#!/usr/bin/python
__DOC__ = """\
# Copyright 2011, Bjarni R. Einarsson <http://bre.klaki.net/>
# License: AGPLv3
#
# lapcat: Location Aware Proxy Chooser And Tunneler
#         a.k.a. Netcat for your Laptop.
#
# This is a netcat-like tool for opening up a TCP connection to some port
# on some host, where the connection strategy depends on where you are.
#
# Requirements:
#   Python 2.x or 3.x
#   PySocksipyChain, <https://github.com/pagekite/PySocksipyChain/>
#
##############################################################################
#
# For example, say we want 'ssh homeserver' to behave like so:
#
#   - When at home, connect directly (fast!)
#   - At work, use the local HTTP Proxy and PageKite (fast!)
#   - From anywhere else, use a Tor hidden service (private!)
#
# With lapcat, this is possible by defining the following rules in a file
# named ~/.lapcat/homeserver (use lapcat -N to generate network IDs).
#
#   [home]
#   if network = 10.1.2.254/aa:bb:cc:dd:ee:ff
#   host = homeserver.local
#   chain = none
#   priority = 1
#
#   [work]
#   if network = 192.168.55.254/gw:ma:ca:dd:re:ss
#   host = homeserver.pagekite.me
#   chain = http:proxy.corp:8080, http:homeserver.pagekite.me:443
#   priority = 1
#
#   [default]
#   host = 12345123451234512345.onion
#   chain = socks5:localhost:9050
#   priority = 100
#
# Then add the following to ~/.ssh/config
#
#   Host homeserver homeserver.pagekite.me
#     CheckHostIP no
#     ProxyCommand /path/to/lapcat homeserver 22
#
"""
import getopt, os, select, socket, subprocess, sys
import sockschain as socks

def DebugPrint(text):
  """Write text followed by a newline to sys.stderr and flush immediately."""
  stream = sys.stderr
  stream.write(text + '\n')
  stream.flush()

# Logging hooks.  Each is False while disabled; the command-line handling at
# the bottom of this file rebinds them to DebugPrint when -v / -V is given,
# which is why callers write "if DEBUG: DEBUG(...)".
# (A "global" statement has no effect at module scope, so none is needed.)
TRACE = False
DEBUG = False

# Directory holding system-wide rule files.
SYS_CONF_DIR = '/etc/lapcat'
# Directory holding per-user rule files ('~' is expanded with
# os.path.expanduser before use).
USER_CONF_DIR = '~/.lapcat'
# Keyword that starts an "import <name>" line inside a rule file.
IMPORT_KEYWORD = 'import'
# Name of the rule that receives variables set outside any [section].
DEFAULT_RULE = 'default'
# Default value of the 'chain' variable.
DEFAULT_CHAIN = 'default'

# Names of the variables that may be assigned in rule files.
V_ACTIVE = 'active'
V_CHAIN = 'chain'
V_DEFAULT_CHAIN = 'default chain'
V_FINAL = 'final'
V_HOST = 'host'
V_PORT = 'port'
V_PRIORITY = 'priority'
V_TEST_COMMAND = 'test command'
V_TEST_HOST = 'if host'
V_TEST_PORT = 'if port'
V_TEST_NETWORK = 'if network'
# Initial contents of the default rule.  This also serves as the list of
# accepted variable names: the rule parser rejects any name not found here.
# '%h' and '%p' are placeholders replaced with the requested host and port
# when connecting.
VARIABLE_DEFAULTS = {
  V_ACTIVE: True,
  V_CHAIN: DEFAULT_CHAIN,
  V_DEFAULT_CHAIN: None,
  V_HOST: '%h',
  V_PORT: '%p',
  V_FINAL: False,
  V_TEST_COMMAND: None,
  V_TEST_HOST: None,
  V_TEST_PORT: None,
  V_TEST_NETWORK: None,
  V_PRIORITY: 100
}


