// socketbase.cpp
//
// Posted by jwatte.

#include "sockimpl.h"
 
using namespace etwork;
using namespace etwork::impl;
 
 
SocketManager::SocketManager()
{
  listening_ = INVALID_SOCKET;
  maxNumSocks_ = FD_SETSIZE;
  numSocks_ = 0;
  maxSock_ = 0;
  allSet_ = (fd_set *)::operator new( sizeof(fd_set) );
  FD_ZERO( allSet_ );
  readSet_ = (fd_set *)::operator new( sizeof(fd_set) );
  FD_ZERO( readSet_ );
  writeSet_ = (fd_set *)::operator new( sizeof(fd_set) );
  FD_ZERO( writeSet_ );
  writeTempSet_ = (fd_set *)::operator new( sizeof(fd_set) );
  FD_ZERO( writeTempSet_ );
  exceptSet_ = (fd_set *)::operator new( sizeof(fd_set) );
  FD_ZERO( exceptSet_ );
  nextSocket_ = 1;
  tmpBuffer_ = 0;
  curQueueSpace_ = 0;
  curTime_ = time_.seconds();
}
 
//  Closes the shared listening/UDP socket if open() created one, then frees
//  the fd_set storage and the receive scratch buffer owned by this manager.
SocketManager::~SocketManager()
{
  if( listening_ != INVALID_SOCKET ) {
    ::closesocket( listening_ );
    listening_ = INVALID_SOCKET;
  }

  //  All five sets were obtained from ::operator new in the constructor.
  fd_set * const ownedSets[] = { allSet_, readSet_, writeSet_, writeTempSet_, exceptSet_ };
  for( size_t k = 0; k < sizeof(ownedSets) / sizeof(ownedSets[0]); ++k ) {
    ::operator delete( ownedSets[k] );
  }

  delete[] tmpBuffer_;
}
 
bool SocketManager::open( EtworkSettings * settings )
{
  settings_ = *settings;
 
  if( settings_.accepting && !settings_.port ) {
    ErrorInfo ei;
    ei.error = EtworkError( ES_error, EA_init, EO_invalid_parameters );
    ei.error.setText( "Port may not be 0 when accepting in EtworkSettings." );
    ei.osError = 0;
    ei.socket = 0;
    etwork_info_from( 0, ei );
    return false;
  }
  //  Unreliable socket managers always get one socket,
  //  even if not "accepting" connections.
  if( settings->accepting || !settings->reliable ) {
    int type = SOCK_STREAM;
    int proto = IPPROTO_TCP;
    if( !settings->reliable ) {
      type = SOCK_DGRAM;
      proto = IPPROTO_UDP;
    }
    listening_ = ::socket( AF_INET, type, proto );
    if( IS_SOCKET_ERROR( listening_ ) ) {
      debug_sock_error( 0, ::WSAGetLastError(), EA_init, "::socket(AF_INET)" );
      return false;
    }
    int one = 1;
    int r = ::setsockopt( listening_, SOL_SOCKET, SO_REUSEADDR, (char const *)&one, sizeof(one) );
    if( r < 0 ) {
      debug_sock_error( 0, ::WSAGetLastError(), EA_init, "::setsockopt(SO_REUSEADDR)" );
      goto failure;
    }
    sockaddr_in addr;
    memset( &addr, 0, sizeof( addr ) );
    addr.sin_family = AF_INET;
    addr.sin_port = htons( settings->port );
    r = ::bind( listening_, (sockaddr const *)&addr, (int)sizeof(addr) );
    if( r < 0 ) {
      debug_sock_error( 0, ::WSAGetLastError(), EA_init, "::bind()" );
      goto failure;
    }
    if( settings->reliable ) {
      r = ::listen( listening_, 100 );
      if( r < 0 ) {
        debug_sock_error( 0, ::WSAGetLastError(), EA_init, "::listen(100)" );
        goto failure;
      }
    }
    else {
      //  make the socket non-blocking
      u_long nonblock = 1;
      r = ::ioctlsocket( listening_, FIONBIO, &nonblock );
      if( r < 0 ) {
        debug_sock_error( 0, ::WSAGetLastError(), EA_init, "::ioctlsocket(FIONBIO)" );
        goto failure;
      }
      //  make sure there's enough queuing space
      change_queuing_space();
    }
  }
  tmpBuffer_ = new char[ settings_.maxMessageSize ];
  if( settings->debug ) {
    OutputDebugString( "Etwork: SocketManager::open() was succesful.\n" );
  }
  regenerate_sets();
  return true;
 
failure:
  ::closesocket( listening_ );
  listening_ = INVALID_SOCKET;
  return false;
}
 
