#include "tcp_server.h"

#include <fstream.h>
#include <iostream.h>

#define USE_REAPER // if you need to comment this out, be warned! The other method
// (fork2) is not well tested and is known to leave zombies on Linux

#ifdef USE_REAPER
static void reaper(int sig_num);  // this one has to be outside of class


// SIGCHLD handler: collects every child that has already terminated so that
// none of them is left behind as a zombie.
//
//   sig_num - signal number delivered by the kernel (unused).
//
// errno is saved on entry and restored on exit: waitpid() overwrites it
// (for instance with ECHILD once no children remain), and the code that was
// interrupted by the signal may be about to inspect errno itself.
static void reaper(int sig_num)
 {
   (void)sig_num;

   const int saved_errno = errno;
   int status;

   // WNOHANG: never block inside the handler, just drain what is ready.
   while (waitpid(-1, &status, WNOHANG) > 0)
     ;

   // Re-arm the handler for platforms where signal() resets the disposition
   // to the default once the handler has been invoked.
   signal(SIGCHLD, reaper);

   errno = saved_errno;
 }

#else

// Double fork: creates a grandchild process and lets the intermediate child
// exit right away, so the caller only ever has to wait for that short-lived
// intermediate process.
//
// Returns 0 in the grandchild, 1 in the calling process on success, and -1
// in the calling process on failure with errno set (the intermediate child
// reports a failed second fork() through its exit status).
static int fork2()
 {
   int exit_info;
   const pid_t child = fork();

   if (child == 0)
    {
      // Intermediate child: spawn the grandchild, then leave immediately.
      const pid_t grandchild = fork();
      if (grandchild == 0)
        return 0;
      _exit(grandchild == -1 ? errno : 0);
    }

   if (child < 0)
     return -1;
   if (waitpid(child, &exit_info, 0) < 0)
     return -1;

   if (!WIFEXITED(exit_info))
    {
      errno = EINTR;  /* well, sort of :-) */
      return -1;
    }

   if (WEXITSTATUS(exit_info) == 0)
     return 1;

   // A non-zero exit status carries the errno of the failed second fork().
   errno = WEXITSTATUS(exit_info);
   return -1;
 }

#endif 


// Opens the log stream, then creates the server's TCP socket, binds it and
// puts it into the listening state.
//
//   port              - TCP port to bind; the socket is bound to INADDR_ANY.
//   listen_queue_size - backlog handed to listen().
//   log_file          - path of the log file, opened in append mode; NULL or
//                       "-" selects stderr.
//
// Exits the process with status 1 when the log file cannot be opened. Every
// other failure is reported through fatal_error().
// NOTE(review): fatal_error() is presumably non-returning - confirm;
// otherwise a NULL tcp_prot would be dereferenced below.
void TcpServer::start_server(int port, int listen_queue_size, char* log_file)
 {
  if(!log_file || !strcmp("-", log_file))
   log = stderr;
  else
   {
    log = fopen(log_file, "a");
    if(!log)
     {
      fprintf(stderr, "Tcp Server: could not open log file %s : %s\n", 
        log_file, strerror(errno));
      exit(1);
     }
   }
  
  // Unbuffered, so log lines are written out as soon as they are produced.
  setbuf(log, NULL);

  struct protoent *tcp_prot = getprotobyname("tcp");
  if(!tcp_prot)
   fatal_error("TcpServer: tcp does not seem to be supported");

  socket_fd = socket(AF_INET,SOCK_STREAM, tcp_prot->p_proto);

  if(socket_fd == -1)
   fatal_error("TcpServer: error creating the socket: %s", strerror(errno));
  
  memset(&serv_addr, 0, sizeof(serv_addr));
  serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  serv_addr.sin_port = htons(port);
  serv_addr.sin_family = AF_INET;
 
  if(bind(socket_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) 
                        == -1)
    {
     fatal_error("TcpServer: Could not bind: %s", strerror(errno));
    }

  if(listen(socket_fd, listen_queue_size) == -1)
   fatal_error("TcpServer: could not listen: %s", strerror(errno));
 }