def Run(argv):
  """Run argv (no shell) and return its standard output as a list of lines.

  The output is decoded with the default codec and split with splitlines(),
  so the returned strings carry no line terminators.
  """
  child = subprocess.Popen(argv, stdout=subprocess.PIPE)
  captured, _unused_stderr = child.communicate()
  text = captured.decode()
  return text.splitlines()

def RunTest(command):
  """Run a shell command and report whether it exited with status zero.

  The command string is handed to the shell (shell=True).  Returns True only
  for exit status 0; any other status, death by signal, or an OSError while
  running the command yields False.  Progress is reported through DEBUG when
  debugging is enabled.
  """
  try:
    if DEBUG:
      DEBUG("Running: %s" % command)
    status = subprocess.call(command, shell=True)
    if DEBUG:
      # A negative status means the child was killed by signal number -status.
      if status >= 0:
        DEBUG("Child returned: %s" % status)
      else:
        DEBUG("Child was terminated by signal: %s" % -status)
    succeeded = (status == 0)
  except OSError:
    if DEBUG:
      DEBUG("Execution failed: %s" % (sys.exc_info(), ))
    succeeded = False
  return succeeded


def GetNetworkId():
  """Return an identifier for the current network: '<gateway>/<hw-address>'.

  The gateway is the second field of the line starting with '0.0.0.0' in the
  output of "netstat -rn" (the last such line wins).  The hardware address is
  the third field of the matching line of "arp -n <gateway>".  Either part is
  the literal string 'unknown' when it cannot be determined, including when
  the external tool is missing or fails to start.
  """
  # FIXME: This probably only works on Linux/IPv4 !

  gateway = 'unknown'
  try:
    for line in Run(['netstat', '-rn']):
      if line.startswith('0.0.0.0'):
        fields = line.split()
        if len(fields) > 1:
          gateway = fields[1].lower()
  except OSError:
    # netstat is not installed or could not be executed: stay 'unknown' so
    # that rules without an "if network" test can still be used.
    if DEBUG: DEBUG("netstat failed: %s" % (sys.exc_info(), ))

  network = 'unknown'
  if gateway != 'unknown':
    try:
      for line in Run(['arp', '-n', gateway]):
        if line.lower().startswith(gateway):
          fields = line.split()
          if len(fields) > 2:
            network = fields[2].lower()
    except OSError:
      if DEBUG: DEBUG("arp failed: %s" % (sys.exc_info(), ))

  if DEBUG: DEBUG("Network is: %s/%s" % (gateway, network))
  return '%s/%s' % (gateway, network)


