From 24f746ae6a3ec571d40b7cc5fb82a26639ef5b9b Mon Sep 17 00:00:00 2001 From: Xadhoom <> Date: Sun, 17 Mar 2019 18:54:05 +0000 Subject: [PATCH] Fix strict aliasing violations in netcode --- SourceX/dvlnet/abstract_net.cpp | 6 +- SourceX/dvlnet/abstract_net.h | 61 ++-- SourceX/dvlnet/base.cpp | 10 +- SourceX/dvlnet/base.h | 116 ++++---- SourceX/dvlnet/frame_queue.cpp | 10 +- SourceX/dvlnet/frame_queue.h | 48 +-- SourceX/dvlnet/loopback.cpp | 6 +- SourceX/dvlnet/loopback.h | 56 ++-- SourceX/dvlnet/packet.cpp | 6 +- SourceX/dvlnet/packet.h | 505 ++++++++++++++++---------------- SourceX/dvlnet/tcp_client.cpp | 6 +- SourceX/dvlnet/tcp_client.h | 50 ++-- SourceX/dvlnet/tcp_server.cpp | 6 +- SourceX/dvlnet/tcp_server.h | 92 +++--- SourceX/dvlnet/udp_p2p.cpp | 8 +- SourceX/dvlnet/udp_p2p.h | 70 ++--- 16 files changed, 554 insertions(+), 502 deletions(-) diff --git a/SourceX/dvlnet/abstract_net.cpp b/SourceX/dvlnet/abstract_net.cpp index 6789b6b9b..afb601b20 100644 --- a/SourceX/dvlnet/abstract_net.cpp +++ b/SourceX/dvlnet/abstract_net.cpp @@ -5,7 +5,8 @@ #include "dvlnet/udp_p2p.h" #include "dvlnet/loopback.h" -namespace dvl { namespace net { +namespace dvl { +namespace net { abstract_net::~abstract_net() { @@ -24,4 +25,5 @@ std::unique_ptr abstract_net::make_net(provider_t provider) } } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/abstract_net.h b/SourceX/dvlnet/abstract_net.h index 36c7f96dc..a69725a5b 100644 --- a/SourceX/dvlnet/abstract_net.h +++ b/SourceX/dvlnet/abstract_net.h @@ -7,36 +7,39 @@ #include "devilution.h" -namespace dvl { namespace net { - typedef std::vector buffer_t; - typedef void(*snet_event_func)(struct _SNETEVENT*); - typedef unsigned long provider_t; - class dvlnet_exception : public std::exception {}; +namespace dvl { +namespace net { - class abstract_net { - public: - virtual int create(std::string addrstr, std::string passwd) = 0; - virtual int join(std::string addrstr, std::string passwd) = 0; - virtual bool SNetReceiveMessage(int* sender, char** data, - int* size) = 0; - virtual bool SNetSendMessage(int dest, void* data, - unsigned int size) = 0; - virtual bool SNetReceiveTurns(char** data, unsigned int* size, - DWORD* status) = 0; - virtual bool SNetSendTurn(char* data, unsigned int size) = 0; - virtual int SNetGetProviderCaps(struct _SNETCAPS* caps) = 0; - virtual void* SNetRegisterEventHandler(event_type evtype, - snet_event_func func) = 0; - virtual void* SNetUnregisterEventHandler(event_type evtype, - snet_event_func func) = 0; - virtual bool SNetLeaveGame(int type) = 0; - virtual bool SNetDropPlayer(int playerid, DWORD flags) = 0; - virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns) = 0; - virtual bool SNetGetTurnsInTransit(int *turns) = 0; - virtual void setup_gameinfo(buffer_t info) = 0; - virtual ~abstract_net(); +typedef std::vector buffer_t; +typedef void(*snet_event_func)(struct _SNETEVENT*); +typedef unsigned long provider_t; +class dvlnet_exception : public std::exception {}; - static std::unique_ptr make_net(provider_t provider); +class abstract_net { +public: + virtual int create(std::string addrstr, std::string passwd) = 0; + virtual int join(std::string addrstr, std::string passwd) = 0; + virtual bool SNetReceiveMessage(int* sender, char** data, + int* size) = 0; + virtual bool SNetSendMessage(int dest, void* data, + unsigned int size) = 0; + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status) = 0; + virtual bool SNetSendTurn(char* data, unsigned int size) = 0; + virtual int SNetGetProviderCaps(struct _SNETCAPS* caps) = 0; + virtual void* SNetRegisterEventHandler(event_type evtype, + snet_event_func func) = 0; + virtual void* SNetUnregisterEventHandler(event_type evtype, + snet_event_func func) = 0; + virtual bool SNetLeaveGame(int type) = 0; + virtual bool SNetDropPlayer(int playerid, DWORD flags) = 0; + virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns) = 0; + virtual bool SNetGetTurnsInTransit(int *turns) = 0; + virtual void setup_gameinfo(buffer_t info) = 0; + virtual ~abstract_net(); + + static std::unique_ptr make_net(provider_t provider); }; -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/base.cpp b/SourceX/dvlnet/base.cpp index f78507588..0860a4628 100644 --- a/SourceX/dvlnet/base.cpp +++ b/SourceX/dvlnet/base.cpp @@ -1,8 +1,10 @@ #include "dvlnet/base.h" #include +#include -namespace dvl { namespace net { +namespace dvl { +namespace net { void base::setup_gameinfo(buffer_t info) { @@ -170,7 +172,8 @@ bool base::SNetSendTurn(char* data, unsigned int size) { if (size != sizeof(turn_t)) ABORT(); - turn_t turn = *reinterpret_cast(data); + turn_t turn; + std::memcpy(&turn, data, sizeof(turn)); auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, turn); send(*pkt); turn_queue[plr_self].push_back(pkt->turn()); @@ -253,4 +256,5 @@ bool base::SNetGetTurnsInTransit(int *turns) return true; } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/base.h b/SourceX/dvlnet/base.h index bc4decb34..56be6c3db 100644 --- a/SourceX/dvlnet/base.h +++ b/SourceX/dvlnet/base.h @@ -18,60 +18,64 @@ #define LEAVE_ENDING 0x40000004 #define LEAVE_DROP 0x40000006 -namespace dvl { namespace net { - class base : public abstract_net { - public: - virtual int create(std::string addrstr, std::string passwd) = 0; - virtual int join(std::string addrstr, std::string passwd) = 0; - - virtual bool SNetReceiveMessage(int* sender, char** data, int* size); - virtual bool SNetSendMessage(int dest, void* data, unsigned int size); - virtual bool SNetReceiveTurns(char** data, unsigned int* size, - DWORD* status); - virtual bool SNetSendTurn(char* data, unsigned int size); - virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); - virtual void* SNetRegisterEventHandler(event_type evtype, - snet_event_func func); - virtual void* SNetUnregisterEventHandler(event_type evtype, - snet_event_func func); - virtual bool SNetLeaveGame(int type); - virtual bool SNetDropPlayer(int playerid, DWORD flags); - virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); - virtual bool SNetGetTurnsInTransit(int *turns); - - virtual void poll() = 0; - virtual void send(packet& pkt) = 0; - - void setup_gameinfo(buffer_t info); - protected: - std::map registered_handlers; - buffer_t game_init_info; - - struct message_t { - int sender; // change int to something else in devilution code later - buffer_t payload; - message_t() : sender(-1), payload({}) {} - message_t(int s, buffer_t p) : sender(s), payload(p) {} - }; - - message_t message_last; - std::deque message_queue; - std::array turn_last = {}; - std::array, MAX_PLRS> turn_queue; - std::array connected_table = {}; - - plr_t plr_self = PLR_BROADCAST; - cookie_t cookie_self = 0; - - std::unique_ptr pktfty; - - void setup_password(std::string pw); - void handle_accept(packet& pkt); - void recv_local(packet& pkt); - void run_event_handler(_SNETEVENT& ev); - - private: - plr_t get_owner(); - void clear_msg(plr_t plr); +namespace dvl { +namespace net { + +class base : public abstract_net { +public: + virtual int create(std::string addrstr, std::string passwd) = 0; + virtual int join(std::string addrstr, std::string passwd) = 0; + + virtual bool SNetReceiveMessage(int* sender, char** data, int* size); + virtual bool SNetSendMessage(int dest, void* data, unsigned int size); + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status); + virtual bool SNetSendTurn(char* data, unsigned int size); + virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); + virtual void* SNetRegisterEventHandler(event_type evtype, + snet_event_func func); + virtual void* SNetUnregisterEventHandler(event_type evtype, + snet_event_func func); + virtual bool SNetLeaveGame(int type); + virtual bool SNetDropPlayer(int playerid, DWORD flags); + virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); + virtual bool SNetGetTurnsInTransit(int *turns); + + virtual void poll() = 0; + virtual void send(packet& pkt) = 0; + + void setup_gameinfo(buffer_t info); +protected: + std::map registered_handlers; + buffer_t game_init_info; + + struct message_t { + int sender; // change int to something else in devilution code later + buffer_t payload; + message_t() : sender(-1), payload({}) {} + message_t(int s, buffer_t p) : sender(s), payload(p) {} }; -}} + + message_t message_last; + std::deque message_queue; + std::array turn_last = {}; + std::array, MAX_PLRS> turn_queue; + std::array connected_table = {}; + + plr_t plr_self = PLR_BROADCAST; + cookie_t cookie_self = 0; + + std::unique_ptr pktfty; + + void setup_password(std::string pw); + void handle_accept(packet& pkt); + void recv_local(packet& pkt); + void run_event_handler(_SNETEVENT& ev); + +private: + plr_t get_owner(); + void clear_msg(plr_t plr); +}; + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/frame_queue.cpp b/SourceX/dvlnet/frame_queue.cpp index 5374995cf..a6211e4fd 100644 --- a/SourceX/dvlnet/frame_queue.cpp +++ b/SourceX/dvlnet/frame_queue.cpp @@ -1,8 +1,11 @@ #include "dvlnet/frame_queue.h" +#include + #include "dvlnet/packet.h" -namespace dvl { namespace net { +namespace dvl { +namespace net { size_t frame_queue::size() { @@ -45,7 +48,7 @@ bool frame_queue::packet_ready() if(size() < sizeof(framesize_t)) return false; auto szbuf = read(sizeof(framesize_t)); - nextsize = *(reinterpret_cast(&szbuf[0])); + std::memcpy(&nextsize, &szbuf[0], sizeof(nextsize)); if(!nextsize) throw frame_queue_exception(); } @@ -75,4 +78,5 @@ buffer_t frame_queue::make_frame(buffer_t packetbuf) return std::move(ret); } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/frame_queue.h b/SourceX/dvlnet/frame_queue.h index 9aed69b94..666f2d21a 100644 --- a/SourceX/dvlnet/frame_queue.h +++ b/SourceX/dvlnet/frame_queue.h @@ -4,25 +4,29 @@ #include "dvlnet/abstract_net.h" -namespace dvl { namespace net { - class frame_queue_exception : public dvlnet_exception {}; - - class frame_queue { - public: - typedef uint32_t framesize_t; - constexpr static framesize_t max_frame_size = 0xFFFF; - private: - size_t current_size = 0; - std::deque buffer_deque; - size_t nextsize = 0; - - size_t size(); - buffer_t read(size_t s); - public: - bool packet_ready(); - buffer_t read_packet(); - void write(buffer_t buf); - - static buffer_t make_frame(buffer_t packetbuf); - }; -}} +namespace dvl { +namespace net { + +class frame_queue_exception : public dvlnet_exception {}; + +class frame_queue { +public: + typedef uint32_t framesize_t; + constexpr static framesize_t max_frame_size = 0xFFFF; +private: + size_t current_size = 0; + std::deque buffer_deque; + size_t nextsize = 0; + + size_t size(); + buffer_t read(size_t s); +public: + bool packet_ready(); + buffer_t read_packet(); + void write(buffer_t buf); + + static buffer_t make_frame(buffer_t packetbuf); +}; + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/loopback.cpp b/SourceX/dvlnet/loopback.cpp index 332504fd0..0cf0c380a 100644 --- a/SourceX/dvlnet/loopback.cpp +++ b/SourceX/dvlnet/loopback.cpp @@ -1,7 +1,8 @@ #include "dvlnet/loopback.h" #include "stubs.h" -namespace dvl { namespace net { +namespace dvl { +namespace net { int loopback::create(std::string addrstr, std::string passwd) { @@ -104,4 +105,5 @@ bool loopback::SNetGetTurnsInTransit(int *turns) return true; } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/loopback.h b/SourceX/dvlnet/loopback.h index 4e45376e8..f3a5034c1 100644 --- a/SourceX/dvlnet/loopback.h +++ b/SourceX/dvlnet/loopback.h @@ -6,30 +6,34 @@ #include "devilution.h" #include "dvlnet/abstract_net.h" -namespace dvl { namespace net { - class loopback : public abstract_net { - private: - std::queue message_queue; - buffer_t message_last; - const int plr_single = 0; +namespace dvl { +namespace net { - public: - virtual int create(std::string addrstr, std::string passwd); - virtual int join(std::string addrstr, std::string passwd); - virtual bool SNetReceiveMessage(int* sender, char** data, int* size); - virtual bool SNetSendMessage(int dest, void* data, unsigned int size); - virtual bool SNetReceiveTurns(char** data, unsigned int* size, - DWORD* status); - virtual bool SNetSendTurn(char* data, unsigned int size); - virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); - virtual void *SNetRegisterEventHandler(event_type evtype, - snet_event_func func); - virtual void *SNetUnregisterEventHandler(event_type evtype, - snet_event_func func); - virtual bool SNetLeaveGame(int type); - virtual bool SNetDropPlayer(int playerid, DWORD flags); - virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); - virtual bool SNetGetTurnsInTransit(int *turns); - virtual void setup_gameinfo(buffer_t info); - }; -}} +class loopback : public abstract_net { +private: + std::queue message_queue; + buffer_t message_last; + const int plr_single = 0; + +public: + virtual int create(std::string addrstr, std::string passwd); + virtual int join(std::string addrstr, std::string passwd); + virtual bool SNetReceiveMessage(int* sender, char** data, int* size); + virtual bool SNetSendMessage(int dest, void* data, unsigned int size); + virtual bool SNetReceiveTurns(char** data, unsigned int* size, + DWORD* status); + virtual bool SNetSendTurn(char* data, unsigned int size); + virtual int SNetGetProviderCaps(struct _SNETCAPS* caps); + virtual void *SNetRegisterEventHandler(event_type evtype, + snet_event_func func); + virtual void *SNetUnregisterEventHandler(event_type evtype, + snet_event_func func); + virtual bool SNetLeaveGame(int type); + virtual bool SNetDropPlayer(int playerid, DWORD flags); + virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); + virtual bool SNetGetTurnsInTransit(int *turns); + virtual void setup_gameinfo(buffer_t info); +}; + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/packet.cpp b/SourceX/dvlnet/packet.cpp index e6dfd5b26..fca7f2f80 100644 --- a/SourceX/dvlnet/packet.cpp +++ b/SourceX/dvlnet/packet.cpp @@ -1,6 +1,7 @@ #include "dvlnet/packet.h" -namespace dvl { namespace net { +namespace dvl { +namespace net { static constexpr bool disable_encryption = false; @@ -174,4 +175,5 @@ packet_factory::packet_factory(std::string pw) ABORT(); } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/packet.h b/SourceX/dvlnet/packet.h index 40e0b462f..dcca79ca0 100644 --- a/SourceX/dvlnet/packet.h +++ b/SourceX/dvlnet/packet.h @@ -3,267 +3,272 @@ #include #include #include +#include #include #include "dvlnet/abstract_net.h" #include "stubs.h" -namespace dvl { namespace net { - enum packet_type : uint8_t { +namespace dvl { +namespace net { + +enum packet_type : uint8_t { PT_MESSAGE = 0x01, PT_TURN = 0x02, PT_JOIN_REQUEST = 0x11, PT_JOIN_ACCEPT = 0x12, PT_CONNECT = 0x13, PT_DISCONNECT = 0x14, - }; - - typedef uint8_t plr_t; - typedef uint32_t cookie_t; - typedef int turn_t; // change int to something else in devilution code later - typedef int leaveinfo_t; // also change later - typedef std::array key_t; - - static constexpr plr_t PLR_MASTER = 0xFE; - static constexpr plr_t PLR_BROADCAST = 0xFF; - - class packet_exception : public dvlnet_exception {}; - - class packet { - protected: - packet_type m_type; - plr_t m_src; - plr_t m_dest; - buffer_t m_message; - turn_t m_turn; - cookie_t m_cookie; - plr_t m_newplr; - buffer_t m_info; - leaveinfo_t m_leaveinfo; - - const key_t& key; - bool have_encrypted = false; - bool have_decrypted = false; - buffer_t encrypted_buffer; - buffer_t decrypted_buffer; - - public: - packet(const key_t& k) : key(k) {}; - - const buffer_t& data(); - - packet_type type(); - plr_t src(); - plr_t dest(); - const buffer_t& message(); - turn_t turn(); - cookie_t cookie(); - plr_t newplr(); - const buffer_t& info(); - leaveinfo_t leaveinfo(); - }; - - template class packet_proc : public packet { - public: - using packet::packet; - void process_data(); - }; - - class packet_in : public packet_proc { - public: - using packet_proc::packet_proc; - void create(buffer_t buf); - void process_element(buffer_t& x); - template void process_element(T& x); - void decrypt(); - }; - - class packet_out : public packet_proc { - public: - using packet_proc::packet_proc; - - template - void create(Args... args); - - void process_element(buffer_t& x); - template void process_element(T& x); - template static const unsigned char* begin(const T& x); - template static const unsigned char* end(const T& x); - void encrypt(); - }; - - template void packet_proc

