/*
 * $Id$ 
 * David Leonard <d@openbsd.org>, 1999. Public domain.
 *
 * tunnel - the poor man's VPN.
 *
 * This program uses the tun(4) device to set up a tunnel endpoint, then
 * uses an rsh-like program to connect to another host and start the
 * other tunnel endpoint.
 *
 * Some environment variables are used to tell what programs to use
 * and run on the other end:
 *
 *  Example usage:
 *
 *	 TUNNEL_RSH=ssh
 *	 TUNNEL_RSH_USER=leonard
 *	 TUNNEL_COMMAND=/home/leonard/bin/tunnel
 *	 export TUNNEL_COMMAND TUNNEL_RSH TUNNEL_RSH_USER
 *	 tunnel nearby-host 10.1.128.9:10.1.128.10
 *	 route add that-remote-host 10.1.128.10
 */

#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <signal.h>
#include <unistd.h>
#include <util.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/param.h>
#include <netinet/in.h>
#include <net/if.h>
#include <net/if_tun.h>
#include <arpa/inet.h>
#include <stdarg.h>

/*
 * Return the local host's name, used as a prefix for diagnostics.
 *
 * The name is fetched with gethostname() on first use and cached in a
 * static buffer, so the returned pointer must not be freed.  If the
 * lookup fails, a warning is printed and "?" is returned instead.
 */
char *
localhost(void)
{
	static char hostname[MAXHOSTNAMELEN] = "";

	if (hostname[0] == '\0') {
		if (gethostname(hostname, sizeof hostname) == -1) {
			warn("gethostname");
			hostname[0] = '?';
			hostname[1] = '\0';
		}
		/*
		 * gethostname() is not required to NUL-terminate a name
		 * that had to be truncated, so terminate it ourselves.
		 */
		hostname[sizeof hostname - 1] = '\0';
	}
	return hostname;
}

/*
 * Fill in a sockaddr from an IPv4 address string.
 *
 * On success the structure is set up as an AF_INET address with port 0
 * and 0 is returned.  If inet_pton() does not accept the string, errno
 * is set to EINVAL and -1 is returned.
 */
int
stosa(char *name, struct sockaddr*sa)
{
	struct sockaddr_in *in4 = (struct sockaddr_in *)sa;

	/* Reject anything inet_pton() does not parse as an IPv4 address. */
	if (inet_pton(AF_INET, name, &in4->sin_addr) != 1) {
		errno = EINVAL;
		return (-1);
	}

	memset(&in4->sin_zero, '\0', sizeof in4->sin_zero);
	in4->sin_port = 0;
	in4->sin_family = AF_INET;
	in4->sin_len = sizeof(*in4);
	return (0);
}

/*
 * Open and configure a tunnel device.
 *
 * hostpair is a "local-addr:remote-addr" string; both halves must be
 * IPv4 addresses accepted by stosa().  The first of /dev/tun0 through
 * /dev/tun9 that can be opened is configured with those two addresses
 * by running /sbin/ifconfig.
 *
 * Returns the open tunnel descriptor (opened O_RDWR | O_NONBLOCK).
 * Every failure is fatal: a diagnostic is printed and the process exits
 * with status 1.
 */