class LapCatConfig(object):
  """Rules describing how to reach one host:port from the current network.

  self.rules maps a section name to a dictionary of variables (the accepted
  names are the keys of VARIABLE_DEFAULTS).  Rules are read from rule files
  by load() / configure(), ordered by ruleOrder(), matched against the
  current request by test() and acted upon by connect().
  """

  def __init__(self, hostname, portnum, network):
    """Remember the target and start with only the built-in default rule.

    portnum is normalised through int() and kept as a string; a value that
    is not an integer raises ValueError.  network is the identifier that
    "if network" tests are compared against.
    """
    self.hostname = hostname
    self.portnum = str(int(portnum))
    self.network = network
    self.rules = {DEFAULT_RULE: {}}
    self.rules[DEFAULT_RULE].update(VARIABLE_DEFAULTS)

  def sysConfig(self, name=None):
    """Return the system-wide rule file path for name (default: the host)."""
    return os.path.join(SYS_CONF_DIR, name or self.hostname)

  def userConfig(self, name=None):
    """Return the per-user rule file path for name (default: the host)."""
    return os.path.join(os.path.expanduser(USER_CONF_DIR),
                        name or self.hostname)

  def globalConfigs(self):
    """List all global configuration files, in order of preference.

    Global files are those named "<number>-<anything>" in the system or user
    configuration directory.  They are sorted by number; for equal numbers
    the system file comes before the user file.  Other file names and
    unreadable directories are skipped.
    """
    configs = []
    for order, dirn in ( ('0', SYS_CONF_DIR),
                         ('1', os.path.expanduser(USER_CONF_DIR)) ):
      try:
        for fn in os.listdir(dirn):
          try:
            pri, rest = fn.split('-', 1)
            pri = '%3.3d-%s' % (int(pri), order)
            configs.append((pri, os.path.join(dirn, fn)))
          except ValueError:
            # No '-' in the name, or a non-numeric prefix: not a global file.
            pass
      except Exception:
        if DEBUG: DEBUG("%s: %s" % (dirn, sys.exc_info()))

    configs.sort(key=lambda k: k[0])
    if DEBUG: DEBUG('Configs are: %s' % configs)
    return [cfg[1] for cfg in configs]

  def load(self, filename=None, require=False, wildcards=False):
    """Load and parse a rule configuration file.

    filename defaults to the per-user file for this host.  When the file
    cannot be opened and wildcards is true, the leading dot-separated labels
    of the file name are replaced one after another by '_ANY_' (for a.b.c:
    _ANY_.b.c, then _ANY_.c, then _ANY_) and the first name that opens is
    used.  If nothing can be opened, the open error is re-raised when
    require is true; otherwise the call is a no-op.

    Recognised lines: blank lines and '#' comments, "[section]" headers,
    "import <name>" (loads <name> from the system and user directories; at
    least one must exist) and "variable = value".  Values 'true'/'yes' and
    'false'/'no' (any case) become booleans; everything else stays a string.

    Returns self.  Raises ValueError, with file name and line number, for an
    empty section name, a failed import, an unknown variable or a line that
    cannot be parsed.
    """
    filename = filename or self.userConfig()
    try:
      fd = open(filename, 'r')
      if DEBUG: DEBUG("Loading: %s" % filename)
    except Exception:
      fd = None
      if wildcards:
        filedir = os.path.dirname(filename)
        parts = os.path.basename(filename).split('.')
        while len(parts) > 0:
          parts[0] = '_ANY_'
          try:
            filename = os.path.join(filedir, '.'.join(parts))
            fd = open(filename, 'r')
            if DEBUG: DEBUG("Loading: %s" % filename)
            break
          except Exception:
            parts.pop(0)

      if not fd:
        if not require: return self
        raise

    try:
      # Variables seen before any [section] header go to the default rule.
      section = self.rules[DEFAULT_RULE]
      count = 0
      for line in fd:
        count += 1
        line = line.strip()

        if line == '' or line.startswith('#'):
          pass

        elif line.startswith('[') and line.endswith(']'):
          secname = line[1:-1]
          if secname == '':
            raise ValueError(('%s(line=%s): Null section') % (filename, count))
          elif secname not in self.rules:
            self.rules[secname] = {}
          section = self.rules[secname]

        elif line.startswith(IMPORT_KEYWORD):
          name = line[len(IMPORT_KEYWORD)+1:].strip()
          files = [self.sysConfig(name=name),
                   self.userConfig(name=name)]
          loaded = False
          for fn in files:
            try:
              self.load(filename=fn, require=True)
              loaded = True
            except IOError:
              pass
          if not loaded:
            raise ValueError(('%s(line=%s): File not found, tried: %s'
                              ) % (filename, count, files))

        elif '=' in line:
          # Split on the first '=' only, so values such as shell test
          # commands may themselves contain '='.
          var, value = line.split('=', 1)

          var = var.strip().lower()
          if var not in VARIABLE_DEFAULTS:
            raise ValueError(('%s(line=%s): Unknown variable: %s'
                              ) % (filename, count, var))

          value = value.strip()
          if value.lower() in ('true', 'yes'): value = True
          elif value.lower() in ('false', 'no'): value = False
          section[var] = value

        else:
          raise ValueError(('%s(line=%s): Invalid line') % (filename, count))
    finally:
      # Always release the file, including when a parse error is raised.
      fd.close()

    return self

  def configure(self):
    """Load all the rules pertaining to this host:port.

    Global files are loaded first (and must be readable), then the system
    and user files for this host, each with wildcard fallback and each
    optional.  Returns self.
    """
    for config in self.globalConfigs():
      self.load(filename=config, require=True)
    self.load(filename=self.sysConfig(),  require=False, wildcards=True)
    self.load(filename=self.userConfig(), require=False, wildcards=True)
    return self

  def ruleOrder(self):
    """Calculate the order in which to evaluate our rules.

    Returns the rule names sorted by ascending integer 'priority'; rules
    without a priority sort as 999.
    """
    keys = [r for r in self.rules]
    keys.sort(key=lambda rule: int(self.rules[rule].get(V_PRIORITY, 999)))
    if DEBUG: DEBUG('Rule order: %s' % keys)
    return keys

  def test(self, rule):
    """Test whether a particular rule matches.

    A rule that is neither active nor defines a default chain never matches.
    Otherwise the requested host, port and current network must each appear
    in the rule's ', '-separated "if host" / "if port" / "if network" list
    (a missing list matches anything), and the rule's "test command", if
    any, must exit with status zero.  Any error while testing counts as
    "no match".
    """
    if not (rule.get(V_ACTIVE, True) or rule.get(V_DEFAULT_CHAIN, False)):
      return False

    try:
      hosts = (rule.get(V_TEST_HOST, '') or self.hostname).lower().split(', ')
      if self.hostname.lower() not in hosts: return False

      ports = (rule.get(V_TEST_PORT, '') or self.portnum).lower().split(', ')
      if self.portnum not in ports: return False

      ntwks = (rule.get(V_TEST_NETWORK, '') or self.network).split(', ')
      if self.network not in ntwks: return False

      if rule.get(V_TEST_COMMAND, False):
        return RunTest(rule[V_TEST_COMMAND])
      else:
        return True
    except Exception:
      return False

  def connect(self):
    """Connect to the host:port.

    Matching rules are tried in ruleOrder().  A rule with a "default chain"
    replaces the sockschain default proxy chain.  A rule that is active and
    has a "chain" gets a connection attempt through that chain; '%h' and
    '%p' in the rule's host, port and proxy specifications are replaced by
    the target host and port.

    Returns the connected socket of the first rule that succeeds.  Raises
    IOError when a rule marked "final" fails, or when no rule succeeded.
    """
    rules = self.ruleOrder()

    for ruleName in rules:
      rule = self.rules[ruleName]
      if self.test(rule):

        if rule.get(V_DEFAULT_CHAIN, False):
          if DEBUG: DEBUG("Configuring default proxy chain: %s" % rule)
          socks.setdefaultproxy()
          for proxy in rule[V_DEFAULT_CHAIN].split(', '):
            socks.adddefaultproxy(*socks.parseproxy(proxy))

        if rule.get(V_CHAIN, False) and rule.get(V_ACTIVE, True):
          sock = None
          try:
            host = (rule.get(V_HOST, '') or self.hostname
                    ).replace('%h', self.hostname)
            port = (rule.get(V_PORT, '') or self.portnum
                    ).replace('%p', self.portnum)

            sock = socks.socksocket(socket.AF_INET, socket.SOCK_STREAM)
            for proxy in rule.get(V_CHAIN, DEFAULT_CHAIN).split(', '):
              sock.addproxy(*socks.parseproxy(proxy.strip()
                                              .replace('%h', host)
                                              .replace('%p', port)))
            sock.connect((host, int(port)))
            if DEBUG: DEBUG('Connected! [%s]' % ruleName)
            return sock

          except Exception:
            if DEBUG: DEBUG('connect(%s) failed: %s' % (ruleName,
                                                        sys.exc_info()))
            # Do not leak the half-open socket before trying the next rule.
            if sock is not None:
              try:
                sock.close()
              except Exception:
                pass
            if rule.get(V_FINAL, False):
              raise IOError("Connect failed at: %s" % ruleName)

    raise IOError("Connect failed, tried: %s" % rules)