namespace {
  struct NotifyActive {
    std::set< ISocket * > & s_;
    NotifyActive( std::set< ISocket * > & s ) : s_( s ) {
    }
    ~NotifyActive() {
      std::set< ISocket * > tmp;
      tmp.swap( s_ );
      std::for_each( tmp.begin(), tmp.end(), notify );
    }
    static void notify( ISocket * s ) {
      Socket * ss = static_cast< Socket * >( s );
      if( !ss->notify_ ) {
        //  The socket notify was removed from this socket while in-flight!
        //  User won't really hear about this through the regular channel.
        etwork_error_from( ss, ss->mgr_, EtworkError( ES_warning, EA_session, EO_internal_error ) );
      }
      else {
        ss->notify_->onNotify();
      }
    }
  };
}
 
void SocketManager::timeout_sockets()
{
  for( SocketMap::iterator ptr = sockets_.begin(), end = sockets_.end(); ptr != end; ) {
    Socket *cl = (*ptr).second;
    bool remove = false;
    if( settings_.timeout > 0 && cl->lastActive_ + settings_.timeout < curTime_ ) {
      etwork_error_from( cl, this, EtworkError( ES_note, EA_session, EO_peer_timeout ) );
      remove = true;
    }
    else if( settings_.keepalive > 0 && cl->lastKeepalive_ + settings_.keepalive < curTime_ ) {
      //  send keepalive message
      cl->write( "", 0 );
    }
    ++ptr;
    if( remove ) {
      cl->close_socket();
      if( cl->accepted_ ) {
        if( cl->notify_ ) {
          notify_.insert( cl );
        }
        else {
          active_.insert( cl );
        }
      }
    }
  }
}
 