::process_data() - { - P& self = static_cast(*this); - self.process_element(m_type); - self.process_element(m_src); - self.process_element(m_dest); - switch (m_type) { - case PT_MESSAGE: - self.process_element(m_message); - break; - case PT_TURN: - self.process_element(m_turn); - break; - case PT_JOIN_REQUEST: - self.process_element(m_cookie); - self.process_element(m_info); - break; - case PT_JOIN_ACCEPT: - self.process_element(m_cookie); - self.process_element(m_newplr); - self.process_element(m_info); - break; - case PT_CONNECT: - self.process_element(m_newplr); - break; - case PT_DISCONNECT: - self.process_element(m_newplr); - self.process_element(m_leaveinfo); - break; - } - } - - inline void packet_in::process_element(buffer_t& x) - { - x.insert(x.begin(), decrypted_buffer.begin(), decrypted_buffer.end()); - decrypted_buffer.resize(0); - } - - template void packet_in::process_element(T& x) - { - if (decrypted_buffer.size() < sizeof(T)) - throw packet_exception(); - x = *reinterpret_cast(decrypted_buffer.data()); - decrypted_buffer.erase(decrypted_buffer.begin(), - decrypted_buffer.begin() + sizeof(T)); - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, buffer_t m) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_MESSAGE; - m_src = s; - m_dest = d; - m_message = std::move(m); - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, turn_t u) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_TURN; - m_src = s; - m_dest = d; - m_turn = u; - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, - cookie_t c, buffer_t i) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_JOIN_REQUEST; - m_src = s; - m_dest = d; - m_cookie = c; - m_info = i; - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, cookie_t c, - plr_t n, buffer_t i) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_JOIN_ACCEPT; - m_src = s; - m_dest = d; - m_cookie = c; - m_newplr = n; - m_info = i; - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, plr_t n) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_CONNECT; - m_src = s; - m_dest = d; - m_newplr = n; - } - - template<> - inline void packet_out::create(plr_t s, plr_t d, plr_t n, - leaveinfo_t l) - { - if (have_encrypted || have_decrypted) - ABORT(); - have_decrypted = true; - m_type = PT_DISCONNECT; - m_src = s; - m_dest = d; - m_newplr = n; - m_leaveinfo = l; - } - - inline void packet_out::process_element(buffer_t& x) - { - encrypted_buffer.insert(encrypted_buffer.end(), x.begin(), x.end()); - } - - template void packet_out::process_element(T& x) - { - encrypted_buffer.insert(encrypted_buffer.end(), begin(x), end(x)); - } - - template const unsigned char* packet_out::begin(const T& x) - { - return reinterpret_cast(&x); - } - - template const unsigned char* packet_out::end(const T& x) - { - return reinterpret_cast(&x) + sizeof(T); - } - - class packet_factory { - key_t key = {}; - - public: - static constexpr unsigned short max_packet_size = 0xFFFF; - - packet_factory(std::string pw = ""); - std::unique_ptr make_packet(buffer_t buf); - template - std::unique_ptr make_packet(Args... args); - }; - - inline std::unique_ptr packet_factory::make_packet(buffer_t buf) - { - auto ret = std::make_unique(key); - ret->create(std::move(buf)); - ret->decrypt(); - return ret; - } +}; + +typedef uint8_t plr_t; +typedef uint32_t cookie_t; +typedef int turn_t; // change int to something else in devilution code later +typedef int leaveinfo_t; // also change later +typedef std::array key_t; + +static constexpr plr_t PLR_MASTER = 0xFE; +static constexpr plr_t PLR_BROADCAST = 0xFF; + +class packet_exception : public dvlnet_exception {}; + +class packet { +protected: + packet_type m_type; + plr_t m_src; + plr_t m_dest; + buffer_t m_message; + turn_t m_turn; + cookie_t m_cookie; + plr_t m_newplr; + buffer_t m_info; + leaveinfo_t m_leaveinfo; + + const key_t& key; + bool have_encrypted = false; + bool have_decrypted = false; + buffer_t encrypted_buffer; + buffer_t decrypted_buffer; + +public: + packet(const key_t& k) : key(k) {}; + + const buffer_t& data(); + + packet_type type(); + plr_t src(); + plr_t dest(); + const buffer_t& message(); + turn_t turn(); + cookie_t cookie(); + plr_t newplr(); + const buffer_t& info(); + leaveinfo_t leaveinfo(); +}; + +template class packet_proc : public packet { +public: + using packet::packet; + void process_data(); +}; + +class packet_in : public packet_proc { +public: + using packet_proc::packet_proc; + void create(buffer_t buf); + void process_element(buffer_t& x); + template void process_element(T& x); + void decrypt(); +}; + +class packet_out : public packet_proc { +public: + using packet_proc::packet_proc; template - std::unique_ptr packet_factory::make_packet(Args... args) - { - auto ret = std::make_unique(key); - ret->create(args...); - ret->encrypt(); - return ret; + void create(Args... args); + + void process_element(buffer_t& x); + template void process_element(T& x); + template static const unsigned char* begin(const T& x); + template static const unsigned char* end(const T& x); + void encrypt(); +}; + +template void packet_proc

