
/*
 * ptytelnet - connect local programs to a remote host over telnet
 *
 * usage: ptytelnet hostname[:port] prog ... [';' prog ...]
 *
 * David Leonard, 1997
 * <david.leonard@it.uq.edu.au>
 *
 * This program telnets to the given host/port and runs the sequence of 
 * programs given. These programs think that their stdin/stdout is a tty
 * but they are really talking through the TELNET protocol decoder to the 
 * host at the other end.
 *
 * Special 'prog's are:
 *	user	- the real stdin/stdout is used so that you can interactively
 *		  do stuff
 *	daemon  - the process backgrounds itself
 *	binary  - the telnet connection is put into 8-bit clean mode
 *      text    - the telnet connection is put into interactive character mode 
 *
 * An example is:
 *	ptytelnet gateway:23  text ';' chat -f login.script ';' binary ';' pppd
 *
 * Compile with: -DTELOPTS -DTELCMDS -lutil
 *
 * $Id$
 */

#include <sys/types.h>		/* pid_t, ssize_t */
#include <sys/ioctl.h>		/* ioctl */
#include <sys/socket.h>		/* socket */
#include <sys/wait.h>		/* waitpid */
#include <netinet/in.h>		/* AF_INET,... */
#include <arpa/inet.h>		/* inet_ntoa */
#include <arpa/telnet.h>	/* IAC,... */
#include <err.h>		/* err,... */
#include <errno.h>		/* errno, EINTR */
#include <netdb.h>		/* gethostbyname */
#include <signal.h>		/* signal, kill */
#include <stdio.h>		/* printf */
#include <stdlib.h>		/* atexit */
#include <string.h>		/* strtok, memset */
#include <termios.h>		/* tcgetattr, tcsetattr */
#include <unistd.h>		/* read, write, fork */
#include <util.h>		/* openpty */

/*
 * Mediated stream types (bit flags).
 * A mediated stream is the full-duplex pairing of two simplex streams.
 * Each simplex direction carries its own combination of these flags,
 * which mediate2() and proto_flush() test to decide how the bytes are
 * transformed and traced.
 */

enum {
	ENCTEL = 0x02,		/* encode data into telnet protocol */
	DECTEL = 0x04,		/* decode data from telnet protocol */
	ISUSER = 0x08,		/* interpret chars from user (^Z/+++) */
	DBGW = 0x10,		/* show what is written */
	DBGR = 0x20		/* show what is read */
};

/* the special escape sequence to exit a mediated 'ISUSER' stream */

char           *user_esc_sequence = "+++";

/* debug verbosity: 0 disables tracing, larger values trace more */
static const int debug_tn = 1;

/*
 * Non-zero once a termination signal has been caught.  It is assigned
 * from the killed() signal handler, so it has to be a volatile
 * sig_atomic_t: that is the only object type a handler may portably
 * store to and have the interrupted code observe.
 */
static volatile sig_atomic_t exiting = 0;

/*
 * protocol_state - the protocol state of one mediated stream.
 *
 * The expecting_*_ack tables are indexed directly by a telnet option
 * byte.  Option bytes arrive from the remote end and may take any value
 * 0..255, so each table needs one slot per possible byte value (they
 * used to have 64, letting the peer read and write past them).
 */

typedef struct {
	int             esc_seq_matched;	/* chars of user_esc_sequence
						 * matched so far */
	int             tn_command;	/* telnet command being decoded
					 * (0 = plain data) */
	unsigned char   out_pend_buf[1024];	/* queued telnet protocol
						 * output */
	unsigned char  *out_pend;	/* next free byte in out_pend_buf */
	char            expecting_d_ack[256];	/* DO/DONT replies awaited,
						 * by option byte */
	char            expecting_w_ack[256];	/* WILL/WONT replies awaited,
						 * by option byte */
} protocol_state;

/*
 * dbg_io
 * - trace, on stderr, a buffer of 'len' bytes that was read from or
 *   written to 'fd'.  Bytes outside printable ASCII are shown as \xNN and
 *   backslashes are doubled so the output is unambiguous.
 * - does nothing unless debug_tn is 3 or more.
 * - defined with a prototype rather than K&R style: callers pass int and
 *   ptrdiff_t lengths, and without a prototype those were not converted
 *   to size_t (undefined behaviour, garbage length on LP64 systems).
 */