int SocketManager::poll( double seconds, ISocket ** outActive, int maxActive )
{
  if( maxActive < 1 || !outActive ) {
    debug_sock_error( 0, WSAEINVAL, EA_session, "SocketManager::poll() maxActive" );
    return -1;
  }
 
  memset( outActive, 0, sizeof( *outActive )*maxActive );
  active_.clear();
  NotifyActive na( notify_ );   //  Make sure they get notified on all exit paths.
                                //  Also, NotifyActive clears the set after notification.
 
  //  handle timeouts
  double now = time_.seconds();
  curTime_ = now;
  timeout_sockets();
 
  if( seconds < 0 ) {
    seconds = 0;
  }
 
  //  Only write to sockets the first time I poll in a given timeout period,
  //  unless they show to have activity (making progress).
  memcpy( writeSet_, allSet_, sizeof(fd_set)+sizeof(SOCKET)*(numSocks_-FD_SETSIZE) );
 
again:
#if defined( WIN32 )
  if( numSocks_ == 0 ) {
    //  no sockets to poll anymore -- return what we have
    std::copy( active_.begin(), active_.end(), outActive );
    return (int)active_.size();
  }
  //  I do this copying around of socket sets to avoid having to
  //  walk each socket struct for each select() -- doing that
  //  would take a lot of cahce misses, and thus be slow.
  //  (I'm kidding myself -- I end up walking a lot of them later
  //  in this function anyway...)
  memcpy( readSet_, allSet_, sizeof(fd_set)+sizeof(SOCKET)*(numSocks_-FD_SETSIZE) );
  memcpy( exceptSet_, allSet_, sizeof(fd_set)+sizeof(SOCKET)*(numSocks_-FD_SETSIZE) );
#else
#error "implement me!"
#endif
 
  //  Calculate timeout
  curTime_ = time_.seconds();
  double then = seconds - curTime_ + now;
  if( then < 0 ) {
    then = 0;
  }
  timeval timo;
  timo.tv_sec = (int)floor( then );
  timo.tv_usec = (int)((seconds-timo.tv_sec)*1000000);
  int r = ::select( (int)maxSock_, readSet_, writeSet_, exceptSet_, &timo );
  // Handle error case
  if( r < 0 ) {
    debug_sock_error( 0, WSAGetLastError(), EA_session, "::select()" );
    if( active_.size() ) {
      std::copy( active_.begin(), active_.end(), outActive );
      if( settings_.debug ) {
        OutputDebugString( "select() failed but returning active sockets\n" );
      }
      return (int)active_.size();
    }
    return -1;
  }
 
  //  Start out assuming no sockets will be writing the next time around.
  FD_ZERO( writeTempSet_ );
  bool progress = true;
#if defined( WIN32 )
  for( size_t i = 0; i < readSet_->fd_count; ++i ) {
    //  service read
    SOCKET s = readSet_->fd_array[i];
#else
#error "implement me"
#endif
    if( s == listening_ ) {
      //  Listening read will also do unreliable sockets.
      //  Return false if it's time to exit out of the read loop.
      if( !handle_listening_read( maxActive ) ) {
        if( settings_.debug ) {
          OutputDebugString( "handle_listening_read() failed.\n" );
        }
        progress = false;
      }
    }
    else {
      //  look up the socket handler for this fd
      SocketMap::iterator ptr = sockets_.find( s );
      ASSERT( ptr != sockets_.end() );
      if( ptr != sockets_.end() ) {
        Socket * s = (*ptr).second;
        //  Tell the socket to put data into its message queue.
        //  Return false if it's time to exit out of the read loop.
        if( s->wants_to_read() ) {
          if( !s->do_read() ) {
            if( settings_.debug ) {
              OutputDebugString( "Socket->do_read() failed.\n" );
            }
            progress = false;
          }
          if( s->notify_ ) {
            notify_.insert( s );
          }
          else {
            active_.insert( s );
          }
        }
        //  this socket has activity -- give it another chance at writing
        if( progress && s->bufOut_.space_used() > 0 && !s->closed() ) {
          FD_SET( (*ptr).first, writeTempSet_ );
        }
      }
    }
    if( active_.size() == maxActive ) {
      if( settings_.debug ) {
        OutputDebugString( "Etwork: SocketManager::poll() filled up the active socket array in read.\n" );
      }
      goto no_more_actives;
    }
#if defined( WIN32 )
  }
  for( size_t i = 0; i < writeSet_->fd_count; ++i ) {
    //  service write
    SOCKET s = writeSet_->fd_array[i];
#else
#error "implement me"
#endif
    if( s == listening_ ) {
      if( !handle_listening_write( maxActive ) ) {
        if( settings_.debug ) {
          OutputDebugString( "handle_listening_write() failed.\n" );
        }
        progress = false;
      }
    }
    else {
      SocketMap::iterator ptr = sockets_.find( s );
      //  socket may close inside read()
      if( ptr != sockets_.end() ) {
        if( (*ptr).second->wants_to_write() ) {
          if( !(*ptr).second->do_write() ) {
            if( settings_.debug ) {
              OutputDebugString( "Socket->do_write() failed.\n" );
            }
            progress = false;
          }
          else {
            if( (*ptr).second->notify_ ) {
              notify_.insert( (*ptr).second );
            }
            else {
              active_.insert( (*ptr).second );
            }
          }
          //  this socket has activity -- give it another chance at writing
          if( progress && (*ptr).second->wants_to_write() ) {
            FD_SET( (*ptr).first, writeTempSet_ );
          }
        }
        //  I don't add back sockets that have writebufData_, because those
        //  mean that they are behind on their window size, so I should
        //  wait trying to ram more data down their throat anyway.
      }
    }
    if( active_.size() == maxActive ) {
      if( settings_.debug ) {
        OutputDebugString( "Etwork: SocketManager::poll() filled up the active socket array in write.\n" );
      }
      goto no_more_actives;
    }
#if defined( WIN32 )
  }
  for( size_t i = 0; i < exceptSet_->fd_count; ++i ) {
    //  service except
    SOCKET s = exceptSet_->fd_array[i];
#else
#error "implement me"
#endif
    if( s == listening_ ) {
      if( !handle_listening_except( maxActive ) ) {
        if( settings_.debug ) {
          OutputDebugString( "handle_listening_except() failed.\n" );
        }
        progress = false;
      }
    }
    else {
      SocketMap::iterator ptr = sockets_.find( s );
      //  sockets may close within read()
      if( ptr != sockets_.end() ) {
        if( !(*ptr).second->do_except() ) {
          if( settings_.debug ) {
            OutputDebugString( "Socket->do_except() failed.\n" );
          }
          progress = false;
        }
        if( (*ptr).second->notify_ ) {
          notify_.insert( (*ptr).second );
        }
        else {
          active_.insert( (*ptr).second );
        }
      }
    }
    if( active_.size() == maxActive ) {
      if( settings_.debug ) {
        OutputDebugString( "Etwork: SocketManager::poll() filled up the active socket array in except.\n" );
      }
      goto no_more_actives;
    }
#if defined( WIN32 )
  }
#else
#error "implement me"
#endif
  then = time_.seconds();
  //  If I made progress (i e, no-one spin-reading on a full buffer), then
  //  I may consider going back for seconds until the timeout runs out.
  if( then-now < seconds && progress ) {
#if defined( WIN32 )
    memcpy( writeSet_, writeTempSet_, sizeof(((fd_set *)0)->fd_count)+sizeof(SOCKET)*numSocks_ );
#else
#error "implement me"
#endif
    goto again;
  }
no_more_actives:
  std::copy( active_.begin(), active_.end(), outActive );
  return (int)active_.size();
}
 
