#!/usr/bin/env ruby
#
#  equity 0.4 - queueing software load balancer with transmission inspection
#               and simple HTTP header rewriting
#
#  Usage: equity [-dp] [-h "Header: Value"] <listen-port> [<node-address>:]<node-port> ...
#  Run equity with no arguments for detailed usage information.
#

require 'socket'
require 'equity/node'
require 'equity/manager'
require 'getoptlong'

# Builds the label used to tag a node's debugging output.
#
# Returns "address:port" of the client the node is currently serving, or the
# literal string "not connected" when the node reports no connection.
def client_id(node)
  return "not connected" unless node.connected?

  format('%s:%s', node.client_address, node.client_port)
end

# Writes +message+ to standard output, but only when $debugging is set.
#
# When +socket+ is given, the line is prefixed with "[label] ", where label is:
# the client_id of the node that owns +socket+, if any; otherwise the socket's
# numeric peer "host:port"; or "???:???" if the peer lookup raises.
def debug(socket, message)
  return unless $debugging

  if socket
    owner = Equity::Node.with_socket(socket)
    label =
      if owner
        client_id(owner)
      else
        begin
          host, port = Socket.getnameinfo(socket.getpeername,
                                          Socket::NI_NUMERICHOST | Socket::NI_NUMERICSERV)
          "#{host}:#{port}"
        rescue
          "???:???"
        end
      end
    print "[#{label}] "
  end
  puts message
end