int
opentun(char *hostpair)
{
	int tunfd;
	int i;
	char tunname[] = "/dev/tun0";
	char *ifname = tunname + 5;	/* interface name part, e.g. "tun0" */
	struct ifreq	  addr_ifr;
	struct ifreq	  dstaddr_ifr;
	char *colon;
	char *remotename, *localname;

	colon = strchr(hostpair, ':');
	if (colon == NULL) {
		/* errx, not err: errno holds nothing meaningful here. */
		errx(1, "%s: invalid address pair \"%s\"", localhost(),
		    hostpair);
	}

	/* Split the pair in place just long enough to copy both halves. */
	*colon = '\0';
	localname = strdup(hostpair);
	remotename = strdup(colon+1);
	*colon = ':';
	if (localname == NULL || remotename == NULL)
		err(1, "%s: strdup", localhost());

	/*
	 * Validate both addresses.  This also guarantees that the strings
	 * later pasted into the ifconfig command line are plain addresses.
	 */
	if (stosa(localname, &addr_ifr.ifr_addr) < 0)
		err(1, "%s: %s", localhost(), localname);
	if (stosa(remotename, &dstaddr_ifr.ifr_dstaddr) < 0)
		err(1, "%s: %s", localhost(), remotename);

	/* Try each tunnel unit in turn, skipping the ones already in use. */
	tunfd = -1;
	for (i = 0; i < 10; i++) {
		tunname[8] = '0'+i;
		tunfd = opendev(tunname, O_RDWR | O_NONBLOCK, OPENDEV_DRCT,
		    NULL);
		if (tunfd >= 0)
			break;
		if (errno == ENXIO)
			err(1, "%s: out of tunnels", localhost());
		if (errno == EBUSY)
			continue;
		err(1, "%s: open %s", localhost(), tunname);
	}
	/* Every unit was busy: do not go on to configure an unopened one. */
	if (tunfd < 0)
		errx(1, "%s: all tunnel devices are busy", localhost());

#if 0
	{
		struct ifreq	  ifr;

		/* Set the endpoint addresses */
		if (ioctl(tunfd, SIOCSIFADDR, &addr_ifr) < 0)
			err(1, "%s: %s: SIOCSIFADDR", localhost(), ifname);

		if (ioctl(tunfd, SIOCSIFDSTADDR, &dstaddr_ifr) < 0)
			err(1, "%s: %s: SIOCSIFDSTADDR", localhost(), ifname);

		/* Mark the interface up */
		if (ioctl(tunfd, SIOCGIFFLAGS, &ifr) < 0)
			err(1, "%s: %s: SIOCGIFFLAGS", localhost(), ifname);
		if (!(ifr.ifr_flags & IFF_UP)) {
			ifr.ifr_flags |= IFF_UP;
			if (ioctl(tunfd, SIOCSIFFLAGS, &ifr) < 0)
				err(1, "%s: %s: SIOCSIFFLAGS", localhost(),
				    ifname);
		}
	}
#else
	{
		char cmdbuf[1024];
		int status;
		int n;

		n = snprintf(cmdbuf, sizeof cmdbuf,
		    "/sbin/ifconfig %s %s %s netmask 255.255.255.255",
		    ifname, localname, remotename);
		if (n < 0 || (size_t)n >= sizeof cmdbuf)
			errx(1, "%s: ifconfig command too long", localhost());
		status = system(cmdbuf);
		if (!(WIFEXITED(status) && WEXITSTATUS(status) == 0)) {
			errx(1, "%s: ifconfig", localhost());
		}
	}
#endif

	free(localname);
	free(remotename);

	return (tunfd);
}

/*
 * SIGCHLD handler: report the death of the transport child.
 *
 * Reaps one child with wait().  If the child exited, this process exits
 * with the child's exit status; if it was killed by a signal, this
 * process exits with status 1.  A stopped or otherwise unrecognised
 * status only produces a warning, and the handler returns.
 *
 * NOTE(review): getenv() and the err(3) family are not async-signal-safe.
 * That is tolerated here because the common paths terminate the process.
 */
void
sigchld(int sig)
{
	/* Preserve errno for the interrupted code if we end up returning. */
	int saved_errno = errno;
	int status;
	pid_t child;
	char *rsh_command;

	if ((rsh_command = getenv("TUNNEL_RSH")) == NULL)
		rsh_command = "ssh";

	child = wait(&status);
	if (child < 0)
		err(1, "%s: wait", localhost());
	if (WIFEXITED(status))
		errx(WEXITSTATUS(status), "%s: %s exited", localhost(),
		    rsh_command);
	else if (WIFSIGNALED(status))
		errx(1, "%s: %s: %s", localhost(), rsh_command,
		    strsignal(WTERMSIG(status)));
	else if (WIFSTOPPED(status))
		warnx("%s: %s: %s", localhost(), rsh_command,
		    strsignal(WSTOPSIG(status)));
	else
		warnx("%s: pid %d: %s (unhandled)", localhost(), (int)child,
		    strsignal(sig));

	errno = saved_errno;
}