def NetCat(host, port, input_fd, output_fd):
  """Connect to host:port using the configured rules and relay data.

  The rules for the current network are loaded, a connection is made and
  socks.netcat() is run between that connection and input_fd / output_fd.
  On failure to connect (IOError) or an invalid rule file or port
  (ValueError) the error is printed to stderr and the process exits with
  status 1.
  """
  try:
    network = GetNetworkId()
    socks.netcat(LapCatConfig(host, port, network).configure().connect(),
                 input_fd, output_fd)
  except (ValueError, IOError):
    # ValueError is what the rule parser raises for malformed files; report
    # it the same way HttpProxy does instead of dying with a traceback.
    DebugPrint('%s' % (sys.exc_info(), ))
    sys.exit(1)

def SetProcTitle(title):
  """Set the process title, if the optional setproctitle module is available.

  This is best-effort: a missing module or any error while setting the title
  is ignored.  Only ordinary exceptions are suppressed, so KeyboardInterrupt
  and SystemExit still propagate.
  """
  try:
    import setproctitle
    setproctitle.setproctitle(title)
  except Exception:
    pass

def HttpProxy(input_fd, output_fd):
  """Serve one HTTP proxy request read from input_fd.

  The request head is read from input_fd's file descriptor (at most 1024
  bytes, stopping at the first blank line or at end of input).  A CONNECT
  request is answered with "200 Tunnel established" and then tunnelled to
  the requested host:port with NetCat().  Any other request of at least
  three words is forwarded, unchanged, over a connection made using the
  rules for the pseudo-host 'lapcat-http-proxy' (port 0).  Shorter input is
  ignored.

  On ValueError or IOError the error is printed to stderr and the process
  exits with status 1.
  """
  try:
    # Get the initial request.  It is read one byte at a time so that nothing
    # beyond the end of the headers is consumed, and kept as bytes because
    # os.read() returns bytes on Python 3.
    raw = b''
    loops = 1024
    while not (loops < 1 or
               raw.endswith(b'\n\n') or
               raw.endswith(b'\r\n\r\n')):
      chunk = os.read(input_fd.fileno(), 1)
      if not chunk:
        break  # End of input: nothing more will arrive.
      raw += chunk
      loops -= 1

    # Text copy used for parsing only; on Python 2 bytes already is str.
    # latin-1 maps every byte to a character, so decoding cannot fail.
    if isinstance(raw, str):
      request = raw
    else:
      request = raw.decode('latin-1')

    if TRACE: TRACE('<<< Got request (l:%s):\n%s<<<\n' % (1024-loops, request))

    # If it is a HTTP CONNECT, we connect directly.
    words = request.split()
    if (len(words) >= 3 and
        words[0].upper() == 'CONNECT' and
        words[2].upper().startswith('HTTP/')):
      # Text streams (sys.stdout on Python 3) expose the underlying binary
      # stream as .buffer; sockets wrapped by FileWrapper take bytes directly.
      out = getattr(output_fd, 'buffer', output_fd)
      out.write(b'HTTP/1.1 200 Tunnel established\r\n\r\n')
      out.flush()
      host, port = words[1].split(':')
      if DEBUG: DEBUG('Using native lapcat connection to %s:%s' % (host, port))
      SetProcTitle('lapcat: %s:%s' % (host, port))
      NetCat(host, port, input_fd, output_fd)

    # Otherwise, forward this to a real HTTP Proxy for processing.
    elif len(words) > 2:
      if DEBUG: DEBUG('Connecting via. lapcat-http-proxy')
      host = 'lapcat-http-proxy'
      network = GetNetworkId()
      SetProcTitle('lapcat: %s' % words[1])
      conn = LapCatConfig(host, 0, network).configure().connect()
      # Replay the bytes already consumed, then relay the rest.
      conn.sendall(raw)
      socks.netcat(conn, input_fd, output_fd)

  except (ValueError, IOError):
    DebugPrint('%s' % (sys.exc_info(), ))
    sys.exit(1)


