From 9d34cb2795ca27e8f49dab5351c7964b38e6a3b8 Mon Sep 17 00:00:00 2001 From: staphen Date: Sun, 24 Sep 2023 14:41:20 -0400 Subject: [PATCH] Remove frame queue exceptions from the network layer --- 3rdParty/libzt/CMakeLists.txt | 2 +- Source/dvlnet/base.cpp | 37 +++++++--- Source/dvlnet/base.h | 2 +- Source/dvlnet/base_protocol.h | 87 +++++++++++++----------- Source/dvlnet/frame_queue.cpp | 37 +++++----- Source/dvlnet/frame_queue.h | 21 +++--- Source/dvlnet/protocol_zt.cpp | 27 ++++++-- Source/dvlnet/protocol_zt.h | 2 +- Source/dvlnet/tcp_client.cpp | 39 +++++++---- Source/dvlnet/tcp_client.h | 2 +- Source/dvlnet/tcp_server.cpp | 124 ++++++++++++++++++---------------- Source/dvlnet/tcp_server.h | 2 +- 12 files changed, 218 insertions(+), 164 deletions(-) diff --git a/3rdParty/libzt/CMakeLists.txt b/3rdParty/libzt/CMakeLists.txt index 76219cefa..42ba3ff7b 100644 --- a/3rdParty/libzt/CMakeLists.txt +++ b/3rdParty/libzt/CMakeLists.txt @@ -5,7 +5,7 @@ set(BUILD_HOST_SELFTEST OFF) include(FetchContent) FetchContent_Declare(libzt GIT_REPOSITORY https://github.com/diasurgical/libzt.git - GIT_TAG def49803307920da0ab5b9e9b70b399fdc2943dc) + GIT_TAG 8c31715bc48bade2097d66ced54db07598268710) FetchContent_MakeAvailableExcludeFromAll(libzt) if(NOT ANDROID) diff --git a/Source/dvlnet/base.cpp b/Source/dvlnet/base.cpp index ec853e478..e0249eaa2 100644 --- a/Source/dvlnet/base.cpp +++ b/Source/dvlnet/base.cpp @@ -52,8 +52,7 @@ tl::expected base::SendEchoRequest(plr_t player) if (!pkt.has_value()) { return tl::make_unexpected(pkt.error()); } - send(**pkt); - return {}; + return send(**pkt); } tl::expected base::HandleAccept(packet &pkt) @@ -139,8 +138,8 @@ tl::expected base::HandleEchoRequest(packet &pkt) .and_then([&](cookie_t &&pktTime) { return pktfty->make_packet(plr_self, pkt.Source(), pktTime); }) - .transform([&](std::unique_ptr &&pkt) { - send(*pkt); + .and_then([&](std::unique_ptr &&pkt) { + return send(*pkt); }); } @@ -246,7 +245,11 @@ bool base::SNetSendMessage(int playerId, void *data, unsigned int size) LogError("make_packet: {}", pkt.error().what()); return false; } - send(**pkt); + tl::expected result = send(**pkt); + if (!result.has_value()) { + LogError("send: {}", result.error().what()); + return false; + } } return true; } @@ -362,7 +365,7 @@ tl::expected base::SendTurnIfReady(turn_t turn) if (!pkt.has_value()) { return tl::make_unexpected(pkt.error()); } - send(**pkt); + return send(**pkt); } return {}; } @@ -383,7 +386,10 @@ tl::expected base::SendFirstTurnIfReady(plr_t player) if (!pkt.has_value()) { return tl::make_unexpected(pkt.error()); } - send(**pkt); + tl::expected result = send(**pkt); + if (!result.has_value()) { + return result; + } } return {}; } @@ -453,7 +459,11 @@ bool base::SNetLeaveGame(int type) LogError("make_packet: {}", pkt.error().what()); return false; } - send(**pkt); + tl::expected result = send(**pkt); + if (!result.has_value()) { + LogError("send: {}", result.error().what()); + return false; + } plr_self = PLR_BROADCAST; return true; } @@ -470,9 +480,14 @@ bool base::SNetDropPlayer(int playerid, uint32_t flags) LogError("make_packet: {}", pkt.error().what()); return false; } - send(**pkt); - if (tl::expected result = RecvLocal(**pkt); !result.has_value()) { - LogError("SNetDropPlayer: {}", result.error().what()); + tl::expected sendResult = send(**pkt); + if (!sendResult.has_value()) { + LogError("send: {}", sendResult.error().what()); + return false; + } + tl::expected receiveResult = RecvLocal(**pkt); + if (!receiveResult.has_value()) { + LogError("SNetDropPlayer: {}", receiveResult.error().what()); return false; } return true; diff --git a/Source/dvlnet/base.h b/Source/dvlnet/base.h index 27abdb327..7963c8062 100644 --- a/Source/dvlnet/base.h +++ b/Source/dvlnet/base.h @@ -30,7 +30,7 @@ public: bool SNetGetTurnsInTransit(uint32_t *turns) override; virtual tl::expected poll() = 0; - virtual void send(packet &pkt) = 0; + virtual tl::expected send(packet &pkt) = 0; virtual void DisconnectNet(plr_t plr); void setup_gameinfo(buffer_t info); diff --git a/Source/dvlnet/base_protocol.h b/Source/dvlnet/base_protocol.h index b3c525f3b..7ea9fe298 100644 --- a/Source/dvlnet/base_protocol.h +++ b/Source/dvlnet/base_protocol.h @@ -18,7 +18,7 @@ public: int create(std::string addrstr) override; int join(std::string addrstr) override; tl::expected poll() override; - void send(packet &pkt) override; + tl::expected send(packet &pkt) override; void DisconnectNet(plr_t plr) override; bool SNetLeaveGame(int type) override; @@ -56,7 +56,7 @@ private: plr_t get_master(); tl::expected InitiateHandshake(plr_t player); - void SendTo(plr_t player, packet &pkt); + tl::expected SendTo(plr_t player, packet &pkt); void DrainSendQueue(plr_t player); void recv(); tl::expected handle_join_request(packet &pkt, endpoint_t sender); @@ -137,16 +137,19 @@ tl::expected base_protocol

::wait_join() tl::expected, PacketError> pkt = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER, cookie_self, game_init_info); if (!pkt.has_value()) { - return tl::unexpected(pkt.error()); + return tl::make_unexpected(pkt.error()); + } + tl::expected result = proto.send(firstpeer, (*pkt)->Data()); + if (!result.has_value()) { + return result; } - proto.send(firstpeer, (*pkt)->Data()); for (auto i = 0; i < 500; ++i) { recv(); if (plr_self != PLR_BROADCAST) - break; // join successful + return {}; // join successful SDL_Delay(10); } - return {}; + return tl::make_unexpected("Timeout waiting to join game"); } template @@ -174,8 +177,10 @@ int base_protocol

::join(std::string addrstr) if (wait_network()) { if (wait_firstpeer()) { - if (tl::expected result = wait_join(); !result.has_value()) { - LogError("wait_join: {}", result.error().what()); + tl::expected result = wait_join(); + if (!result.has_value()) { + const std::string_view message = result.error().what(); + SDL_SetError("wait_join: %.*s", static_cast(message.size()), message.data()); return -1; } } @@ -212,36 +217,39 @@ tl::expected base_protocol

::InitiateHandshake(plr_t player } template -void base_protocol

::send(packet &pkt) +tl::expected base_protocol

::send(packet &pkt) { plr_t destination = pkt.Destination(); - if (destination < MAX_PLRS) { - if (destination == MyPlayerId) - return; - SendTo(destination, pkt); - } else if (destination == PLR_BROADCAST) { - for (plr_t player = 0; player < Players.size(); player++) - SendTo(player, pkt); - } else if (destination == PLR_MASTER) { - throw dvlnet_exception(); - } else { - throw dvlnet_exception(); + if (destination == PLR_BROADCAST) { + for (plr_t player = 0; player < Players.size(); player++) { + tl::expected result = SendTo(player, pkt); + if (!result.has_value()) + LogError("Failed to send packet {} to player {}: {}", static_cast(pkt.Type()), player, result.error().what()); + } + return {}; } + if (destination >= MAX_PLRS) + return tl::make_unexpected("Invalid player ID"); + if (destination == MyPlayerId) + return {}; + return SendTo(destination, pkt); } template -void base_protocol

::SendTo(plr_t player, packet &pkt) +tl::expected base_protocol

::SendTo(plr_t player, packet &pkt) { Peer &peer = peers[player]; if (!peer.endpoint) - return; + return {}; // The handshake uses echo packets so clients know // when they can safely drain their send queues - if (peer.sendQueue && !IsAnyOf(pkt.Type(), PT_ECHO_REQUEST, PT_ECHO_REPLY)) + if (peer.sendQueue && !IsAnyOf(pkt.Type(), PT_ECHO_REQUEST, PT_ECHO_REPLY)) { peer.sendQueue->push_back(pkt); - else - proto.send(peer.endpoint, pkt.Data()); + return {}; + } + + return proto.send(peer.endpoint, pkt.Data()); } template @@ -301,20 +309,13 @@ tl::expected base_protocol

::handle_join_request(packet &in for (plr_t j = 0; j < Players.size(); ++j) { endpoint_t peer = peers[j].endpoint; if ((j != plr_self) && (j != i) && peer) { - { - tl::expected, PacketError> pkt - = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, i, senderinfo); - if (!pkt.has_value()) - return tl::make_unexpected(pkt.error()); - proto.send(peer, (*pkt)->Data()); - } - { - tl::expected, PacketError> pkt - = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, j, peer.serialize()); - if (!pkt.has_value()) - return tl::make_unexpected(pkt.error()); - proto.send(sender, (*pkt)->Data()); - } + tl::expected result + = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, i, senderinfo) + .and_then([&](std::unique_ptr &&pkt) { return proto.send(peer, pkt->Data()); }) + .and_then([&]() { return pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, j, peer.serialize()); }) + .and_then([&](std::unique_ptr &&pkt) { return proto.send(sender, pkt->Data()); }); + if (!result.has_value()) + return result; } } @@ -327,7 +328,9 @@ tl::expected base_protocol

::handle_join_request(packet &in = pktfty->make_packet(plr_self, PLR_BROADCAST, *cookie, i, game_init_info); if (!pkt.has_value()) return tl::make_unexpected(pkt.error()); - proto.send(sender, (*pkt)->Data()); + tl::expected result = proto.send(sender, (*pkt)->Data()); + if (!result.has_value()) + return result; DrainSendQueue(i); return {}; } @@ -473,7 +476,9 @@ void base_protocol

::DrainSendQueue(plr_t player) std::deque &sendQueue = *srcPeer.sendQueue; while (!sendQueue.empty()) { packet &pkt = sendQueue.front(); - proto.send(srcPeer.endpoint, pkt.Data()); + tl::expected result = proto.send(srcPeer.endpoint, pkt.Data()); + if (!result.has_value()) + LogError("DrainSendQueue failed to send packet: {}", result.error().what()); sendQueue.pop_front(); } diff --git a/Source/dvlnet/frame_queue.cpp b/Source/dvlnet/frame_queue.cpp index 73e321c01..fb9518e86 100644 --- a/Source/dvlnet/frame_queue.cpp +++ b/Source/dvlnet/frame_queue.cpp @@ -9,21 +9,24 @@ namespace devilution { namespace net { -#if DVL_EXCEPTIONS -#define FRAME_QUEUE_ERROR throw frame_queue_exception() -#else -#define FRAME_QUEUE_ERROR app_fatal("frame queue error") -#endif +namespace { + +PacketError FrameQueueError() +{ + return PacketError("Incorrect frame size"); +} + +} // namespace framesize_t frame_queue::Size() const { return current_size; } -buffer_t frame_queue::Read(framesize_t s) +tl::expected frame_queue::Read(framesize_t s) { if (current_size < s) - FRAME_QUEUE_ERROR; + return tl::make_unexpected(FrameQueueError()); buffer_t ret; while (s > 0 && s >= buffer_deque.front().size()) { s -= buffer_deque.front().size(); @@ -50,33 +53,35 @@ void frame_queue::Write(buffer_t buf) buffer_deque.push_back(std::move(buf)); } -bool frame_queue::PacketReady() +tl::expected frame_queue::PacketReady() { if (nextsize == 0) { if (Size() < sizeof(framesize_t)) return false; - auto szbuf = Read(sizeof(framesize_t)); - std::memcpy(&nextsize, &szbuf[0], sizeof(framesize_t)); + tl::expected szbuf = Read(sizeof(framesize_t)); + if (!szbuf.has_value()) + return tl::make_unexpected(szbuf.error()); + std::memcpy(&nextsize, &(*szbuf)[0], sizeof(framesize_t)); if (nextsize == 0) - FRAME_QUEUE_ERROR; + return tl::make_unexpected(FrameQueueError()); } return Size() >= nextsize; } -buffer_t frame_queue::ReadPacket() +tl::expected frame_queue::ReadPacket() { if (nextsize == 0 || Size() < nextsize) - FRAME_QUEUE_ERROR; - auto ret = Read(nextsize); + return tl::make_unexpected(FrameQueueError()); + tl::expected ret = Read(nextsize); nextsize = 0; return ret; } -buffer_t frame_queue::MakeFrame(buffer_t packetbuf) +tl::expected frame_queue::MakeFrame(buffer_t packetbuf) { buffer_t ret; if (packetbuf.size() > max_frame_size) - ABORT(); + return tl::make_unexpected("Buffer exceeds maximum frame size"); framesize_t size = packetbuf.size(); ret.insert(ret.end(), packet_out::begin(size), packet_out::end(size)); ret.insert(ret.end(), packetbuf.begin(), packetbuf.end()); diff --git a/Source/dvlnet/frame_queue.h b/Source/dvlnet/frame_queue.h index e2673fa41..eadf82b2d 100644 --- a/Source/dvlnet/frame_queue.h +++ b/Source/dvlnet/frame_queue.h @@ -5,19 +5,14 @@ #include #include +#include + +#include "dvlnet/packet.h" + namespace devilution { namespace net { typedef std::vector buffer_t; - -class frame_queue_exception : public std::exception { -public: - const char *what() const throw() override - { - return "Incorrect frame size"; - } -}; - typedef uint32_t framesize_t; class frame_queue { @@ -30,14 +25,14 @@ private: framesize_t nextsize = 0; framesize_t Size() const; - buffer_t Read(framesize_t s); + tl::expected Read(framesize_t s); public: - bool PacketReady(); - buffer_t ReadPacket(); + tl::expected PacketReady(); + tl::expected ReadPacket(); void Write(buffer_t buf); - static buffer_t MakeFrame(buffer_t packetbuf); + static tl::expected MakeFrame(buffer_t packetbuf); }; } // namespace net diff --git a/Source/dvlnet/protocol_zt.cpp b/Source/dvlnet/protocol_zt.cpp index 04e17bd87..eabbddc43 100644 --- a/Source/dvlnet/protocol_zt.cpp +++ b/Source/dvlnet/protocol_zt.cpp @@ -90,10 +90,13 @@ bool protocol_zt::network_online() return true; } -bool protocol_zt::send(const endpoint &peer, const buffer_t &data) +tl::expected protocol_zt::send(const endpoint &peer, const buffer_t &data) { - peer_list[peer].send_queue.push_back(frame_queue::MakeFrame(data)); - return true; + tl::expected frame = frame_queue::MakeFrame(data); + if (!frame.has_value()) + return tl::make_unexpected(frame.error()); + peer_list[peer].send_queue.push_back(*frame); + return {}; } bool protocol_zt::send_oob(const endpoint &peer, const buffer_t &data) const @@ -238,11 +241,21 @@ bool protocol_zt::recv(endpoint &peer, buffer_t &data) } for (auto &p : peer_list) { - if (p.second.recv_queue.PacketReady()) { - peer = p.first; - data = p.second.recv_queue.ReadPacket(); - return true; + tl::expected ready = p.second.recv_queue.PacketReady(); + if (!ready.has_value()) { + LogError("PacketReady: {}", ready.error().what()); + continue; } + if (!*ready) + continue; + tl::expected packet = p.second.recv_queue.ReadPacket(); + if (!packet.has_value()) { + LogError("Failed reading packet data from peer: {}", packet.error().what()); + continue; + } + peer = p.first; + data = *packet; + return true; } return false; } diff --git a/Source/dvlnet/protocol_zt.h b/Source/dvlnet/protocol_zt.h index ade771ea1..214a2d797 100644 --- a/Source/dvlnet/protocol_zt.h +++ b/Source/dvlnet/protocol_zt.h @@ -68,7 +68,7 @@ public: protocol_zt(); ~protocol_zt(); void disconnect(const endpoint &peer); - bool send(const endpoint &peer, const buffer_t &data); + tl::expected send(const endpoint &peer, const buffer_t &data); bool send_oob(const endpoint &peer, const buffer_t &data) const; bool send_oob_mc(const buffer_t &data) const; bool recv(endpoint &peer, buffer_t &data); diff --git a/Source/dvlnet/tcp_client.cpp b/Source/dvlnet/tcp_client.cpp index 5761a725a..b4821147b 100644 --- a/Source/dvlnet/tcp_client.cpp +++ b/Source/dvlnet/tcp_client.cpp @@ -59,10 +59,16 @@ int tcp_client::join(std::string addrstr) SDL_SetError("make_packet: %.*s", static_cast(message.size()), message.data()); return -1; } - send(**pkt); + tl::expected sendResult = send(**pkt); + if (!sendResult.has_value()) { + const std::string_view message = sendResult.error().what(); + SDL_SetError("send: %.*s", static_cast(message.size()), message.data()); + return -1; + } for (auto i = 0; i < NoSleep; ++i) { - if (tl::expected result = poll(); !result.has_value()) { - const std::string_view message = result.error().what(); + tl::expected pollResult = poll(); + if (!pollResult.has_value()) { + const std::string_view message = pollResult.error().what(); SDL_SetError("%.*s", static_cast(message.size()), message.data()); return -1; } @@ -117,13 +123,19 @@ void tcp_client::HandleReceive(const asio::error_code &error, size_t bytesRead) recv_buffer.resize(bytesRead); recv_queue.Write(std::move(recv_buffer)); recv_buffer.resize(frame_queue::max_frame_size); - while (recv_queue.PacketReady()) { - tl::expected, PacketError> pkt = pktfty->make_packet(recv_queue.ReadPacket()); - if (!pkt.has_value()) { - RaiseIoHandlerError(pkt.error()); + while (true) { + tl::expected ready = recv_queue.PacketReady(); + if (!ready.has_value()) { + RaiseIoHandlerError(ready.error()); return; } - if (tl::expected result = RecvLocal(**pkt); !result.has_value()) { + if (!*ready) + break; + tl::expected result + = recv_queue.ReadPacket() + .and_then([this](buffer_t &&pktData) { return pktfty->make_packet(pktData); }) + .and_then([this](std::unique_ptr &&pkt) { return RecvLocal(*pkt); }); + if (!result.has_value()) { RaiseIoHandlerError(result.error()); return; } @@ -144,11 +156,14 @@ void tcp_client::HandleSend(const asio::error_code &error, size_t bytesSent) RaiseIoHandlerError(error.message()); } -void tcp_client::send(packet &pkt) +tl::expected tcp_client::send(packet &pkt) { - auto frame = std::make_unique(frame_queue::MakeFrame(pkt.Data())); - auto buf = asio::buffer(*frame); - asio::async_write(sock, buf, [this, frame = std::move(frame)](const asio::error_code &error, size_t bytesSent) { + tl::expected frame = frame_queue::MakeFrame(pkt.Data()); + if (!frame.has_value()) + return tl::make_unexpected(frame.error()); + std::unique_ptr framePtr = std::make_unique(*frame); + asio::mutable_buffer buf = asio::buffer(*framePtr); + asio::async_write(sock, buf, [this, frame = std::move(framePtr)](const asio::error_code &error, size_t bytesSent) { HandleSend(error, bytesSent); }); } diff --git a/Source/dvlnet/tcp_client.h b/Source/dvlnet/tcp_client.h index 2baedc919..c4029afd3 100644 --- a/Source/dvlnet/tcp_client.h +++ b/Source/dvlnet/tcp_client.h @@ -31,7 +31,7 @@ public: int join(std::string addrstr) override; tl::expected poll() override; - void send(packet &pkt) override; + tl::expected send(packet &pkt) override; bool SNetLeaveGame(int type) override; diff --git a/Source/dvlnet/tcp_server.cpp b/Source/dvlnet/tcp_server.cpp index f968e8766..c796bcfba 100644 --- a/Source/dvlnet/tcp_server.cpp +++ b/Source/dvlnet/tcp_server.cpp @@ -77,33 +77,43 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec, con->recv_buffer.resize(bytesRead); con->recv_queue.Write(std::move(con->recv_buffer)); con->recv_buffer.resize(frame_queue::max_frame_size); - try { - while (con->recv_queue.PacketReady()) { - tl::expected, PacketError> pkt = pktfty.make_packet(con->recv_queue.ReadPacket()); - if (!pkt.has_value()) { - Log("make_packet: {}", pkt.error().what()); + while (true) { + tl::expected ready = con->recv_queue.PacketReady(); + if (!ready.has_value()) { + Log("PacketReady: {}", ready.error().what()); + DropConnection(con); + return; + } + if (!*ready) + break; + tl::expected pktData = con->recv_queue.ReadPacket(); + if (!pktData.has_value()) { + Log("ReadPacket: {}", pktData.error().what()); + DropConnection(con); + return; + } + tl::expected, PacketError> pkt = pktfty.make_packet(*pktData); + if (!pkt.has_value()) { + Log("make_packet: {}", pkt.error().what()); + DropConnection(con); + return; + } + if (con->plr == PLR_BROADCAST) { + tl::expected result = HandleReceiveNewPlayer(con, **pkt); + if (!result.has_value()) { + Log("HandleReceiveNewPlayer: {}", result.error().what()); DropConnection(con); return; } - if (con->plr == PLR_BROADCAST) { - if (tl::expected result = HandleReceiveNewPlayer(con, **pkt); !result.has_value()) { - Log("HandleReceiveNewPlayer: {}", result.error().what()); - DropConnection(con); - return; - } - } else { - con->timeout = timeout_active; - if (tl::expected result = HandleReceivePacket(**pkt); !result.has_value()) { - Log("Network error: {}", result.error().what()); - DropConnection(con); - return; - } + } else { + con->timeout = timeout_active; + tl::expected result = HandleReceivePacket(**pkt); + if (!result.has_value()) { + Log("Network error: {}", result.error().what()); + DropConnection(con); + return; } } - } catch (frame_queue_exception &e) { - Log("Invalid packet: {}", e.what()); - DropConnection(con); - return; } StartReceive(con); } @@ -123,33 +133,22 @@ tl::expected tcp_server::HandleReceiveNewPlayer(const scc &co for (plr_t player = 0; player < Players.size(); player++) { if (connections[player]) { - { - tl::expected, PacketError> pkt - = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, newplr); - if (!pkt.has_value()) - return tl::make_unexpected(pkt.error()); - StartSend(connections[player], **pkt); - } - - { - tl::expected, PacketError> pkt - = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, player); - if (!pkt.has_value()) - return tl::make_unexpected(pkt.error()); - StartSend(con, **pkt); - } + tl::expected result + = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, newplr) + .and_then([&](std::unique_ptr &&pkt) { return StartSend(connections[player], *pkt); }) + .and_then([&]() { return pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, player); }) + .and_then([&](std::unique_ptr &&pkt) { return StartSend(con, *pkt); }); + if (!result.has_value()) + return result; } } - tl::expected cookie = inPkt.Cookie(); - if (!cookie.has_value()) - return tl::make_unexpected(cookie.error()); - tl::expected, PacketError> pkt - = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - *cookie, newplr, game_init_info); - if (!pkt.has_value()) - return tl::make_unexpected(pkt.error()); - StartSend(con, **pkt); + tl::expected result + = inPkt.Cookie() + .and_then([&](cookie_t &&cookie) { return pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, cookie, newplr, game_init_info); }) + .and_then([&](std::unique_ptr &&pkt) { return StartSend(con, *pkt); }); + if (!result.has_value()) + return result; con->plr = newplr; connections[newplr] = con; con->timeout = timeout_active; @@ -164,24 +163,31 @@ tl::expected tcp_server::HandleReceivePacket(packet &pkt) tl::expected tcp_server::SendPacket(packet &pkt) { if (pkt.Destination() == PLR_BROADCAST) { - for (size_t i = 0; i < Players.size(); ++i) - if (i != pkt.Source() && connections[i]) - StartSend(connections[i], pkt); - } else { - if (pkt.Destination() >= MAX_PLRS) - return tl::make_unexpected(ServerError()); - if ((pkt.Destination() != pkt.Source()) && connections[pkt.Destination()]) - StartSend(connections[pkt.Destination()], pkt); + for (size_t i = 0; i < Players.size(); ++i) { + if (i == pkt.Source() || !connections[i]) + continue; + tl::expected result = StartSend(connections[i], pkt); + if (!result.has_value()) + LogError("Failed to send packet {} to player {}: {}", static_cast(pkt.Type()), i, result.error().what()); + } + return {}; } - return {}; + if (pkt.Destination() >= MAX_PLRS) + return tl::make_unexpected(ServerError()); + if (pkt.Destination() == pkt.Source() || !connections[pkt.Destination()]) + return {}; + return StartSend(connections[pkt.Destination()], pkt); } -void tcp_server::StartSend(const scc &con, packet &pkt) +tl::expected tcp_server::StartSend(const scc &con, packet &pkt) { - auto frame = std::make_unique(frame_queue::MakeFrame(pkt.Data())); - auto buf = asio::buffer(*frame); + tl::expected frame = frame_queue::MakeFrame(pkt.Data()); + if (!frame.has_value()) + return tl::make_unexpected(frame.error()); + std::unique_ptr framePtr = std::make_unique(*frame); + asio::mutable_buffer buf = asio::buffer(*framePtr); asio::async_write(con->socket, buf, - [this, con, frame = std::move(frame)](const asio::error_code &ec, size_t bytesSent) { + [this, con, frame = std::move(framePtr)](const asio::error_code &ec, size_t bytesSent) { HandleSend(con, ec, bytesSent); }); } diff --git a/Source/dvlnet/tcp_server.h b/Source/dvlnet/tcp_server.h index 07a1365b3..c9907d72c 100644 --- a/Source/dvlnet/tcp_server.h +++ b/Source/dvlnet/tcp_server.h @@ -80,7 +80,7 @@ private: tl::expected HandleReceiveNewPlayer(const scc &con, packet &pkt); tl::expected HandleReceivePacket(packet &pkt); tl::expected SendPacket(packet &pkt); - void StartSend(const scc &con, packet &pkt); + tl::expected StartSend(const scc &con, packet &pkt); void HandleSend(const scc &con, const asio::error_code &ec, size_t bytesSent); void StartTimeout(const scc &con); void HandleTimeout(const scc &con, const asio::error_code &ec);