void
dbg_io(const char *msg, int fd, const void *buf, size_t len)
{
	const unsigned char *p = buf;
	size_t          i;

	if (debug_tn < 3)
		return;

	/* stderr, like every other diagnostic; stdout may carry session data */
	fprintf(stderr, "[%s(%d): ", msg, fd);
	for (i = 0; i < len; i++) {
		unsigned char   c = p[i];

		if (c == '\\')
			fputc('\\', stderr);
		if (c >= ' ' && c < 0x7f)	/* 0x7f (DEL) is not printable */
			fputc(c, stderr);
		else
			fprintf(stderr, "\\x%02x", c);
	}
	fprintf(stderr, "]\n");
}

/*
 * init_protocol_state
 * - reset a protocol_state to its starting condition: no escape chars
 *   matched, not inside a telnet command, no acknowledgements awaited,
 *   and an empty protocol output queue.
 */

void
init_protocol_state(protocol_state * state)
{
	protocol_state  blank = { 0 };

	blank.out_pend = state->out_pend_buf;	/* queue starts empty */
	*state = blank;
}

/*
 * Terminal settings of stdin as captured by raw_tty(), plus a flag that
 * says whether restore_tty() still has to put them back.
 */

static struct termios tty_save;
static int      tty_needs_restoring;	/* static storage: starts at 0 */

/*
 * raw_tty
 * - switch the terminal on stdin into raw mode, first saving its current
 *   settings in tty_save so that restore_tty() can put them back.
 * - does nothing if settings are already saved (the tty is already raw).
 * - if stdin is not a terminal, tcgetattr() fails and nothing is changed
 *   or marked for restoring.  Otherwise tty_save would hold garbage that
 *   restore_tty() would later try to apply.
 */

void
raw_tty()
{
	struct termios  t;
	int             fd = STDIN_FILENO;

	if (tty_needs_restoring)
		return;

	if (tcgetattr(fd, &tty_save) == -1)
		return;		/* not a tty: leave it alone */

	t = tty_save;
	cfmakeraw(&t);
	if (tcsetattr(fd, TCSANOW, &t) == -1) {
		/* nothing was changed, so there is nothing to restore */
		warn("tcsetattr");
		return;
	}
	tty_needs_restoring = 1;
}

/*
 * restore_tty
 * - put back the stdin terminal settings saved by raw_tty(), if any are
 *   pending; calling it again afterwards is harmless.
 */

void
restore_tty()
{
	if (!tty_needs_restoring)
		return;
	tcsetattr(STDIN_FILENO, TCSANOW, &tty_save);
	tty_needs_restoring = 0;
}

/*
 * suspend - causes the process to be suspended
 */

void
suspend()
{
	int             needs_restore = tty_needs_restoring;

	if (needs_restore)
		restore_tty();
	printf("[Suspended.]\n");
	kill(getpid(), SIGSTOP);
	printf("[Resumed.]\n");
	if (needs_restore)
		raw_tty();
}

/*
 * tn_send
 * - enqueue the three-byte telnet command IAC <cmd> <op> on the protocol
 *   output queue of 'state'; proto_flush() writes the queue out later.
 * - if 'expectack' is set, remember that the peer should answer, so that
 *   tn_recv() treats the answer as an acknowledgement instead of a new
 *   request: WILL/WONT expect DO/DONT back, DO/DONT expect WILL/WONT.
 * - when the queue has no room left, the command is dropped with a
 *   warning instead of overrunning out_pend_buf.  This can happen when a
 *   peer packs many requests into a single read, since every request
 *   gets a reply.
 */