/*
 * Start a process like ssh going as the tunnel transport.
 *
 * host     - remote host handed to the rsh-like program
 * pifd     - out: descriptor on which the child's standard output is read
 * pofd     - out: descriptor whose data becomes the child's standard input
 * addrpair - "local:remote" address pair; it is passed to the far end
 *            reversed ("remote:local") after "-remote"
 *
 * The programs are chosen by environment variables:
 *   TUNNEL_RSH       transport program (default "ssh")
 *   TUNNEL_RSH_USER  if set, passed to the transport as "-l user"
 *   TUNNEL_COMMAND   command run on the remote host (default "tunnel")
 *
 * A SIGCHLD handler is installed before forking.  The child switches its
 * effective uid and gid to the real ones before exec'ing the transport.
 *
 * Returns the child's pid.  Failures are fatal (exit status 1).
 */
pid_t
build_tunnel(char *host, int *pifd, int *pofd, char *addrpair)
{
	int topipe[2];
	int frompipe[2];
	pid_t child;
	char *tunnel_command;
	char *rsh_command;
	char *rsh_user;
	char *addrpair_rev;
	size_t revsize;
	char *s;

	/* Reverse the local:remote string */
	s = strchr(addrpair, ':');
	if (s == NULL)
		errx(1, "%s: invalid address pair \"%s\"", localhost(),
		    addrpair);
	/* The reversed string has exactly the same length as the input. */
	revsize = strlen(addrpair) + 1;
	addrpair_rev = malloc(revsize);
	if (addrpair_rev == NULL)
		err(1, "%s: malloc", localhost());
	/* A "%.*s" precision must be an int, not a ptrdiff_t. */
	snprintf(addrpair_rev, revsize, "%s:%.*s",
	    s + 1, (int)(s - addrpair), addrpair);

	if (pipe(topipe) < 0 || pipe(frompipe) < 0)
		err(1, "%s: pipe", localhost());

	if ((tunnel_command = getenv("TUNNEL_COMMAND")) == NULL)
		tunnel_command = "tunnel";
	if ((rsh_command = getenv("TUNNEL_RSH")) == NULL)
		rsh_command = "ssh";
	rsh_user = getenv("TUNNEL_RSH_USER");	/* NULL means no "-l" */

	signal(SIGCHLD, sigchld);
	child = fork();
	if (child < 0)
		err(1, "%s: fork", localhost());
	if (child == 0) {
		/* In child: wire the pipes to stdin/stdout, then exec. */
		if (dup2(topipe[0], STDIN_FILENO) < 0)
			err(1, "%s: dup2", localhost());
		if (dup2(frompipe[1], STDOUT_FILENO) < 0)
			err(1, "%s: dup2", localhost());
		close(topipe[0]);
		close(topipe[1]);
		close(frompipe[0]);
		close(frompipe[1]);
		if (seteuid(getuid()) == -1)
			err(1, "%s: seteuid", localhost());
		if (setegid(getgid()) == -1)
			err(1, "%s: setegid", localhost());
		/*
		 * The argument list of a variadic exec must end with a null
		 * pointer of type char *, hence the explicit cast.
		 */
		if (rsh_user != NULL) {
			execlp(rsh_command,
			    rsh_command,
			    "-l", rsh_user, host,
			    tunnel_command, "-remote", addrpair_rev,
			    (char *)NULL);
		} else {
			execlp(rsh_command,
			    rsh_command,
			    host,
			    tunnel_command, "-remote", addrpair_rev,
			    (char *)NULL);
		}
		/* Only reached if the exec failed. */
		err(1, "%s: %s", localhost(), rsh_command);
	}

	/* In parent: keep only our ends of the two pipes. */
	close(topipe[0]);
	close(frompipe[1]);

	*pifd = frompipe[0];
	*pofd = topipe[1];

	free(addrpair_rev);

	return (child);
}