//  Moves up to maxAccepted sockets from the pending-accept queue (accepted_)
//  into the live socket map, marks each one accepted, and stores them in
//  outAccepted (unused slots are zeroed). Queuing space and fd sets are
//  refreshed when anything was moved.
//  Returns the number of sockets written. Returns 0 when maxAccepted is 0.
//  Returns -1, after reporting WSAEINVAL, when maxAccepted is negative or
//  outAccepted is null with a positive count.
int SocketManager::accept( ISocket ** outAccepted, int maxAccepted )
{
  //  A negative count would turn into a huge size_t in the memset below.
  if( maxAccepted < 0 || ( maxAccepted > 0 && !outAccepted ) ) {
    debug_sock_error( 0, WSAEINVAL, EA_session, "SocketManager::accept() maxAccepted" );
    return -1;
  }
  if( maxAccepted == 0 ) {
    return 0;
  }

  memset( outAccepted, 0, sizeof(*outAccepted)*maxAccepted );
  int i = 0;
  for( ; i < maxAccepted && !accepted_.empty(); ++i ) {
    Socket * s = accepted_.front();
    accepted_.pop_front();
    // note that s_ is a "socket id" for unreliable sockets
    sockets_[s->s_] = s;
    outAccepted[i] = s;
    s->accepted_ = true;
  }
  if( i > 0 ) {
    change_queuing_space();
    regenerate_sets();
  }
  return i;
}
 