void
tn_send(int expectack, protocol_state * state, unsigned char cmd,
	unsigned char op)
{
	size_t          used = (size_t) (state->out_pend - state->out_pend_buf);

	if (debug_tn >= 1) {
		/*
		 * TELCMD()/TELOPT() index fixed name tables, and option bytes
		 * relayed from the peer may lie beyond TELOPT_LAST.
		 */
		if (TELCMD_OK(cmd) && TELOPT_OK(op))
			fprintf(stderr, "[send %s %s]\r\n", TELCMD(cmd),
			    TELOPT(op));
		else
			fprintf(stderr, "[send %d %d]\r\n", cmd, op);
	}

	if (sizeof state->out_pend_buf - used < 3) {
		warnx("telnet output queue full; dropping command %d %d",
		    cmd, op);
		return;
	}
	*state->out_pend++ = IAC;
	*state->out_pend++ = cmd;
	*state->out_pend++ = op;

	if (!expectack)
		return;
	if ((cmd == WILL || cmd == WONT) && op < sizeof state->expecting_d_ack)
		state->expecting_d_ack[op] = 1;
	if ((cmd == DO || cmd == DONT) && op < sizeof state->expecting_w_ack)
		state->expecting_w_ack[op] = 1;
}

/*
 * tn_recv
 * - process a received telnet option command: 'cmd' is WILL, WONT, DO
 *   or DONT and 'op' is the option byte that followed it.
 * - if it answers one of our own requests (see tn_send), the pending
 *   expectation is cleared and nothing is sent.
 * - otherwise the peer is making a request and a reply is queued:
 *     WILL/WONT: agree (DO/DONT) for BINARY, LINEMODE and ECHO,
 *                answer DO for SGA, and DONT for everything else;
 *     DO/DONT:   agree (WILL/WONT) for SGA and BINARY, and WONT for
 *                everything else.
 */

void
tn_recv(protocol_state * state, unsigned char cmd, unsigned char op)
{
	int             is_will = (cmd == WILL || cmd == WONT);
	int             is_do = (cmd == DO || cmd == DONT);

	if (debug_tn > 2) {
		/* 'op' is peer-controlled: only look it up when in range */
		if (TELCMD_OK(cmd) && TELOPT_OK(op))
			fprintf(stderr, "[recv %s %s]\r\n", TELCMD(cmd),
			    TELOPT(op));
		else
			fprintf(stderr, "[recv %d %d]\r\n", cmd, op);
	}

	/* an answer to a request we made? */
	if (is_will && op < sizeof state->expecting_w_ack &&
	    state->expecting_w_ack[op]) {
		state->expecting_w_ack[op] = 0;
		return;
	}
	if (is_do && op < sizeof state->expecting_d_ack &&
	    state->expecting_d_ack[op]) {
		/* clear the flag that was tested (was expecting_w_ack) */
		state->expecting_d_ack[op] = 0;
		return;
	}

	/* the server is instigating something */
	if (is_will) {
		int             okcmd = cmd + 2;	/* WILL->DO, WONT->DONT */

		if (op == TELOPT_BINARY || op == TELOPT_LINEMODE ||
		    op == TELOPT_ECHO)
			tn_send(0, state, okcmd, op);
		else if (op == TELOPT_SGA)
			tn_send(0, state, DO, op);
		else
			tn_send(0, state, DONT, op);
	}
	if (is_do) {
		int             okcmd = cmd - 2;	/* DO->WILL, DONT->WONT */

		if (op == TELOPT_SGA || op == TELOPT_BINARY)
			tn_send(0, state, okcmd, op);
		else
			tn_send(0, state, WONT, op);
	}
}

/*
 * proto_flush
 * - if the stream to 'to' is telnet-encoded ('type' has ENCTEL), write
 *   out the telnet commands queued in 'state' by tn_send() and empty the
 *   queue.  Other stream types are left untouched.
 * - keeps writing until the whole queue is sent, retrying partial writes
 *   and writes interrupted by a signal.  On any other error a warning is
 *   printed and the rest of the queue is discarded.
 */

void
proto_flush(to, type, state)
	int             to;
	int             type;
	protocol_state *state;
{
	size_t          len, done;
	ssize_t         n;

	if (!(type & ENCTEL))
		return;
	len = (size_t) (state->out_pend - state->out_pend_buf);
	if (len == 0)
		return;

	done = 0;
	while (done < len) {
		n = write(to, state->out_pend_buf + done, len - done);
		if (n < 0 && errno == EINTR)
			continue;
		if (n < 0) {
			warn("fd %d: write", to);
			break;
		}
		if (n == 0) {
			warnx("fd %d: write made no progress", to);
			break;
		}
		done += (size_t) n;
	}
	if (type & DBGW)
		dbg_io("write", to, state->out_pend_buf, len);
	state->out_pend = state->out_pend_buf;
}