class FileWrapper(object):
  """Minimal file-like adapter around a connected socket.

  Provides the flush/close/write/fileno methods so that a socket can be
  passed where an input or output file object is expected.  There is no
  read method; callers read through the file descriptor instead.
  """

  def __init__(self, sock):
    """Wrap sock; the wrapper does not duplicate the socket."""
    self.sock = sock

  def flush(self):
    """Do nothing: data is handed to the socket as soon as it is written."""
    pass

  def close(self):
    """Close the underlying socket."""
    return self.sock.close()

  def write(self, data):
    """Send all of data and return the number of bytes written.

    sendall() is used rather than send() because send() may transmit only
    part of the buffer, which would silently drop the remainder.
    """
    self.sock.sendall(data)
    return len(data)

  def fileno(self):
    """Return the file descriptor of the underlying socket."""
    return self.sock.fileno()

def ForkAndListen(outfmt, baseport=0, tries=1, loop=False, relative=False):
  """Listen on a local port, announce it on stdout and wait for a client.

  Ports baseport, baseport+1, ... (tries attempts) are tried until one can
  be bound; an IPv6 socket is preferred, with IPv4 as the fallback.  The
  bound port number (or its offset from baseport when relative is true) is
  written to stdout using the %-format outfmt, after which stdout and stdin
  are closed.

  Without loop, the calling process exits at once and a forked child waits
  up to 15 seconds for a single connection, exiting if none arrives.  With
  loop, the process keeps accepting connections and forks for each one; the
  parent goes on listening while the child returns.

  Returns (fw, fw), where fw is a FileWrapper around the accepted
  connection, to be used as both input and output.  Raises IOError if no
  port could be bound.
  """
  # NOTE(review): the listener binds to '' (all interfaces) although the
  # formats used by the command-line options announce 127.0.0.1.
  srv = None
  for t in range(0, tries):
    try:
      try:
        srv = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
      except Exception:
        srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
      srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
      srv.bind(('', baseport+t))
      break
    except Exception:
      # This port is unusable: release the socket and try the next one.
      if srv is not None:
        try:
          srv.close()
        except Exception:
          pass
      srv = None

  if srv is None:
    raise IOError('Unable to bind a listening port (base=%s, tries=%s)'
                  % (baseport, tries))

  srv.listen(3)

  if relative:
    sys.stdout.write((outfmt+'\n') % (srv.getsockname()[1]-baseport))
  else:
    sys.stdout.write((outfmt+'\n') % srv.getsockname()[1])

  # Close stdout/stdin so that a caller capturing our output (e.g. shell
  # backticks) sees end-of-file, then detach unless we are looping.
  sys.stdout.flush()
  os.close(sys.stdout.fileno())
  os.close(sys.stdin.fileno())
  if not loop and os.fork() != 0: os._exit(0)

  while True:
    # Wait for a connection...
    i, o, e = select.select([srv], [], [], 15)
    if srv in i:
      client, address = srv.accept()
      if DEBUG: DEBUG('Accepted: %s' % (address, ))
      # Not looping: serve this client in this process (no fork happens).
      # Looping: fork; the child serves the client, the parent keeps
      # listening.
      if not (loop and (os.fork() != 0)):
        srv.close()
        fw = FileWrapper(client)
        return fw, fw
      client.close()
    elif not loop:
      # Or die?
      os._exit(0)