::process_data() +{ + P& self = static_cast(*this); + self.process_element(m_type); + self.process_element(m_src); + self.process_element(m_dest); + switch (m_type) { + case PT_MESSAGE: + self.process_element(m_message); + break; + case PT_TURN: + self.process_element(m_turn); + break; + case PT_JOIN_REQUEST: + self.process_element(m_cookie); + self.process_element(m_info); + break; + case PT_JOIN_ACCEPT: + self.process_element(m_cookie); + self.process_element(m_newplr); + self.process_element(m_info); + break; + case PT_CONNECT: + self.process_element(m_newplr); + break; + case PT_DISCONNECT: + self.process_element(m_newplr); + self.process_element(m_leaveinfo); + break; } -}} +} + +inline void packet_in::process_element(buffer_t& x) +{ + x.insert(x.begin(), decrypted_buffer.begin(), decrypted_buffer.end()); + decrypted_buffer.resize(0); +} + +template void packet_in::process_element(T& x) +{ + if (decrypted_buffer.size() < sizeof(T)) + throw packet_exception(); + std::memcpy(&x, decrypted_buffer.data(), sizeof(T)); + decrypted_buffer.erase(decrypted_buffer.begin(), + decrypted_buffer.begin() + sizeof(T)); +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, buffer_t m) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_MESSAGE; + m_src = s; + m_dest = d; + m_message = std::move(m); +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, turn_t u) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_TURN; + m_src = s; + m_dest = d; + m_turn = u; +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, + cookie_t c, buffer_t i) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_JOIN_REQUEST; + m_src = s; + m_dest = d; + m_cookie = c; + m_info = i; +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, cookie_t c, + plr_t n, buffer_t i) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_JOIN_ACCEPT; + m_src = s; + m_dest = d; + m_cookie = c; + m_newplr = n; + m_info = i; +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, plr_t n) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_CONNECT; + m_src = s; + m_dest = d; + m_newplr = n; +} + +template<> +inline void packet_out::create(plr_t s, plr_t d, plr_t n, + leaveinfo_t l) +{ + if (have_encrypted || have_decrypted) + ABORT(); + have_decrypted = true; + m_type = PT_DISCONNECT; + m_src = s; + m_dest = d; + m_newplr = n; + m_leaveinfo = l; +} + +inline void packet_out::process_element(buffer_t& x) +{ + encrypted_buffer.insert(encrypted_buffer.end(), x.begin(), x.end()); +} + +template void packet_out::process_element(T& x) +{ + encrypted_buffer.insert(encrypted_buffer.end(), begin(x), end(x)); +} + +template const unsigned char* packet_out::begin(const T& x) +{ + return reinterpret_cast(&x); +} + +template const unsigned char* packet_out::end(const T& x) +{ + return reinterpret_cast(&x) + sizeof(T); +} + +class packet_factory { + key_t key = {}; + +public: + static constexpr unsigned short max_packet_size = 0xFFFF; + + packet_factory(std::string pw = ""); + std::unique_ptr make_packet(buffer_t buf); + template + std::unique_ptr make_packet(Args... args); +}; + +inline std::unique_ptr packet_factory::make_packet(buffer_t buf) +{ + auto ret = std::make_unique(key); + ret->create(std::move(buf)); + ret->decrypt(); + return ret; +} + +template +std::unique_ptr packet_factory::make_packet(Args... args) +{ + auto ret = std::make_unique(key); + ret->create(args...); + ret->encrypt(); + return ret; +} + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/tcp_client.cpp b/SourceX/dvlnet/tcp_client.cpp index ed683d12d..51e2e4ab0 100644 --- a/SourceX/dvlnet/tcp_client.cpp +++ b/SourceX/dvlnet/tcp_client.cpp @@ -7,7 +7,8 @@ #include #include -namespace dvl { namespace net { +namespace dvl { +namespace net { int tcp_client::create(std::string addrstr, std::string passwd) { @@ -102,4 +103,5 @@ void tcp_client::send(packet& pkt) std::placeholders::_1, std::placeholders::_2)); } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/tcp_client.h b/SourceX/dvlnet/tcp_client.h index bdb8cbf40..fd3db20e4 100644 --- a/SourceX/dvlnet/tcp_client.h +++ b/SourceX/dvlnet/tcp_client.h @@ -12,26 +12,30 @@ #include "dvlnet/base.h" #include "dvlnet/tcp_server.h" -namespace dvl { namespace net { - class tcp_client : public base { - public: - int create(std::string addrstr, std::string passwd); - int join(std::string addrstr, std::string passwd); - - constexpr static unsigned short default_port = 6112; - - virtual void poll(); - virtual void send(packet& pkt); - private: - frame_queue recv_queue; - buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); - - asio::io_context ioc; - asio::ip::tcp::socket sock = asio::ip::tcp::socket(ioc); - std::unique_ptr local_server; // must be declared *after* ioc - - void handle_recv(const asio::error_code& error, size_t bytes_read); - void start_recv(); - void handle_send(const asio::error_code& error, size_t bytes_sent); - }; -}} +namespace dvl { +namespace net { + +class tcp_client : public base { +public: + int create(std::string addrstr, std::string passwd); + int join(std::string addrstr, std::string passwd); + + constexpr static unsigned short default_port = 6112; + + virtual void poll(); + virtual void send(packet& pkt); +private: + frame_queue recv_queue; + buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); + + asio::io_context ioc; + asio::ip::tcp::socket sock = asio::ip::tcp::socket(ioc); + std::unique_ptr local_server; // must be declared *after* ioc + + void handle_recv(const asio::error_code& error, size_t bytes_read); + void start_recv(); + void handle_send(const asio::error_code& error, size_t bytes_sent); +}; + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/tcp_server.cpp b/SourceX/dvlnet/tcp_server.cpp index 8ecbde56c..2fe1a4f2c 100644 --- a/SourceX/dvlnet/tcp_server.cpp +++ b/SourceX/dvlnet/tcp_server.cpp @@ -5,7 +5,8 @@ #include "dvlnet/base.h" -namespace dvl { namespace net { +namespace dvl { +namespace net { tcp_server::tcp_server(asio::io_context& ioc, std::string bindaddr, unsigned short port, std::string pw) : @@ -208,4 +209,5 @@ void tcp_server::drop_connection(scc con) con->socket.close(); } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/tcp_server.h b/SourceX/dvlnet/tcp_server.h index bd4918f6d..1c5271b13 100644 --- a/SourceX/dvlnet/tcp_server.h +++ b/SourceX/dvlnet/tcp_server.h @@ -12,53 +12,57 @@ #include "dvlnet/abstract_net.h" #include "dvlnet/frame_queue.h" -namespace dvl { namespace net { - class server_exception : public dvlnet_exception {}; +namespace dvl { +namespace net { - class tcp_server { - public: - tcp_server(asio::io_context& ioc, std::string bindaddr, - unsigned short port, std::string pw); - std::string localhost_self(); +class server_exception : public dvlnet_exception {}; - private: - static constexpr int timeout_connect = 30; - static constexpr int timeout_active = 60; +class tcp_server { +public: + tcp_server(asio::io_context& ioc, std::string bindaddr, + unsigned short port, std::string pw); + std::string localhost_self(); - struct client_connection { - frame_queue recv_queue; - buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); - plr_t plr = PLR_BROADCAST; - asio::ip::tcp::socket socket; - asio::steady_timer timer; - int timeout; - client_connection(asio::io_context& ioc) : - socket(ioc), timer(ioc) {} - }; +private: + static constexpr int timeout_connect = 30; + static constexpr int timeout_active = 60; - typedef std::shared_ptr scc; + struct client_connection { + frame_queue recv_queue; + buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); + plr_t plr = PLR_BROADCAST; + asio::ip::tcp::socket socket; + asio::steady_timer timer; + int timeout; + client_connection(asio::io_context& ioc) : + socket(ioc), timer(ioc) {} + }; - asio::io_context& ioc; - packet_factory pktfty; - std::unique_ptr acceptor; - std::array connections; - buffer_t game_init_info; + typedef std::shared_ptr scc; - scc make_connection(); - plr_t next_free(); - bool empty(); - void start_accept(); - void handle_accept(scc con, const asio::error_code& ec); - void start_recv(scc con); - void handle_recv(scc con, const asio::error_code& ec, size_t bytes_read); - void handle_recv_newplr(scc con, packet& pkt); - void handle_recv_packet(packet& pkt); - void send_connect(scc con); - void send_packet(packet& pkt); - void start_send(scc con, packet& pkt); - void handle_send(scc con, const asio::error_code& ec, size_t bytes_sent); - void start_timeout(scc con); - void handle_timeout(scc con, const asio::error_code& ec); - void drop_connection(scc con); - }; -}} + asio::io_context& ioc; + packet_factory pktfty; + std::unique_ptr acceptor; + std::array connections; + buffer_t game_init_info; + + scc make_connection(); + plr_t next_free(); + bool empty(); + void start_accept(); + void handle_accept(scc con, const asio::error_code& ec); + void start_recv(scc con); + void handle_recv(scc con, const asio::error_code& ec, size_t bytes_read); + void handle_recv_newplr(scc con, packet& pkt); + void handle_recv_packet(packet& pkt); + void send_connect(scc con); + void send_packet(packet& pkt); + void start_send(scc con, packet& pkt); + void handle_send(scc con, const asio::error_code& ec, size_t bytes_sent); + void start_timeout(scc con); + void handle_timeout(scc con, const asio::error_code& ec); + void drop_connection(scc con); +}; + +} //namespace net +} //namespace dvl diff --git a/SourceX/dvlnet/udp_p2p.cpp b/SourceX/dvlnet/udp_p2p.cpp index 0c9c7d8e2..1a143cdeb 100644 --- a/SourceX/dvlnet/udp_p2p.cpp +++ b/SourceX/dvlnet/udp_p2p.cpp @@ -2,7 +2,8 @@ #include -namespace dvl { namespace net { +namespace dvl { +namespace net { const udp_p2p::endpoint udp_p2p::none; @@ -51,7 +52,7 @@ int udp_p2p::join(std::string addrstr, std::string passwd) sock.connect(themaster); master = themaster; { // hack: try to join for 5 seconds - randombytes_buf(reinterpret_cast(&cookie_self), + randombytes_buf(reinterpret_cast(&cookie_self), sizeof(cookie_t)); auto pkt = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER, cookie_self, @@ -169,4 +170,5 @@ void udp_p2p::recv_decrypted(packet& pkt, endpoint sender) recv_local(pkt); } -}} +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/udp_p2p.h b/SourceX/dvlnet/udp_p2p.h index 585ff485a..9f3a976b3 100644 --- a/SourceX/dvlnet/udp_p2p.h +++ b/SourceX/dvlnet/udp_p2p.h @@ -10,36 +10,40 @@ #include "dvlnet/packet.h" #include "dvlnet/base.h" -namespace dvl { namespace net { - class udp_p2p : public base { - public: - virtual int create(std::string addrstr, std::string passwd); - virtual int join(std::string addrstr, std::string passwd); - virtual void poll(); - virtual void send(packet& pkt); - - private: - typedef asio::ip::udp::endpoint endpoint; - static const endpoint none; - - unsigned short udpport_self = 0; - - static constexpr unsigned short default_port = 6112; - static constexpr unsigned short try_ports = 512; - static constexpr int ACTIVE = 60; - - asio::io_context io_context; - endpoint master; - - std::set connection_requests_pending; - std::array nexthop_table; - - asio::ip::udp::socket sock = asio::ip::udp::socket(io_context); - - void recv(); - void handle_join_request(packet& pkt, endpoint sender); - void send_internal(packet& pkt, endpoint sender = none); - std::set dests_for_addr(plr_t dest, endpoint sender); - void recv_decrypted(packet& pkt, endpoint sender); - }; -}} +namespace dvl { +namespace net { + +class udp_p2p : public base { +public: + virtual int create(std::string addrstr, std::string passwd); + virtual int join(std::string addrstr, std::string passwd); + virtual void poll(); + virtual void send(packet& pkt); + +private: + typedef asio::ip::udp::endpoint endpoint; + static const endpoint none; + + unsigned short udpport_self = 0; + + static constexpr unsigned short default_port = 6112; + static constexpr unsigned short try_ports = 512; + static constexpr int ACTIVE = 60; + + asio::io_context io_context; + endpoint master; + + std::set connection_requests_pending; + std::array nexthop_table; + + asio::ip::udp::socket sock = asio::ip::udp::socket(io_context); + + void recv(); + void handle_join_request(packet& pkt, endpoint sender); + void send_internal(packet& pkt, endpoint sender = none); + std::set dests_for_addr(plr_t dest, endpoint sender); + void recv_decrypted(packet& pkt, endpoint sender); +}; + +} // namespace net +} // namespace dvl