/*
 * mediate2
 * - copy one "burst" of data (the result of a single read()) from 'from'
 *   to 'to', transforming it according to the flags in 'type':
 *     ENCTEL     every IAC byte is doubled so it travels as data;
 *     DECTEL     telnet commands are stripped from the data, and option
 *                negotiations (IAC WILL/WONT/DO/DONT <op>) are handed to
 *                tn_recv(), which queues any reply in 'state';
 *     ISUSER     ^Z suspends this process (the ^Z is still forwarded),
 *                and typing user_esc_sequence ends the session;
 *     DBGR/DBGW  the bytes read / written are traced with dbg_io().
 * - decoding state that spans bursts is kept in 'state'.
 * - returns 0 to carry on, 1 when 'from' reaches end of file or the escape
 *   sequence is typed, and -1 on a read or write error.
 * - NOTE(review): only option negotiation is understood.  Any other
 *   command after IAC is reported, and whatever follows it (e.g. the
 *   parameters of an IAC SB ... IAC SE subnegotiation) is passed through
 *   as data.
 */

int
mediate2(from, to, type, state)
	int             from;
	int             to;
	int             type;
	protocol_state *state;
{
	static unsigned char buffer[8192];
	/* worst case: ENCTEL doubles every input byte */
	static unsigned char outbuffer[2 * sizeof buffer];
	unsigned char  *out, *in, *end;
	ssize_t         len;
	size_t          total, done;
	int             result = 0;

	/* read from 'from' into the input buffer */

	len = read(from, buffer, sizeof buffer);
	if (len < 0) {
		warn("fd %d: read", from);
		return -1;
	}
	if (len == 0) {
		if (debug_tn > 1)
			fprintf(stderr, "fd %d closed\n", from);
		return 1;
	}

	if (type & DBGR)
		dbg_io("read", from, buffer, (size_t) len);

	/* copy the input buffer to the output buffer, transforming it */

	end = buffer + len;
	for (out = outbuffer, in = buffer; in < end; in++) {

		/* telnet encoding: double up IACs (see also proto_flush()) */
		if ((type & ENCTEL) && *in == IAC)
			*out++ = IAC;

		/* decode telnet stream before passing through */
		if (type & DECTEL) {
			switch (state->tn_command) {
			case 0:
				if (*in == IAC) {
					state->tn_command = IAC;
					continue;
				}
				break;

			case IAC:
				switch (*in) {
				case IAC:	/* escaped 0xff data byte */
					*out++ = IAC;
					state->tn_command = 0;
					break;
				case WILL:
				case WONT:
				case DO:
				case DONT:
					state->tn_command = *in;
					break;
				default:
					/* protocol problem: errno is irrelevant */
					warnx("unknown telnet command %d", *in);
					state->tn_command = 0;
					break;
				}
				continue;

			case WILL:
			case WONT:
			case DO:
			case DONT:
				tn_recv(state, state->tn_command, *in);
				state->tn_command = 0;
				continue;

			default:
				warnx("unknown telnet cmd %d", state->tn_command);
				state->tn_command = 0;
				continue;
			}
		}

		if ((type & ISUSER) && *in == ('Z' & 0x3f))	/* ^Z */
			suspend();
		*out++ = *in;

		/* check for the escape sequence */
		if (type & ISUSER) {
			if (*in == user_esc_sequence[state->esc_seq_matched]) {
				state->esc_seq_matched++;
				if (user_esc_sequence[state->esc_seq_matched] == '\0') {
					/*
					 * Escape typed: still deliver what was
					 * converted so far (it used to be lost),
					 * drop the rest of this burst, and start
					 * matching from scratch next time.
					 */
					state->esc_seq_matched = 0;
					result = 1;
					break;
				}
			} else {
				state->esc_seq_matched = 0;
			}
		}
	}

	/* blast the converted bytes to the other end of the simplex stream */

	total = (size_t) (out - outbuffer);
	if (type & DBGW)
		dbg_io("write", to, outbuffer, total);
	for (done = 0; done < total;) {
		ssize_t         n = write(to, outbuffer + done, total - done);

		if (n < 0 && errno == EINTR)
			continue;	/* interrupted before writing: retry */
		if (n < 0) {
			warn("fd %d: write", to);
			return -1;
		}
		if (n == 0) {
			warnx("fd %d: write made no progress", to);
			return -1;
		}
		done += (size_t) n;	/* partial write: send the rest */
	}

	return result;
}