# Prints data in hex and ASCII if packet dumping is enabled.
BYTES_PER_PACKET_LINE = 16
# Printable ASCII (space through tilde), written as a character-class range so
# it can be interpolated into a regexp.
SAFE_PACKET_RX = '\x20-\x7E'.freeze
# Matches any byte outside SAFE_PACKET_RX; such bytes are shown as ".".
UNSAFE_PACKET_BYTE_RX = /[^#{SAFE_PACKET_RX}]/

# Prints +data+ (just read from +src+) as a hex + ASCII dump when
# $dump_packets is set. Nothing is printed if no node owns +src+.
#
# The arrow is "--->" when +src+ is the first of the node's sockets and
# "<---" otherwise. Each output line covers BYTES_PER_PACKET_LINE bytes: two
# hex columns grouped into 2-byte words, then the bytes with non-printable
# ones replaced by ".". +dst+ is currently unused. +data+ itself is not
# modified.
def dump_packet(data, src, dst)
  return unless $dump_packets

  node = Equity::Node.with_socket(src) || return
  arrow = src == node.sockets.first ? '--->' : '<---'
  print "[#{client_id(node)}] #{arrow} #{node.address}:#{node.port}\n\n"

  # Work on a binary copy so slicing is byte-wise and invalid multibyte
  # sequences can never make the regexp substitution raise.
  remaining = data.dup.force_encoding(Encoding::BINARY)
  until remaining.empty?
    bytes = remaining.slice!(0, BYTES_PER_PACKET_LINE)
    # Two hex digits per byte, padded with spaces so short final lines align.
    hex = bytes.unpack('H*').first.ljust(BYTES_PER_PACKET_LINE * 2)
    half = hex.length / 2
    hex1 = hex[0, half].gsub(/(....)/, '\1 ')
    hex2 = hex[half..-1].gsub(/(....)/, '\1 ')
    ascii = bytes.gsub(UNSAFE_PACKET_BYTE_RX, '.')
    puts "  #{hex1}  #{hex2}  #{ascii}"
  end
  print "\n"
end

# Handles rewriting of HTTP request headers. This method isn't very smart about
# how it detects HTTP requests, but should work for most situations.
#
# +data+ is the chunk just read from +src+ and is modified in place. It is only
# touched when it was read from the first of the owning node's sockets,
# contains " HTTP/", and holds a complete header block (terminated by a blank
# line). For each entry of $header_rewrites, the first header with that name
# (matched case-insensitively, as HTTP header names are) gets the configured
# value.
def rewrite_headers!(data, src)
  return if $header_rewrites.empty?
  return unless (http_slash_idx = data.index(' HTTP/'))

  node = Equity::Node.with_socket(src) || return
  return unless src == node.sockets.first

  # Locate both boundaries before changing anything. If the header block is
  # not entirely inside this chunk (split across reads), forward the data
  # unmodified rather than a half-rewritten request.
  request_line_end = data.index("\r\n", http_slash_idx)
  return unless request_line_end
  headers_end = data.index("\r\n\r\n", request_line_end)
  return unless headers_end

  debug(src, "rewriting HTTP request headers")

  before_headers = data[0, request_line_end]
  # Every header line in this slice is preceded by "\r\n" (the first one
  # included); the last line has no trailing "\r\n" because the slice stops
  # at the blank-line terminator.
  headers = data[request_line_end, headers_end - request_line_end]
  after_headers = data[headers_end..-1]

  $header_rewrites.each do |header, value|
    match = /\r\n#{Regexp.escape(header)}:[ \t]*([^\r\n]*)/i.match(headers)
    next unless match

    # Plain concatenation, so backslashes in +value+ are copied literally.
    headers = match.pre_match + "\r\n#{header}: #{value}" + match.post_match
    debug(src, "  #{header}: #{match[1]} => #{value}")
  end
  data.replace(before_headers + headers + after_headers)
rescue StandardError
  # Rewriting the headers is far less important than keeping the program
  # alive. +data+ is only replaced once the rewrite has fully succeeded, so
  # an error here leaves it exactly as it was read.
end


# Process arguments.
# -d and -p are also pulled straight out of ARGV here, so they are honoured
# wherever they appear; GetoptLong below recognises them as well.
$debugging       = !!ARGV.delete('-d')
$dump_packets    = !!ARGV.delete('-p')
$header_rewrites = {}

# Splits the argument of a -h/--header option ("Name: value") into a
# [name, value] pair. Only the first ": " separates the two, so values may
# themselves contain ": ". Returns nil when the name or the value is missing.
def parse_header_rewrite(argument)
  name, value = argument.split(/: +/, 2)
  return nil if name.nil? || name.empty? || value.nil? || value.empty?

  [name, value]
end

options = GetoptLong.new(
  ["-d", "--debug",   GetoptLong::NO_ARGUMENT],
  ["-p", "--packets", GetoptLong::NO_ARGUMENT],
  ["-h", "--header",  GetoptLong::REQUIRED_ARGUMENT]
)
begin
  options.each do |option, argument|
    case option
    when "-d"
      $debugging = true
    when "-p"
      $dump_packets = true
    when "-h"
      header = parse_header_rewrite(argument)
      unless header
        STDERR.puts "I don't understand this header rewrite:"
        STDERR.puts "  #{argument}"
        STDERR.puts "Run equity with no arguments for help."
        exit(64) # EX_USAGE
      end
      $header_rewrites[header.first] = header.last
    end
  end
rescue GetoptLong::Error
  # Unknown option, missing argument, etc. By default GetoptLong has already
  # reported the problem on stderr; exit with a usage error instead of a
  # backtrace.
  STDERR.puts "Run equity with no arguments for help."
  exit(64) # EX_USAGE
end

# Without a listen port and at least one node there is nothing to do: print
# the usage summary to stderr and exit with EX_USAGE.
unless ARGV.size >= 2
  STDERR.puts(
    'equity 0.4',
    'Usage: equity [<options>] <listen-port> [<node-address>:]<node-port> ...',
    '',
    'Node addresses default to localhost.',
    '',
    'The following options can be used:',
    '-d Debug mode. Stays in the foreground and prints status messages.',
    '-p Packet dumper. Stays in the foreground and displays the raw data being',
    '   transferred.',
    '-h "<header>: <value>" Rewrite an HTTP request header before retransmitting.',
    '   Can be used more than once to rewrite several headers.',
    ''
  )
  exit(64) # EX_USAGE
end

# Instantiate nodes.

# Parses a port given on the command line as a decimal integer. Returns nil
# when +spec+ is not one; String#to_i would silently turn such input into 0.
def parse_port_spec(spec)
  Integer(spec.to_s, 10)
rescue ArgumentError
  nil
end

listen_spec = ARGV.shift
$listen_port = parse_port_spec(listen_spec)
unless $listen_port
  STDERR.puts "Invalid listen port: #{listen_spec}"
  STDERR.puts "Run equity with no arguments for help."
  exit(64) # EX_USAGE
end

# Each remaining argument is "[<node-address>:]<node-port>". Splitting on the
# last colon keeps the port as the final field, which also lets bare IPv6
# addresses such as "::1:8080" through intact. A missing or empty address
# defaults to localhost.
$nodes = ARGV.map do |node_spec|
  address, separator, port = node_spec.rpartition(':')
  address = 'localhost' if separator.empty? || address.empty?
  Equity::Node.new(address, port)
end

# Daemonize.
# Detach and run in the background. This is skipped when -d or -p is given so
# their output stays visible in the foreground.
unless $debugging || $dump_packets
  # fork returns the child's pid (truthy) in the parent, so the parent exits
  # here and only the child carries on.
  fork && exit
  # Start a new session, leaving the process without a controlling terminal.
  Process.setsid
  # NOTE(review): ignoring SIGHUP around the second fork looks defensive --
  # confirm intent.
  trap 'SIGHUP', 'IGNORE'
  # Fork again so the surviving process is not a session leader and therefore
  # cannot acquire a controlling terminal.
  fork && exit
  # Leave the launch directory and clear the file-creation mask.
  # NOTE(review): umask 0000 means files created later get exactly the
  # requested permissions (possibly world-writable) -- confirm this is wanted.
  Dir.chdir '/tmp'
  File.umask 0000
  # Close every live IO object except the standard streams; failures (for
  # example already-closed IOs) are ignored.
  ObjectSpace.each_object(IO) do |io|
    unless [STDIN, STDOUT, STDERR].include?(io)
      io.close rescue nil
    end
  end
  # Point the standard streams at /dev/null; STDOUT is opened for appending
  # and STDERR is redirected to the same place.
  STDIN.reopen "/dev/null"
  STDOUT.reopen "/dev/null", "a"
  STDERR.reopen STDOUT
end

# In debug mode, announce every configured HTTP request header rewrite.
$header_rewrites.each_pair do |name, replacement|
  debug(nil, "Will rewrite value of HTTP request header `#{name}' to `#{replacement}'")
end

# On SIGINT (Ctrl-C), report each node's connection counter through debug
# (so it only appears in debug mode), then exit.
trap('INT') do
  debug(nil, "\nNode Counters")
  $nodes.each { |n| debug(nil, "#{n} - #{n.counter} connections") }
  exit
end

# Start up the control server, handing it the node list and the listen port.
$control_server = Equity::Manager.new($nodes, $listen_port)

# Start up the client-facing listener. Accepted clients wait in $client_queue
# until a node can take them.
$client_queue = []
$server = TCPServer.new(nil, $listen_port).tap do |listener|
  # Set the pending-connection backlog to 5.
  listener.listen(5)
end

debug(nil, "Ready on port #{$listen_port}")

# Loop forever.

# Returns every socket the main loop waits on: the listening server socket
# plus all sockets currently held by the nodes.
def watched_sockets
  [$server, $nodes.map { |node| node.sockets }].flatten
end

# Accepts one pending connection on $server and queues it for dispatch.
# accept_nonblock is used because a client can disappear between select()
# reporting the server readable and the accept. A blocking accept would then
# stall the whole balancer until some other client connected.
def accept_client
  client = $server.accept_nonblock
  $client_queue << client
  debug(client, 'connection accepted')
rescue IO::WaitReadable, Errno::ECONNABORTED, Errno::EPROTO, Errno::EINTR => e
  debug(nil, "accept failed: #{e.message}")
end

# Reads whatever is available on +socket+ and forwards it to the socket's
# mate, applying header rewriting and packet dumping on the way. An empty read
# (peer closed) or any socket/IO error disconnects the owning node.
def relay(socket)
  node = $nodes.find { |n| n.owns_socket?(socket) }
  # Skip sockets no node claims, e.g. presumably one released by a disconnect
  # earlier in this same select round.
  return if node.nil?

  begin
    data = socket.recvfrom(65536)[0]
    raise Errno::EPIPE if data.empty?
    mate = socket.mate
    rewrite_headers!(data, socket)
    dump_packet(data, socket, mate)
    mate.write(data)
  rescue IOError, SystemCallError
    # Connection closed or broken. Log first, so the message can still be
    # attributed to the client; node.disconnect presumably drops that
    # association.
    debug(socket, 'disconnected')
    node.disconnect
  end
end

# Dequeues as many clients as possible, showing favour to nodes that have
# handled the fewest clients. Stops at the first client no idle node accepted.
# If every node failed to connect, all queued clients are closed.
def dispatch_queued_clients
  until $client_queue.empty?
    client = $client_queue.first
    assigned = false
    failures = 0
    $nodes.sort_by { |n| n.counter }.each do |node|
      next if node.connected?
      begin
        node.connect(client)
      rescue StandardError => e
        # StandardError rather than Exception: SystemExit (raised by the
        # SIGINT handler's exit) and other signals must not be swallowed.
        debug(client, "connection to #{node} failed: #{e.message}")
        failures += 1
        next
      end
      debug(client, "assigned to #{node}")
      assigned = true
      break
    end

    if assigned
      $client_queue.shift
      next
    end

    if failures == $nodes.length
      # All nodes failed. Disconnect all queued clients.
      debug(nil, 'All nodes failed')
      $client_queue.each { |queued| queued.close }
      $client_queue.clear
    end
    break
  end
end

loop do
  # Wait up to 5 seconds for readable sockets. The timeout keeps queued
  # clients and control commands serviced even when no traffic arrives.
  readable, = select(watched_sockets, nil, nil, 5)
  (readable || []).each do |socket|
    if socket == $server
      # Incoming connection. Accept it and queue it.
      accept_client
    else
      # Data received. Transfer it to the socket's mate.
      relay(socket)
    end
  end

  dispatch_queued_clients

  # Check for control packets.
  $control_server.process_commands
end