// Forks the server into the background and serves clients forever, using
// one forked handler process per accepted connection.
//
// Control flow as seen by the caller:
//   - In the original (foreground) process the listening socket is closed
//     and the function returns right after the fork.
//   - The background process never returns; it loops on accept().
//   - Each client handler process returns from this function once
//     handle_client() is done (or once access was denied).
//     NOTE(review): the caller therefore continues running in every handler
//     process as well - confirm that callers simply exit after run_server().
//
// The peer address is not requested from accept(): nothing here uses it
// (access_allowed()/get_peer_addr() call getpeername() themselves), and
// passing serv_addr as the output buffer would overwrite the bind address.
void TcpServer::run_server()
 {
   
   #ifdef USE_REAPER
    pid_t fork_pid = fork();
   #else
    pid_t fork_pid = fork2();

   #endif

   if(fork_pid == -1)
    fatal_error("TcpServer: could not fork into background: %s", strerror(errno));

   #ifdef USE_REAPER
    // Installed on both sides of the fork, i.e. also in the calling process.
    signal(SIGCHLD, reaper);
   #endif

   if(fork_pid > 0) 
    {
     // Foreground process: the background process owns the listener now.
     close(socket_fd);
     socket_fd = -1;  // make sure nothing closes this descriptor number again
     return;
    }

   // NOTE(review): descriptors 0-2 are closed rather than redirected, so
   // later accept()/open() calls may be handed these numbers - confirm that
   // nothing writes to stdout/stderr after this point.
   close(0);
   close(1);
   close(2);

   while(1)
    {
     int client_socket_fd = accept(socket_fd, NULL, NULL);
     
     if(client_socket_fd == -1)
      {
       // EINTR is expected whenever SIGCHLD interrupts accept(); stay quiet.
       if(errno != EINTR)
        warn( "TcpServer: error on accept : %s\n", strerror(errno));
       continue;
      }

     pid_t client_fork_pid = fork();
     if(client_fork_pid == -1)
      {
       warn("TcpServer: could not fork off a client handler : %s\n",
            strerror(errno));
       // Same as before: the connection stays open in this process.
       continue;
      }

     if(client_fork_pid > 0)
      {
       // Server process: the handler owns the connection now.
       close(client_socket_fd);
       continue;
      }

     // Client handler process: it never accepts, so drop the listener.
     close(socket_fd);
     socket_fd = -1;

     if(access_control)
      if(!access_allowed(client_socket_fd))
       {
        close(client_socket_fd);
        return;
       }

     handle_client(client_socket_fd);

     // Best effort: a failure to close is deliberately not reported.
     close(client_socket_fd);
     return;
    }
 }
   
// Builds a server that is listening by the time the constructor returns:
// debug_level and access_control are initialised to 0, then the arguments
// are forwarded unchanged to start_server().
TcpServer::TcpServer(int port, int listen_queue_size, char* log_file)
 : debug_level(0),
   access_control(0)
 {
  this->start_server(port, listen_queue_size, log_file);
 }

// Destructor. The body is empty: neither socket_fd nor the log stream is
// closed here. The listening socket is closed by cleanup(), and by
// run_server() in the processes that no longer need it.
TcpServer::~TcpServer()
 {
 }

// Decides whether the peer connected on client_socket_fd may be served.
//
//   client_socket_fd - connected client socket.
//
// Returns 1 when ip_rules accepts the peer's IPv4 address and 0 otherwise.
// A peer whose address cannot be determined is denied (with a warning);
// a peer rejected by the rules is reported through log_access_denied().
int TcpServer::access_allowed(int client_socket_fd)
 {
  struct sockaddr_in peername;
  // getpeername() takes a socklen_t*; it is a value-result argument.
  socklen_t peername_len = sizeof(peername);

  if(getpeername(client_socket_fd, (struct sockaddr*)&peername, &peername_len) == -1)
   {
    warn("Cannot get peer - access denied");
    return 0;
   }

  if(ip_rules.ok(peername.sin_addr))
   return 1;

  char* src_addr = inet_ntoa(peername.sin_addr);
  log_access_denied(src_addr);
  return 0;
 }

// Returns the peer's IPv4 address in dotted-decimal text form.
//
//   client_socket_fd - connected client socket.
//
// On success the result is the buffer returned by inet_ntoa(), which that
// function reuses on every call - copy the text if it has to be kept. When
// the peer cannot be determined a warning is issued and "unknown" is
// returned. The caller must not free the result in either case.
char* TcpServer::get_peer_addr(int client_socket_fd)
 {
  // Writable static storage: a string literal must not be handed out
  // through a non-const char*.
  static char unknown_addr[] = "unknown";

  struct sockaddr_in peername;
  // getpeername() takes a socklen_t*; it is a value-result argument.
  socklen_t peername_len = sizeof(peername);

  if(getpeername(client_socket_fd, (struct sockaddr*)&peername, &peername_len) == -1)
   {
    warn("Cannot get peer");
    return unknown_addr;
   }

  return inet_ntoa(peername.sin_addr);
 }