/*
 * mediate
 * - shuttle data between two pairs of file descriptors: bursts readable
 *   on 'from1' go through mediate2() to 'to1' with flags 'type1', and
 *   bursts on 'from2' go to 'to2' with 'type2'.  Telnet commands queued
 *   in 'state' are flushed to whichever destination is ENCTEL.
 * - returns the first non-zero mediate2() result, or the value of
 *   'exiting' once a termination signal has been caught.
 * - NOTE(review): a readable source whose destination is not writable is
 *   simply retried, so the loop spins until the destination drains.
 */

int
mediate(state, from1, to1, type1, from2, to2, type2)
	protocol_state *state;
	int             from1;
	int             to1;
	int             type1;
	int             from2;
	int             to2;
	int             type2;
{
	fd_set          reads;
	fd_set          writes;
	int             res;
	struct timeval  to_poll;

	if (debug_tn > 1)
		fprintf(stderr, "[mediate %d->%d and %d->%d]\n",
		       from1, to1, from2, to2);

	for (;;) {
		/*
		 * Flush queued telnet commands before blocking.  Requests
		 * queued before mediate() was entered (negotiate_text(),
		 * negotiate_binary()) then go out at once, rather than only
		 * after the first burst of traffic arrives.
		 */
		proto_flush(to1, type1, state);
		proto_flush(to2, type2, state);

		/* a signal may have arrived while we were outside select() */
		if (exiting) {
			if (debug_tn > 3)
				fprintf(stderr, "exiting mediate loop\n");
			return (exiting);
		}

		/* wait for data */

		FD_ZERO(&reads);
		FD_SET(from1, &reads);
		FD_SET(from2, &reads);
		res = select(FD_SETSIZE, &reads, NULL, NULL, NULL);
		if (exiting) {
			if (debug_tn > 3)
				fprintf(stderr, "exiting mediate loop\n");
			return (exiting);
		}
		if (res < 0) {
			if (errno == EINTR)
				continue;	/* a signal we don't act on */
			err(1, "select");
		}

		/* poll (without blocking) which destinations are writable */

		FD_ZERO(&writes);
		FD_SET(to1, &writes);
		FD_SET(to2, &writes);
		to_poll.tv_sec = 0;
		to_poll.tv_usec = 0;
		res = select(FD_SETSIZE, NULL, &writes, NULL, &to_poll);
		if (res < 0) {
			if (errno == EINTR)
				continue;
			err(1, "select");
		}

		/* check from1->to1 */
		if (FD_ISSET(from1, &reads) && FD_ISSET(to1, &writes)) {
			if (debug_tn > 3)
				fputc('>', stderr);
			res = mediate2(from1, to1, type1, state);
			if (res)
				return (res);
		}
		/* check from2->to2 */
		if (FD_ISSET(from2, &reads) && FD_ISSET(to2, &writes)) {
			if (debug_tn > 3)
				fputc('<', stderr);
			res = mediate2(from2, to2, type2, state);
			if (res)
				return (res);
		}
	}
	/* NOTREACHED */
}

/*
 * open_telnet
 * - open a TCP connection to a remote host's telnet server and return the
 *   connected socket.  Any failure is fatal (the process exits).
 * - hostname is of the form "hostname[:port]", with the port defaulting
 *   to 23.  The string is modified in place: strtok() overwrites the ':'.
 * - only IPv4 addresses are supported.
 */