int SocketManager::connect( char const * address, unsigned short port, ISocket ** outConnected )
{
  *outConnected = 0;
  sockaddr_in addr;
  memset( &addr, 0, sizeof( addr ) );
  addr.sin_family = AF_INET;
  addr.sin_port = htons( port );
  {
    //  gethostbyname() is not thread-safe, even across instances
    Locker ghl( gethostLock );
    hostent * ent = gethostbyname( address );
    if( !ent ) {
      debug_sock_error( 0, WSAGetLastError(), EA_address, "::gethostbyname()" );
      return -1;
    }
    memcpy( &addr.sin_addr, ent->h_addr_list[0], sizeof(addr.sin_addr) );
  }
 
  SOCKET s;
  if( settings_.reliable ) {
    s = ::socket( AF_INET, SOCK_STREAM, IPPROTO_TCP );
    if( IS_SOCKET_ERROR( s ) ) {
      debug_sock_error( 0, WSAGetLastError(), EA_connect, "::socket(AF_INET)" );
      return -1;
    }
    int r = ::connect( s, (sockaddr const *)&addr, (int)sizeof( addr ) );
    if( r < 0 ) {
      debug_sock_error( 0, WSAGetLastError(), EA_connect, "::connect()" );
      ::closesocket( s );
      return -1;
    }
    int one = 1;
    r = ::setsockopt( s, IPPROTO_TCP, TCP_NODELAY, (char const *)&one, sizeof( one ) );
    if( r < 0 ) {
      debug_sock_error( 0, WSAGetLastError(), EA_connect, "::setsockopt(TCP_NODELAY)" );
    }
  }
  else {
    s = listening_;
  }
  //  @TODO: There is a potential race here, where we may have a socket 
  //  waiting inside accepting_ but not yet accepted, yet the client 
  //  calls connect() on that address. Figure out what to do: return the 
  //  accepted socket? Delete the accepted socket? (that might lead to 
  //  live-lock if we're unlucky)
  Socket * so = new Socket( this, s, addr );
  //  note that s_ is a "socket id" for unreliable sockets
  sockets_[so->s_] = so;
  *outConnected = so;
  so->accepted_ = true;
  regenerate_sets();
  if( !settings_.reliable ) {
    //  gotta make sure that we can find this socket again
    socketAddrs_[addr] = so;
    so->write( "", 0 ); //  send an empty packet to establish a connection
  }
  return 1;
}
 
void SocketManager::dispose()
{
  if( sockets_.size() ) {
    char buf[2048];
    _snprintf( buf, 2048, "Etwork: SocketManager::dispose() sees %d active sockets.\n", sockets_.size() );
    buf[2047] = 0;
    OutputDebugString( buf );
    if( settings_.debug ) {
      __asm { int 3 }
    }
  }
  if( accepted_.size() ) {
    char buf[2048];
    _snprintf( buf, 2048, "Etwork: SocketManager::dispose() sees %d sockets pending accept.\n", accepted_.size() );
    buf[2047] = 0;
    OutputDebugString( buf );
    if( settings_.debug ) {
      __asm { int 3 }
    }
  }
  delete this;
}
 
//  Forwards a Winsock error code for `sock` (may be null) in the given error
//  area to wsa_error_from(). Call sites pass the failing call's name as
//  `func`, but it is currently not used in the report.
void SocketManager::debug_sock_error( ISocket * sock, int err, ErrorArea area, char const * func )
{
  (void)func;
  wsa_error_from( sock, this, err, area );
}
 
//  Forgets a socket. First drops its entry from socketAddrs_. If the socket
//  is in the live map, it is erased there, and the fd sets are rebuilt for
//  reliable managers. Otherwise it is erased from the pending-accept queue.
//  Asserts when the socket is in neither place.
void SocketManager::remove_socket( Socket * s )
{
  socketAddrs_.erase( s->addr_ );

  SocketMap::iterator live = sockets_.find( s->s_ );
  if( live == sockets_.end() ) {
    std::deque< Socket * >::iterator pending = std::find( accepted_.begin(), accepted_.end(), s );
    if( pending == accepted_.end() ) {
      ASSERT( !"Socket not found in SocketManager::remove_socket()" );
      return;
    }
    accepted_.erase( pending );
    return;
  }

  sockets_.erase( live );
  if( settings_.reliable ) {
    regenerate_sets();
  }
}