/* IMSpector - Instant Messenger Transparent Proxy Service
 * http://www.imspector.org/
 * (c) Lawrence Manning <lawrence@aslak.net>, 2006
 *          
 * Released under the GPL v2. */

#include "imspector.h"

/* Path of the unix domain socket on which the on-demand cert server
 * (sslcertserver) listens and to which sslcertclient connects.
 * NOTE(review): a fixed name in world-writable /tmp could be pre-created or
 * connected to by other local users; consider a private directory. */
#define CERT_SOCKET "/tmp/.imspectorcert"

/* Values for sslverify, picked from the "ssl_verify" option in init(); they
 * decide what happens when the IM server's cert fails validation. */
/* Ignore the validation failure. */
#define VERIFY_OFF 0
/* Carry on, but give the IM client a self-signed cert. */
#define VERIFY_SELFSIGNED 1
/* Drop the connection. */
#define VERIFY_BLOCK 2

SSLState::SSLState(void)
{
	/* Nothing is allocated until init() runs.  Start every handle out empty
	 * so that free() (which tests for NULL) is safe even without init(). */
	ctx = NULL;
	connectctx = NULL;
	method = NULL;
}

/* Prepares SSL for the proxy.
 *
 * Initialises OpenSSL and creates two contexts.  connectctx is used for our
 * outbound connection to the IM server.  ctx is used for the connection we
 * accept from the IM client, and the "ssl_key" private key is loaded into it.
 * IM server verification is configured from "ssl_verify_dir" (CA cert
 * directory) and "ssl_verify" ("block", "selfsigned", anything else = off).
 *
 * If "ssl_cert_dir" is empty, the single static cert named by "ssl_cert" is
 * loaded into ctx.  Otherwise certs are created on demand: the CA cert
 * ("ssl_ca_cert"), CA key ("ssl_ca_key") and server key ("ssl_key") are read
 * into memory, and a child process is forked to run the cert server.
 *
 * Returns true on success.  On failure it logs to syslog and returns false. */
