/*
 * Simplified example SSL client/server program.
 * Used as a template.
 */

#include <sys/types.h>
#include <sys/socket.h>

#include <netinet/in.h>

#include <arpa/inet.h>

#include <err.h>
#include <netdb.h>
#include <poll.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <openssl/rsa.h>
#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/err.h>

/*
 * A line read from stdin whose first byte is TRIGGER is not sent to the
 * peer; instead it starts the SSL handshake on this side.
 */
#define TRIGGER '!'

/* Connected TCP socket: the client socket, or the accepted server socket. */
int s = -1;

/* Nonzero when running as the server (no command-line arguments). */
int isserver = 0;

/* SSL context and connection; both stay NULL until startssl() creates them. */
SSL_CTX *ctx = NULL;
SSL *ssl = NULL;

/*
 * Return a printable description of the oldest error in this thread's
 * OpenSSL error queue, removing that error from the queue.
 *
 * The result lives in a static buffer that the next call overwrites.
 * If the queue is empty (as happens for I/O failures reported as
 * SSL_ERROR_SYSCALL), a fixed placeholder is returned instead of the
 * meaningless rendering of error code 0.
 */
static char *
sslerr(void)
{
	static char buf[1024];
	unsigned long code = ERR_get_error();

	if (code == 0)
		snprintf(buf, sizeof buf, "%s", "no OpenSSL error queued");
	else
		/* Bounded variant: never writes more than sizeof buf bytes. */
		ERR_error_string_n(code, buf, sizeof buf);
	return buf;
}

/*
 * Open a TCP endpoint over IPv4 for host/port.
 *
 * When the global isserver is zero, the function returns a socket
 * connected to the first resolved address that accepts the connection.
 * When isserver is nonzero, it returns a socket bound to host/port and
 * listening with a backlog of 1. A NULL host selects the wildcard
 * address via AI_PASSIVE.
 *
 * IPv4 is requested explicitly; main() reads the bound port back through
 * a struct sockaddr_in, so the address family must stay AF_INET.
 *
 * This function never returns on failure. Resolution errors exit via
 * errx(). socket/bind/listen/connect errors exit via err(), naming the
 * last step that failed.
 */
int
endpoint(const char *host, const char *port)
{
	struct addrinfo hints, *res, *res0;
	const char *cause = "no address to try";
	int error;
	int fd = -1;	/* not "s": avoid shadowing the global socket */

	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_INET;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;
	if (isserver)
		hints.ai_flags = AI_PASSIVE;
	error = getaddrinfo(host, port, &hints, &res0);
	if (error)
		/* getaddrinfo reports via its return value, not errno. */
		errx(1, "%s", gai_strerror(error));

	for (res = res0; res != NULL; res = res->ai_next) {
		fd = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (fd == -1) {
			cause = "socket";
			continue;
		}
		if (isserver) {
			if (bind(fd, res->ai_addr, res->ai_addrlen) == -1)
				cause = "bind";
			else if (listen(fd, 1) == -1)
				cause = "listen";
			else
				break;		/* listening socket ready */
		} else {
			if (connect(fd, res->ai_addr, res->ai_addrlen) == -1)
				cause = "connect";
			else
				break;		/* connected */
		}
		/* This address failed; discard the socket and try the next. */
		close(fd);
		fd = -1;
	}
	/* Report before freeaddrinfo() so errno still describes `cause`. */
	if (fd == -1)
		err(1, "%s", cause);
	freeaddrinfo(res0);
	return fd;
}

/*
 * Upgrade the already-connected plaintext socket `s` to SSL.
 *
 * Client mode (isserver == 0): creates a client context and runs the
 * handshake with SSL_VERIFY_NONE. The server certificate is NOT checked,
 * which is acceptable for a test tool but never for real use.
 *
 * Server mode: creates a server context and loads the PEM certificate
 * "../s.crt" and private key "../s.key", with paths relative to the
 * current working directory. It checks that the two match, then accepts
 * the handshake.
 *
 * On success, the globals ctx and ssl hold the new context and
 * connection. Any failure terminates the program with the queued
 * OpenSSL error.
 */