/*
 * Shuttle packets between the tunnel device and the transport.
 *
 * ifd   - descriptor from which the peer's framed packets are read
 * ofd   - descriptor to which framed packets for the peer are written
 * tunfd - the tunnel device
 *
 * On the wire every packet is preceded by a 32-bit length in network
 * byte order; the length counts the 4-byte header as well as the
 * payload.  One packet is buffered per direction: the tunnel is not read
 * again until the previous packet has been fully written to ofd, and ifd
 * is not read again until the packet just received has been written to
 * the tunnel.
 *
 * Returns when the tunnel device or ifd reports end-of-file.  I/O errors
 * and length headers that do not fit the receive buffer are fatal (exit
 * status 1).
 */
void
transport(int ifd, int ofd, int tunfd)
{
	fd_set rfds, wfds;
	int maxfd;
	char *obuffer, *ibuffer;
	char *obuf, *ibuf;
	size_t obufmaxsiz, ibufmaxsiz;
	size_t want;
	int ibuflen, obuflen;
	u_int32_t ipktlen;
	u_int32_t hdr;
	int len;

	FD_ZERO(&rfds);
	FD_ZERO(&wfds);

	maxfd = ifd < ofd ? ofd : ifd;
	maxfd = maxfd < tunfd ? tunfd : maxfd;

	ibufmaxsiz = obufmaxsiz = 8192;

	if ((ibuffer = malloc(ibufmaxsiz)) == NULL)
		err(1, "%s: malloc", localhost());
	if ((obuffer = malloc(obufmaxsiz)) == NULL)
		err(1, "%s: malloc", localhost());

	ipktlen = 0;		/* 0 until a complete header has arrived */
	ibuflen = 0;
	ibuf = ibuffer;
	obuflen = 0;
	obuf = NULL;

	while(1) {
		/*
		 * Outgoing side: drain the pending packet to ofd before
		 * accepting another one from the tunnel.
		 */
		if (obuflen) {
			FD_SET(ofd, &wfds);
			FD_CLR(tunfd, &rfds);
		} else {
			FD_CLR(ofd, &wfds);
			FD_SET(tunfd, &rfds);
		}

		/*
		 * Incoming side: once a whole packet has been collected,
		 * deliver it to the tunnel before reading any more.
		 */
		if (ipktlen && ((u_int32_t)ibuflen == ipktlen)) {
			FD_SET(tunfd, &wfds);
			FD_CLR(ifd, &rfds);
		} else {
			FD_CLR(tunfd, &wfds);
			FD_SET(ifd, &rfds);
		}

#ifdef DEBUG
		fprintf(stderr, 
			"%s: obuflen=%-6d ibuflen=%-6d ipktlen=%-6d  "
			    "r:%c%c%c w:%c%c%c\n",
			localhost(),
			obuflen, ibuflen, (int)ipktlen,
			(FD_ISSET(ifd, &rfds)   ? 'i' : '-'),
			(FD_ISSET(ofd, &rfds)   ? 'o' : '-'),
			(FD_ISSET(tunfd, &rfds) ? 't' : '-'),
			(FD_ISSET(ifd, &wfds)   ? 'i' : '-'),
			(FD_ISSET(ofd, &wfds)   ? 'o' : '-'),
			(FD_ISSET(tunfd, &wfds) ? 't' : '-'));
#endif

		if (select(maxfd + 1, &rfds, &wfds, NULL, NULL) < 0) {
			if (errno != EINTR)
				err(1, "%s: select", localhost());
			else
				continue;
		}

		if (FD_ISSET(tunfd, &rfds)) {
			/*
			 * Read an entire packet from the
			 * tunnel device and store it in the
			 * outgoing buffer, behind the header.
			 */
			len = read(tunfd, obuffer + sizeof hdr,
				obufmaxsiz - sizeof hdr);
			if (len < 0)
				err(1, "%s: tun read", localhost());
			if (len == 0) {
				warnx("%s: tunnel closed", localhost());
				goto done;
			}
			hdr = htonl(len + sizeof hdr);
			memcpy(obuffer, &hdr, sizeof hdr);

			obuf = obuffer;
			obuflen = len + sizeof hdr;
		}

		if (FD_ISSET(ofd, &wfds)) {
			/*
			 * Write some of the outgoing buffer out.
			 */
			len = write(ofd, obuf, obuflen);
			if (len < 0)
				err(1, "%s: output write", localhost());
			obuflen -= len;
			obuf += len;
		}

		if (FD_ISSET(tunfd, &wfds)) {
			/*
			 * Write a fully received packet 
			 * onto the tunnel device
			 */
			len = write(tunfd, ibuffer + sizeof hdr,
				ibuflen - sizeof hdr);
			if (len < 0)
				err(1, "%s: tun write", localhost());

			ibuflen = 0;
			ibuf = ibuffer;
			ipktlen = 0;
		}

		if (FD_ISSET(ifd, &rfds)) {
			/*
			 * Still collecting the header, or already
			 * collecting the packet body?
			 */
			if (ipktlen == 0)
				want = sizeof hdr - ibuflen;
			else
				want = ipktlen - ibuflen;

			len = read(ifd, ibuf, want);
			if (len < 0)
				err(1, "%s: input read", localhost());
			if (len == 0) {
				warnx("%s: tunnel closed", localhost());
				goto done;
			}
			ibuflen += len;
			ibuf += len;

			if (ipktlen == 0 && (size_t)ibuflen == sizeof hdr) {
				memcpy(&hdr, ibuffer, sizeof hdr);
				ipktlen = ntohl(hdr);
				/*
				 * The length comes from the peer: refuse
				 * anything shorter than the header itself
				 * or larger than the receive buffer.
				 */
				if (ipktlen < sizeof hdr ||
				    ipktlen > ibufmaxsiz)
					errx(1, "%s: bad packet length %lu",
					    localhost(),
					    (unsigned long)ipktlen);
			}
		}
	}

done:
	free(ibuffer);
	free(obuffer);
}