bool SSLState::init(class Options &options, bool debugmode)
{
	localdebugmode = debugmode;
		
	/* Init SSL. */
	SSL_library_init();
	SSL_load_error_strings();
	
	method = SSLv23_method();
	
	if (!method)
	{
		syslog(LOG_ERR, "Error: Couldn't set SSL method: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	
	connectctx = SSL_CTX_new(method);
	ctx = SSL_CTX_new(method);

	if (!connectctx || !ctx)
	{
		syslog(LOG_ERR, "Error: Couldn't create SSL contexts: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	
	if (SSL_CTX_use_PrivateKey_file(ctx, options["ssl_key"].c_str(), SSL_FILETYPE_PEM) <= 0)
	{
		syslog(LOG_ERR, "Error: Couldnt open server private key: %s",
			ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	
	/* These are private copies of these option. */
	sslcertdir = options["ssl_cert_dir"];
	sslverifydir = options["ssl_verify_dir"];
	
	/* Deal with IM server verification options. */
	if (!sslverifydir.empty())
	{
		/* Our client connection will be verified by the CA certs in this dir. */
		if (SSL_CTX_load_verify_locations(connectctx, NULL, sslverifydir.c_str()) <= 0)
		{
			syslog(LOG_ERR, "Error: Couldn't set verify location: %s",
				ERR_error_string(ERR_get_error(), NULL));
			return false;
		}
	}

	/* If set to block, then connections that fail validation checks are
	 * dropped. If set to selfsigned, then certs will still be passed, but the cert
	 * given to the IM client will always be self-signed. */
	if (options["ssl_verify"] == "block")
		sslverify = VERIFY_BLOCK;
	else if (options["ssl_verify"] == "selfsigned")
		sslverify = VERIFY_SELFSIGNED;
	else
		sslverify = VERIFY_OFF;

	if (sslcertdir.empty())
	{
		/* We are not creating cert on demand, so fall back to loading a static
		 * cert into the CTX. */
		if (!loadcert(ctx, "default", options["ssl_cert"]))
		{
			syslog(LOG_ERR, "Error: Unable to set connection with certificate");
			return false;
		}
	}
	else
	{
		/* Woot, on demand certs! */
		syslog(LOG_INFO, "Creating certs on demand into: %s", sslcertdir.c_str());
	
		FILE *hfile = NULL;
		
		/* Each PEM file below is closed straight after reading, before the
		 * result is checked, so a parse failure doesn't leak the handle. */

		/* First up, load in the CA cert. */
		if (!(hfile = fopen(options["ssl_ca_cert"].c_str(), "r")))
		{
			syslog(LOG_ERR, "Error: Unable to open CA cert");
			return false;
		}
		cacert = PEM_read_X509(hfile, NULL, NULL, NULL);
		fclose(hfile);
		if (!cacert)
		{
			syslog(LOG_ERR, "Error: Couldn't read CA cert: %s", ERR_error_string(ERR_get_error(), NULL));
			return false;
		}
		
		/* Now the CA private key. */
		if (!(hfile = fopen(options["ssl_ca_key"].c_str(), "r")))
		{
			syslog(LOG_ERR, "Error: Unable to open CA key");
			return false;
		}
		cakey = PEM_read_PrivateKey(hfile, NULL, NULL, NULL);
		fclose(hfile);
		if (!cakey)
		{
			syslog(LOG_ERR, "Error: Couldn't read CA key: %s", ERR_error_string(ERR_get_error(), NULL));
			return false;
		}
		
		/* Finally, the server key that will be used for all connections. */
		if (!(hfile = fopen(options["ssl_key"].c_str(), "r")))
		{
			syslog(LOG_ERR, "Error: Unable to open server key");
			return false;
		}
		serverkey = PEM_read_PrivateKey(hfile, NULL, NULL, NULL);
		fclose(hfile);
		if (!serverkey)
		{
			syslog(LOG_ERR, "Error: Couldn't read server key: %s", ERR_error_string(ERR_get_error(), NULL));
			return false;
		}

		/* Fork off the server process. */
		switch (fork())
		{
			/* An error occured. */
			case -1:
				syslog(LOG_ERR, "Error: Fork failed: %s", strerror(errno));
				return false;
			
			/* In the child.  sslcertserver() only returns if it can't bind
			 * its socket. */
			case 0:
				sslcertserver(sslcertdir);
				debugprint(localdebugmode, "Error: We should not come here");
				exit(0);
		
			/* In the parent. */
			default:
				break;
		}	
	}
	
	return true;
}

void SSLState::free(void)
{
	if (ctx)
	{
		SSL_CTX_free(ctx);
		ctx = NULL;
	}
	if (connectctx)
	{
		SSL_CTX_free(connectctx);
		connectctx = NULL;
	}
}

/* Convert the IM server socket (our "client") into an SSL socket, setting up
 * a CTX for the IM client (our "server").
 *
 * Performs the SSL connect to the IM server using connectctx.  It then acts on
 * the result of validating the server's cert according to sslverify:
 *   - VERIFY_BLOCK refuses the connection.
 *   - VERIFY_SELFSIGNED requests a self-signed cert for the IM client.
 *   - VERIFY_OFF carries on regardless.
 * When on-demand certs are enabled (sslcertdir set), it asks the cert server
 * for a cert matching the IM server's Common Name and loads it into ctx.
 *
 * Returns false in any of these cases:
 *   - the SSL connect fails;
 *   - the connection is blocked;
 *   - the on-demand cert can't be created or loaded.
 * Otherwise returns true. */
bool SSLState::imserversocktossl(class Socket &imserversock)
{
	/* First, upgrade the connection to the IM server; the connected socket. */
	imserversock.enablessl(connectctx);

	if (!imserversock.sslconnect())
		return false;

	/* Determine the commonname of the im server's cert, so we can associate
	 * a particular local cert to the client. */
	std::string commonname = imserversock.getpeercommonname();
	
	debugprint(localdebugmode, "Switching to SSL mode for Common Name: %s", commonname.c_str());

	/* Fetch the outcome of validating the IM server's cert; X509_V_OK means
	 * it passed. */
	int result = imserversock.getvalidatecertresult();
	
	debugprint(localdebugmode, "Validation result: %d", result);

	bool selfsigned = false;
	
	/* See what it should do with failed validations. */
	if (result != X509_V_OK)
	{
		switch (sslverify)
		{
			case VERIFY_BLOCK:
				debugprint(localdebugmode, "Blocking connection because of validation error: %d",
					result);
				return false;

			case VERIFY_SELFSIGNED:
				selfsigned = true;
				break;

			default:
				break;
		}
	}		

	/* See if we are doing on demand certs. */	
	if (!sslcertdir.empty())
	{
		debugprint(localdebugmode, "Requesting cert");
		
		/* Ask the helper process to make us a cert. */
		if (!sslcertclient(commonname, selfsigned))
		{
			syslog(LOG_ERR, "Error: Couldn't create and sign new certificate");
			return false;
		}

		/* And load the cert in. */	
		if (!loadcert(ctx, commonname, formatcertfilename(commonname)))
		{
			syslog(LOG_ERR, "Error: Unable to set connection with new certificate");
			return false;
		}
		
		debugprint(localdebugmode, "CTX loaded with cert");
	}
	
	return true;
}

/* Completes SSL with the IM client.  It attaches ctx, the context prepared by
 * init() (and, for on-demand certs, loaded by imserversocktossl()), to the
 * client socket and performs the SSL accept.  Returns whether the accept
 * succeeded. */
bool SSLState::clientsocktossl(class Socket &clientsock)
{
	clientsock.enablessl(ctx);

	return clientsock.sslaccept();
}

/* Private stuff here. */

/* A "client" for the cert process. Returns true if the server responded with "OK".
 * Called from the protocol handler process. */
bool SSLState::sslcertclient(std::string commonname, bool selfsigned)
{
	class Socket certsock(AF_UNIX, SOCK_STREAM);
	
	/* Complete the connection. */
	if (!(certsock.connectsocket(CERT_SOCKET, ""))) return -1;
	
	/* Add on a CR as the server needs these for end of line. */
	std::string commandlinecr = stringprintf("%s %s\n", commonname.c_str(),
		selfsigned ? "TRUE" : "FALSE");
	
	if (!certsock.sendalldata(commandlinecr.c_str(), commandlinecr.length())) return -1;
	
	char buffer[BUFFER_SIZE];
	
	memset(buffer, 0, BUFFER_SIZE);
	
	if (certsock.recvline(buffer, BUFFER_SIZE) < 0)
	{
		syslog(LOG_ERR, "Error: Couldn't get result from cert server");
		return false;
	}
		
	stripnewline(buffer);
	
	certsock.closesocket();
	
	return (strcmp(buffer, "OK") == 0);
}

/* Installs a PEM certificate into an SSL context whose private key is
 * already loaded, then checks that the key and certificate match.
 *
 * ctx          - context to receive the certificate.
 * domain       - the CN the certificate is for.  It is used only in log
 *                messages, so the user can tell which cert is at fault.
 * certfilename - path of the PEM certificate file.
 *
 * This is used at startup for the fixed "ssl_cert" certificate.  It is also
 * used, inside the protocol process, for certificates created on demand.
 * Returns false (after logging to syslog) if the file can't be used or the
 * key doesn't match. */
bool SSLState::loadcert(SSL_CTX *ctx, std::string domain, std::string certfilename)
{
	const char *cn = domain.c_str();

	bool loaded = SSL_CTX_use_certificate_file(ctx, certfilename.c_str(), SSL_FILETYPE_PEM) > 0;
	if (!loaded)
	{
		syslog(LOG_ERR, "Error: Couldn't open certificate for %s: %s", cn, 
			ERR_error_string(ERR_get_error(), NULL));
		return false;
	}

	if (SSL_CTX_check_private_key(ctx))
		return true;

	syslog(LOG_ERR, "Error: Private key and certificate do not match for %s: %s",
		cn, ERR_error_string(ERR_get_error(), NULL));
	return false;
}

/* A simple, single process "server" that creates certificates on request.
 * It runs in the child forked by init() and listens on the CERT_SOCKET unix
 * socket.  Each request is one line, "<commonname> <TRUE|FALSE>", where the
 * second word asks for a self-signed cert.  The reply is "OK\n" if signcert()
 * succeeded and "FAIL\n" otherwise.  Requests are handled one at a time.
 *
 * Only returns (false) if the socket can't be bound; otherwise it loops
 * forever.  The sslcertdir parameter is not used directly: signcert() builds
 * filenames from the member of the same name. */
bool SSLState::sslcertserver(std::string sslcertdir)
{
	class Socket certsock(AF_UNIX, SOCK_STREAM);
	
	if (!certsock.listensocket(CERT_SOCKET))
	{
		syslog(LOG_ERR, "Error: Couldn't bind to cert socket");
		return false;
	}

	/* This loop has no exit, except when the parent kills it off. */
	while (true)
	{
		std::string clientaddress;
		class Socket clientsock(AF_UNIX, SOCK_STREAM);
		char buffer[BUFFER_SIZE];
		
		if (!certsock.awaitconnection(clientsock, clientaddress)) continue;

		/* From here on the accepted connection is closed on every path, not
		 * just on success.  Otherwise a failed request could leave its
		 * descriptor open in this long-lived process. */
		memset(buffer, 0, BUFFER_SIZE);
		if (clientsock.recvline(buffer, BUFFER_SIZE) < 0)
		{
			syslog(LOG_ERR, "Error: Couldn't get Common Name from cert client");
			clientsock.closesocket();
			continue;
		}

		/* Passed is the CN followed by the self-signed flag. */
		stripnewline(buffer);

		std::string command; std::vector<std::string> args; int argc;
		
		/* A typical comandline will be: bob.com TRUE\n */
		chopline(buffer, command, args, argc);
		
		std::string resultstring = "FAIL";

		if (argc > 0)
		{
			if (signcert(command, args[0] == "TRUE"))
				resultstring = "OK";
		}

		resultstring += '\n';
		
		if (clientsock.sendline(resultstring.c_str(), resultstring.length()) < 0)
		{
			syslog(LOG_ERR, "Error: Couldn't send result to cert client");
			clientsock.closesocket();
			continue;
		}

		clientsock.closesocket();
	}
	
	return true;
}

/* This does the actual work of signing a cert. Called from within the certserver
 * process. */
bool SSLState::signcert(std::string commonname, bool selfsigned)
{
	struct stat statbuf;
	FILE *hfile = NULL;
	X509 *servercert = NULL;
	X509_NAME *servercertname = NULL;
	std::string certfilename;
	
	memset(&statbuf, 0, sizeof(struct stat));

	/* Get the final filename for the cert. */
	certfilename = formatcertfilename(commonname.c_str());
	
	debugprint(localdebugmode, "Cert %s has filename: %s", commonname.c_str(),
		certfilename.c_str(), selfsigned);

	/* See if the cert already exists, leaving us nothing to do. */	
	if (stat(certfilename.c_str(), &statbuf) == 0)
	{
		debugprint(localdebugmode, "Cert %s already exists", commonname.c_str());
		return true;
	}
		
	debugprint(localdebugmode, "Cert %s needs to be created and signed", commonname.c_str());

	if (!(servercert = X509_new()))
	{
		syslog(LOG_ERR, "Error: Couldn't create new cert object: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	
	X509_set_version(servercert, 2);
	
	/* Set the serial number, which we will just hash from the commonname. */
	ASN1_INTEGER_set(X509_get_serialNumber(servercert), (int) hash(commonname.c_str()));
	
	/* Certificate will be valid from an hour ago, incase the clocks on the IMSpector
	 * machine and the client do not quite match, and for 5 years. */
	X509_gmtime_adj(X509_get_notBefore(servercert), - 60 * 60);
	X509_gmtime_adj(X509_get_notAfter(servercert), 60 * 60 * 24 * 365 * 5);

	/* The public key comes from the global server key. */
	X509_set_pubkey(servercert, serverkey);
	
	if (!(servercertname = X509_get_subject_name(servercert)))
	{
		syslog(LOG_ERR, "Error: Couldn't create new cert name object: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	
	/* The created cert will have only a CN field, which of course is duplicated
	 * from the IMSpector->IM server connection. */
	X509_NAME_add_entry_by_txt(servercertname, "CN", MBSTRING_ASC,
		(unsigned char *) commonname.c_str(), -1, -1, 0);
	
	/* By default we sign with our CA. */
	X509 *cert = cacert;
	EVP_PKEY *key = cakey;

	/* But we can make a self signed cert, if required. */
	if (selfsigned)
	{
		debugprint(localdebugmode, "Generating a self-signed cert");

		cert = servercert;
		key = serverkey;
	}

	/* The issuer of the cert is our CA, unless its self-signed. */
	X509_set_issuer_name(servercert, X509_get_subject_name(cert));

	/* The big call: sign the cert! */		
	if (!X509_sign(servercert, key, EVP_sha1()))
	{
		syslog(LOG_ERR, "Error: Couldn't sign cert: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}

	/* Finally write out the cert to disk.  We never diretly load a CTX with a
	 * cert from memory, it always come from a file on disk. */
	if (!(hfile = fopen(certfilename.c_str(), "w")))
	{
		syslog(LOG_ERR, "Error: Unable to create certificate file");
		return false;
	}
	if (!(PEM_write_X509(hfile, servercert)))
	{
		syslog(LOG_ERR, "Error: Couldn't write signed cert: %s", ERR_error_string(ERR_get_error(), NULL));
		return false;
	}
	fclose(hfile);
	
	X509_free(servercert);
	
	syslog(LOG_INFO, "%s cert for %s created and signed", selfsigned ? "Self-signed" : "Regular",
		commonname.c_str());
	
	return true;
}

/* Maps a commonname to the path of its on-demand cert.  The path is inside
 * sslcertdir, and the file name is hash(commonname) as zero-padded
 * eight-digit hex plus a ".pem" suffix.  Because only the hash is used, the
 * CN text itself never appears in the path. */
std::string SSLState::formatcertfilename(std::string commonname)
{
	const char *directory = sslcertdir.c_str();
	std::string path = stringprintf("%s/%08x.pem", directory, hash(commonname.c_str()));

	return path;
}
