From 0010c9307ff3f3cd5d55d4f8fa82c0bc5bc0e6f6 Mon Sep 17 00:00:00 2001 From: Xadhoom Date: Thu, 7 Nov 2019 22:23:25 +0000 Subject: [PATCH] Fix network: creating and joining --- CMakeLists.txt | 1 + SourceX/dvlnet/abstract_net.cpp | 9 +- SourceX/dvlnet/abstract_net.h | 22 ++--- SourceX/dvlnet/base.cpp | 23 +++-- SourceX/dvlnet/base.h | 16 ++-- SourceX/dvlnet/cdwrap.cpp | 1 + SourceX/dvlnet/cdwrap.h | 155 ++++++++++++++++++++++++++++++++ SourceX/dvlnet/tcp_client.cpp | 24 +++-- SourceX/dvlnet/tcp_client.h | 5 +- SourceX/dvlnet/tcp_server.cpp | 44 ++++----- SourceX/dvlnet/tcp_server.h | 7 +- 11 files changed, 238 insertions(+), 69 deletions(-) create mode 100644 SourceX/dvlnet/cdwrap.cpp create mode 100644 SourceX/dvlnet/cdwrap.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e7300c6b..86697a723 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -232,6 +232,7 @@ set(devilutionx_SRCS SourceX/dvlnet/packet.cpp SourceX/dvlnet/base.cpp SourceX/dvlnet/frame_queue.cpp + SourceX/dvlnet/cdwrap.cpp SourceX/DiabloUI/art_draw.cpp SourceX/DiabloUI/errorart.cpp SourceX/DiabloUI/art.cpp diff --git a/SourceX/dvlnet/abstract_net.cpp b/SourceX/dvlnet/abstract_net.cpp index cfab74ecb..f51ec2928 100644 --- a/SourceX/dvlnet/abstract_net.cpp +++ b/SourceX/dvlnet/abstract_net.cpp @@ -2,6 +2,7 @@ #include "stubs.h" #ifndef NONET +#include "dvlnet/cdwrap.h" #include "dvlnet/tcp_client.h" #include "dvlnet/udp_p2p.h" #endif @@ -10,10 +11,6 @@ namespace dvl { namespace net { -abstract_net::~abstract_net() -{ -} - std::unique_ptr abstract_net::make_net(provider_t provider) { #ifdef NONET @@ -21,10 +18,10 @@ std::unique_ptr abstract_net::make_net(provider_t provider) #else switch (provider) { case SELCONN_TCP: - return std::unique_ptr(new tcp_client); + return std::unique_ptr(new cdwrap); #ifdef BUGGY case SELCONN_UDP: - return std::unique_ptr(new udp_p2p); + return std::unique_ptr(new cdwrap); #endif case SELCONN_LOOPBACK: return std::unique_ptr(new loopback); diff --git a/SourceX/dvlnet/abstract_net.h b/SourceX/dvlnet/abstract_net.h index d8ed34477..73430c37e 100644 --- a/SourceX/dvlnet/abstract_net.h +++ b/SourceX/dvlnet/abstract_net.h @@ -20,28 +20,28 @@ public: virtual int create(std::string addrstr, std::string passwd) = 0; virtual int join(std::string addrstr, std::string passwd) = 0; virtual bool SNetReceiveMessage(int *sender, char **data, - int *size) - = 0; + int *size) + = 0; virtual bool SNetSendMessage(int dest, void *data, - unsigned int size) - = 0; + unsigned int size) + = 0; virtual bool SNetReceiveTurns(char **data, unsigned int *size, - DWORD *status) - = 0; + DWORD *status) + = 0; virtual bool SNetSendTurn(char *data, unsigned int size) = 0; virtual int SNetGetProviderCaps(struct _SNETCAPS *caps) = 0; virtual bool SNetRegisterEventHandler(event_type evtype, - SEVTHANDLER func) - = 0; + SEVTHANDLER func) + = 0; virtual bool SNetUnregisterEventHandler(event_type evtype, - SEVTHANDLER func) - = 0; + SEVTHANDLER func) + = 0; virtual bool SNetLeaveGame(int type) = 0; virtual bool SNetDropPlayer(int playerid, DWORD flags) = 0; virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns) = 0; virtual bool SNetGetTurnsInTransit(int *turns) = 0; virtual void setup_gameinfo(buffer_t info) = 0; - virtual ~abstract_net(); + virtual ~abstract_net() = default; static std::unique_ptr make_net(provider_t provider); }; diff --git a/SourceX/dvlnet/base.cpp b/SourceX/dvlnet/base.cpp index 096a2b471..5da03319b 100644 --- a/SourceX/dvlnet/base.cpp +++ b/SourceX/dvlnet/base.cpp @@ -9,7 +9,6 @@ namespace net { void base::setup_gameinfo(buffer_t info) { game_init_info = std::move(info); - pktfty.reset(new packet_factory()); } void base::setup_password(std::string pw) @@ -48,11 +47,11 @@ void base::handle_accept(packet &pkt) void base::clear_msg(plr_t plr) { message_queue.erase(std::remove_if(message_queue.begin(), - message_queue.end(), - [&](message_t &msg) { - return msg.sender == plr; - }), - message_queue.end()); + message_queue.end(), + [&](message_t &msg) { + return msg.sender == plr; + }), + message_queue.end()); } void base::recv_local(packet &pkt) @@ -113,7 +112,7 @@ bool base::SNetReceiveMessage(int *sender, char **data, int *size) bool base::SNetSendMessage(int playerID, void *data, unsigned int size) { if (playerID != SNPLAYER_ALL && playerID != SNPLAYER_OTHERS - && (playerID < 0 || playerID >= MAX_PLRS)) + && (playerID < 0 || playerID >= MAX_PLRS)) abort(); auto raw_message = reinterpret_cast(data); buffer_t message(raw_message, raw_message + size); @@ -190,7 +189,7 @@ int base::SNetGetProviderCaps(struct _SNETCAPS *caps) caps->latencyms = 0; // unused caps->defaultturnssec = 10; // ? caps->defaultturnsintransit = 1; // maximum acceptable number - // of turns in queue? + // of turns in queue? return 1; } @@ -217,7 +216,7 @@ bool base::SNetRegisterEventHandler(event_type evtype, SEVTHANDLER func) bool base::SNetLeaveGame(int type) { auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, - plr_self, type); + plr_self, type); send(*pkt); return true; } @@ -225,9 +224,9 @@ bool base::SNetLeaveGame(int type) bool base::SNetDropPlayer(int playerid, DWORD flags) { auto pkt = pktfty->make_packet(plr_self, - PLR_BROADCAST, - (plr_t)playerid, - (leaveinfo_t)flags); + PLR_BROADCAST, + (plr_t)playerid, + (leaveinfo_t)flags); send(*pkt); recv_local(*pkt); return true; diff --git a/SourceX/dvlnet/base.h b/SourceX/dvlnet/base.h index df61a686f..3157480b1 100644 --- a/SourceX/dvlnet/base.h +++ b/SourceX/dvlnet/base.h @@ -29,13 +29,13 @@ public: virtual bool SNetReceiveMessage(int *sender, char **data, int *size); virtual bool SNetSendMessage(int dest, void *data, unsigned int size); virtual bool SNetReceiveTurns(char **data, unsigned int *size, - DWORD *status); + DWORD *status); virtual bool SNetSendTurn(char *data, unsigned int size); virtual int SNetGetProviderCaps(struct _SNETCAPS *caps); virtual bool SNetRegisterEventHandler(event_type evtype, - SEVTHANDLER func); + SEVTHANDLER func); virtual bool SNetUnregisterEventHandler(event_type evtype, - SEVTHANDLER func); + SEVTHANDLER func); virtual bool SNetLeaveGame(int type); virtual bool SNetDropPlayer(int playerid, DWORD flags); virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); @@ -46,6 +46,8 @@ public: void setup_gameinfo(buffer_t info); + virtual ~base() = default; + protected: std::map registered_handlers; buffer_t game_init_info; @@ -54,13 +56,13 @@ protected: int sender; // change int to something else in devilution code later buffer_t payload; message_t() - : sender(-1) - , payload({}) + : sender(-1) + , payload({}) { } message_t(int s, buffer_t p) - : sender(s) - , payload(p) + : sender(s) + , payload(p) { } }; diff --git a/SourceX/dvlnet/cdwrap.cpp b/SourceX/dvlnet/cdwrap.cpp new file mode 100644 index 000000000..dcc6d04b3 --- /dev/null +++ b/SourceX/dvlnet/cdwrap.cpp @@ -0,0 +1 @@ +#include "dvlnet/cdwrap.h" diff --git a/SourceX/dvlnet/cdwrap.h b/SourceX/dvlnet/cdwrap.h new file mode 100644 index 000000000..0bee0aab9 --- /dev/null +++ b/SourceX/dvlnet/cdwrap.h @@ -0,0 +1,155 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "devilution.h" + +#include "dvlnet/abstract_net.h" + +namespace dvl { +namespace net { + +template +class cdwrap : public abstract_net { +private: + std::unique_ptr dvlnet_wrap; + std::map registered_handlers; + buffer_t game_init_info; + + void reset(); + +public: + virtual int create(std::string addrstr, std::string passwd); + virtual int join(std::string addrstr, std::string passwd); + virtual bool SNetReceiveMessage(int *sender, char **data, int *size); + virtual bool SNetSendMessage(int dest, void *data, + unsigned int size); + virtual bool SNetReceiveTurns(char **data, unsigned int *size, + DWORD *status); + virtual bool SNetSendTurn(char *data, unsigned int size); + virtual int SNetGetProviderCaps(struct _SNETCAPS *caps); + virtual bool SNetRegisterEventHandler(event_type evtype, + SEVTHANDLER func); + virtual bool SNetUnregisterEventHandler(event_type evtype, + SEVTHANDLER func); + virtual bool SNetLeaveGame(int type); + virtual bool SNetDropPlayer(int playerid, DWORD flags); + virtual bool SNetGetOwnerTurnsWaiting(DWORD *turns); + virtual bool SNetGetTurnsInTransit(int *turns); + virtual void setup_gameinfo(buffer_t info); + + virtual ~cdwrap() = default; +}; + +template +void cdwrap::reset() +{ + dvlnet_wrap.reset(new T); + dvlnet_wrap->setup_gameinfo(game_init_info); + + for (const auto &pair : registered_handlers) + dvlnet_wrap->SNetRegisterEventHandler(pair.first, pair.second); +} + +template +int cdwrap::create(std::string addrstr, std::string passwd) +{ + reset(); + return dvlnet_wrap->create(addrstr, passwd); +} + +template +int cdwrap::join(std::string addrstr, std::string passwd) +{ + reset(); + return dvlnet_wrap->join(addrstr, passwd); +} + +template +void cdwrap::setup_gameinfo(buffer_t info) +{ + game_init_info = std::move(info); + if (dvlnet_wrap) + dvlnet_wrap->setup_gameinfo(game_init_info); +} + +template +bool cdwrap::SNetReceiveMessage(int *sender, char **data, int *size) +{ + return dvlnet_wrap->SNetReceiveMessage(sender, data, size); +} + +template +bool cdwrap::SNetSendMessage(int playerID, void *data, unsigned int size) +{ + return dvlnet_wrap->SNetSendMessage(playerID, data, size); +} + +template +bool cdwrap::SNetReceiveTurns(char **data, unsigned int *size, DWORD *status) +{ + return dvlnet_wrap->SNetReceiveTurns(data, size, status); +} + +template +bool cdwrap::SNetSendTurn(char *data, unsigned int size) +{ + return dvlnet_wrap->SNetSendTurn(data, size); +} + +template +int cdwrap::SNetGetProviderCaps(struct _SNETCAPS *caps) +{ + return dvlnet_wrap->SNetGetProviderCaps(caps); +} + +template +bool cdwrap::SNetUnregisterEventHandler(event_type evtype, SEVTHANDLER func) +{ + registered_handlers.erase(evtype); + if (dvlnet_wrap) + return dvlnet_wrap->SNetUnregisterEventHandler(evtype, func); + else + return true; +} + +template +bool cdwrap::SNetRegisterEventHandler(event_type evtype, SEVTHANDLER func) +{ + registered_handlers[evtype] = func; + if (dvlnet_wrap) + return dvlnet_wrap->SNetRegisterEventHandler(evtype, func); + else + return true; +} + +template +bool cdwrap::SNetLeaveGame(int type) +{ + return dvlnet_wrap->SNetLeaveGame(type); +} + +template +bool cdwrap::SNetDropPlayer(int playerid, DWORD flags) +{ + return dvlnet_wrap->SNetDropPlayer(playerid, flags); +} + +template +bool cdwrap::SNetGetOwnerTurnsWaiting(DWORD *turns) +{ + return dvlnet_wrap->SNetGetOwnerTurnsWaiting(turns); +} + +template +bool cdwrap::SNetGetTurnsInTransit(int *turns) +{ + return dvlnet_wrap->SNetGetTurnsInTransit(turns); +} + +} // namespace net +} // namespace dvl diff --git a/SourceX/dvlnet/tcp_client.cpp b/SourceX/dvlnet/tcp_client.cpp index d3db57f3f..6edffee38 100644 --- a/SourceX/dvlnet/tcp_client.cpp +++ b/SourceX/dvlnet/tcp_client.cpp @@ -40,10 +40,10 @@ int tcp_client::join(std::string addrstr, std::string passwd) start_recv(); { randombytes_buf(reinterpret_cast(&cookie_self), - sizeof(cookie_t)); + sizeof(cookie_t)); auto pkt = pktfty->make_packet(PLR_BROADCAST, - PLR_MASTER, cookie_self, - game_init_info); + PLR_MASTER, cookie_self, + game_init_info); send(*pkt); for (auto i = 0; i < no_sleep; ++i) { try { @@ -88,8 +88,8 @@ void tcp_client::handle_recv(const asio::error_code &error, size_t bytes_read) void tcp_client::start_recv() { sock.async_receive(asio::buffer(recv_buffer), - std::bind(&tcp_client::handle_recv, this, - std::placeholders::_1, std::placeholders::_2)); + std::bind(&tcp_client::handle_recv, this, + std::placeholders::_1, std::placeholders::_2)); } void tcp_client::handle_send(const asio::error_code &error, size_t bytes_sent) @@ -107,10 +107,16 @@ void tcp_client::send(packet &pkt) }); } -bool tcp_client::SNetLeaveGame(int type){ - if(sock.is_open()) - sock.close(); - return true; +bool tcp_client::SNetLeaveGame(int type) +{ + auto ret = base::SNetLeaveGame(type); + poll(); + local_server.reset(); + return ret; +} + +tcp_client::~tcp_client() +{ } } // namespace net diff --git a/SourceX/dvlnet/tcp_client.h b/SourceX/dvlnet/tcp_client.h index 763457757..30329347f 100644 --- a/SourceX/dvlnet/tcp_client.h +++ b/SourceX/dvlnet/tcp_client.h @@ -19,13 +19,16 @@ class tcp_client : public base { public: int create(std::string addrstr, std::string passwd); int join(std::string addrstr, std::string passwd); - virtual bool SNetLeaveGame(int type); constexpr static unsigned short default_port = 6112; virtual void poll(); virtual void send(packet &pkt); + virtual bool SNetLeaveGame(int type); + + virtual ~tcp_client(); + private: frame_queue recv_queue; buffer_t recv_buffer = buffer_t(frame_queue::max_frame_size); diff --git a/SourceX/dvlnet/tcp_server.cpp b/SourceX/dvlnet/tcp_server.cpp index c04dab56c..630e8e71f 100644 --- a/SourceX/dvlnet/tcp_server.cpp +++ b/SourceX/dvlnet/tcp_server.cpp @@ -9,9 +9,9 @@ namespace dvl { namespace net { tcp_server::tcp_server(asio::io_context &ioc, std::string bindaddr, - unsigned short port, std::string pw) - : ioc(ioc) - , pktfty(pw) + unsigned short port, std::string pw) + : ioc(ioc) + , pktfty(pw) { auto addr = asio::ip::address::from_string(bindaddr); auto ep = asio::ip::tcp::endpoint(addr, port); @@ -57,13 +57,13 @@ bool tcp_server::empty() void tcp_server::start_recv(scc con) { con->socket.async_receive(asio::buffer(con->recv_buffer), - std::bind(&tcp_server::handle_recv, this, con, - std::placeholders::_1, - std::placeholders::_2)); + std::bind(&tcp_server::handle_recv, this, con, + std::placeholders::_1, + std::placeholders::_2)); } void tcp_server::handle_recv(scc con, const asio::error_code &ec, - size_t bytes_read) + size_t bytes_read) { if (ec || bytes_read == 0) { drop_connection(con); @@ -92,7 +92,7 @@ void tcp_server::handle_recv(scc con, const asio::error_code &ec, void tcp_server::send_connect(scc con) { auto pkt = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - con->plr); + con->plr); send_packet(*pkt); } @@ -104,8 +104,8 @@ void tcp_server::handle_recv_newplr(scc con, packet &pkt) if (empty()) game_init_info = pkt.info(); auto reply = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - pkt.cookie(), newplr, - game_init_info); + pkt.cookie(), newplr, + game_init_info); start_send(con, *reply); con->plr = newplr; connections[newplr] = con; @@ -137,14 +137,14 @@ void tcp_server::start_send(scc con, packet &pkt) const auto *frame = new buffer_t(frame_queue::make_frame(pkt.data())); auto buf = asio::buffer(*frame); asio::async_write(con->socket, buf, - [this, con, frame](const asio::error_code &ec, size_t bytes_sent) { - handle_send(con, ec, bytes_sent); - delete frame; - }); + [this, con, frame](const asio::error_code &ec, size_t bytes_sent) { + handle_send(con, ec, bytes_sent); + delete frame; + }); } void tcp_server::handle_send(scc con, const asio::error_code &ec, - size_t bytes_sent) + size_t bytes_sent) { // empty for now } @@ -153,9 +153,9 @@ void tcp_server::start_accept() { auto nextcon = make_connection(); acceptor->async_accept(nextcon->socket, - std::bind(&tcp_server::handle_accept, - this, nextcon, - std::placeholders::_1)); + std::bind(&tcp_server::handle_accept, + this, nextcon, + std::placeholders::_1)); } void tcp_server::handle_accept(scc con, const asio::error_code &ec) @@ -176,7 +176,7 @@ void tcp_server::start_timeout(scc con) { con->timer.expires_after(std::chrono::seconds(1)); con->timer.async_wait(std::bind(&tcp_server::handle_timeout, this, con, - std::placeholders::_1)); + std::placeholders::_1)); } void tcp_server::handle_timeout(scc con, const asio::error_code &ec) @@ -200,7 +200,7 @@ void tcp_server::drop_connection(scc con) { if (con->plr != PLR_BROADCAST) { auto pkt = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - con->plr, LEAVE_DROP); + con->plr, LEAVE_DROP); connections[con->plr] = nullptr; send_packet(*pkt); // TODO: investigate if it is really ok for the server to @@ -210,5 +210,9 @@ void tcp_server::drop_connection(scc con) con->socket.close(); } +tcp_server::~tcp_server() +{ +} + } // namespace net } // namespace dvl diff --git a/SourceX/dvlnet/tcp_server.h b/SourceX/dvlnet/tcp_server.h index 8fec25356..c6bddfde7 100644 --- a/SourceX/dvlnet/tcp_server.h +++ b/SourceX/dvlnet/tcp_server.h @@ -21,8 +21,9 @@ class server_exception : public dvlnet_exception { class tcp_server { public: tcp_server(asio::io_context &ioc, std::string bindaddr, - unsigned short port, std::string pw); + unsigned short port, std::string pw); std::string localhost_self(); + virtual ~tcp_server(); private: static constexpr int timeout_connect = 30; @@ -36,8 +37,8 @@ private: asio::steady_timer timer; int timeout; client_connection(asio::io_context &ioc) - : socket(ioc) - , timer(ioc) + : socket(ioc) + , timer(ioc) { } };