void
startssl(void)
{
	printf("[establishing ssl connection...]");

	if (!isserver) {
		ctx = SSL_CTX_new(SSLv23_client_method());
		if (ctx == NULL)
			errx(1, "SSL_CTX_new: %s", sslerr());
		SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
	} else {
		ctx = SSL_CTX_new(SSLv23_server_method());
		if (ctx == NULL)
			errx(1, "SSL_CTX_new: %s", sslerr());
		SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
		SSL_CTX_set_options(ctx, SSL_OP_ALL);
		/* These calls are specified to return 1 on success only. */
		if (SSL_CTX_use_certificate_file(ctx, "../s.crt", 
		    SSL_FILETYPE_PEM) != 1)
			errx(1, "SSL_CTX_use_certificate_file: %s", sslerr());
		if (SSL_CTX_use_PrivateKey_file(ctx, "../s.key", 
		    SSL_FILETYPE_PEM) != 1)
			errx(1, "SSL_CTX_use_PrivateKey_file: %s", sslerr());
		if (SSL_CTX_check_private_key(ctx) != 1)
			errx(1, "certificate and key mismatch: %s", sslerr());
	}

	ssl = SSL_new(ctx);
	if (ssl == NULL)
		errx(1, "SSL_new: %s", sslerr());
	if (SSL_set_fd(ssl, s) != 1)
		errx(1, "SSL_set_fd: %s", sslerr());

	if (!isserver) {
		SSL_set_connect_state(ssl);
		if (SSL_connect(ssl) <= 0)
			errx(1, "SSL_connect: %s", sslerr());
	} else {
		SSL_set_accept_state(ssl);
		if (SSL_accept(ssl) <= 0)
			errx(1, "SSL_accept: %s", sslerr());
	}

#if OPENSSL_VERSION_NUMBER < 0x10100000L
	/*
	 * Poking the SSL internals only compiles on old OpenSSL.
	 * From 1.1.0 on, SSL is an opaque type.
	 */
	ssl->debug = 1;
#endif
	puts(isserver ? "[accepted]" : "[connected]");
}

int
main(argc, argv)
	int argc;
	char *argv[];
{
	int l, i;
	char buf[256];
	int len;
	struct pollfd pfd[2];
	struct sockaddr_in addr;
	socklen_t addrlen;

	setbuf(stdout, NULL);
	setbuf(stdin, NULL);

	SSL_library_init();
	SSL_load_error_strings();

	/* Make a TCP client or server connection */
	switch (argc) {
	case 3:
		isserver = 0;
		s = endpoint(argv[1], argv[2]);
		break;
	case 2:
		isserver = 0;
		s = endpoint("localhost", argv[1]);
		break;
	case 1:
		isserver = 1;
		l = endpoint(NULL, "0");
		addrlen = sizeof addr;
		if (getsockname(l, (struct sockaddr *)&addr, &addrlen) == -1)
			err(1, "getsockname");
		printf("waiting on port %u\n", ntohs(addr.sin_port));
		addrlen = sizeof addr;
		if ((s = accept(l, (struct sockaddr *)&addr, &addrlen)) == -1)
			err(1, "accept");
		close(l);
		break;
	default:
		fprintf(stderr,  "usage: %s [[host] port]\n", argv[0]);
		exit(1);
	}

	printf("connected\n");
	/* startssl(); */

	/* Loop copying characters from stdin */
	for (;;) {
		pfd[0].fd = STDIN_FILENO;
		pfd[0].events = POLLIN;
		pfd[0].revents = 0;

		if (ssl)
			pfd[1].fd = SSL_get_fd(ssl);
		else
			pfd[1].fd = s;
		pfd[1].events = POLLIN;
		pfd[1].revents = 0;

		if (poll(pfd, 2, -1) == -1)
			err(1, "poll");

		if (pfd[0].revents & POLLIN) {
			len = read(STDIN_FILENO, buf, sizeof buf - 1);
			if (len == -1)
				err(1, "read");
			if (len == 0)
				break;
			/* Convert LF to CRLF */
			if (len > 1 && buf[len - 1] == '\n') {
				buf[len - 1] = '\r';
				buf[len] = '\n';
				len++;
			}
			if (ssl) {
				len = SSL_write(ssl, buf, len);
				switch (SSL_get_error(ssl, len)) {
				case SSL_ERROR_NONE: break;
				default: err(1, "SSL_write: %s", sslerr());
				}
			} else {
				if (buf[0] == TRIGGER)
					startssl();
					/* rest of buf is lost! */
				else if (write(s, buf, len) == -1)
					err(1, "write");
			}
		}

		if (pfd[1].revents & POLLIN) {
			if (ssl) {
				len = SSL_read(ssl, buf, sizeof buf);
				switch (SSL_get_error(ssl, len)) {
				case SSL_ERROR_NONE: break;
				case SSL_ERROR_ZERO_RETURN: len = 0; break;
				default: err(1, "SSL_read: %s", sslerr());
				}
			} else {
				len = read(s, buf, sizeof buf);
				if (len == -1)
					err(1, "read");
			}
			if (len == 0)
				break;
			/* if (write(STDOUT_FILENO, buf, len) == -1)
				err(1, "write"); */
			for (i = 0; i < len; i++) {
				char c = buf[i];
				if ((c >= ' ' && c <= '~') || c == '\t') 
					putchar(c);
				else if (c == '\n')
					printf("\\n\n");
				else if (c == '\r')
					printf("\\r");
				else if (c == '\0')
					printf("\\0");
				else
					printf("\\%03o", c);
			}
		}
	}

	if (ssl) {
		SSL_set_shutdown(ssl, SSL_SENT_SHUTDOWN|SSL_RECEIVED_SHUTDOWN);
		SSL_free(ssl);
		SSL_CTX_free(ctx);
		ERR_remove_state(0);
	}
	close(s);

	printf("disconnected\n");
	exit(0);
}