// Closes the listening socket. Safe to call more than once: after the first
// call socket_fd is -1, so a repeated call cannot close a descriptor number
// that has meanwhile been reused for something else.
void TcpServer::cleanup()
 {
  if(socket_fd != -1)
   {
    close(socket_fd);
    socket_fd = -1;
   }
 }

#ifdef DEBUG

// Small TcpServer subclass for the DEBUG test driver. It adds no state of
// its own: the constructor forwards its arguments to TcpServer, and
// handle_client() (defined below the class) supplies the per-connection
// behaviour.
class TestServer : public TcpServer
 {
  public:

   TestServer(int port, int queue_size, char* log_file)
    : TcpServer(port, queue_size, log_file)
    {
    }

  protected:

   // Runs in the client handler process for one accepted connection.
   void handle_client(int client_socket_fd);
 };



// Test conversation with one client: logs the peer address, asks for a
// name, reads a single line and answers with a good-bye message.
//
//   client_socket_fd - connected client socket.
//
// The two stdio streams wrap the same descriptor and are intentionally not
// fclose()d here: run_server() closes client_socket_fd itself after this
// function returns.
void TestServer::handle_client(int client_socket_fd)
 {
  char buf[512];

  info("connect from %s ", get_peer_addr(client_socket_fd));

  // Each stream is checked before it is used; setbuf() on a NULL stream is
  // undefined behaviour.
  FILE *in  = fdopen(client_socket_fd, "r");
  if(!in)
   {
    warn("fdopen failed on 'in': %s\n", strerror(errno));
    return;
   }

  FILE *out  = fdopen(client_socket_fd, "w");
  if(!out)
   {
    warn("fdopen failed on 'out': %s\n", strerror(errno));
    return;
   }

  setbuf(in, NULL);
  setbuf(out, NULL);

  fprintf(out, "Please enter your name:");

  // fgets() leaves buf untouched on EOF/error (e.g. the client hung up);
  // treat that as an empty name instead of printing uninitialised memory.
  if(!fgets( buf, sizeof(buf), in))
   buf[0] = '\0';

  fprintf(out, "Good-bye, %s\n", buf);
 }

// DEBUG test driver.
//
// Usage: <program> [port [log_file [ip_rules_file]]]
//   port          - TCP port to listen on (default 1234).
//   log_file      - log file path; "-" (the default) is passed through to
//                   the server as-is.
//   ip_rules_file - optional file of whitespace-separated "netnum netmask"
//                   pairs; when given, each pair is added as an IP rule and
//                   access control is enabled.
int main(int c, char** argv)
 {
  int port = 1234;
  // Writable storage for the default: a string literal must not be bound
  // to a non-const char*.
  char default_log_file[] = "-";
  char* log_file = default_log_file;
  char* ip_rules_file = NULL;
  if(c > 1) port = atoi(argv[1]);
  if(c > 2) log_file = argv[2];
  if(c > 3) ip_rules_file = argv[3];
   

  TestServer s(port, 50, log_file);
 
  if(ip_rules_file)
   {
    ifstream rules(ip_rules_file);
    if(!rules)
    {
     cerr << "Could not open rules file" << endl;
     exit(1);
    } 
    
    while(1)
    {
     char netnum[128], netmask[128];

     // width() caps the number of characters stored by the next char*
     // extraction (terminator included) and is reset afterwards, so it is
     // set again before each read. Without it an over-long token would
     // overflow the buffer.
     rules.width(sizeof(netnum));
     rules >> netnum;
     rules.width(sizeof(netmask));
     rules >> netmask;

     // Check the stream before touching the buffers: after a failed read
     // they hold stale or uninitialised data.
     if(!rules) break;

     cout << "reading: " << netnum << " " << netmask << endl;
     s.add_ip_rule(netnum, netmask);
    }
    s.enable_access_control();
   }
     
  s.run_server();
  return 0;
 }

#endif 
