diff --git a/CMakeLists.txt b/CMakeLists.txt index 0685acb9a..60c0eacfd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,6 +144,9 @@ set(STUB_SOURCES Stub/dvlnet/loopback.cpp Stub/dvlnet/packet.cpp Stub/dvlnet/base.cpp + Stub/dvlnet/frame_queue.cpp + Stub/dvlnet/tcp_client.cpp + Stub/dvlnet/tcp_server.cpp Stub/dvlnet/udp_p2p.cpp Stub/DiabloUI/credits.cpp diff --git a/Stub/dvlnet/base.cpp b/Stub/dvlnet/base.cpp index 2dad0818f..7a2532db5 100644 --- a/Stub/dvlnet/base.cpp +++ b/Stub/dvlnet/base.cpp @@ -2,19 +2,14 @@ using namespace dvlnet; +base::base(buffer_t info) +{ + game_init_info = std::move(info); +} + void base::setup_password(std::string pw) { - //pw.resize(std::min(pw.size(), crypto_pwhash_PASSWD_MAX)); - //pw.resize(std::max(pw.size(), crypto_pwhash_PASSWD_MIN), 0); - std::string salt("devilution-salt"); - salt.resize(crypto_pwhash_SALTBYTES, 0); - if (crypto_pwhash(key.data(), crypto_secretbox_KEYBYTES, - pw.data(), pw.size(), - reinterpret_cast(salt.data()), - crypto_pwhash_OPSLIMIT_INTERACTIVE, - crypto_pwhash_MEMLIMIT_INTERACTIVE, - crypto_pwhash_ALG_DEFAULT)) - ABORT(); + pktfty = std::make_unique(pw); } void base::run_event_handler(_SNETEVENT& ev) @@ -25,17 +20,53 @@ void base::run_event_handler(_SNETEVENT& ev) } } -void base::recv_local(upacket& pkt) +void base::handle_accept(packet& pkt) +{ + if (plr_self != PLR_BROADCAST) + return; // already have player id + if (pkt.cookie() == cookie_self) + plr_self = pkt.newplr(); + if (game_init_info != pkt.info()) { + // we joined and did not create + _SNETEVENT ev; + ev.eventid = EVENT_TYPE_PLAYER_CREATE_GAME; + ev.playerid = plr_self; + ev.data = const_cast(pkt.info().data()); + ev.databytes = pkt.info().size(); + run_event_handler(ev); + } +} + +void base::recv_local(packet& pkt) { - switch (pkt->type()) { + switch (pkt.type()) { case PT_MESSAGE: - message_queue.push(message_t(pkt->src(), pkt->message())); + message_queue.push(message_t(pkt.src(), pkt.message())); break; case PT_TURN: - turn_queue[pkt->src()].push(pkt->turn()); + turn_queue[pkt.src()].push(pkt.turn()); break; - case PT_LEAVE_GAME: - // todo + case PT_JOIN_ACCEPT: + handle_accept(pkt); + break; + case PT_CONNECT: + connected_table[pkt.newplr()] = true; + active_table[pkt.newplr()] = true; + break; + case PT_DISCONNECT: + if (pkt.newplr() != plr_self) { + auto leaveinfo = pkt.leaveinfo(); + _SNETEVENT ev; + ev.eventid = EVENT_TYPE_PLAYER_LEAVE_GAME; + ev.playerid = pkt.newplr(); + ev.data = reinterpret_cast(&leaveinfo); + ev.databytes = sizeof(leaveinfo_t); + run_event_handler(ev); + } else { + // problem + } + connected_table[pkt.newplr()] = false; + active_table[pkt.newplr()] = false; break; // otherwise drop } @@ -50,13 +81,14 @@ bool base::SNetReceiveMessage(int* sender, char** data, int* size) message_queue.pop(); *sender = message_last.sender; *size = message_last.payload.size(); - *data = reinterpret_cast(message_last.payload.data()); + *data = reinterpret_cast(message_last.payload.data()); return true; } bool base::SNetSendMessage(int playerID, void* data, unsigned int size) { - if (playerID != SNPLAYER_ALL && playerID != SNPLAYER_OTHERS && (playerID < 0 || playerID >= MAX_PLRS)) + if (playerID != SNPLAYER_ALL && playerID != SNPLAYER_OTHERS + && (playerID < 0 || playerID >= MAX_PLRS)) abort(); auto raw_message = reinterpret_cast(data); buffer_t message(raw_message, raw_message + size); @@ -64,23 +96,25 @@ bool base::SNetSendMessage(int playerID, void* data, unsigned int size) message_queue.push(message_t(plr_self, message)); plr_t dest; if (playerID == SNPLAYER_ALL || playerID == SNPLAYER_OTHERS) - dest = ADDR_BROADCAST; + dest = PLR_BROADCAST; else dest = playerID; - upacket pkt = make_packet(PT_MESSAGE, plr_self, dest, message); - send(pkt); + if(dest != plr_self) { + auto pkt = pktfty->make_packet(plr_self, dest, message); + send(*pkt); + } return true; } -bool base::SNetReceiveTurns(char **data, unsigned int *size, DWORD *status) +bool base::SNetReceiveTurns(char** data, unsigned int* size, DWORD* status) { poll(); for (auto i = 0; i < MAX_PLRS; ++i) { status[i] = 0; - if (i == plr_self || active(i)) { + if (active_table[i] || i == plr_self) { status[i] |= PS_ACTIVE; } - if (i == plr_self || connected(i)) { + if (connected_table[i] || i == plr_self) { status[i] |= PS_CONNECTED; } if (!turn_queue[i].empty()) { @@ -88,7 +122,7 @@ bool base::SNetReceiveTurns(char **data, unsigned int *size, DWORD *status) status[i] |= PS_TURN_ARRIVED; turn_last[i] = turn_queue[i].front(); turn_queue[i].pop(); - data[i] = reinterpret_cast(&turn_last[i]); + data[i] = reinterpret_cast(&turn_last[i]); } } return true; @@ -98,8 +132,9 @@ bool base::SNetSendTurn(char* data, unsigned int size) { if (size != sizeof(turn_t)) ABORT(); - upacket pkt = make_packet(PT_TURN, plr_self, ADDR_BROADCAST, *reinterpret_cast(data)); - send(pkt); + auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, + *reinterpret_cast(data)); + send(*pkt); return true; } @@ -113,23 +148,35 @@ int base::SNetGetProviderCaps(struct _SNETCAPS* caps) caps->bytessec = 1000000; // ? caps->latencyms = 0; // unused caps->defaultturnssec = 10; // ? - caps->defaultturnsintransit = 1; // maximum acceptable number of turns in queue? + caps->defaultturnsintransit = 1; // maximum acceptable number + // of turns in queue? return 1; } -void* base::SNetUnregisterEventHandler(event_type evtype, void(__stdcall* func)(struct _SNETEVENT *)) +void* base::SNetUnregisterEventHandler(event_type evtype, snet_event_func func) { registered_handlers.erase(evtype); return (void*)func; } -void* base::SNetRegisterEventHandler(event_type evtype, void(__stdcall* func)(struct _SNETEVENT *)) +void* base::SNetRegisterEventHandler(event_type evtype, snet_event_func func) { + /* + engine registers handler for: + EVENT_TYPE_PLAYER_LEAVE_GAME + EVENT_TYPE_PLAYER_CREATE_GAME (should be raised during SNetCreateGame + for non-creating player) + EVENT_TYPE_PLAYER_MESSAGE (for bnet? not implemented) + (engine uses same function for all three) + */ registered_handlers[evtype] = func; return (void*)func; - // need to handle: - // EVENT_TYPE_PLAYER_LEAVE_GAME - // EVENT_TYPE_PLAYER_CREATE_GAME (raised during SNetCreateGame?) - // EVENT_TYPE_PLAYER_MESSAGE - // all by the same function +} + +bool base::SNetLeaveGame(int type) +{ + auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, + plr_self, type); + send(*pkt); + return true; } diff --git a/Stub/dvlnet/base.h b/Stub/dvlnet/base.h index b5f6b65d0..079f51e08 100644 --- a/Stub/dvlnet/base.h +++ b/Stub/dvlnet/base.h @@ -1,36 +1,37 @@ -// exact meaning yet to be worked out +#pragma once + #define PS_CONNECTED 0x10000 #define PS_TURN_ARRIVED 0x20000 #define PS_ACTIVE 0x40000 +#define LEAVE_NORMAL 3 +#define LEAVE_ENDING 0x40000004 + namespace dvlnet { class base : public dvlnet { public: + base(buffer_t info); + virtual int create(std::string addrstr, std::string passwd) = 0; virtual int join(std::string addrstr, std::string passwd) = 0; virtual bool SNetReceiveMessage(int* sender, char** data, int* size); virtual bool SNetSendMessage(int dest, void* data, unsigned int size); - virtual bool SNetReceiveTurns(char** data, unsigned int* size, DWORD* status); + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status); virtual bool SNetSendTurn(char* data, unsigned int size); virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); - virtual void* SNetRegisterEventHandler(event_type evtype, void(__stdcall* func)(struct _SNETEVENT*)); - virtual void* SNetUnregisterEventHandler(event_type evtype, void(__stdcall* func)(struct _SNETEVENT*)); + virtual void* SNetRegisterEventHandler(event_type evtype, + snet_event_func func); + virtual void* SNetUnregisterEventHandler(event_type evtype, + snet_event_func func); + virtual bool SNetLeaveGame(int type); virtual void poll() = 0; - virtual void send(upacket& pkt) = 0; - virtual bool connected(plr_t p) = 0; - virtual bool active(plr_t p) = 0; - - static constexpr unsigned short max_packet_size = 0xFFFF; - upacket make_packet(buffer_t buf); - template upacket make_packet(T t, Args... args); + virtual void send(packet& pkt) = 0; protected: - static constexpr daddr_t ADDR_BROADCAST = 0xFF; - static constexpr daddr_t ADDR_MASTER = 0xFE; - - std::map registered_handlers; + std::map registered_handlers; buffer_t game_init_info; struct message_t { @@ -42,32 +43,19 @@ namespace dvlnet { message_t message_last; std::queue message_queue; - std::array turn_last = { 0 }; + std::array turn_last = {}; std::array, MAX_PLRS> turn_queue; + std::array active_table = {}; + std::array connected_table = {}; - plr_t plr_self = ADDR_BROADCAST; + plr_t plr_self = PLR_BROADCAST; cookie_t cookie_self = 0; - key_t key = { 0 }; + std::unique_ptr pktfty; void setup_password(std::string pw); - void recv_local(upacket &pkt); - void run_event_handler(_SNETEVENT &ev); + void handle_accept(packet& pkt); + void recv_local(packet& pkt); + void run_event_handler(_SNETEVENT& ev); }; - - inline upacket base::make_packet(buffer_t buf) - { - auto ret = std::make_unique(key); - ret->create(std::move(buf)); - ret->decrypt(); - return ret; - } - - template upacket base::make_packet(T t, Args... args) - { - auto ret = std::make_unique(key); - ret->create(t, args...); - ret->encrypt(); - return ret; - } } diff --git a/Stub/dvlnet/dvlnet.h b/Stub/dvlnet/dvlnet.h index 95f7c7fcd..352049723 100644 --- a/Stub/dvlnet/dvlnet.h +++ b/Stub/dvlnet/dvlnet.h @@ -1,27 +1,39 @@ #pragma once -#include "sodium.h" +#include #include #include #include #include namespace dvlnet { + typedef void(__stdcall *snet_event_func)(struct _SNETEVENT*); + class dvlnet_exception : public std::exception {}; + class dvlnet { public: virtual int create(std::string addrstr, std::string passwd) = 0; virtual int join(std::string addrstr, std::string passwd) = 0; - virtual bool SNetReceiveMessage(int *sender, char **data, int *size) = 0; - virtual bool SNetSendMessage(int dest, void *data, unsigned int size) = 0; - virtual bool SNetReceiveTurns(char **data, unsigned int *size, DWORD *status) = 0; - virtual bool SNetSendTurn(char *data, unsigned int size) = 0; - virtual int SNetGetProviderCaps(struct _SNETCAPS *caps) = 0; - virtual void *SNetRegisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)) = 0; - virtual void *SNetUnregisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)) = 0; + virtual bool SNetReceiveMessage(int* sender, char** data, + int* size) = 0; + virtual bool SNetSendMessage(int dest, void* data, + unsigned int size) = 0; + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status) = 0; + virtual bool SNetSendTurn(char* data, unsigned int size) = 0; + virtual int SNetGetProviderCaps(struct _SNETCAPS* caps) = 0; + virtual void* SNetRegisterEventHandler(event_type evtype, + snet_event_func func) = 0; + virtual void* SNetUnregisterEventHandler(event_type evtype, + snet_event_func func) = 0; + virtual bool SNetLeaveGame(int type) = 0; virtual ~dvlnet() {} }; } #include "dvlnet/packet.h" +#include "dvlnet/frame_queue.h" #include "dvlnet/loopback.h" #include "dvlnet/base.h" +#include "dvlnet/tcp_server.h" +#include "dvlnet/tcp_client.h" #include "dvlnet/udp_p2p.h" diff --git a/Stub/dvlnet/frame_queue.cpp b/Stub/dvlnet/frame_queue.cpp new file mode 100644 index 000000000..ff09a977d --- /dev/null +++ b/Stub/dvlnet/frame_queue.cpp @@ -0,0 +1,74 @@ +#include "../types.h" + +using namespace dvlnet; + +size_t frame_queue::size() +{ + return current_size; +} + +buffer_t frame_queue::read(size_t s) +{ + if(current_size < s) + throw frame_queue_exception(); + buffer_t ret; + while(s > 0 && s >= buffer_deque.front().size()) { + s -= buffer_deque.front().size(); + current_size -= buffer_deque.front().size(); + ret.insert(ret.end(), + buffer_deque.front().begin(), + buffer_deque.front().end()); + buffer_deque.pop_front(); + } + if(s > 0) { + ret.insert(ret.end(), + buffer_deque.front().begin(), + buffer_deque.front().begin()+s); + buffer_deque.front().erase(buffer_deque.front().begin(), + buffer_deque.front().begin()+s); + current_size -= s; + } + return std::move(ret); +} + +void frame_queue::write(buffer_t buf) +{ + current_size += buf.size(); + buffer_deque.push_back(std::move(buf)); +} + +bool frame_queue::packet_ready() +{ + if(!nextsize) { + if(size() < sizeof(framesize_t)) + return false; + auto szbuf = read(sizeof(framesize_t)); + nextsize = *(reinterpret_cast(&szbuf[0])); + if(!nextsize) + throw frame_queue_exception(); + } + if(size() >= nextsize) + return true; + else + return false; +} + +buffer_t frame_queue::read_packet() +{ + if(!nextsize || (size() < nextsize)) + throw frame_queue_exception(); + auto ret = std::move(read(nextsize)); + nextsize = 0; + return std::move(ret); +} + +buffer_t frame_queue::make_frame(buffer_t packetbuf) +{ + buffer_t ret; + if(packetbuf.size() > max_frame_size) + ABORT(); + frame_queue::framesize_t size = packetbuf.size(); + ret.insert(ret.end(), packet_out::begin(size), packet_out::end(size)); + ret.insert(ret.end(), packetbuf.begin(), packetbuf.end()); + return std::move(ret); +} diff --git a/Stub/dvlnet/frame_queue.h b/Stub/dvlnet/frame_queue.h new file mode 100644 index 000000000..73c91a305 --- /dev/null +++ b/Stub/dvlnet/frame_queue.h @@ -0,0 +1,24 @@ +#pragma once + +namespace dvlnet { + class frame_queue_exception : public dvlnet_exception {}; + + class frame_queue { + public: + typedef uint32_t framesize_t; + constexpr static framesize_t max_frame_size = 0xFFFF; + private: + size_t current_size = 0; + std::deque buffer_deque; + size_t nextsize = 0; + + size_t size(); + buffer_t read(size_t s); + public: + bool packet_ready(); + buffer_t read_packet(); + void write(buffer_t buf); + + static buffer_t make_frame(buffer_t packetbuf); + }; +} diff --git a/Stub/dvlnet/loopback.cpp b/Stub/dvlnet/loopback.cpp index f8e862100..e6873ca11 100644 --- a/Stub/dvlnet/loopback.cpp +++ b/Stub/dvlnet/loopback.cpp @@ -12,7 +12,7 @@ int loopback::join(std::string addrstr, std::string passwd) ABORT(); } -bool loopback::SNetReceiveMessage(int *sender, char **data, int *size) +bool loopback::SNetReceiveMessage(int* sender, char** data, int* size) { if (message_queue.empty()) return false; @@ -24,29 +24,29 @@ bool loopback::SNetReceiveMessage(int *sender, char **data, int *size) return true; } -bool loopback::SNetSendMessage(int dest, void *data, unsigned int size) +bool loopback::SNetSendMessage(int dest, void* data, unsigned int size) { if (dest == plr_single || dest == SNPLAYER_ALL) { - auto raw_message = reinterpret_cast(data); + auto raw_message = reinterpret_cast(data); buffer_t message(raw_message, raw_message + size); message_queue.push(message); } return true; } -bool loopback::SNetReceiveTurns(char **data, unsigned int *size, DWORD *status) +bool loopback::SNetReceiveTurns(char** data, unsigned int* size, DWORD* status) { // todo: check that this is safe return true; } -bool loopback::SNetSendTurn(char *data, unsigned int size) +bool loopback::SNetSendTurn(char* data, unsigned int size) { // todo: check that this is safe return true; } -int loopback::SNetGetProviderCaps(struct _SNETCAPS *caps) +int loopback::SNetGetProviderCaps(struct _SNETCAPS* caps) { caps->size = 0; // engine writes only ?!? caps->flags = 0; // unused @@ -56,20 +56,28 @@ int loopback::SNetGetProviderCaps(struct _SNETCAPS *caps) caps->bytessec = 1000000; // ? caps->latencyms = 0; // unused caps->defaultturnssec = 10; // ? - caps->defaultturnsintransit = 1; // maximum acceptable number of turns in queue? + caps->defaultturnsintransit = 1; // maximum acceptable number + // of turns in queue? return 1; } -void *loopback::SNetRegisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)) +void* loopback::SNetRegisterEventHandler(event_type evtype, + snet_event_func func) { // not called in real singleplayer mode // not needed in pseudo multiplayer mode (?) return this; } -void *loopback::SNetUnregisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)) +void* loopback::SNetUnregisterEventHandler(event_type evtype, + snet_event_func func) { // not called in real singleplayer mode // not needed in pseudo multiplayer mode (?) return this; } + +bool loopback::SNetLeaveGame(int type) +{ + return true; +} diff --git a/Stub/dvlnet/loopback.h b/Stub/dvlnet/loopback.h index 67378723a..69b95aa5b 100644 --- a/Stub/dvlnet/loopback.h +++ b/Stub/dvlnet/loopback.h @@ -10,12 +10,16 @@ namespace dvlnet { public: virtual int create(std::string addrstr, std::string passwd); virtual int join(std::string addrstr, std::string passwd); - virtual bool SNetReceiveMessage(int *sender, char **data, int *size); - virtual bool SNetSendMessage(int dest, void *data, unsigned int size); - virtual bool SNetReceiveTurns(char **data, unsigned int *size, DWORD *status); - virtual bool SNetSendTurn(char *data, unsigned int size); - virtual int SNetGetProviderCaps(struct _SNETCAPS *caps); - virtual void *SNetRegisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)); - virtual void *SNetUnregisterEventHandler(event_type evtype, void(__stdcall *func)(struct _SNETEVENT *)); + virtual bool SNetReceiveMessage(int* sender, char** data, int* size); + virtual bool SNetSendMessage(int dest, void* data, unsigned int size); + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status); + virtual bool SNetSendTurn(char* data, unsigned int size); + virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); + virtual void *SNetRegisterEventHandler(event_type evtype, + snet_event_func func); + virtual void *SNetUnregisterEventHandler(event_type evtype, + snet_event_func func); + virtual bool SNetLeaveGame(int type); }; } diff --git a/Stub/dvlnet/packet.cpp b/Stub/dvlnet/packet.cpp index dc1b119f9..b552738da 100644 --- a/Stub/dvlnet/packet.cpp +++ b/Stub/dvlnet/packet.cpp @@ -32,7 +32,7 @@ plr_t packet::dest() return m_dest; } -const buffer_t &packet::message() +const buffer_t& packet::message() { if (!have_decrypted) ABORT(); @@ -63,27 +63,28 @@ plr_t packet::newplr() { if (!have_decrypted) ABORT(); - if (m_type != PT_JOIN_ACCEPT) + if (m_type != PT_JOIN_ACCEPT && m_type != PT_CONNECT + && m_type != PT_DISCONNECT) throw packet_exception(); return m_newplr; } -plr_t packet::oldplr() +const buffer_t& packet::info() { if (!have_decrypted) ABORT(); - if (m_type != PT_LEAVE_GAME) + if (m_type != PT_JOIN_REQUEST && m_type != PT_JOIN_ACCEPT) throw packet_exception(); - return m_oldplr; + return m_info; } -const buffer_t &packet::info() +leaveinfo_t packet::leaveinfo() { if (!have_decrypted) ABORT(); - if (m_type != PT_JOIN_ACCEPT) + if (m_type != PT_DISCONNECT) throw packet_exception(); - return m_info; + return m_leaveinfo; } void packet_in::create(buffer_t buf) @@ -104,13 +105,17 @@ void packet_in::decrypt() if (encrypted_buffer.size() < crypto_secretbox_NONCEBYTES + crypto_secretbox_MACBYTES + sizeof(packet_type) + 2 * sizeof(plr_t)) throw packet_exception(); - auto pktlen = encrypted_buffer.size() - crypto_secretbox_NONCEBYTES - crypto_secretbox_MACBYTES; + auto pktlen = (encrypted_buffer.size() + - crypto_secretbox_NONCEBYTES + - crypto_secretbox_MACBYTES); decrypted_buffer.resize(pktlen); if (crypto_secretbox_open_easy(decrypted_buffer.data(), - encrypted_buffer.data() + crypto_secretbox_NONCEBYTES, - encrypted_buffer.size() - crypto_secretbox_NONCEBYTES, - encrypted_buffer.data(), - key.data())) + encrypted_buffer.data() + + crypto_secretbox_NONCEBYTES, + encrypted_buffer.size() + - crypto_secretbox_NONCEBYTES, + encrypted_buffer.data(), + key.data())) throw packet_exception(); } else { if (encrypted_buffer.size() < sizeof(packet_type) + 2 * sizeof(plr_t)) @@ -123,90 +128,6 @@ void packet_in::decrypt() have_decrypted = true; } -void packet_out::create(packet_type t, - plr_t s, - plr_t d, - buffer_t m) -{ - if (have_encrypted || have_decrypted) - ABORT(); - if (t != PT_MESSAGE) - ABORT(); - have_decrypted = true; - m_type = t; - m_src = s; - m_dest = d; - m_message = std::move(m); -} - -void packet_out::create(packet_type t, - plr_t s, - plr_t d, - turn_t u) -{ - if (have_encrypted || have_decrypted) - ABORT(); - if (t != PT_TURN) - ABORT(); - have_decrypted = true; - m_type = t; - m_src = s; - m_dest = d; - m_turn = u; -} - -void packet_out::create(packet_type t, - plr_t s, - plr_t d, - cookie_t c) -{ - if (have_encrypted || have_decrypted) - ABORT(); - if (t != PT_JOIN_REQUEST) - ABORT(); - have_decrypted = true; - m_type = t; - m_src = s; - m_dest = d; - m_cookie = c; -} - -void packet_out::create(packet_type t, - plr_t s, - plr_t d, - cookie_t c, - plr_t n, - buffer_t i) -{ - if (have_encrypted || have_decrypted) - ABORT(); - if (t != PT_JOIN_ACCEPT) - ABORT(); - have_decrypted = true; - m_type = t; - m_src = s; - m_dest = d; - m_cookie = c; - m_newplr = n; - m_info = i; -} - -void packet_out::create(packet_type t, - plr_t s, - plr_t d, - plr_t o) -{ - if (have_encrypted || have_decrypted) - ABORT(); - if (t != PT_LEAVE_GAME) - ABORT(); - have_decrypted = true; - m_type = t; - m_src = s; - m_dest = d; - m_oldplr = o; -} - void packet_out::encrypt() { if (!have_decrypted) @@ -218,15 +139,36 @@ void packet_out::encrypt() if (!disable_encryption) { auto len_cleartext = encrypted_buffer.size(); - encrypted_buffer.insert(encrypted_buffer.begin(), crypto_secretbox_NONCEBYTES, 0); - encrypted_buffer.insert(encrypted_buffer.end(), crypto_secretbox_MACBYTES, 0); + encrypted_buffer.insert(encrypted_buffer.begin(), + crypto_secretbox_NONCEBYTES, 0); + encrypted_buffer.insert(encrypted_buffer.end(), + crypto_secretbox_MACBYTES, 0); randombytes_buf(encrypted_buffer.data(), crypto_secretbox_NONCEBYTES); - if (crypto_secretbox_easy(encrypted_buffer.data() + crypto_secretbox_NONCEBYTES, - encrypted_buffer.data() + crypto_secretbox_NONCEBYTES, - len_cleartext, - encrypted_buffer.data(), - key.data())) + if (crypto_secretbox_easy(encrypted_buffer.data() + + crypto_secretbox_NONCEBYTES, + encrypted_buffer.data() + + crypto_secretbox_NONCEBYTES, + len_cleartext, + encrypted_buffer.data(), + key.data())) ABORT(); } have_encrypted = true; } + +packet_factory::packet_factory(std::string pw) +{ + if (sodium_init() < 0) + ABORT(); + //pw.resize(std::min(pw.size(), crypto_pwhash_PASSWD_MAX)); + //pw.resize(std::max(pw.size(), crypto_pwhash_PASSWD_MIN), 0); + std::string salt("devilution-salt"); + salt.resize(crypto_pwhash_SALTBYTES, 0); + if (crypto_pwhash(key.data(), crypto_secretbox_KEYBYTES, + pw.data(), pw.size(), + reinterpret_cast(salt.data()), + crypto_pwhash_OPSLIMIT_INTERACTIVE, + crypto_pwhash_MEMLIMIT_INTERACTIVE, + crypto_pwhash_ALG_DEFAULT)) + ABORT(); +} diff --git a/Stub/dvlnet/packet.h b/Stub/dvlnet/packet.h index 0fa97074e..2035c2a37 100644 --- a/Stub/dvlnet/packet.h +++ b/Stub/dvlnet/packet.h @@ -6,16 +6,21 @@ namespace dvlnet { PT_TURN = 0x02, PT_JOIN_REQUEST = 0x11, PT_JOIN_ACCEPT = 0x12, - PT_LEAVE_GAME = 0x13, + PT_CONNECT = 0x13, + PT_DISCONNECT = 0x14, }; typedef uint8_t plr_t; typedef uint32_t cookie_t; typedef int turn_t; // change int to something else in devilution code later + typedef int leaveinfo_t; // also change later typedef std::vector buffer_t; typedef std::array key_t; - class packet_exception : public std::exception {}; + static constexpr plr_t PLR_MASTER = 0xFE; + static constexpr plr_t PLR_BROADCAST = 0xFF; + + class packet_exception : public dvlnet_exception {}; class packet { protected: @@ -26,29 +31,29 @@ namespace dvlnet { turn_t m_turn; cookie_t m_cookie; plr_t m_newplr; - plr_t m_oldplr; buffer_t m_info; + leaveinfo_t m_leaveinfo; - const key_t &key; + const key_t& key; bool have_encrypted = false; bool have_decrypted = false; buffer_t encrypted_buffer; buffer_t decrypted_buffer; public: - packet(const key_t &k) : key(k) {}; + packet(const key_t& k) : key(k) {}; - const buffer_t &data(); + const buffer_t& data(); packet_type type(); plr_t src(); plr_t dest(); - const buffer_t &message(); + const buffer_t& message(); turn_t turn(); cookie_t cookie(); plr_t newplr(); - plr_t oldplr(); - const buffer_t &info(); + const buffer_t& info(); + leaveinfo_t leaveinfo(); }; template class packet_proc : public packet { @@ -57,35 +62,32 @@ namespace dvlnet { void process_data(); }; - typedef std::unique_ptr upacket; - class packet_in : public packet_proc { public: using packet_proc::packet_proc; void create(buffer_t buf); - void process_element(buffer_t &x); - template void process_element(T &x); + void process_element(buffer_t& x); + template void process_element(T& x); void decrypt(); }; class packet_out : public packet_proc { public: using packet_proc::packet_proc; - void create(packet_type t, plr_t s, plr_t d, buffer_t m); - void create(packet_type t, plr_t s, plr_t d, turn_t u); - void create(packet_type t, plr_t s, plr_t d, cookie_t c); - void create(packet_type t, plr_t s, plr_t d, cookie_t c, plr_t n, buffer_t i); - void create(packet_type t, plr_t s, plr_t d, plr_t o); - void process_element(buffer_t &x); - template void process_element(T &x); - template static const unsigned char *begin(const T &x); - template static const unsigned char *end(const T &x); + + template + void create(Args... args); + + void process_element(buffer_t& x); + template void process_element(T& x); + template static const unsigned char* begin(const T& x); + template static const unsigned char* end(const T& x); void encrypt(); }; template void packet_proc

::process_data() { - P &self = static_cast(*this); + P& self = static_cast(*this); self.process_element(m_type); self.process_element(m_src); self.process_element(m_dest); @@ -98,48 +100,163 @@ namespace dvlnet { break; case PT_JOIN_REQUEST: self.process_element(m_cookie); + self.process_element(m_info); break; case PT_JOIN_ACCEPT: self.process_element(m_cookie); self.process_element(m_newplr); self.process_element(m_info); break; - case PT_LEAVE_GAME: + case PT_CONNECT: + self.process_element(m_newplr); + break; + case PT_DISCONNECT: + self.process_element(m_newplr); + self.process_element(m_leaveinfo); break; } } - inline void packet_in::process_element(buffer_t &x) + inline void packet_in::process_element(buffer_t& x) { x.insert(x.begin(), decrypted_buffer.begin(), decrypted_buffer.end()); decrypted_buffer.resize(0); } - template void packet_in::process_element(T &x) + template void packet_in::process_element(T& x) { if (decrypted_buffer.size() < sizeof(T)) throw packet_exception(); x = *reinterpret_cast(decrypted_buffer.data()); - decrypted_buffer.erase(decrypted_buffer.begin(), decrypted_buffer.begin() + sizeof(T)); + decrypted_buffer.erase(decrypted_buffer.begin(), + decrypted_buffer.begin() + sizeof(T)); } - inline void packet_out::process_element(buffer_t &x) + template<> + inline void packet_out::create(plr_t s, plr_t d, buffer_t m) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_MESSAGE; + m_src = s; + m_dest = d; + m_message = std::move(m); + } + + template<> + inline void packet_out::create(plr_t s, plr_t d, turn_t u) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_TURN; + m_src = s; + m_dest = d; + m_turn = u; + } + + template<> + inline void packet_out::create(plr_t s, plr_t d, + cookie_t c, buffer_t i) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_JOIN_REQUEST; + m_src = s; + m_dest = d; + m_cookie = c; + m_info = i; + } + + template<> + inline void packet_out::create(plr_t s, plr_t d, cookie_t c, + plr_t n, buffer_t i) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_JOIN_ACCEPT; + m_src = s; + m_dest = d; + m_cookie = c; + m_newplr = n; + m_info = i; + } + + template<> + inline void packet_out::create(plr_t s, plr_t d, plr_t n) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_CONNECT; + m_src = s; + m_dest = d; + m_newplr = n; + } + + template<> + inline void packet_out::create(plr_t s, plr_t d, plr_t n, + leaveinfo_t l) + { + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_DISCONNECT; + m_src = s; + m_dest = d; + m_newplr = n; + m_leaveinfo = l; + } + + inline void packet_out::process_element(buffer_t& x) { encrypted_buffer.insert(encrypted_buffer.end(), x.begin(), x.end()); } - template void packet_out::process_element(T &x) + template void packet_out::process_element(T& x) { encrypted_buffer.insert(encrypted_buffer.end(), begin(x), end(x)); } - template const unsigned char* packet_out::begin(const T &x) + template const unsigned char* packet_out::begin(const T& x) { return reinterpret_cast(&x); } - template const unsigned char* packet_out::end(const T &x) + template const unsigned char* packet_out::end(const T& x) { return reinterpret_cast(&x) + sizeof(T); } + + class packet_factory { + key_t key = {}; + + public: + static constexpr unsigned short max_packet_size = 0xFFFF; + + packet_factory(std::string pw); + std::unique_ptr make_packet(buffer_t buf); + template + std::unique_ptr make_packet(Args... args); + }; + + inline std::unique_ptr packet_factory::make_packet(buffer_t buf) + { + auto ret = std::make_unique(key); + ret->create(std::move(buf)); + ret->decrypt(); + return ret; + } + + template + std::unique_ptr packet_factory::make_packet(Args... args) + { + auto ret = std::make_unique(key); + ret->create(args...); + ret->encrypt(); + return ret; + } } diff --git a/Stub/dvlnet/tcp_client.cpp b/Stub/dvlnet/tcp_client.cpp new file mode 100644 index 000000000..e0728424e --- /dev/null +++ b/Stub/dvlnet/tcp_client.cpp @@ -0,0 +1,79 @@ +#include "../types.h" + +using namespace dvlnet; + +tcp_client::tcp_client(buffer_t info) : + base(std::move(info)) +{ +} + +int tcp_client::create(std::string addrstr, std::string passwd) +{ + local_server = std::make_unique(ioc, addrstr, + default_port, passwd); + return join(local_server->localhost_self(), passwd); +} + +int tcp_client::join(std::string addrstr, std::string passwd) +{ + setup_password(passwd); + auto ipaddr = asio::ip::make_address(addrstr); + sock.connect(asio::ip::tcp::endpoint(ipaddr, default_port)); + start_recv(); + { // hack: try to join for 5 seconds + randombytes_buf(reinterpret_cast(&cookie_self), + sizeof(cookie_t)); + auto pkt = pktfty->make_packet(PLR_BROADCAST, + PLR_MASTER, cookie_self, + game_init_info); + send(*pkt); + for (auto i = 0; i < 5; ++i) { + poll(); + if (plr_self != PLR_BROADCAST) + break; // join successful + sleep(1); + } + } + return (plr_self == PLR_BROADCAST ? -1 : plr_self); +} + +void tcp_client::poll() +{ + ioc.poll(); +} + +void tcp_client::handle_recv(const asio::error_code& error, size_t bytes_read) +{ + if(error) + throw std::runtime_error(""); + if(bytes_read == 0) + throw std::runtime_error(""); + recv_buffer.resize(bytes_read); + recv_queue.write(std::move(recv_buffer)); + recv_buffer.resize(frame_queue::max_frame_size); + while(recv_queue.packet_ready()) { + auto pkt = pktfty->make_packet(recv_queue.read_packet()); + recv_local(*pkt); + } + start_recv(); +} + +void tcp_client::start_recv() +{ + sock.async_receive(asio::buffer(recv_buffer), + std::bind(&tcp_client::handle_recv, this, + std::placeholders::_1, std::placeholders::_2)); +} + +void tcp_client::handle_send(const asio::error_code& error, size_t bytes_sent) +{ + // empty for now +} + +void tcp_client::send(packet& pkt) +{ + auto frame = frame_queue::make_frame(pkt.data()); + asio::async_write(sock, asio::buffer(frame), + std::bind(&tcp_client::handle_send, this, + std::placeholders::_1, std::placeholders::_2)); +} diff --git a/Stub/dvlnet/tcp_client.h b/Stub/dvlnet/tcp_client.h new file mode 100644 index 000000000..88e65cf0b --- /dev/null +++ b/Stub/dvlnet/tcp_client.h @@ -0,0 +1,27 @@ +#pragma once + +namespace dvlnet { + class tcp_client : public base { + public: + tcp_client(buffer_t info); + int create(std::string addrstr, std::string passwd); + int join(std::string addrstr, std::string passwd); + + static constexpr unsigned short default_port = 6112; + + virtual void poll(); + virtual void send(packet& pkt); + private: + std::unique_ptr local_server; + + frame_queue recv_queue; + buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); + + asio::io_context ioc; + asio::ip::tcp::socket sock = asio::ip::tcp::socket(ioc); + + void handle_recv(const asio::error_code& error, size_t bytes_read); + void start_recv(); + void handle_send(const asio::error_code& error, size_t bytes_sent); + }; +} diff --git a/Stub/dvlnet/tcp_server.cpp b/Stub/dvlnet/tcp_server.cpp new file mode 100644 index 000000000..e2522d8e2 --- /dev/null +++ b/Stub/dvlnet/tcp_server.cpp @@ -0,0 +1,200 @@ +#include "../types.h" + +using namespace dvlnet; + +tcp_server::tcp_server(asio::io_context& ioc, std::string bindaddr, + unsigned short port, std::string pw) : + ioc(ioc), pktfty(pw) +{ + auto addr = asio::ip::address::from_string(bindaddr); + auto ep = asio::ip::tcp::endpoint(addr, port); + acceptor = std::make_unique(ioc, ep); + start_accept(); +} + +std::string tcp_server::localhost_self() +{ + auto addr = acceptor->local_endpoint().address(); + if(addr.is_unspecified()) { + if(addr.is_v4()) { + return asio::ip::address_v4::loopback().to_string(); + } else if(addr.is_v6()) { + return asio::ip::address_v6::loopback().to_string(); + } else { + ABORT(); + } + } else { + return addr.to_string(); + } +} + +tcp_server::scc tcp_server::make_connection() +{ + return std::make_shared(ioc); +} + +plr_t tcp_server::next_free() +{ + for(plr_t i = 0; i < MAX_PLRS; ++i) + if(!connections[i]) + return i; + return PLR_BROADCAST; +} + +bool tcp_server::empty() +{ + for(plr_t i = 0; i < MAX_PLRS; ++i) + if(connections[i]) + return false; + return true; +} + +void tcp_server::start_recv(scc con) +{ + con->socket.async_receive(asio::buffer(con->recv_buffer), + std::bind(&tcp_server::handle_recv, this, con, + std::placeholders::_1, + std::placeholders::_2)); +} + +void tcp_server::handle_recv(scc con, const asio::error_code& ec, + size_t bytes_read) +{ + if(ec || bytes_read == 0) { + drop_connection(con); + return; + } + con->recv_buffer.resize(bytes_read); + con->recv_queue.write(std::move(con->recv_buffer)); + con->recv_buffer.resize(frame_queue::max_frame_size); + while(con->recv_queue.packet_ready()) { + try { + auto pkt = pktfty.make_packet(con->recv_queue.read_packet()); + if(con->plr == PLR_BROADCAST) { + handle_recv_newplr(con, *pkt); + } else { + con->timeout = timeout_active; + handle_recv_packet(*pkt); + } + } catch (dvlnet_exception e) { + drop_connection(con); + return; + } + } + start_recv(con); +} + +void tcp_server::send_connect(scc con) +{ + auto pkt = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, + con->plr); + send_packet(*pkt); +} + +void tcp_server::handle_recv_newplr(scc con, packet& pkt) +{ + auto newplr = next_free(); + if(newplr == PLR_BROADCAST) + throw server_exception(); + if(empty()) + game_init_info = pkt.info(); + auto reply = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, + pkt.cookie(), newplr, + game_init_info); + start_send(con, *reply); + con->plr = newplr; + connections[newplr] = con; + con->timeout = timeout_active; + send_connect(con); +} + +void tcp_server::handle_recv_packet(packet& pkt) +{ + send_packet(pkt); +} + +void tcp_server::send_packet(packet& pkt) +{ + if(pkt.dest() == PLR_BROADCAST) { + for(auto i = 0; i < MAX_PLRS; ++i) + if(i != pkt.src() && connections[i]) + start_send(connections[i], pkt); + } else { + if(pkt.dest() < 0 || pkt.dest() >= MAX_PLRS) + throw server_exception(); + if((pkt.dest() != pkt.src()) && connections[pkt.dest()]) + start_send(connections[pkt.dest()], pkt); + } +} + +void tcp_server::start_send(scc con, packet& pkt) +{ + auto frame = frame_queue::make_frame(pkt.data()); + asio::async_write(con->socket, asio::buffer(frame), + std::bind(&tcp_server::handle_send, this, con, + std::placeholders::_1, std::placeholders::_2)); + +} + +void tcp_server::handle_send(scc con, const asio::error_code& ec, + size_t bytes_sent) +{ + // empty for now +} + +void tcp_server::start_accept() +{ + auto nextcon = make_connection(); + acceptor->async_accept(nextcon->socket, + std::bind(&tcp_server::handle_accept, + this, nextcon, + std::placeholders::_1)); +} + +void tcp_server::handle_accept(scc con, const asio::error_code& ec) +{ + if(next_free() == PLR_BROADCAST) { + drop_connection(con); + } else { + con->timeout = timeout_connect; + start_recv(con); + start_timeout(con); + } + start_accept(); +} + +void tcp_server::start_timeout(scc con) +{ + con->timer.expires_after(std::chrono::seconds(1)); + con->timer.async_wait(std::bind(&tcp_server::handle_timeout, this, con, + std::placeholders::_1)); +} + +void tcp_server::handle_timeout(scc con, const asio::error_code& ec) +{ + if(ec) { + drop_connection(con); + return; + } + if(con->timeout > 0) + con->timeout -= 1; + if(con->timeout < 0) + con->timeout = 0; + if(!con->timeout) { + drop_connection(con); + return; + } + start_timeout(con); +} + +void tcp_server::drop_connection(scc con) +{ + if(con->plr != PLR_BROADCAST) { + auto pkt = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, + con->plr, 0); + connections[con->plr] = nullptr; + send_packet(*pkt); + } + con->timer.cancel(); + con->socket.close(); +} diff --git a/Stub/dvlnet/tcp_server.h b/Stub/dvlnet/tcp_server.h new file mode 100644 index 000000000..08f2e250c --- /dev/null +++ b/Stub/dvlnet/tcp_server.h @@ -0,0 +1,52 @@ +#pragma once + +namespace dvlnet { + class server_exception : public dvlnet_exception {}; + + class tcp_server { + public: + tcp_server(asio::io_context& ioc, std::string bindaddr, + unsigned short port, std::string pw); + std::string localhost_self(); + + private: + static constexpr int timeout_connect = 30; + static constexpr int timeout_active = 60; + + struct client_connection { + frame_queue recv_queue; + buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); + plr_t plr = PLR_BROADCAST; + asio::ip::tcp::socket socket; + asio::steady_timer timer; + int timeout; + client_connection(asio::io_context& ioc) : + socket(ioc), timer(ioc) {} + }; + + typedef std::shared_ptr scc; + + asio::io_context& ioc; + packet_factory pktfty; + std::unique_ptr acceptor; + std::array connections; + buffer_t game_init_info; + + scc make_connection(); + plr_t next_free(); + bool empty(); + void start_accept(); + void handle_accept(scc con, const asio::error_code& ec); + void start_recv(scc con); + void handle_recv(scc con, const asio::error_code& ec, size_t bytes_read); + void handle_recv_newplr(scc con, packet& pkt); + void handle_recv_packet(packet& pkt); + void send_connect(scc con); + void send_packet(packet& pkt); + void start_send(scc con, packet& pkt); + void handle_send(scc con, const asio::error_code& ec, size_t bytes_sent); + void start_timeout(scc con); + void handle_timeout(scc con, const asio::error_code& ec); + void drop_connection(scc con); + }; +} diff --git a/Stub/dvlnet/udp_p2p.cpp b/Stub/dvlnet/udp_p2p.cpp index 5ef7a63df..fc66b9315 100644 --- a/Stub/dvlnet/udp_p2p.cpp +++ b/Stub/dvlnet/udp_p2p.cpp @@ -4,17 +4,14 @@ using namespace dvlnet; const udp_p2p::endpoint udp_p2p::none; -udp_p2p::udp_p2p(buffer_t info) +udp_p2p::udp_p2p(buffer_t info) : + base(info) { - if (sodium_init() < 0) - abort(); - game_init_info = std::move(info); } - int udp_p2p::create(std::string addrstr, std::string passwd) { - sock = asio::ip::udp::socket(io_context); // to be removed later + sock = asio::ip::udp::socket(io_context);// to be removed later setup_password(passwd); auto ipaddr = asio::ip::make_address(addrstr); if (ipaddr.is_v4()) @@ -28,7 +25,8 @@ int udp_p2p::create(std::string addrstr, std::string passwd) try { sock.bind(asio::ip::udp::endpoint(asio::ip::address_v6(), port)); } catch (std::exception e) { - eprintf("bind: %s, %s\n", asio::ip::address_v6().to_string(), e.what()); + eprintf("bind: %s, %s\n", asio::ip::address_v6().to_string(), + e.what()); } ++port; } @@ -58,16 +56,18 @@ int udp_p2p::join(std::string addrstr, std::string passwd) { // hack: try to join for 5 seconds randombytes_buf(reinterpret_cast(&cookie_self), sizeof(cookie_t)); - upacket pkt = make_packet(PT_JOIN_REQUEST, ADDR_BROADCAST, ADDR_MASTER, cookie_self); - send(pkt); + auto pkt = pktfty->make_packet(PLR_BROADCAST, + PLR_MASTER, cookie_self, + game_init_info); + send(*pkt); for (auto i = 0; i < 5; ++i) { recv(); - if (plr_self != ADDR_BROADCAST) + if (plr_self != PLR_BROADCAST) break; // join successful sleep(1); } } - return (plr_self == ADDR_BROADCAST ? 4 : plr_self); + return (plr_self == PLR_BROADCAST ? 4 : plr_self); } void udp_p2p::poll() @@ -75,33 +75,23 @@ void udp_p2p::poll() recv(); } -void udp_p2p::send(upacket& pkt) +void udp_p2p::send(packet& pkt) { send_internal(pkt, none); } -bool udp_p2p::connected(plr_t p) -{ - return active_table[p]; -} - -bool udp_p2p::active(plr_t p) -{ - return active_table[p]; -} - void udp_p2p::recv() { try { while (1) { // read until kernel buffer is empty? try { endpoint sender; - buffer_t pkt_buf(max_packet_size); + buffer_t pkt_buf(packet_factory::max_packet_size); size_t pkt_len; pkt_len = sock.receive_from(asio::buffer(pkt_buf), sender); pkt_buf.resize(pkt_len); - upacket pkt = make_packet(pkt_buf); - recv_decrypted(pkt, sender); + auto pkt = pktfty->make_packet(pkt_buf); + recv_decrypted(*pkt, sender); } catch (packet_exception e) { // drop packet } @@ -111,10 +101,10 @@ void udp_p2p::recv() } } -void udp_p2p::send_internal(upacket &pkt, endpoint sender) +void udp_p2p::send_internal(packet& pkt, endpoint sender) { - for (auto &dest : dests_for_addr(pkt->dest(), sender)) { - sock.send_to(asio::buffer(pkt->data()), dest); + for (auto &dest : dests_for_addr(pkt.dest(), sender)) { + sock.send_to(asio::buffer(pkt.data()), dest); } } @@ -127,12 +117,13 @@ std::set udp_p2p::dests_for_addr(plr_t dest, endpoint sender) if (0 <= dest && dest < MAX_PLRS) { if (active_table[dest]) ret.insert(nexthop_table[dest]); - } else if (dest == ADDR_BROADCAST) { + } else if (dest == PLR_BROADCAST) { for (auto i = 0; i < MAX_PLRS; ++i) if (i != plr_self && active_table[i]) ret.insert(nexthop_table[i]); - ret.insert(connection_requests_pending.begin(), connection_requests_pending.end()); - } else if (dest == ADDR_MASTER) { + ret.insert(connection_requests_pending.begin(), + connection_requests_pending.end()); + } else if (dest == PLR_MASTER) { if (master != none) ret.insert(master); } @@ -140,7 +131,7 @@ std::set udp_p2p::dests_for_addr(plr_t dest, endpoint sender) return ret; } -void udp_p2p::handle_join_request(upacket &pkt, endpoint sender) +void udp_p2p::handle_join_request(packet& pkt, endpoint sender) { plr_t i; for (i = 0; i < MAX_PLRS; ++i) { @@ -149,49 +140,35 @@ void udp_p2p::handle_join_request(upacket &pkt, endpoint sender) break; } } - upacket reply = make_packet(PT_JOIN_ACCEPT, plr_self, ADDR_BROADCAST, pkt->cookie(), i, game_init_info); - send(reply); + auto reply = pktfty->make_packet(plr_self, PLR_BROADCAST, + pkt.cookie(), i, + game_init_info); + send(*reply); } -void udp_p2p::recv_decrypted(upacket &pkt, endpoint sender) +void udp_p2p::recv_decrypted(packet& pkt, endpoint sender) { // 1. route send_internal(pkt, sender); // 2. handle local - if (pkt->src() == ADDR_BROADCAST && pkt->dest() == ADDR_MASTER) { + if (pkt.src() == PLR_BROADCAST && pkt.dest() == PLR_MASTER) { connection_requests_pending.insert(sender); if (master == none) { handle_join_request(pkt, sender); } } // normal packets - if (pkt->src() < 0 || pkt->src() >= MAX_PLRS) + if (pkt.src() < 0 || pkt.src() >= MAX_PLRS) return; //drop packet - if (active_table[pkt->src()]) { //WRONG?!? - if (sender != nexthop_table[pkt->src()]) + if (active_table[pkt.src()]) { //WRONG?!? + if (sender != nexthop_table[pkt.src()]) return; //rpfilter fail: drop packet } else { - nexthop_table[pkt->src()] = sender; // new connection: accept + nexthop_table[pkt.src()] = sender; // new connection: accept } - active_table[pkt->src()] = ACTIVE; - if (pkt->dest() != plr_self && pkt->dest() != ADDR_BROADCAST) + active_table[pkt.src()] = true; + connected_table[pkt.src()] = true; + if (pkt.dest() != plr_self && pkt.dest() != PLR_BROADCAST) return; //packet not for us, drop - if(pkt->type() == PT_JOIN_ACCEPT) - handle_accept(pkt); - else - recv_local(pkt); -} - -void udp_p2p::handle_accept(upacket &pkt) -{ - if (plr_self != ADDR_BROADCAST) - return; // already have player id - if (pkt->cookie() == cookie_self) - plr_self = pkt->newplr(); - _SNETEVENT ev; - ev.eventid = EVENT_TYPE_PLAYER_CREATE_GAME; - ev.playerid = plr_self; - ev.data = const_cast(pkt->info().data()); - ev.databytes = pkt->info().size(); - run_event_handler(ev); + recv_local(pkt); } diff --git a/Stub/dvlnet/udp_p2p.h b/Stub/dvlnet/udp_p2p.h index c484d6ba7..c06c06219 100644 --- a/Stub/dvlnet/udp_p2p.h +++ b/Stub/dvlnet/udp_p2p.h @@ -5,9 +5,7 @@ namespace dvlnet { virtual int create(std::string addrstr, std::string passwd); virtual int join(std::string addrstr, std::string passwd); virtual void poll(); - virtual void send(upacket& pkt); - virtual bool connected(plr_t p); - virtual bool active(plr_t p); + virtual void send(packet& pkt); private: typedef asio::ip::udp::endpoint endpoint; @@ -24,15 +22,13 @@ namespace dvlnet { std::set connection_requests_pending; std::array nexthop_table; - std::array active_table = { 0 }; asio::ip::udp::socket sock = asio::ip::udp::socket(io_context); void recv(); - void handle_join_request(upacket &pkt, endpoint sender); - void handle_accept(upacket &pkt); - void send_internal(upacket& pkt, endpoint sender = none); + void handle_join_request(packet& pkt, endpoint sender); + void send_internal(packet& pkt, endpoint sender = none); std::set dests_for_addr(plr_t dest, endpoint sender); - void recv_decrypted(upacket &pkt, endpoint sender); + void recv_decrypted(packet& pkt, endpoint sender); }; } diff --git a/Stub/storm_net.cpp b/Stub/storm_net.cpp index 9fd2207c6..eed502c3b 100644 --- a/Stub/storm_net.cpp +++ b/Stub/storm_net.cpp @@ -54,7 +54,8 @@ BOOL STORMAPI SNetDestroy() BOOL STORMAPI SNetDropPlayer(int playerid, DWORD flags) { - UNIMPLEMENTED(); + DUMMY(); + return TRUE; } BOOL STORMAPI SNetGetGameInfo(int type, void *dst, unsigned int length, unsigned int *byteswritten) @@ -65,8 +66,7 @@ BOOL STORMAPI SNetGetGameInfo(int type, void *dst, unsigned int length, unsigned BOOL STORMAPI SNetLeaveGame(int type) { - DUMMY(); - return TRUE; + return dvlnet_inst->SNetLeaveGame(type); } BOOL STORMAPI SNetSendServerChatCommand(const char *command) @@ -86,8 +86,8 @@ int __stdcall SNetInitializeProvider(unsigned long provider, struct _SNETPROGRAM { if (provider == 'UDPN') { dvlnet::buffer_t game_init_info((char*)client_info->initdata, - (char*)client_info->initdata + client_info->initdatabytes); - dvlnet_inst = std::make_unique(std::move(game_init_info)); + (char*)client_info->initdata + client_info->initdatabytes); + dvlnet_inst = std::make_unique(std::move(game_init_info)); } else if (provider == 'SCBL' || provider == 0) { dvlnet_inst = std::make_unique(); } else {