/*
 * Print the command-line synopsis on stderr, using argv0 as the program
 * name, and exit with status 2.
 */
void
usage(argv0)
	char *argv0;
{
	(void)fprintf(stderr,
	    "usage: %s host local-addr:remote-addr\n",
	    argv0);
	exit(2);
}

/*
 * Entry point.
 *
 *   tunnel host local-addr:remote-addr      local end
 *   tunnel -remote local-addr:remote-addr   remote end
 *
 * The local end detaches with daemon(), opens its tunnel device and
 * starts the transport child.  The remote end uses standard input and
 * output as the transport.  Either way packets are then relayed until
 * transport() returns, after which the local end interrupts its child.
 */
int
main(int argc, char **argv)
{
	int ifd;
	int ofd;
	int tunfd;
	pid_t child;
	char *host;
	char *addrpair;

	if (argc != 3) 
		usage(argv[0]);

	host = argv[1];
	addrpair = argv[2];

	if (strcmp(host, "-remote") == 0) {
		/* remote end */
		ifd = STDIN_FILENO;
		ofd = STDOUT_FILENO;
		child = 0;
		tunfd = opentun(addrpair);
	} else {
		/* local end */
		if (daemon(0, 1) == -1)
			err(1, "%s: daemon", localhost());
		tunfd = opentun(addrpair);
		child = build_tunnel(host, &ifd, &ofd, addrpair);
	}

	transport(ifd, ofd, tunfd);

	/* kill() takes the process id first and the signal second. */
	if (child && kill(child, SIGINT) < 0)
		err(1, "%s: kill %d", localhost(), (int)child);

	exit(0);
}