int
open_telnet(hostname)
	char           *hostname;
{
	int             sockfd;
	struct sockaddr_in from, to;
	struct hostent *he;
	int             port;
	char           *portstr;

	/* determine the telnet server's port number */

	port = 23;		/* well-known port */
	strtok(hostname, ":");
	portstr = strtok(NULL, ":");
	if (portstr) {
		char           *end;
		long            val = strtol(portstr, &end, 10);

		/* atoi() silently turned junk such as "telnet" into port 0 */
		if (end == portstr || *end != '\0' || val < 1 || val > 65535)
			errx(1, "%s: invalid port", portstr);
		port = (int) val;
	}

	/* lookup the IP address of the remote host */

	he = gethostbyname(hostname);
	if (!he) {
		herror(hostname);
		exit(1);
	}
	/* the address is copied into a sockaddr_in below: it must fit */
	if (he->h_addrtype != AF_INET ||
	    (size_t) he->h_length != sizeof to.sin_addr)
		errx(1, "%s: not an IPv4 address", hostname);

	/* set up where we are connecting to (zeroing sin_zero too) */

	memset(&to, 0, sizeof to);
	to.sin_len = sizeof(to);
	to.sin_family = AF_INET;
	to.sin_port = htons(port);
	memcpy(&to.sin_addr, he->h_addr, sizeof to.sin_addr);

	if (debug_tn >= 1)
		fprintf(stderr, "[Connecting to %s:%d...]\n",
		    inet_ntoa(to.sin_addr), port);

	/* set up where are we connecting from */

	memset(&from, 0, sizeof from);
	from.sin_len = sizeof(from);
	from.sin_family = AF_INET;
	from.sin_port = htons(0);	/* any */
	from.sin_addr.s_addr = INADDR_ANY;	/* any interface */

	/* allocate a network socket */

	sockfd = socket(AF_INET, SOCK_STREAM, 0);
	if (sockfd < 0)
		err(1, "socket");

	/* bind one end of the socket to the local address */

	if (bind(sockfd, (struct sockaddr *) &from, sizeof from) < 0)
		err(1, "bind");

	/* connect the other end of the socket to the remote address */

	if (connect(sockfd, (struct sockaddr *) &to, sizeof to) < 0)
		err(1, "connect");

	if (debug_tn >= 1)
		fprintf(stderr, "[Connected.]\n");

	return (sockfd);
}
/*
 * killed
 * - handler for SIGINT, SIGTERM and SIGHUP (installed by prog_cmd()).
 * - it only records the request by setting 'exiting'.  mediate() checks
 *   that flag and returns, and main() exits once the current command has
 *   finished.
 * - the debug message is assembled by hand and written with write(2),
 *   because stdio functions such as fprintf() are not async-signal-safe.
 *   errno is saved and restored so the interrupted code never sees it
 *   change underneath it.
 */

void
killed(int sig)
{
	if (debug_tn >= 1) {
		static const char prefix[] = "[Got signal ";
		static const char suffix[] = ".]\n";
		char            msg[sizeof prefix + sizeof suffix + 16];
		char            digits[16];
		unsigned int    v = (unsigned int) sig;
		size_t          n = 0, d = 0, k;
		int             saved_errno = errno;
		ssize_t         w;

		do {
			digits[d++] = (char) ('0' + v % 10);
			v /= 10;
		} while (v != 0 && d < sizeof digits);

		for (k = 0; prefix[k] != '\0'; k++)
			msg[n++] = prefix[k];
		while (d > 0)
			msg[n++] = digits[--d];
		for (k = 0; suffix[k] != '\0'; k++)
			msg[n++] = suffix[k];

		w = write(STDERR_FILENO, msg, n);
		(void) w;	/* nothing useful to do if this fails */
		errno = saved_errno;
	}
	exiting = 1;
}

/*
 * negotiate_text
 * - queue the requests for an interactive text session: we will not do
 *   LINEMODE, binary transmission is switched off in both directions, and
 *   the server is asked to echo.  Each request is recorded as awaiting
 *   an acknowledgement.
 * - no SGA (suppress go-ahead) request is sent from here.
 */

void
negotiate_text(protocol_state *state)
{
	static const unsigned char requests[][2] = {
		{ WONT, TELOPT_LINEMODE },
		{ DONT, TELOPT_BINARY },
		{ WONT, TELOPT_BINARY },
		{ DO, TELOPT_ECHO },
	};
	size_t          i;

	for (i = 0; i < sizeof requests / sizeof requests[0]; i++)
		tn_send(1, state, requests[i][0], requests[i][1]);
}

/*
 * negotiate_binary
 * - queue the requests for an 8-bit clean session: binary transmission
 *   in both directions, no remote echo, and suppressed go-aheads.  Each
 *   request is recorded as awaiting an acknowledgement.
 */

