network_worker.cpp

Go to the documentation of this file.
00001 /* $Id: network_worker.cpp 52533 2012-01-07 02:35:17Z shadowmaster $ */
00002 /*
00003    Copyright (C) 2003 - 2012 by David White <dave@whitevine.net>
00004    Part of the Battle for Wesnoth Project http://www.wesnoth.org/
00005 
00006    This program is free software; you can redistribute it and/or modify
00007    it under the terms of the GNU General Public License as published by
00008    the Free Software Foundation; either version 2 of the License, or
00009    (at your option) any later version.
00010    This program is distributed in the hope that it will be useful,
00011    but WITHOUT ANY WARRANTY.
00012 
00013    See the COPYING file for more details.
00014 */
00015 
00016 /**
00017  * The network worker handles data transfers in threads.
00018  * Remember to use mutexes as little as possible.
00019  * All global variables should only be accessed while holding a mutex.
00020  * FIXME: @todo All code which holds a mutex should run in O(1) time
00021  * for scalability. Implement read/write locks.
00022  *  (postponed for 1.5)
00023  */
00024 
00025 #include "global.hpp"
00026 
00027 #include "scoped_resource.hpp"
00028 #include "log.hpp"
00029 #include "network_worker.hpp"
00030 #include "filesystem.hpp"
00031 #include "thread.hpp"
00032 #include "serialization/binary_or_text.hpp"
00033 #include "serialization/parser.hpp"
00034 #include "wesconfig.h"
00035 
00036 #include <cerrno>
00037 #include <deque>
00038 #include <sstream>
00039 
00040 #ifdef HAVE_SENDFILE
00041 #include <sys/sendfile.h>
00042 #include <netinet/in.h>
00043 #include <netinet/tcp.h>
00044 #endif
00045 
00046 
00047 #ifdef __AMIGAOS4__
00048 #include <unistd.h>
00049 //#include <sys/clib2_net.h>
00050 #endif
00051 
00052 #if defined(_WIN32) || defined(__WIN32__) || defined (WIN32)
00053 #  undef INADDR_ANY
00054 #  undef INADDR_BROADCAST
00055 #  undef INADDR_NONE
00056 #  ifndef NOMINMAX
00057 #   define NOMINMAX
00058 #  endif
00059 #  include <windows.h>
00060 #  define USE_SELECT 1
00061 typedef int socklen_t;
00062 #else
00063 #  include <sys/types.h>
00064 #  include <sys/socket.h>
00065 #  ifdef __BEOS__
00066 #    include <socket.h>
00067 #  else
00068 #    include <fcntl.h>
00069 #  endif
00070 #  define SOCKET int
00071 #  ifdef HAVE_POLL_H
00072 #    define USE_POLL 1
00073 #    include <poll.h>
00074 #  elif defined(HAVE_SYS_POLL_H)
00075 #    define USE_POLL 1
00076 #    include <sys/poll.h>
00077 #  endif
00078 #  ifndef USE_POLL
00079 #    define USE_SELECT 1
00080 #    ifdef HAVE_SYS_SELECT_H
00081 #      include <sys/select.h>
00082 #    else
00083 #      include <sys/time.h>
00084 #      include <sys/types.h>
00085 #      include <unistd.h>
00086 #    endif
00087 #  endif
00088 #endif
00089 
// Log domain "network" used by this file, plus shorthand macros that open a
// log stream on that domain at debug, info and error severity respectively.
00090 static lg::log_domain log_network("network");
00091 #define DBG_NW LOG_STREAM(debug, log_network)
00092 #define LOG_NW LOG_STREAM(info, log_network)
00093 #define ERR_NW LOG_STREAM(err, log_network)
00094 
00095 namespace {
/**
 * Local definition of the structure behind a TCPsocket handle.
 * Code in this file reinterpret_casts TCPsocket values to this type in order
 * to reach the underlying OS socket descriptor (channel) and the ready flag.
 * NOTE(review): presumably this has to mirror the networking library's
 * private layout field-for-field -- confirm before reordering or adding members.
 */
00096 struct _TCPsocket {
00097     int ready;
00098     SOCKET channel;
00099     IPaddress remoteAddress;
00100     IPaddress localAddress;
00101     int sflag;
00102 };
00103 
// Number of independent shards the per-socket state is split into. Each shard
// has its own mutex, condition variable, queues and worker threads. May be
// overridden at compile time; defaults to a single shard.
00104 #ifndef NUM_SHARDS
00105 #define NUM_SHARDS 1
00106 #endif
00107 
// Per shard: number of worker threads currently idle (counted up before a
// worker starts looking for work and down again once it has found some).
00108 unsigned int waiting_threads[NUM_SHARDS];
// Thread pool limits as passed to the manager constructor. A worker exits when
// min_threads is non-zero and that many workers are already idle; a
// max_threads of 0 means "unlimited".
00109 size_t min_threads = 0;
00110 size_t max_threads = 0;
00111 
00112 size_t get_shard(TCPsocket sock) { return reinterpret_cast<uintptr_t>(sock)%NUM_SHARDS; }
00113 
/**
 * One unit of data travelling through the worker: either something queued for
 * sending on `sock`, or something that was received from `sock`.
 */
00114 struct buffer {
00115     explicit buffer(TCPsocket sock) :
00116         sock(sock),
00117         config_buf(),
00118         config_error(""),
00119         stream(),
00120         raw_buffer()
00121         {}
00122 
    // Socket the data is sent to / was received from.
00123     TCPsocket sock;
    // Parsed config for received data (filled by the worker when not in raw mode).
00124     mutable config config_buf;
    // Dual purpose: for received data it holds the parse error message (empty
    // when parsing succeeded); for outgoing data a non-empty value is treated
    // by the worker as the path of a file to send.
00125     std::string config_error;
    // Serialised outgoing data; copied into raw_buffer (with a length prefix)
    // by the worker when raw_buffer is still empty.
00126     std::ostringstream stream;
00127 
00128     /**
00129      * This field is used if we're sending a raw buffer instead of through a
00130      * config object. It will contain the entire contents of the buffer being
00131      * sent.
00132      */
00133     std::vector<char> raw_buffer;
00134 };
00135 
00136 
// managed: true between construction and destruction of the active manager;
// workers exit once it is false and there is no work left for them.
// raw_data_only: when set, received payloads are handed over as raw bytes
// instead of being parsed, and transfer statistics are not updated.
00137 bool managed = false, raw_data_only = false;
// Per shard: buffers queued for sending.
00138 typedef std::vector< buffer* > buffer_set;
00139 buffer_set outgoing_bufs[NUM_SHARDS];
00140 
00141 /** a queue of sockets that we are waiting to receive on */
00142 typedef std::vector<TCPsocket> receive_list;
00143 receive_list pending_receives[NUM_SHARDS];
00144 
// Data received by the workers, waiting to be collected by the main thread.
00145 typedef std::deque<buffer*> received_queue;
00146 received_queue received_data_queue;  // guarded by received_mutex
00147 
// State of a socket as seen by the workers: READY = free to be picked up,
// LOCKED = a worker is currently sending/receiving on it, ERRORED = the last
// transfer failed, INTERRUPT = an in-progress transfer should be aborted.
00148 enum SOCKET_STATE { SOCKET_READY, SOCKET_LOCKED, SOCKET_ERRORED, SOCKET_INTERRUPT };
00149 typedef std::map<TCPsocket,SOCKET_STATE> socket_state_map;
// Per socket: (.first = send statistics, .second = receive statistics).
00150 typedef std::map<TCPsocket, std::pair<network::statistics,network::statistics> > socket_stats_map;
00151 
// Per shard: current state of every known socket (guarded by the shard mutex).
00152 socket_state_map sockets_locked[NUM_SHARDS];
00153 socket_stats_map transfer_stats; // stats_mutex
00154 
// Per shard: number of transfers that ended in SOCKET_ERRORED.
00155 int socket_errors[NUM_SHARDS];
// Synchronisation objects; allocated by the manager constructor and freed by
// its destructor.
00156 threading::mutex* shard_mutexes[NUM_SHARDS];
00157 threading::mutex* stats_mutex = NULL;
00158 threading::mutex* received_mutex = NULL;
// Per shard: signalled when new work is queued or the pool shuts down.
00159 threading::condition* cond[NUM_SHARDS];
00160 
// Per shard: worker thread objects keyed by thread id, and the ids of workers
// that have exited and whose thread objects still need to be deleted.
00161 std::map<Uint32,threading::thread*> threads[NUM_SHARDS];
00162 std::vector<Uint32> to_clear[NUM_SHARDS];
// Cached result of check_send_buffer_size(); 0 means "not queried yet".
00163 int system_send_buffer_size = 0;
// Whether send_file() should use the operating system's sendfile() call.
00164 bool network_use_system_sendfile = false;
00165 
00166 int receive_bytes(TCPsocket s, char* buf, size_t nbytes)
00167 {
00168 #ifdef NETWORK_USE_RAW_SOCKETS
00169     _TCPsocket* sock = reinterpret_cast<_TCPsocket*>(s);
00170     int res = 0;
00171     do {
00172         errno = 0;
00173         res = recv(sock->channel, buf, nbytes, MSG_DONTWAIT);
00174     } while(errno == EINTR);
00175     sock->ready = 0;
00176     return res;
00177 #else
00178     return SDLNet_TCP_Recv(s, buf, nbytes);
00179 #endif
00180 }
00181 
00182 
00183 void check_send_buffer_size(TCPsocket& s)
00184 {
00185     if (system_send_buffer_size)
00186         return;
00187     _TCPsocket* sock = reinterpret_cast<_TCPsocket*>(s);
00188     socklen_t len = sizeof(system_send_buffer_size);
00189 #ifdef _WIN32
00190     getsockopt(sock->channel, SOL_SOCKET, SO_RCVBUF,reinterpret_cast<char*>(&system_send_buffer_size), &len);
00191 #else
00192     getsockopt(sock->channel, SOL_SOCKET, SO_RCVBUF,&system_send_buffer_size, &len);
00193 #endif
00194     --system_send_buffer_size;
00195     DBG_NW << "send buffer size: " << system_send_buffer_size << "\n";
00196 }
00197 
00198 bool receive_with_timeout(TCPsocket s, char* buf, size_t nbytes,
00199         bool update_stats=false, int idle_timeout_ms=30000,
00200         int total_timeout_ms=300000)
00201 {
00202 #if !defined(USE_POLL) && !defined(USE_SELECT)
00203     int startTicks = SDL_GetTicks();
00204     int time_used = 0;
00205 #endif
00206     int timeout_ms = idle_timeout_ms;
00207     while(nbytes > 0) {
00208         const int bytes_read = receive_bytes(s, buf, nbytes);
00209         if(bytes_read == 0) {
00210             return false;
00211         } else if(bytes_read < 0) {
00212 #if defined(EAGAIN) && !defined(__BEOS__) && !defined(_WIN32)
00213             if(errno == EAGAIN)
00214 #elif defined(EWOULDBLOCK)
00215             if(errno == EWOULDBLOCK)
00216 #elif defined(_WIN32) && defined(WSAEWOULDBLOCK)
00217             if(WSAGetLastError() == WSAEWOULDBLOCK)
00218 #else
00219             // assume non-recoverable error.
00220             if(false)
00221 #endif
00222             {
00223 #ifdef USE_POLL
00224                 struct pollfd fd = { reinterpret_cast<_TCPsocket*>(s)->channel, POLLIN, 0 };
00225                 int poll_res;
00226 
00227                 //we timeout of the poll every 100ms. This lets us check to
00228                 //see if we have been disconnected, in which case we should
00229                 //abort the receive.
00230                 const int poll_timeout = std::min(timeout_ms, 100);
00231                 do {
00232                     poll_res = poll(&fd, 1, poll_timeout);
00233 
00234                     if(poll_res == 0) {
00235                         timeout_ms -= poll_timeout;
00236                         total_timeout_ms -= poll_timeout;
00237                         if(timeout_ms <= 0 || total_timeout_ms <= 0) {
00238                             //we've been waiting too long; abort the receive
00239                             //as having failed due to timeout.
00240                             return false;
00241                         }
00242 
00243                         //check to see if we've been interrupted
00244                         const size_t shard = get_shard(s);
00245                         const threading::lock lock(*shard_mutexes[shard]);
00246                         socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
00247                         assert(lock_it != sockets_locked[shard].end());
00248                         if(lock_it->second == SOCKET_INTERRUPT) {
00249                             return false;
00250                         }
00251                     }
00252 
00253                 } while(poll_res == 0 || (poll_res == -1 && errno == EINTR));
00254 
00255                 if (poll_res < 1)
00256                     return false;
00257 #elif defined(USE_SELECT)
00258                 int retval;
00259                 const int select_timeout = std::min(timeout_ms, 100);
00260                 do {
00261                     fd_set readfds;
00262                     FD_ZERO(&readfds);
00263                     FD_SET(((_TCPsocket*)s)->channel, &readfds);
00264                     struct timeval tv;
00265                     tv.tv_sec = select_timeout/1000;
00266                     tv.tv_usec = select_timeout % 1000 * 1000;
00267                     retval = select(((_TCPsocket*)s)->channel + 1, &readfds, NULL, NULL, &tv);
00268                     DBG_NW << "select retval: " << retval << ", timeout idle " << timeout_ms
00269                         << " total " << total_timeout_ms << " (ms)\n";
00270                     if(retval == 0) {
00271                         timeout_ms -= select_timeout;
00272                         total_timeout_ms -= select_timeout;
00273                         if(timeout_ms <= 0 || total_timeout_ms <= 0) {
00274                             //we've been waiting too long; abort the receive
00275                             //as having failed due to timeout.
00276                             return false;
00277                         }
00278 
00279                         //check to see if we've been interrupted
00280                         const size_t shard = get_shard(s);
00281                         const threading::lock lock(*shard_mutexes[shard]);
00282                         socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
00283                         assert(lock_it != sockets_locked[shard].end());
00284                         if(lock_it->second == SOCKET_INTERRUPT) {
00285                             return false;
00286                         }
00287                     }
00288                 } while(retval == 0 || (retval == -1 && errno == EINTR));
00289 
00290                 if (retval < 1) {
00291                     return false;
00292                 }
00293 #else
00294                 //TODO: consider replacing this with a select call
00295                 time_used = SDL_GetTicks() - startTicks;
00296                 if(time_used >= timeout_ms) {
00297                     return false;
00298                 }
00299                 SDL_Delay(20);
00300 #endif
00301             } else {
00302                 return false;
00303             }
00304         } else {
00305             timeout_ms = idle_timeout_ms;
00306             buf += bytes_read;
00307             if(update_stats && !raw_data_only) {
00308                 const threading::lock lock(*stats_mutex);
00309                 transfer_stats[s].second.transfer(static_cast<size_t>(bytes_read));
00310             }
00311 
00312             if(bytes_read > static_cast<int>(nbytes)) {
00313                 return false;
00314             }
00315             nbytes -= bytes_read;
00316             // We got some data from server so reset start time so slow conenction won't timeout.
00317 #if !defined(USE_POLL) && !defined(USE_SELECT)
00318             startTicks = SDL_GetTicks();
00319 #endif
00320         }
00321         {
00322             const size_t shard = get_shard(s);
00323             const threading::lock lock(*shard_mutexes[shard]);
00324             socket_state_map::iterator lock_it = sockets_locked[shard].find(s);
00325             assert(lock_it != sockets_locked[shard].end());
00326             if(lock_it->second == SOCKET_INTERRUPT) {
00327                 return false;
00328             }
00329         }
00330     }
00331 
00332     return true;
00333 }
00334 
00335 /**
00336  * @todo See if the TCPsocket argument should be removed.
00337  */
00338 static void output_to_buffer(TCPsocket /*sock*/, const config& cfg, std::ostringstream& compressor)
00339 {
00340     config_writer writer(compressor, true);
00341     writer.write(cfg);
00342 }
00343 
00344 static void make_network_buffer(const char* input, int len, std::vector<char>& buf)
00345 {
00346     buf.resize(4 + len);
00347     SDLNet_Write32(len, &buf[0]);
00348     memcpy(&buf[4], input, len);
00349 }
00350 
00351 static SOCKET_STATE send_buffer(TCPsocket sock, std::vector<char>& buf, int in_size = -1)
00352 {
00353 #ifdef __BEOS__
00354     int timeout = 60000;
00355 #endif
00356 //  check_send_buffer_size(sock);
00357     size_t upto = 0;
00358     size_t size = buf.size();
00359     if (in_size != -1)
00360         size = in_size;
00361     int send_len = 0;
00362 
00363     if (!raw_data_only)
00364     {
00365         const threading::lock lock(*stats_mutex);
00366         transfer_stats[sock].first.fresh_current(size);
00367     }
00368 #ifdef __BEOS__
00369     while(upto < size && timeout > 0) {
00370 #else
00371     while(true) {
00372 #endif
00373         {
00374             const size_t shard = get_shard(sock);
00375             // check if the socket is still locked
00376             const threading::lock lock(*shard_mutexes[shard]);
00377             if(sockets_locked[shard][sock] != SOCKET_LOCKED)
00378             {
00379                 return SOCKET_ERRORED;
00380             }
00381         }
00382         send_len = static_cast<int>(size - upto);
00383         const int res = SDLNet_TCP_Send(sock, &buf[upto],send_len);
00384 
00385 
00386         if( res == send_len) {
00387             if (!raw_data_only)
00388             {
00389                 const threading::lock lock(*stats_mutex);
00390                 transfer_stats[sock].first.transfer(static_cast<size_t>(res));
00391             }
00392             return SOCKET_READY;
00393         }
00394 #if defined(_WIN32)
00395         if(WSAGetLastError() == WSAEWOULDBLOCK)
00396 #elif defined(EAGAIN) && !defined(__BEOS__)
00397         if(errno == EAGAIN)
00398 #elif defined(EWOULDBLOCK)
00399         if(errno == EWOULDBLOCK)
00400 #endif
00401         {
00402             // update how far we are
00403             upto += static_cast<size_t>(res);
00404             if (!raw_data_only)
00405             {
00406                 const threading::lock lock(*stats_mutex);
00407                 transfer_stats[sock].first.transfer(static_cast<size_t>(res));
00408             }
00409 
00410 #ifdef USE_POLL
00411             struct pollfd fd = { ((_TCPsocket*)sock)->channel, POLLOUT, 0 };
00412             int poll_res;
00413             do {
00414                 poll_res = poll(&fd, 1, 60000);
00415             } while(poll_res == -1 && errno == EINTR);
00416 
00417 
00418             if(poll_res > 0)
00419                 continue;
00420 #elif defined(USE_SELECT) && !defined(__BEOS__)
00421             fd_set writefds;
00422             FD_ZERO(&writefds);
00423             FD_SET(((_TCPsocket*)sock)->channel, &writefds);
00424             int retval;
00425             struct timeval tv;
00426             tv.tv_sec = 60;
00427             tv.tv_usec = 0;
00428 
00429             do {
00430                 retval = select(((_TCPsocket*)sock)->channel + 1, NULL, &writefds, NULL, &tv);
00431             } while(retval == -1 && errno == EINTR);
00432 
00433             if(retval > 0)
00434                 continue;
00435 #elif defined(__BEOS__)
00436             if(res > 0) {
00437                 // some data was sent, reset timeout
00438                 timeout = 60000;
00439             } else {
00440                 // sleep for 100 milliseconds
00441                 SDL_Delay(100);
00442                 timeout -= 100;
00443             }
00444             continue;
00445 #endif
00446         }
00447 
00448         return SOCKET_ERRORED;
00449     }
00450 }
00451 
#ifdef HAVE_SENDFILE

#ifdef TCP_CORK
    /**
     * Scope guard that sets the TCP_CORK option on a socket when constructed
     * and clears it again when destroyed.
     */
    struct cork_setter {
        cork_setter(int socket) : flag_(1), fd_(socket)
        {
            setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag_, sizeof(flag_));
        }

        ~cork_setter()
        {
            flag_ = 0;
            setsockopt(fd_, IPPROTO_TCP, TCP_CORK, &flag_, sizeof(flag_));
        }

        private:
        int flag_;  // current value handed to setsockopt
        int fd_;    // socket descriptor the option is applied to
    };
#else
    /** No-op stand-in used when TCP_CORK is not available. */
    struct cork_setter
    {
        cork_setter(int) {}
    };
#endif

/** Releaser for scoped_fd: closes the file descriptor. */
struct close_fd {
        void operator()(int descriptor) const
        {
            close(descriptor);
        }
};

/** File descriptor that is closed automatically at end of scope. */
typedef util::scoped_resource<int, close_fd> scoped_fd;
#endif
00481 
00482 static SOCKET_STATE send_file(buffer* buf)
00483 {
00484     size_t upto = 0;
00485     size_t filesize = file_size(buf->config_error);
00486 #ifdef HAVE_SENDFILE
00487     // implements linux sendfile support
00488     LOG_NW << "send_file use system sendfile: " << (network_use_system_sendfile?"yes":"no") << "\n";
00489     if (network_use_system_sendfile)
00490     {
00491         std::vector<char> buffer;
00492         buffer.resize(4);
00493         SDLNet_Write32(filesize,&buffer[0]);
00494         int socket = reinterpret_cast<_TCPsocket*>(buf->sock)->channel;
00495         const scoped_fd in_file(open(buf->config_error.c_str(), O_RDONLY));
00496         cork_setter set_socket_cork(socket);
00497         int poll_res;
00498         struct pollfd fd = {socket, POLLOUT, 0 };
00499         do {
00500             poll_res = poll(&fd, 1, 600000);
00501         } while(poll_res == -1 && errno == EINTR);
00502 
00503         SOCKET_STATE result;
00504         if (poll_res > 0)
00505             result = send_buffer(buf->sock, buffer, 4);
00506         else
00507             result = SOCKET_ERRORED;
00508 
00509 
00510         if (result != SOCKET_READY)
00511         {
00512             return result;
00513         }
00514         result = SOCKET_READY;
00515 
00516         while (true)
00517         {
00518 
00519             do {
00520                 poll_res = poll(&fd, 1, 600000);
00521             } while(poll_res == -1 && errno == EINTR);
00522 
00523             if (poll_res <= 0 )
00524             {
00525                 result = SOCKET_ERRORED;
00526                 break;
00527             }
00528 
00529 
00530             int bytes = ::sendfile(socket, in_file, 0, filesize);
00531 
00532             if (bytes == -1)
00533             {
00534                 if (errno == EAGAIN)
00535                     continue;
00536                 result = SOCKET_ERRORED;
00537                 break;
00538             }
00539             upto += bytes;
00540 
00541 
00542             if (upto == filesize)
00543             {
00544                 break;
00545             }
00546         }
00547 
00548         return result;
00549     }
00550 #endif
00551     // default sendfile implementation
00552     // if no system implementation is enabled
00553     int send_size = 0;
00554     // reserve 1024*8 bytes buffer
00555     buf->raw_buffer.resize(std::min<size_t>(1024*8, filesize));
00556     SDLNet_Write32(filesize,&buf->raw_buffer[0]);
00557     scoped_istream file_stream = istream_file(buf->config_error);
00558     SOCKET_STATE result = send_buffer(buf->sock, buf->raw_buffer, 4);
00559 
00560     if (!file_stream->good()) {
00561         ERR_NW << "send_file: Couldn't open file " << buf->config_error << "\n";
00562     }
00563     if (result != SOCKET_READY)
00564     {
00565         return result;
00566     }
00567     while (file_stream->good())
00568     {
00569         // read data
00570         file_stream->read(&buf->raw_buffer[0], buf->raw_buffer.size());
00571         send_size = file_stream->gcount();
00572         upto += send_size;
00573         // send data to socket
00574         result = send_buffer(buf->sock, buf->raw_buffer, send_size);
00575         if (result != SOCKET_READY)
00576         {
00577             break;
00578         }
00579         if (upto == filesize)
00580         {
00581             break;
00582         }
00583 
00584     }
00585     if (upto != filesize && !file_stream->good()) {
00586         ERR_NW << "send_file failed because the stream from file '"
00587             << buf->config_error << "' is not good. Sent up to: " << upto
00588             << " of file size: " << filesize << "\n";
00589     }
00590     return result;
00591 }
00592 
00593 static SOCKET_STATE receive_buf(TCPsocket sock, std::vector<char>& buf)
00594 {
00595     union {
00596     char buf[4] ALIGN_4;
00597     Uint32 num;
00598     } num_buf;
00599     bool res = receive_with_timeout(sock,num_buf.buf,4,false);
00600 
00601     if(!res) {
00602         return SOCKET_ERRORED;
00603     }
00604 
00605     const int len = SDLNet_Read32(&num_buf);
00606 
00607     if(len < 1 || len > 100000000) {
00608         return SOCKET_ERRORED;
00609     }
00610 
00611     buf.resize(len);
00612     char* beg = &buf[0];
00613     const char* const end = beg + len;
00614 
00615     if (!raw_data_only)
00616     {
00617         const threading::lock lock(*stats_mutex);
00618         transfer_stats[sock].second.fresh_current(len);
00619     }
00620 
00621     res = receive_with_timeout(sock, beg, end - beg, true);
00622     if(!res) {
00623         return SOCKET_ERRORED;
00624     }
00625 
00626     return SOCKET_READY;
00627 }
00628 
00629 inline void check_socket_result(TCPsocket& sock, SOCKET_STATE& result)
00630 {
00631     const size_t shard = get_shard(sock);
00632     const threading::lock lock(*shard_mutexes[shard]);
00633     socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
00634     assert(lock_it != sockets_locked[shard].end());
00635     lock_it->second = result;
00636     if(result == SOCKET_ERRORED) {
00637         ++socket_errors[shard];
00638     }
00639 }
00640 
/**
 * Entry point of every worker thread.
 *
 * @param shard_num The index of the shard this worker serves, encoded as a
 *                  pointer value.
 * @return 0 when the worker exits.
 *
 * Each iteration of the outer loop, under the shard mutex:
 *  - deletes the thread objects of workers that have exited (while managed);
 *  - exits if min_threads is non-zero and that many workers are already idle;
 *  - looks for an outgoing buffer whose socket is SOCKET_READY, and failing
 *    that for a pending receive on a SOCKET_READY socket; the chosen socket
 *    is marked SOCKET_LOCKED and removed from its queue;
 *  - if nothing is found, exits when the pool is no longer managed, otherwise
 *    waits on the shard's condition variable and looks again;
 *  - if no other worker is idle, starts an additional worker (subject to
 *    max_threads).
 * The transfer itself then runs without the shard mutex. Received data is
 * appended to received_data_queue, and the outcome of the transfer is written
 * back as the socket's state.
 */
00641 static int process_queue(void* shard_num)
00642 {
00643     size_t shard = static_cast<size_t>(reinterpret_cast<uintptr_t>(shard_num));
00644     DBG_NW << "thread started...\n";
00645     for(;;) {
00646 
00647         //if we find a socket to send data to, sent_buf will be non-NULL. If we find a socket
00648         //to receive data from, sent_buf will be NULL. 'sock' will always refer to the socket
00649         //that data is being sent to/received from
00650         TCPsocket sock = NULL;
00651         buffer* sent_buf = 0;
00652 
00653         {
00654             const threading::lock lock(*shard_mutexes[shard]);
            // Clean up after workers that have exited: drop them from the
            // thread map and delete their thread objects.
00655             while(managed && !to_clear[shard].empty()) {
00656                 Uint32 tmp = to_clear[shard].back();
00657                 to_clear[shard].pop_back();
00658                 threading::thread *zombie = threads[shard][tmp];
00659                 threads[shard].erase(tmp);
00660                 delete zombie;
00661 
00662             }
            // Enough idle workers already: retire this one and register it
            // for clean-up by another worker.
00663             if(min_threads && waiting_threads[shard] >= min_threads) {
00664                     DBG_NW << "worker thread exiting... not enough jobs\n";
00665                     to_clear[shard].push_back(threading::get_current_thread_id());
00666                     return 0;
00667             }
00668             waiting_threads[shard]++;
00669             for(;;) {
00670 
                // Sends take priority: pick the first queued buffer whose
                // socket is not currently in use.
00671                 buffer_set::iterator itor = outgoing_bufs[shard].begin(), itor_end = outgoing_bufs[shard].end();
00672                 for(; itor != itor_end; ++itor) {
00673                     socket_state_map::iterator lock_it = sockets_locked[shard].find((*itor)->sock);
00674                     assert(lock_it != sockets_locked[shard].end());
00675                     if(lock_it->second == SOCKET_READY) {
00676                         lock_it->second = SOCKET_LOCKED;
00677                         sent_buf = *itor;
00678                         sock = sent_buf->sock;
00679                         outgoing_bufs[shard].erase(itor);
00680                         break;
00681                     }
00682                 }
00683 
                // Nothing to send: look for a socket waiting to be read.
00684                 if(sock == NULL) {
00685                     receive_list::iterator itor = pending_receives[shard].begin(), itor_end = pending_receives[shard].end();
00686                     for(; itor != itor_end; ++itor) {
00687                         socket_state_map::iterator lock_it = sockets_locked[shard].find(*itor);
00688                         assert(lock_it != sockets_locked[shard].end());
00689                         if(lock_it->second == SOCKET_READY) {
00690                             lock_it->second = SOCKET_LOCKED;
00691                             sock = *itor;
00692                             pending_receives[shard].erase(itor);
00693                             break;
00694                         }
00695                     }
00696                 }
00697 
00698                 if(sock != NULL) {
00699                     break;
00700                 }
00701 
                // No work available and the pool is shutting down: exit.
00702                 if(managed == false) {
00703                     DBG_NW << "worker thread exiting...\n";
00704                     waiting_threads[shard]--;
00705                     to_clear[shard].push_back(threading::get_current_thread_id());
00706                     return 0;
00707                 }
00708 
00709                 cond[shard]->wait(*shard_mutexes[shard]); // temporarily release the mutex and wait for a buffer
00710             }
00711             waiting_threads[shard]--;
00712             // if we are the last thread in the pool, create a new one
00713             if(!waiting_threads[shard] && managed == true) {
00714                 // max_threads of 0 is unlimited
00715                 if(!max_threads || max_threads >threads[shard].size()) {
00716                     threading::thread * tmp = new threading::thread(process_queue,shard_num);
00717                     threads[shard][tmp->get_id()] =tmp;
00718                 }
00719             }
00720         }
00721 
00722         assert(sock);
00723 
00724         DBG_NW << "thread found a buffer...\n";
00725 
00726         SOCKET_STATE result = SOCKET_READY;
        // Filled only by receive_buf(); stays empty after a send.
00727         std::vector<char> buf;
00728 
00729         if(sent_buf) {
00730 
            // For outgoing buffers a non-empty config_error holds the path of
            // a file to send.
00731             if(!sent_buf->config_error.empty())
00732             {
00733                 // We have file to send over net
00734                 result = send_file(sent_buf);
00735             } else {
                // No raw buffer prepared yet: build one from the stream.
00736                 if(sent_buf->raw_buffer.empty()) {
00737                     const std::string &value = sent_buf->stream.str();
00738                     make_network_buffer(value.c_str(), value.size(), sent_buf->raw_buffer);
00739                 }
00740 
00741                 result = send_buffer(sent_buf->sock, sent_buf->raw_buffer);
00742             }
00743             delete sent_buf;
00744         } else {
00745             result = receive_buf(sock,buf);
00746         }
00747 
00748 
        // A failed transfer, or a send (buf is empty then): just record the
        // socket's new state and go back to looking for work.
00749         if(result != SOCKET_READY || buf.empty())
00750         {
00751             check_socket_result(sock,result);
00752                 continue;
00753         }
00754         //if we received data, add it to the queue
00755         buffer* received_data = new buffer(sock);
00756 
00757         if(raw_data_only) {
00758             received_data->raw_buffer.swap(buf);
00759         } else {
00760             std::string buffer(buf.begin(), buf.end());
00761             std::istringstream stream(buffer);
00762             try {
00763                 read_gz(received_data->config_buf, stream);
00764             } catch(config::error &e)
00765             {
                // Keep the message; it is queued along with the buffer.
00766                 received_data->config_error = e.message;
00767             }
00768         }
00769 
00770         {
00771             // Now add data
00772             const threading::lock lock_received(*received_mutex);
00773             received_data_queue.push_back(received_data);
00774         }
        // Only now write the result back as the socket's state.
00775         check_socket_result(sock,result);
00776     }
00777     // unreachable
00778 }
00779 
00780 } //anonymous namespace
00781 
00782 namespace network_worker_pool
00783 {
00784 
00785 manager::manager(size_t p_min_threads,size_t p_max_threads) : active_(!managed)
00786 {
00787     if(active_) {
00788         managed = true;
00789         for(int i = 0; i != NUM_SHARDS; ++i) {
00790             shard_mutexes[i] = new threading::mutex();
00791             cond[i] = new threading::condition();
00792         }
00793         stats_mutex = new threading::mutex();
00794         received_mutex = new threading::mutex();
00795 
00796         min_threads = p_min_threads;
00797         max_threads = p_max_threads;
00798 
00799         for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
00800             const threading::lock lock(*shard_mutexes[shard]);
00801             for(size_t n = 0; n != p_min_threads; ++n) {
00802                 threading::thread * tmp = new threading::thread(process_queue,(void*)uintptr_t(shard));
00803                 threads[shard][tmp->get_id()] = tmp;
00804             }
00805         }
00806     }
00807 }
00808 
00809 manager::~manager()
00810 {
00811     if(active_) {
00812         managed = false;
00813 
00814         for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
00815             {
00816                 const threading::lock lock(*shard_mutexes[shard]);
00817                 socket_errors[shard] = 0;
00818             }
00819 
00820             cond[shard]->notify_all();
00821 
00822             for(std::map<Uint32,threading::thread*>::const_iterator i = threads[shard].begin(); i != threads[shard].end(); ++i) {
00823 
00824                 DBG_NW << "waiting for thread " << i->first << " to exit...\n";
00825                 delete i->second;
00826                 DBG_NW << "thread exited...\n";
00827             }
00828 
00829             // Condition variables must be deleted first as
00830             // they make reference to mutexs. If the mutexs
00831             // are destroyed first, the condition variables
00832             // will access memory already freed by way of
00833             // stale mutex. Bad things will follow. ;)
00834             threads[shard].clear();
00835             // Have to clean up to_clear so no bogus clearing of threads
00836             to_clear[shard].clear();
00837             delete cond[shard];
00838             cond[shard] = NULL;
00839             delete shard_mutexes[shard];
00840             shard_mutexes[shard] = NULL;
00841         }
00842 
00843         delete stats_mutex;
00844         delete received_mutex;
00845         stats_mutex = 0;
00846         received_mutex = 0;
00847 
00848         for(int i = 0; i != NUM_SHARDS; ++i) {
00849             sockets_locked[i].clear();
00850         }
00851         transfer_stats.clear();
00852 
00853         DBG_NW << "exiting manager::~manager()\n";
00854     }
00855 }
00856 
00857 network::pending_statistics get_pending_stats()
00858 {
00859     network::pending_statistics stats;
00860     stats.npending_sends = 0;
00861     stats.nbytes_pending_sends = 0;
00862     for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
00863         const threading::lock lock(*shard_mutexes[shard]);
00864         stats.npending_sends += outgoing_bufs[shard].size();
00865         for(buffer_set::const_iterator i = outgoing_bufs[shard].begin(); i != outgoing_bufs[shard].end(); ++i) {
00866             stats.nbytes_pending_sends += (*i)->raw_buffer.size();
00867         }
00868     }
00869 
00870     return stats;
00871 }
00872 
00873 void set_raw_data_only()
00874 {
00875     raw_data_only = true;
00876 }
00877 
00878 void set_use_system_sendfile(bool use)
00879 {
00880     network_use_system_sendfile = use;
00881 }
00882 
00883 void receive_data(TCPsocket sock)
00884 {
00885     {
00886         const size_t shard = get_shard(sock);
00887         const threading::lock lock(*shard_mutexes[shard]);
00888         pending_receives[shard].push_back(sock);
00889 
00890         socket_state_map::const_iterator i = sockets_locked[shard].insert(std::pair<TCPsocket,SOCKET_STATE>(sock,SOCKET_READY)).first;
00891         if(i->second == SOCKET_READY || i->second == SOCKET_ERRORED) {
00892             cond[shard]->notify_one();
00893         }
00894     }
00895 }
00896 
00897 TCPsocket get_received_data(TCPsocket sock, config& cfg, network::bandwidth_in_ptr& bandwidth_in)
00898 {
00899     assert(!raw_data_only);
00900     const threading::lock lock_received(*received_mutex);
00901     received_queue::iterator itor = received_data_queue.begin();
00902     if(sock != NULL) {
00903         for(; itor != received_data_queue.end(); ++itor) {
00904             if((*itor)->sock == sock) {
00905                 break;
00906             }
00907         }
00908     }
00909 
00910     if(itor == received_data_queue.end()) {
00911         return NULL;
00912     } else if (!(*itor)->config_error.empty()){
00913         // throw the error in parent thread
00914         std::string error = (*itor)->config_error;
00915         buffer* buf = *itor;
00916         received_data_queue.erase(itor);
00917         delete buf;
00918         throw config::error(error);
00919     } else {
00920         cfg.swap((*itor)->config_buf);
00921         const TCPsocket res = (*itor)->sock;
00922         buffer* buf = *itor;
00923         bandwidth_in.reset(new network::bandwidth_in((*itor)->raw_buffer.size()));
00924         received_data_queue.erase(itor);
00925         delete buf;
00926         return res;
00927     }
00928 }
00929 
00930 TCPsocket get_received_data(std::vector<char>& out)
00931 {
00932     assert(raw_data_only);
00933     const threading::lock lock_received(*received_mutex);
00934     if(received_data_queue.empty()) {
00935         return NULL;
00936     }
00937 
00938     buffer* buf = received_data_queue.front();
00939     received_data_queue.pop_front();
00940     out.swap(buf->raw_buffer);
00941     const TCPsocket res = buf->sock;
00942     delete buf;
00943     return res;
00944 }
00945 
00946 static void queue_buffer(TCPsocket sock, buffer* queued_buf)
00947 {
00948     const size_t shard = get_shard(sock);
00949     const threading::lock lock(*shard_mutexes[shard]);
00950     outgoing_bufs[shard].push_back(queued_buf);
00951     socket_state_map::const_iterator i = sockets_locked[shard].insert(std::pair<TCPsocket,SOCKET_STATE>(sock,SOCKET_READY)).first;
00952     if(i->second == SOCKET_READY || i->second == SOCKET_ERRORED) {
00953         cond[shard]->notify_one();
00954     }
00955 
00956 }
00957 
00958 void queue_raw_data(TCPsocket sock, const char* buf, int len)
00959 {
00960     buffer* queued_buf = new buffer(sock);
00961     assert(*buf == 31);
00962     make_network_buffer(buf, len, queued_buf->raw_buffer);
00963     queue_buffer(sock, queued_buf);
00964 }
00965 
00966 
00967 void queue_file(TCPsocket sock, const std::string& filename)
00968 {
00969     buffer* queued_buf = new buffer(sock);
00970     queued_buf->config_error = filename;
00971     queue_buffer(sock, queued_buf);
00972 }
00973 
00974 size_t queue_data(TCPsocket sock,const config& buf, const std::string& packet_type)
00975 {
00976     DBG_NW << "queuing data...\n";
00977 
00978     buffer* queued_buf = new buffer(sock);
00979     output_to_buffer(sock, buf, queued_buf->stream);
00980     const size_t size = queued_buf->stream.str().size();
00981 
00982     network::add_bandwidth_out(packet_type, size);
00983     queue_buffer(sock, queued_buf);
00984     return size;
00985 }
00986 
00987 namespace
00988 {
00989 
00990 /** Caller has to make sure to own the mutex for this shard */
00991 void remove_buffers(TCPsocket sock)
00992 {
00993     {
00994         const size_t shard = get_shard(sock);
00995         for(buffer_set::iterator i = outgoing_bufs[shard].begin(); i != outgoing_bufs[shard].end();) {
00996             if ((*i)->sock == sock)
00997             {
00998                 buffer* buf = *i;
00999                 i = outgoing_bufs[shard].erase(i);
01000                 delete buf;
01001             }
01002             else
01003             {
01004                 ++i;
01005             }
01006         }
01007     }
01008 
01009     {
01010         const threading::lock lock_receive(*received_mutex);
01011 
01012         for(received_queue::iterator j = received_data_queue.begin(); j != received_data_queue.end(); ) {
01013             if((*j)->sock == sock) {
01014                 buffer *buf = *j;
01015                 j = received_data_queue.erase(j);
01016                 delete buf;
01017             } else {
01018                 ++j;
01019             }
01020         }
01021     }
01022 }
01023 
01024 } // anonymous namespace
01025 
01026 bool is_locked(const TCPsocket sock) {
01027     const size_t shard = get_shard(sock);
01028     const threading::lock lock(*shard_mutexes[shard]);
01029     const socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
01030     if (lock_it == sockets_locked[shard].end()) return false;
01031     return (lock_it->second == SOCKET_LOCKED);
01032 }
01033 
01034 bool close_socket(TCPsocket sock)
01035 {
01036     {
01037         const size_t shard = get_shard(sock);
01038         const threading::lock lock(*shard_mutexes[shard]);
01039 
01040         pending_receives[shard].erase(std::remove(pending_receives[shard].begin(),pending_receives[shard].end(),sock),pending_receives[shard].end());
01041 
01042         const socket_state_map::iterator lock_it = sockets_locked[shard].find(sock);
01043         if(lock_it == sockets_locked[shard].end()) {
01044             remove_buffers(sock);
01045             return true;
01046         }
01047         if (!(lock_it->second == SOCKET_LOCKED || lock_it->second == SOCKET_INTERRUPT)) {
01048             sockets_locked[shard].erase(lock_it);
01049             remove_buffers(sock);
01050             return true;
01051         } else {
01052             lock_it->second = SOCKET_INTERRUPT;
01053             return false;
01054         }
01055 
01056     }
01057 
01058 
01059 }
01060 
01061 TCPsocket detect_error()
01062 {
01063     for(size_t shard = 0; shard != NUM_SHARDS; ++shard) {
01064         const threading::lock lock(*shard_mutexes[shard]);
01065         if(socket_errors[shard] > 0) {
01066             for(socket_state_map::iterator i = sockets_locked[shard].begin(); i != sockets_locked[shard].end();) {
01067                 if(i->second == SOCKET_ERRORED) {
01068                     --socket_errors[shard];
01069                     const TCPsocket sock = i->first;
01070                     sockets_locked[shard].erase(i++);
01071                     pending_receives[shard].erase(std::remove(pending_receives[shard].begin(),pending_receives[shard].end(),sock),pending_receives[shard].end());
01072                     remove_buffers(sock);
01073                     return sock;
01074                 }
01075                 else
01076                 {
01077                     ++i;
01078                 }
01079             }
01080         }
01081 
01082         socket_errors[shard] = 0;
01083     }
01084 
01085     return 0;
01086 }
01087 
01088 std::pair<network::statistics,network::statistics> get_current_transfer_stats(TCPsocket sock)
01089 {
01090     const threading::lock lock(*stats_mutex);
01091     return transfer_stats[sock];
01092 }
01093 
01094 } // network_worker_pool namespace
01095 
01096 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

Generated by doxygen 1.7.1 on Fri May 25 2012 01:03:07 for The Battle for Wesnoth
Gna! | Forum | Wiki | CIA | devdocs