if __name__ == '__main__':
  # Command-line entry point: parse the options, optionally start a local
  # listener, then run in the selected mode (netcat, HTTP proxy, network ID,
  # help) or print the usage summary and exit with status 100.
  try:
    opts, args = getopt.getopt(sys.argv[1:], 'hl:NPRtvV:',
                               ['help', 'listen=', 'tc', 'tp', 'tf=',
                                'vnc', 'rdp'])
    usage_error = False
  except getopt.GetoptError:
    # Unknown option or missing argument: report it and fall through to the
    # usage text below instead of dying with a traceback.
    DebugPrint('%s' % (sys.exc_info()[1], ))
    opts, args, usage_error = [], [], True

  # Accept "host:port" as a single argument.
  if len(args) == 1 and ':' in args[0]:
    args = args[0].split(':')

  use_sysdefaults = True
  mode, portadd, inlinefmt, inlineargs = 'netcat', 0, '', {}
  if usage_error:
    mode = 'invalid'

  for opt, arg in opts:
    if '-V' == opt:
      # -V <file> is -v with stderr appended to <file>.
      opt = '-v'
      sys.stderr = open(arg, 'a')
    if '-v' == opt:
      # Each repetition raises verbosity: first lapcat debugging, then
      # sockschain debugging, then request tracing.
      if DEBUG and socks.DEBUG: TRACE = DebugPrint
      if DEBUG: socks.DEBUG = DebugPrint
      DEBUG = DebugPrint

    if '-N' == opt: mode = 'networkid'
    elif '-P' == opt: mode = 'httpproxy'
    elif '-R' == opt: use_sysdefaults = False
    elif opt in ('-h', '--help'):
      # getopt consumes -h, so it never shows up in args; handle it here.
      mode = 'help'
      break
    else:
      if mode not in ('netcat', 'httpproxy'):
        mode = 'invalid'
        break
      elif '-t' == opt:   inlinefmt = '127.0.0.1 %d'
      elif '--tc' == opt: inlinefmt = '127.0.0.1:%d'
      elif '--tp' == opt: inlinefmt = '%d'
      elif '--tf' == opt: inlinefmt = arg
      elif '--rdp' == opt:
        inlinefmt = '127.0.0.1:%d'
        if len(args) == 1: args.append('3389')
      elif '--vnc' == opt:
        # VNC screens are numbered relative to port 5900.
        inlinefmt, portadd = '127.0.0.1:%d', 5900
        inlineargs = {'baseport': 5900, 'tries': 20, 'relative': True}
        if len(args) == 1: args.append('0')
      elif opt in ('-l', '--listen'):
        inlinefmt = '127.0.0.1:%d'
        inlineargs = {'baseport': int(arg), 'tries': 1, 'loop': True}

  if use_sysdefaults:
    socks.usesystemdefaults()

  # Set up the listener, if necessary...
  if inlinefmt and ((mode == 'netcat' and len(args) == 2) or
                    (mode == 'httpproxy')):
    fin, fout = ForkAndListen(inlinefmt, **inlineargs)
  else:
    fin, fout = sys.stdin, sys.stdout

  # Do proxy stuff!
  if mode == 'netcat' and len(args) == 2:
    NetCat(args[0], portadd + int(args[1].replace(':', '')), fin, fout)
  elif mode == 'httpproxy' and len(args) == 0:
    HttpProxy(fin, fout)

  # Or print information!
  elif mode == 'networkid' and len(args) == 0:
    print('%s' % GetNetworkId())
  elif mode == 'help' or (len(args) == 1 and args[0] in ('-h', '--help')):
    DebugPrint(__DOC__)
  else:
    print((
      '%(p)s: Location Aware Proxy Chooser And Tunneler / NetCat for Laptops\n'
      '\n'
      'Usage: %(p)s [-v [-v]]      host port     # Connect to host:port\n'
      '       %(p)s <-t|--tc|--tp> host port     # Inline proxy mode\n'
      '       %(p)s --tf=<fmt>     host port     # Inline proxy mode\n'
      '       %(p)s --rdp          host [port]   # Inline RDP proxy mode\n'
      '       %(p)s --vnc          host [screen] # Inline VNC proxy mode\n'
      '       %(p)s -l port        host port     # Local port <=> host proxy\n'
      '       %(p)s -N                           # Show current network ID\n'
      '       %(p)s -P                           # Behave like an HTTP Proxy\n'
      '       %(p)s -h                           # Print instructions\n'
      '\n'
      'To use with ssh, add to ~/.ssh/config:\n'
      '    ProxyCommand %(fp)s %%h %%p\n'
      '    CheckHostIP no\n'
      '\n'
      'Inline use examples:\n'
      '    $ vncviewer `%(p)s --vnc hostname`\n'
      '    $ rdesktop `%(p)s --rdp homebox.pagekite.me`\n'
      '    $ irssi -c localhost -p `%(p)s --tp irc.freenode.net 6667`\n'
    ) % {'fp': os.path.abspath(sys.argv[0]),
         'p': os.path.basename(sys.argv[0])})
    sys.exit(100)