void
negotiate_binary(protocol_state *state)
{
	static const unsigned char requests[][2] = {
		{ DO, TELOPT_BINARY },
		{ WILL, TELOPT_BINARY },
		{ DONT, TELOPT_ECHO },
		{ DO, TELOPT_SGA },
	};
	size_t          i;

	for (i = 0; i < sizeof requests / sizeof requests[0]; i++)
		tn_send(1, state, requests[i][0], requests[i][1]);
}

/*
 * user_cmd
 * - the 'user' command: connect the real terminal to the telnet session
 *   so the user can interact with the remote host directly (log in,
 *   start slirp, ...).  It runs until the user types user_esc_sequence or
 *   one side closes.
 * - text-mode negotiation is queued first.  The terminal is raw for the
 *   duration and restored afterwards, and also at exit in case the
 *   session ends through a fatal error.
 */

void
user_cmd(sockfd, state)
	int             sockfd;
	protocol_state *state;
{
	static int      restore_registered = 0;

	/* tell them we won't do linemode, and ask for remote echo */

	negotiate_text(state);

	printf("[Type %s to continue, or ^Z to suspend.]\n",
	       user_esc_sequence);
	/* mediate() below writes to stdout with write(2), bypassing stdio */
	fflush(stdout);

	/* register once only, however many 'user' commands are given */
	if (!restore_registered) {
		if (atexit(restore_tty) == 0)
			restore_registered = 1;
		else
			warnx("atexit: cannot register terminal restore");
	}
	raw_tty();
	mediate(state,
		STDIN_FILENO, sockfd, ISUSER | ENCTEL,
		sockfd, STDOUT_FILENO, DECTEL
		);
	restore_tty();
}

/*
 * prog_cmd
 * - run the program argv[0] (with arguments argv[1..argc-1]) with its
 *   stdin and stdout on a newly allocated pty, which also becomes its
 *   controlling terminal.  Data is then shuttled between the pty and the
 *   telnet connection 'sockfd' until one side closes or a termination
 *   signal is caught.
 * - afterwards the child is reaped, after being sent SIGKILL if it is
 *   still running.  A non-zero exit status or death by a signal is fatal.
 */

void
prog_cmd(sockfd, state, argc, argv)
	int             sockfd;
	protocol_state *state;
	int             argc;
	char          **argv;
{
	int             ptyslave, ptymaster;
	char          **arglist;
	int             i;
	pid_t           child, reaped;
	int             status;

	/* copy the arguments into an array we can null-terminate */
	arglist = malloc((argc + 1) * sizeof *arglist);
	if (arglist == NULL)
		err(1, "malloc");
	for (i = 0; i < argc; i++)
		arglist[i] = argv[i];
	arglist[argc] = NULL;

	/* allocate a pty pair; the slave's name is not needed */
	if (openpty(&ptymaster, &ptyslave, NULL, NULL, NULL) == -1)
		err(1, "openpty");

	/* on interrupts, stop mediating so the child gets cleaned up */
	if (signal(SIGINT, killed) == SIG_ERR)
		err(1, "signal SIGINT");
	if (signal(SIGTERM, killed) == SIG_ERR)
		err(1, "signal SIGTERM");
	if (signal(SIGHUP, killed) == SIG_ERR)
		err(1, "signal SIGHUP");

	/* a vanished peer should show up as a write error, not kill us */
	if (signal(SIGPIPE, SIG_IGN) == SIG_ERR)
		err(1, "signal SIGPIPE");

	child = fork();
	if (child == -1)
		err(1, "fork");
	if (child == 0) {
		/*
		 * Child.  Failures use warn() + _exit() rather than err(),
		 * which would run the parent's atexit handlers and flush the
		 * stdio buffers duplicated by fork().
		 */
		close(ptymaster);
		close(sockfd);

		/* an ignored SIGPIPE would otherwise survive execvp() */
		signal(SIGPIPE, SIG_DFL);

		/* bind standard input/output to the pty (stderr is kept) */
		if (dup2(ptyslave, STDIN_FILENO) == -1 ||
		    dup2(ptyslave, STDOUT_FILENO) == -1) {
			warn("dup2");
			_exit(1);
		}
		if (ptyslave != STDIN_FILENO && ptyslave != STDOUT_FILENO)
			close(ptyslave);

		/* start own session (loses controlling tty) */
		if (setsid() == -1) {
			warn("setsid");
			_exit(1);
		}

		/* make the pty this process's controlling tty (/dev/tty) */
		if (ioctl(STDIN_FILENO, TIOCSCTTY, 0) == -1) {
			warn("ioctl TIOCSCTTY");
			_exit(1);
		}

		/* execute the requested program (usually pppd) */
		execvp(arglist[0], arglist);

		warn("%s", arglist[0]);
		_exit(1);
	}
	/* only the child needs arglist; messages below use argv[0] */
	free(arglist);
	arglist = NULL;

	/* parent doesn't use the slave side of the pty */
	close(ptyslave);

	/* transfer data between the telnet connection and the pty */

	mediate(state,
		ptymaster, sockfd, ENCTEL | DBGR | DBGW,
		sockfd, ptymaster, DECTEL | DBGR | DBGW);

	/* no need to talk to pty any more */
	close(ptymaster);

	printf("[Waiting for child process %d to finish.]\n", (int) child);

	/* WNOHANG returns 0 (not -1) while the child is still running */
	reaped = waitpid(child, &status, WNOHANG);
	if (reaped == 0) {
		if (kill(child, SIGKILL) == -1)
			warn("pid %d", (int) child);
		do
			reaped = waitpid(child, &status, 0);
		while (reaped == -1 && errno == EINTR);
	}
	if (reaped == -1) {
		/* 'status' was not filled in: nothing more to report */
		warn("waitpid %d", (int) child);
		return;
	}

	if (WIFEXITED(status) && WEXITSTATUS(status) != 0)
		errx(WEXITSTATUS(status), "%s: child exited abnormally",
			argv[0]);
	/* errx: errno says nothing about how the child died */
	if (WIFSIGNALED(status))
		errx(1, "%s: %s", argv[0], strsignal(WTERMSIG(status)));
}

/*
 * main
 * - connect to argv[1] ("host[:port]"), then run each ';'-separated
 *   command from the rest of the command line in order.  'daemon',
 *   'user', 'text' and 'binary' are built in.  Anything else is run as a
 *   program on a pty connected to the telnet session.
 * - stops with exit status 'exiting' after the current command if a
 *   termination signal was caught.
 */

int
main(argc, argv)
	int             argc;
	char          **argv;
{
	int             sockfd;
	protocol_state  state;
	int             newargc;

	/* check command line arguments usage */
	if (argc < 3) {
		fprintf(stderr, "usage: %s host cmd...\n", argv[0]);
		exit(2);
	}
	/* connect to the host using telnet */
	sockfd = open_telnet(argv[1]);

	/* skip the host:port argument */
	argc -= 2;
	argv += 2;

	init_protocol_state(&state);

	while (argc != 0) {

		/* skip to the end of command or a ';' */
		newargc = 0;
		while (newargc < argc && strcmp(argv[newargc], ";") != 0)
			newargc++;

		if (debug_tn >= 1) {
			int i;
			fprintf(stderr, "[processing command:");
			for (i = 0; i < newargc; i++)
				fprintf(stderr, " %s", argv[i]);
			fprintf(stderr, "]\n");
		}

		if (newargc != 0) {
			if (strcmp(argv[0], "daemon") == 0) {
				/* a failed daemon() would leave us in the foreground */
				if (daemon(1, 0) == -1)
					err(1, "daemon");
			} else if (strcmp(argv[0], "user") == 0) {
				user_cmd(sockfd, &state);
			} else if (strcmp(argv[0], "text") == 0) {
				negotiate_text(&state);
			} else if (strcmp(argv[0], "binary") == 0) {
				negotiate_binary(&state);
			} else {
				/* other command */
				prog_cmd(sockfd, &state, newargc, argv);
			}
		}

		/* skip the command words */
		argc -= newargc;
		argv += newargc;

		/* skip the semicolon */
		if (argc != 0) {
			argc--;
			argv++;
		}

		if (exiting)
			exit(exiting);
	}
	exit(0);
}

