diff --git a/Source/dvlnet/base.cpp b/Source/dvlnet/base.cpp index 4e88c2dc2..ec853e478 100644 --- a/Source/dvlnet/base.cpp +++ b/Source/dvlnet/base.cpp @@ -5,6 +5,8 @@ #include #include +#include + #include "player.h" namespace devilution { @@ -37,93 +39,119 @@ void base::DisconnectNet(plr_t plr) { } -void base::SendEchoRequest(plr_t player) +tl::expected base::SendEchoRequest(plr_t player) { if (plr_self == PLR_BROADCAST) - return; + return {}; if (player == plr_self) - return; + return {}; timestamp_t now = SDL_GetTicks(); - auto echo = pktfty->make_packet(plr_self, player, now); - send(*echo); + tl::expected, PacketError> pkt + = pktfty->make_packet(plr_self, player, now); + if (!pkt.has_value()) { + return tl::make_unexpected(pkt.error()); + } + send(**pkt); + return {}; } -void base::HandleAccept(packet &pkt) +tl::expected base::HandleAccept(packet &pkt) { if (plr_self != PLR_BROADCAST) { - return; // already have player id + return {}; // already have player id } if (pkt.Cookie() == cookie_self) { - plr_self = pkt.NewPlayer(); + tl::expected newPlayerPkt = pkt.NewPlayer(); + if (!newPlayerPkt.has_value()) + return tl::make_unexpected(newPlayerPkt.error()); + plr_self = *std::move(newPlayerPkt); Connect(plr_self); } - if (game_init_info != pkt.Info()) { - if (pkt.Info().size() != sizeof(GameData)) { + tl::expected infoPkt = pkt.Info(); + if (!infoPkt.has_value()) + return tl::make_unexpected(infoPkt.error()); + const buffer_t &info = **infoPkt; + if (game_init_info != info) { + if (info.size() != sizeof(GameData)) { ABORT(); } // we joined and did not create - game_init_info = pkt.Info(); + game_init_info = info; _SNETEVENT ev; ev.eventid = EVENT_TYPE_PLAYER_CREATE_GAME; ev.playerid = plr_self; - ev.data = const_cast(pkt.Info().data()); - ev.databytes = pkt.Info().size(); + ev.data = const_cast(info.data()); + ev.databytes = info.size(); RunEventHandler(ev); } + return {}; } -void base::HandleConnect(packet &pkt) +tl::expected base::HandleConnect(packet &pkt) { - plr_t newPlayer = pkt.NewPlayer(); - Connect(newPlayer); + return pkt.NewPlayer().transform([this](plr_t &&newPlayer) { + Connect(newPlayer); + }); } -void base::HandleTurn(packet &pkt) +tl::expected base::HandleTurn(packet &pkt) { plr_t src = pkt.Source(); PlayerState &playerState = playerStateTable_[src]; std::deque &turnQueue = playerState.turnQueue; - const turn_t &turn = pkt.Turn(); - turnQueue.push_back(turn); - MakeReady(turn.SequenceNumber); -} - -void base::HandleDisconnect(packet &pkt) -{ - plr_t newPlayer = pkt.NewPlayer(); - if (newPlayer != plr_self) { - if (IsConnected(newPlayer)) { - auto leaveinfo = pkt.LeaveInfo(); - _SNETEVENT ev; - ev.eventid = EVENT_TYPE_PLAYER_LEAVE_GAME; - ev.playerid = pkt.NewPlayer(); - ev.data = reinterpret_cast(&leaveinfo); - ev.databytes = sizeof(leaveinfo_t); - RunEventHandler(ev); - DisconnectNet(newPlayer); - ClearMsg(newPlayer); - PlayerState &playerState = playerStateTable_[newPlayer]; - playerState.isConnected = false; - playerState.turnQueue.clear(); - } - } else { + return pkt.Turn().transform([&](turn_t &&turn) { + turnQueue.push_back(turn); + MakeReady(turn.SequenceNumber); + }); +} + +tl::expected base::HandleDisconnect(packet &pkt) +{ + tl::expected newPlayer = pkt.NewPlayer(); + if (!newPlayer.has_value()) + return tl::make_unexpected(newPlayer.error()); + if (*newPlayer == plr_self) { ABORT(); // we were dropped by the owner?!? } + if (IsConnected(*newPlayer)) { + tl::expected leaveinfo = pkt.LeaveInfo(); + if (!leaveinfo.has_value()) + return tl::make_unexpected(leaveinfo.error()); + _SNETEVENT ev; + ev.eventid = EVENT_TYPE_PLAYER_LEAVE_GAME; + ev.playerid = *newPlayer; + ev.data = reinterpret_cast(&*leaveinfo); + ev.databytes = sizeof(leaveinfo_t); + RunEventHandler(ev); + DisconnectNet(*newPlayer); + ClearMsg(*newPlayer); + PlayerState &playerState = playerStateTable_[*newPlayer]; + playerState.isConnected = false; + playerState.turnQueue.clear(); + } + return {}; } -void base::HandleEchoRequest(packet &pkt) +tl::expected base::HandleEchoRequest(packet &pkt) { - auto reply = pktfty->make_packet(plr_self, pkt.Source(), pkt.Time()); - send(*reply); + return pkt.Time() + .and_then([&](cookie_t &&pktTime) { + return pktfty->make_packet(plr_self, pkt.Source(), pktTime); + }) + .transform([&](std::unique_ptr &&pkt) { + send(*pkt); + }); } -void base::HandleEchoReply(packet &pkt) +tl::expected base::HandleEchoReply(packet &pkt) { - uint32_t now = SDL_GetTicks(); + const uint32_t now = SDL_GetTicks(); plr_t src = pkt.Source(); - PlayerState &playerState = playerStateTable_[src]; - playerState.roundTripLatency = now - pkt.Time(); + return pkt.Time().transform([&](cookie_t &&pktTime) { + PlayerState &playerState = playerStateTable_[src]; + playerState.roundTripLatency = now - pktTime; + }); } void base::ClearMsg(plr_t plr) @@ -136,14 +164,15 @@ void base::ClearMsg(plr_t plr) message_queue.end()); } -void base::Connect(plr_t player) +tl::expected base::Connect(plr_t player) { PlayerState &playerState = playerStateTable_[player]; bool wasConnected = playerState.isConnected; playerState.isConnected = true; if (!wasConnected) - SendFirstTurnIfReady(player); + return SendFirstTurnIfReady(player); + return {}; } bool base::IsConnected(plr_t player) const @@ -152,35 +181,33 @@ bool base::IsConnected(plr_t player) const return playerState.isConnected; } -void base::RecvLocal(packet &pkt) +tl::expected base::RecvLocal(packet &pkt) { if (pkt.Source() < MAX_PLRS) { - Connect(pkt.Source()); + if (tl::expected result = Connect(pkt.Source()); + !result.has_value()) { + return result; + } } switch (pkt.Type()) { case PT_MESSAGE: - message_queue.emplace_back(pkt.Source(), pkt.Message()); - break; + return pkt.Message().transform([&](const buffer_t *message) { + message_queue.emplace_back(pkt.Source(), *message); + }); case PT_TURN: - HandleTurn(pkt); - break; + return HandleTurn(pkt); case PT_JOIN_ACCEPT: - HandleAccept(pkt); - break; + return HandleAccept(pkt); case PT_CONNECT: - HandleConnect(pkt); - break; + return HandleConnect(pkt); case PT_DISCONNECT: - HandleDisconnect(pkt); - break; + return HandleDisconnect(pkt); case PT_ECHO_REQUEST: - HandleEchoRequest(pkt); - break; + return HandleEchoRequest(pkt); case PT_ECHO_REPLY: - HandleEchoReply(pkt); - break; + return HandleEchoReply(pkt); default: - break; + return {}; // otherwise drop } } @@ -213,8 +240,13 @@ bool base::SNetSendMessage(int playerId, void *data, unsigned int size) else dest = playerId; if (dest != plr_self) { - auto pkt = pktfty->make_packet(plr_self, dest, message); - send(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet(plr_self, dest, message); + if (!pkt.has_value()) { + LogError("make_packet: {}", pkt.error().what()); + return false; + } + send(**pkt); } return true; } @@ -319,37 +351,47 @@ bool base::SNetSendTurn(char *data, unsigned int size) return true; } -void base::SendTurnIfReady(turn_t turn) +tl::expected base::SendTurnIfReady(turn_t turn) { if (awaitingSequenceNumber_) awaitingSequenceNumber_ = !IsGameHost(); if (!awaitingSequenceNumber_) { - auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, turn); - send(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet(plr_self, PLR_BROADCAST, turn); + if (!pkt.has_value()) { + return tl::make_unexpected(pkt.error()); + } + send(**pkt); } + return {}; } -void base::SendFirstTurnIfReady(plr_t player) +tl::expected base::SendFirstTurnIfReady(plr_t player) { if (awaitingSequenceNumber_) - return; + return {}; PlayerState &playerState = playerStateTable_[plr_self]; std::deque &turnQueue = playerState.turnQueue; if (turnQueue.empty()) - return; + return {}; for (turn_t turn : turnQueue) { - auto pkt = pktfty->make_packet(plr_self, player, turn); - send(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet(plr_self, player, turn); + if (!pkt.has_value()) { + return tl::make_unexpected(pkt.error()); + } + send(**pkt); } + return {}; } -void base::MakeReady(seq_t sequenceNumber) +tl::expected base::MakeReady(seq_t sequenceNumber) { if (!awaitingSequenceNumber_) - return; + return {}; current_turn = sequenceNumber; next_turn = sequenceNumber; @@ -360,8 +402,12 @@ void base::MakeReady(seq_t sequenceNumber) for (turn_t &turn : turnQueue) { turn.SequenceNumber = next_turn; next_turn++; - SendTurnIfReady(turn); + if (tl::expected result = SendTurnIfReady(turn); + !result.has_value()) { + return result; + } } + return {}; } void base::SNetGetProviderCaps(struct _SNETCAPS *caps) @@ -400,21 +446,35 @@ bool base::SNetRegisterEventHandler(event_type evtype, SEVTHANDLER func) bool base::SNetLeaveGame(int type) { - auto pkt = pktfty->make_packet(plr_self, PLR_BROADCAST, - plr_self, type); - send(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet( + plr_self, PLR_BROADCAST, plr_self, type); + if (!pkt.has_value()) { + LogError("make_packet: {}", pkt.error().what()); + return false; + } + send(**pkt); plr_self = PLR_BROADCAST; return true; } bool base::SNetDropPlayer(int playerid, uint32_t flags) { - auto pkt = pktfty->make_packet(plr_self, - PLR_BROADCAST, - (plr_t)playerid, - (leaveinfo_t)flags); - send(*pkt); - RecvLocal(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet( + plr_self, + PLR_BROADCAST, + (plr_t)playerid, + (leaveinfo_t)flags); + if (!pkt.has_value()) { + LogError("make_packet: {}", pkt.error().what()); + return false; + } + send(**pkt); + if (tl::expected result = RecvLocal(**pkt); !result.has_value()) { + LogError("SNetDropPlayer: {}", result.error().what()); + return false; + } return true; } diff --git a/Source/dvlnet/base.h b/Source/dvlnet/base.h index 699e0adc8..28a85b90d 100644 --- a/Source/dvlnet/base.h +++ b/Source/dvlnet/base.h @@ -76,10 +76,10 @@ protected: std::unique_ptr pktfty; - void Connect(plr_t player); - void RecvLocal(packet &pkt); + tl::expected Connect(plr_t player); + tl::expected RecvLocal(packet &pkt); void RunEventHandler(_SNETEVENT &ev); - void SendEchoRequest(plr_t player); + tl::expected SendEchoRequest(plr_t player); [[nodiscard]] bool IsConnected(plr_t player) const; virtual bool IsGameHost() = 0; @@ -90,17 +90,17 @@ private: plr_t GetOwner(); bool AllTurnsArrived(); - void MakeReady(seq_t sequenceNumber); - void SendTurnIfReady(turn_t turn); - void SendFirstTurnIfReady(plr_t player); + tl::expected MakeReady(seq_t sequenceNumber); + tl::expected SendTurnIfReady(turn_t turn); + tl::expected SendFirstTurnIfReady(plr_t player); void ClearMsg(plr_t plr); - void HandleAccept(packet &pkt); - void HandleConnect(packet &pkt); - void HandleTurn(packet &pkt); - void HandleDisconnect(packet &pkt); - void HandleEchoRequest(packet &pkt); - void HandleEchoReply(packet &pkt); + tl::expected HandleAccept(packet &pkt); + tl::expected HandleConnect(packet &pkt); + tl::expected HandleTurn(packet &pkt); + tl::expected HandleDisconnect(packet &pkt); + tl::expected HandleEchoRequest(packet &pkt); + tl::expected HandleEchoReply(packet &pkt); }; } // namespace net diff --git a/Source/dvlnet/base_protocol.h b/Source/dvlnet/base_protocol.h index 9d0ab0bd6..9594cf8e4 100644 --- a/Source/dvlnet/base_protocol.h +++ b/Source/dvlnet/base_protocol.h @@ -10,8 +10,7 @@ #include "player.h" #include "utils/log.hpp" -namespace devilution { -namespace net { +namespace devilution::net { template class base_protocol : public base { @@ -56,18 +55,18 @@ private: bool isGameHost_; plr_t get_master(); - void InitiateHandshake(plr_t player); + tl::expected InitiateHandshake(plr_t player); void SendTo(plr_t player, packet &pkt); void DrainSendQueue(plr_t player); void recv(); - void handle_join_request(packet &pkt, endpoint_t sender); - void recv_decrypted(packet &pkt, endpoint_t sender); - void recv_ingame(packet &pkt, endpoint_t sender); + tl::expected handle_join_request(packet &pkt, endpoint_t sender); + tl::expected recv_decrypted(packet &pkt, endpoint_t sender); + tl::expected recv_ingame(packet &pkt, endpoint_t sender); bool is_recognized(endpoint_t sender); bool wait_network(); bool wait_firstpeer(); - void wait_join(); + tl::expected wait_join(); }; template @@ -121,24 +120,33 @@ bool base_protocol

::send_info_request() { if (!proto.network_online()) return false; - auto pkt = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER); - proto.send_oob_mc(pkt->Data()); + tl::expected, PacketError> pkt + = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER); + if (!pkt.has_value()) { + LogError("make_packet: {}", pkt.error().what()); + return false; + } + proto.send_oob_mc((*pkt)->Data()); return true; } template -void base_protocol

::wait_join() +tl::expected base_protocol

::wait_join() { cookie_self = packet_out::GenerateCookie(); - auto pkt = pktfty->make_packet(PLR_BROADCAST, - PLR_MASTER, cookie_self, game_init_info); - proto.send(firstpeer, pkt->Data()); + tl::expected, PacketError> pkt + = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER, cookie_self, game_init_info); + if (!pkt.has_value()) { + return tl::unexpected(pkt.error()); + } + proto.send(firstpeer, (*pkt)->Data()); for (auto i = 0; i < 500; ++i) { recv(); if (plr_self != PLR_BROADCAST) break; // join successful SDL_Delay(10); } + return {}; } template @@ -149,7 +157,11 @@ int base_protocol

::create(std::string addrstr) if (wait_network()) { plr_self = 0; - Connect(plr_self); + if (tl::expected result = Connect(plr_self); + !result.has_value()) { + LogError("Connect: {}", result.error().what()); + return -1; + } } return (plr_self == PLR_BROADCAST ? -1 : plr_self); } @@ -161,8 +173,12 @@ int base_protocol

::join(std::string addrstr) isGameHost_ = false; if (wait_network()) { - if (wait_firstpeer()) - wait_join(); + if (wait_firstpeer()) { + if (tl::expected result = wait_join(); !result.has_value()) { + LogError("wait_join: {}", result.error().what()); + return -1; + } + } } return (plr_self == PLR_BROADCAST ? -1 : plr_self); } @@ -180,7 +196,7 @@ void base_protocol

::poll() } template -void base_protocol

::InitiateHandshake(plr_t player) +tl::expected base_protocol

::InitiateHandshake(plr_t player) { Peer &peer = peers[player]; @@ -189,7 +205,9 @@ void base_protocol

::InitiateHandshake(plr_t player) // If the connection is already open, it should be safe to initiate from either end. // If not, only the player with the smaller player number should initiate the handshake. if (plr_self < player || proto.is_peer_connected(peer.endpoint)) - SendEchoRequest(player); + return SendEchoRequest(player); + + return {}; } template @@ -232,13 +250,15 @@ void base_protocol

::recv() buffer_t pkt_buf; endpoint_t sender; while (proto.recv(sender, pkt_buf)) { // read until kernel buffer is empty? - try { - auto pkt = pktfty->make_packet(pkt_buf); - recv_decrypted(*pkt, sender); - } catch (packet_exception &e) { + tl::expected result + = pktfty->make_packet(pkt_buf) + .and_then([&](std::unique_ptr &&pkt) { + return recv_decrypted(*pkt, sender); + }); + if (!result.has_value()) { // drop packet proto.disconnect(sender); - Log("{}", e.what()); + Log("{}", result.error().what()); } } while (proto.get_disconnected(sender)) { @@ -256,7 +276,7 @@ void base_protocol

::recv() } template -void base_protocol

::handle_join_request(packet &pkt, endpoint_t sender) +tl::expected base_protocol

::handle_join_request(packet &inPkt, endpoint_t sender) { plr_t i; for (i = 0; i < Players.size(); ++i) { @@ -264,70 +284,94 @@ void base_protocol

::handle_join_request(packet &pkt, endpoint_t sender) if (i != plr_self && !peer.endpoint) { peer.endpoint = sender; peer.sendQueue = std::make_unique>(); - Connect(i); + if (tl::expected result = Connect(i); + !result.has_value()) { + return result; + } break; } } if (i >= MAX_PLRS) { // already full - return; + return {}; } auto senderinfo = sender.serialize(); for (plr_t j = 0; j < Players.size(); ++j) { endpoint_t peer = peers[j].endpoint; if ((j != plr_self) && (j != i) && peer) { - auto peerpkt = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, i, senderinfo); - proto.send(peer, peerpkt->Data()); - - auto infopkt = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, j, peer.serialize()); - proto.send(sender, infopkt->Data()); + { + tl::expected, PacketError> pkt + = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, i, senderinfo); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + proto.send(peer, (*pkt)->Data()); + } + { + tl::expected, PacketError> pkt + = pktfty->make_packet(PLR_MASTER, PLR_BROADCAST, j, peer.serialize()); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + proto.send(sender, (*pkt)->Data()); + } } } // PT_JOIN_ACCEPT must be sent after all PT_CONNECT packets so the new player does // not resume game logic until after having been notified of all existing players - auto reply = pktfty->make_packet(plr_self, PLR_BROADCAST, - pkt.Cookie(), i, - game_init_info); - proto.send(sender, reply->Data()); + tl::expected cookie = inPkt.Cookie(); + if (!cookie.has_value()) + return tl::make_unexpected(cookie.error()); + tl::expected, PacketError> pkt + = pktfty->make_packet(plr_self, PLR_BROADCAST, *cookie, i, game_init_info); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + proto.send(sender, (*pkt)->Data()); DrainSendQueue(i); + return {}; } template -void base_protocol

::recv_decrypted(packet &pkt, endpoint_t sender) +tl::expected base_protocol

::recv_decrypted(packet &pkt, endpoint_t sender) { if (pkt.Source() == PLR_BROADCAST && pkt.Destination() == PLR_MASTER && pkt.Type() == PT_INFO_REPLY) { size_t neededSize = sizeof(GameData) + (PlayerNameLength * MAX_PLRS); - if (pkt.Info().size() < neededSize) - return; - const GameData *gameData = (const GameData *)pkt.Info().data(); + const tl::expected pktInfo = pkt.Info(); + if (!pktInfo.has_value()) + return tl::make_unexpected(pktInfo.error()); + const buffer_t &infoBuffer = **pktInfo; + if (infoBuffer.size() < neededSize) + return {}; + const GameData *gameData = reinterpret_cast(infoBuffer.data()); if (gameData->size != sizeof(GameData)) - return; + return {}; std::vector playerNames; for (size_t i = 0; i < Players.size(); i++) { std::string playerName; - const char *playerNamePointer = (const char *)(pkt.Info().data() + sizeof(GameData) + (i * PlayerNameLength)); + const char *playerNamePointer = reinterpret_cast(infoBuffer.data() + sizeof(GameData) + (i * PlayerNameLength)); playerName.append(playerNamePointer, strnlen(playerNamePointer, PlayerNameLength)); if (!playerName.empty()) playerNames.push_back(playerName); } std::string gameName; - size_t gameNameSize = pkt.Info().size() - neededSize; + size_t gameNameSize = infoBuffer.size() - neededSize; gameName.resize(gameNameSize); - std::memcpy(&gameName[0], pkt.Info().data() + neededSize, gameNameSize); + std::memcpy(&gameName[0], infoBuffer.data() + neededSize, gameNameSize); game_list[gameName] = GameListValue { *gameData, std::move(playerNames), sender }; - return; + return {}; } - recv_ingame(pkt, sender); + return recv_ingame(pkt, sender); } template -void base_protocol

::recv_ingame(packet &pkt, endpoint_t sender) +tl::expected base_protocol

::recv_ingame(packet &pkt, endpoint_t sender) { if (pkt.Source() == PLR_BROADCAST && pkt.Destination() == PLR_MASTER) { if (pkt.Type() == PT_JOIN_REQUEST) { - handle_join_request(pkt, sender); + if (tl::expected result = handle_join_request(pkt, sender); + !result.has_value()) { + return result; + } } else if (pkt.Type() == PT_INFO_REQUEST) { if ((plr_self != PLR_BROADCAST) && (get_master() == plr_self)) { buffer_t buf; @@ -341,58 +385,81 @@ void base_protocol

::recv_ingame(packet &pkt, endpoint_t sender) } } std::memcpy(buf.data() + game_init_info.size() + (PlayerNameLength * MAX_PLRS), &gamename[0], gamename.size()); - auto reply = pktfty->make_packet(PLR_BROADCAST, - PLR_MASTER, - buf); - proto.send_oob(sender, reply->Data()); + tl::expected, PacketError> reply + = pktfty->make_packet(PLR_BROADCAST, PLR_MASTER, buf); + if (!reply.has_value()) { + return tl::make_unexpected(reply.error()); + } + proto.send_oob(sender, (*reply)->Data()); } } - return; - } else if (pkt.Source() == PLR_MASTER && pkt.Type() == PT_CONNECT) { + return {}; + } + if (pkt.Source() == PLR_MASTER && pkt.Type() == PT_CONNECT) { if (!is_recognized(sender)) { LogDebug("Invalid packet: PT_CONNECT received from unrecognized endpoint"); - return; + return {}; } // addrinfo packets - plr_t newPlayer = pkt.NewPlayer(); - Peer &peer = peers[newPlayer]; - peer.endpoint.unserialize(pkt.Info()); + tl::expected newPlayer = pkt.NewPlayer(); + if (!newPlayer.has_value()) + return tl::make_unexpected(newPlayer.error()); + Peer &peer = peers[*newPlayer]; + tl::expected pktInfo = pkt.Info(); + if (!pktInfo.has_value()) + return tl::make_unexpected(pktInfo.error()); + peer.endpoint.unserialize(**pktInfo); peer.sendQueue = std::make_unique>(); - Connect(newPlayer); + if (tl::expected result = Connect(*newPlayer); + !result.has_value()) { + return result; + } if (plr_self != PLR_BROADCAST) - InitiateHandshake(newPlayer); - - return; - } else if (pkt.Source() >= MAX_PLRS) { + return InitiateHandshake(*newPlayer); + return {}; + } + if (pkt.Source() >= MAX_PLRS) { // normal packets LogDebug("Invalid packet: packet source ({}) >= MAX_PLRS", pkt.Source()); - return; - } else if (sender == firstpeer && pkt.Type() == PT_JOIN_ACCEPT) { + return {}; + } + if (sender == firstpeer && pkt.Type() == PT_JOIN_ACCEPT) { plr_t src = pkt.Source(); peers[src].endpoint = sender; - Connect(src); + if (tl::expected result = Connect(src); + !result.has_value()) { + return result; + } firstpeer = {}; } else if (sender != peers[pkt.Source()].endpoint) { LogDebug("Invalid packet: packet source ({}) received from unrecognized endpoint", pkt.Source()); - return; + return {}; } if (pkt.Destination() != plr_self && pkt.Destination() != PLR_BROADCAST) - return; // packet not for us, drop + return {}; // packet not for us, drop bool wasBroadcast = plr_self == PLR_BROADCAST; - RecvLocal(pkt); + if (tl::expected result = RecvLocal(pkt); + !result.has_value()) { + return result; + } if (plr_self != PLR_BROADCAST) { if (wasBroadcast) { // Send a handshake to everyone just after PT_JOIN_ACCEPT - for (plr_t player = 0; player < Players.size(); player++) - InitiateHandshake(player); + for (plr_t player = 0; player < Players.size(); player++) { + if (tl::expected result = InitiateHandshake(player); + !result.has_value()) { + return result; + } + } } DrainSendQueue(pkt.Source()); } + return {}; } template @@ -462,5 +529,4 @@ std::string base_protocol

::make_default_gamename() return proto.make_default_gamename(); } -} // namespace net -} // namespace devilution +} // namespace devilution::net diff --git a/Source/dvlnet/packet.cpp b/Source/dvlnet/packet.cpp index 08771561a..715c6cc3c 100644 --- a/Source/dvlnet/packet.cpp +++ b/Source/dvlnet/packet.cpp @@ -10,8 +10,12 @@ #include #endif -namespace devilution { -namespace net { +#include + +#include "utils/algorithm/container.hpp" +#include "utils/str_cat.hpp" + +namespace devilution::net { #ifdef PACKET_ENCRYPTION @@ -80,7 +84,12 @@ const char *packet_type_to_string(uint8_t packetType) } } -wrong_packet_type_exception::wrong_packet_type_exception(std::initializer_list expectedTypes, std::uint8_t actual) +PacketTypeError::PacketTypeError(std::uint8_t unknownPacketType) + : message_(StrCat("Unknown packet type ", unknownPacketType)) +{ +} + +PacketTypeError::PacketTypeError(std::initializer_list expectedTypes, std::uint8_t actual) { message_ = "Expected packet of type "; const auto appendPacketType = [this](std::uint8_t t) { @@ -88,7 +97,7 @@ wrong_packet_type_exception::wrong_packet_type_exception(std::initializer_list

expectedTypes, std::uint8_t actualType) +tl::expected CheckPacketTypeOneOf(std::initializer_list expectedTypes, std::uint8_t actualType) { - for (std::uint8_t packetType : expectedTypes) - if (actualType == packetType) - return; -#if DVL_EXCEPTIONS - throw wrong_packet_type_exception(expectedTypes, actualType); -#else - app_fatal("wrong packet type"); -#endif + if (c_none_of(expectedTypes, + [actualType](uint8_t type) { return type == actualType; })) { + return tl::make_unexpected(PacketTypeError(expectedTypes, actualType)); + } + return {}; } } // namespace @@ -143,64 +149,60 @@ plr_t packet::Destination() const return m_dest; } -const buffer_t &packet::Message() +tl::expected packet::Message() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_MESSAGE }, m_type); - return m_message; + return CheckPacketTypeOneOf({ PT_MESSAGE }, m_type) + .transform([this]() { return &m_message; }); } -turn_t packet::Turn() +tl::expected packet::Turn() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_TURN }, m_type); - return m_turn; + return CheckPacketTypeOneOf({ PT_TURN }, m_type) + .transform([this]() { return m_turn; }); } -cookie_t packet::Cookie() +tl::expected packet::Cookie() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_JOIN_REQUEST, PT_JOIN_ACCEPT }, m_type); - return m_cookie; + return CheckPacketTypeOneOf({ PT_JOIN_REQUEST, PT_JOIN_ACCEPT }, m_type) + .transform([this]() { return m_cookie; }); } -plr_t packet::NewPlayer() +tl::expected packet::NewPlayer() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_JOIN_ACCEPT, PT_CONNECT, PT_DISCONNECT }, m_type); - return m_newplr; + return CheckPacketTypeOneOf({ PT_JOIN_ACCEPT, PT_CONNECT, PT_DISCONNECT }, m_type) + .transform([this]() { return m_newplr; }); } -timestamp_t packet::Time() +tl::expected packet::Time() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_ECHO_REQUEST, PT_ECHO_REPLY }, m_type); - return m_time; + return CheckPacketTypeOneOf({ PT_ECHO_REQUEST, PT_ECHO_REPLY }, m_type) + .transform([this]() { return m_time; }); } -const buffer_t &packet::Info() +tl::expected packet::Info() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_JOIN_REQUEST, PT_JOIN_ACCEPT, PT_CONNECT, PT_INFO_REPLY }, m_type); - return m_info; + return CheckPacketTypeOneOf({ PT_JOIN_REQUEST, PT_JOIN_ACCEPT, PT_CONNECT, PT_INFO_REPLY }, m_type) + .transform([this]() { return &m_info; }); } -leaveinfo_t packet::LeaveInfo() +tl::expected packet::LeaveInfo() { assert(have_decrypted); - CheckPacketTypeOneOf({ PT_DISCONNECT }, m_type); - return m_leaveinfo; + return CheckPacketTypeOneOf({ PT_DISCONNECT }, m_type) + .transform([this]() { return m_leaveinfo; }); } -void packet_in::Create(buffer_t buf) +tl::expected packet_in::Create(buffer_t buf) { assert(!have_encrypted && !have_decrypted); if (buf.size() < sizeof(packet_type) + 2 * sizeof(plr_t)) -#if DVL_EXCEPTIONS - throw packet_exception(); -#else - app_fatal("invalid packet"); -#endif + return tl::make_unexpected(PacketError()); decrypted_buffer = std::move(buf); have_decrypted = true; @@ -210,10 +212,11 @@ void packet_in::Create(buffer_t buf) // we save a copy in encrypted_buffer anyway encrypted_buffer = decrypted_buffer; have_encrypted = true; + return {}; } #ifdef PACKET_ENCRYPTION -void packet_in::Decrypt(buffer_t buf) +tl::expected packet_in::Decrypt(buffer_t buf) { assert(!have_encrypted && !have_decrypted); encrypted_buffer = std::move(buf); @@ -222,7 +225,7 @@ void packet_in::Decrypt(buffer_t buf) if (encrypted_buffer.size() < crypto_secretbox_NONCEBYTES + crypto_secretbox_MACBYTES + sizeof(packet_type) + 2 * sizeof(plr_t)) - throw packet_exception(); + return tl::make_unexpected(PacketError()); auto pktlen = (encrypted_buffer.size() - crypto_secretbox_NONCEBYTES - crypto_secretbox_MACBYTES); @@ -234,9 +237,10 @@ void packet_in::Decrypt(buffer_t buf) encrypted_buffer.data(), key.data()); if (status != 0) - throw packet_exception(); + return tl::make_unexpected(PacketError()); have_decrypted = true; + return {}; } #endif @@ -298,5 +302,4 @@ packet_factory::packet_factory(std::string pw) #endif } -} // namespace net -} // namespace devilution +} // namespace devilution::net diff --git a/Source/dvlnet/packet.h b/Source/dvlnet/packet.h index a306a422f..a6edffa94 100644 --- a/Source/dvlnet/packet.h +++ b/Source/dvlnet/packet.h @@ -6,6 +6,8 @@ #include #include +#include + #ifdef PACKET_ENCRYPTION #include #endif @@ -13,6 +15,7 @@ #include "appfat.h" #include "dvlnet/abstract_net.h" #include "utils/attributes.h" +#include "utils/str_cat.hpp" #include "utils/stubs.h" namespace devilution { @@ -56,19 +59,20 @@ struct turn_t { static constexpr plr_t PLR_MASTER = 0xFE; static constexpr plr_t PLR_BROADCAST = 0xFF; -class packet_exception : public dvlnet_exception { +class PacketError { public: - const char *what() const throw() override + virtual const char *what() const { return "Incorrect package size"; } }; -class wrong_packet_type_exception : public packet_exception { +class PacketTypeError : public PacketError { public: - wrong_packet_type_exception(std::initializer_list expectedTypes, std::uint8_t actual); + explicit PacketTypeError(std::uint8_t unknownPacketType); + PacketTypeError(std::initializer_list expectedTypes, std::uint8_t actual); - const char *what() const throw() override + const char *what() const override { return message_.c_str(); } @@ -105,30 +109,30 @@ public: packet_type Type(); plr_t Source() const; plr_t Destination() const; - const buffer_t &Message(); - turn_t Turn(); - cookie_t Cookie(); - plr_t NewPlayer(); - timestamp_t Time(); - const buffer_t &Info(); - leaveinfo_t LeaveInfo(); + tl::expected Message(); + tl::expected Turn(); + tl::expected Cookie(); + tl::expected NewPlayer(); + tl::expected Time(); + tl::expected Info(); + tl::expected LeaveInfo(); }; template class packet_proc : public packet { public: using packet::packet; - void process_data(); + tl::expected process_data(); }; class packet_in : public packet_proc { public: using packet_proc::packet_proc; - void Create(buffer_t buf); - void process_element(buffer_t &x); + tl::expected Create(buffer_t buf); + tl::expected process_element(buffer_t &x); template - void process_element(T &x); - void Decrypt(buffer_t buf); + tl::expected process_element(T &x); + tl::expected Decrypt(buffer_t buf); }; class packet_out : public packet_proc { @@ -138,9 +142,9 @@ public: template void create(Args... args); - void process_element(buffer_t &x); + tl::expected process_element(buffer_t &x); template - void process_element(T &x); + tl::expected process_element(T &x); template static const unsigned char *begin(const T &x); template @@ -150,67 +154,68 @@ public: }; template -void packet_proc

::process_data() +tl::expected packet_proc

::process_data() { P &self = static_cast

(*this); - self.process_element(m_type); - self.process_element(m_src); - self.process_element(m_dest); + { + tl::expected result + = self.process_element(m_type) + .and_then([&]() { + return self.process_element(m_src); + }) + .and_then([&]() { + return self.process_element(m_dest); + }); + if (!result.has_value()) + return result; + } switch (m_type) { case PT_MESSAGE: - self.process_element(m_message); - break; + return self.process_element(m_message); case PT_TURN: - self.process_element(m_turn.SequenceNumber); - self.process_element(m_turn.Value); - break; + return self.process_element(m_turn.SequenceNumber) + .and_then([&]() { return self.process_element(m_turn.Value); }); case PT_JOIN_REQUEST: - self.process_element(m_cookie); - self.process_element(m_info); - break; + return self.process_element(m_cookie) + .and_then([&]() { return self.process_element(m_info); }); case PT_JOIN_ACCEPT: - self.process_element(m_cookie); - self.process_element(m_newplr); - self.process_element(m_info); - break; + return self.process_element(m_cookie) + .and_then([&]() { return self.process_element(m_newplr); }) + .and_then([&]() { return self.process_element(m_info); }); case PT_CONNECT: - self.process_element(m_newplr); - self.process_element(m_info); - break; + return self.process_element(m_newplr) + .and_then([&]() { return self.process_element(m_info); }); case PT_DISCONNECT: - self.process_element(m_newplr); - self.process_element(m_leaveinfo); - break; + return self.process_element(m_newplr) + .and_then([&]() { return self.process_element(m_leaveinfo); }); case PT_INFO_REPLY: - self.process_element(m_info); - break; + return self.process_element(m_info); case PT_INFO_REQUEST: - break; + return {}; case PT_ECHO_REQUEST: case PT_ECHO_REPLY: - self.process_element(m_time); - break; + return self.process_element(m_time); } + return tl::make_unexpected(PacketTypeError(m_type)); } -inline void packet_in::process_element(buffer_t &x) +inline tl::expected packet_in::process_element(buffer_t &x) { x.insert(x.begin(), decrypted_buffer.begin(), decrypted_buffer.end()); decrypted_buffer.resize(0); + return {}; } template -void packet_in::process_element(T &x) +tl::expected packet_in::process_element(T &x) { - if (decrypted_buffer.size() < sizeof(T)) -#if DVL_EXCEPTIONS - throw packet_exception(); -#else - app_fatal("invalid packet"); -#endif + if (decrypted_buffer.size() < sizeof(T)) { + return tl::make_unexpected(PacketError()); + } std::memcpy(&x, decrypted_buffer.data(), sizeof(T)); decrypted_buffer.erase(decrypted_buffer.begin(), decrypted_buffer.begin() + sizeof(T)); + return {}; } template <> @@ -352,15 +357,17 @@ inline void packet_out::create(plr_t s, plr_t d, timestamp_t t) m_time = t; } -inline void packet_out::process_element(buffer_t &x) +inline tl::expected packet_out::process_element(buffer_t &x) { decrypted_buffer.insert(decrypted_buffer.end(), x.begin(), x.end()); + return {}; } template -void packet_out::process_element(T &x) +tl::expected packet_out::process_element(T &x) { decrypted_buffer.insert(decrypted_buffer.end(), begin(x), end(x)); + return {}; } template @@ -384,12 +391,12 @@ public: packet_factory(); packet_factory(std::string pw); - std::unique_ptr make_packet(buffer_t buf); + tl::expected, PacketError> make_packet(buffer_t buf); template - std::unique_ptr make_packet(Args... args); + tl::expected, PacketError> make_packet(Args... args); }; -inline std::unique_ptr packet_factory::make_packet(buffer_t buf) +inline tl::expected, PacketError> packet_factory::make_packet(buffer_t buf) { auto ret = std::make_unique(key); #ifndef PACKET_ENCRYPTION @@ -400,16 +407,20 @@ inline std::unique_ptr packet_factory::make_packet(buffer_t buf) else ret->Decrypt(std::move(buf)); #endif - ret->process_data(); + if (const tl::expected result = ret->process_data(); !result.has_value()) { + return tl::make_unexpected(result.error()); + } return ret; } template -std::unique_ptr packet_factory::make_packet(Args... args) +tl::expected, PacketError> packet_factory::make_packet(Args... args) { auto ret = std::make_unique(key); ret->create(args...); - ret->process_data(); + if (const tl::expected result = ret->process_data(); !result.has_value()) { + return tl::make_unexpected(result.error()); + } #ifdef PACKET_ENCRYPTION if (secure) ret->Encrypt(); diff --git a/Source/dvlnet/tcp_client.cpp b/Source/dvlnet/tcp_client.cpp index 2df717086..3b7dcd862 100644 --- a/Source/dvlnet/tcp_client.cpp +++ b/Source/dvlnet/tcp_client.cpp @@ -1,15 +1,17 @@ #include "dvlnet/tcp_client.h" -#include "options.h" -#include "utils/language.h" -#include #include #include #include #include #include +#include #include +#include + +#include "options.h" +#include "utils/language.h" namespace devilution::net { @@ -43,10 +45,14 @@ int tcp_client::join(std::string addrstr) StartReceive(); { cookie_self = packet_out::GenerateCookie(); - auto pkt = pktfty->make_packet(PLR_BROADCAST, - PLR_MASTER, cookie_self, - game_init_info); - send(*pkt); + tl::expected, PacketError> pkt + = pktfty->make_packet( + PLR_BROADCAST, PLR_MASTER, cookie_self, game_init_info); + if (!pkt.has_value()) { + SDL_SetError("make_packet: %s", pkt.error().what()); + return -1; + } + send(**pkt); for (auto i = 0; i < NoSleep; ++i) { try { poll(); @@ -95,8 +101,15 @@ void tcp_client::HandleReceive(const asio::error_code &error, size_t bytesRead) recv_queue.Write(std::move(recv_buffer)); recv_buffer.resize(frame_queue::max_frame_size); while (recv_queue.PacketReady()) { - auto pkt = pktfty->make_packet(recv_queue.ReadPacket()); - RecvLocal(*pkt); + tl::expected, PacketError> pkt = pktfty->make_packet(recv_queue.ReadPacket()); + if (!pkt.has_value()) { + LogError("make_packet: {}", pkt.error().what()); + return; + } + if (tl::expected result = RecvLocal(**pkt); !result.has_value()) { + LogError("RecvLocal: {}", result.error().what()); + return; + } } StartReceive(); } diff --git a/Source/dvlnet/tcp_server.cpp b/Source/dvlnet/tcp_server.cpp index 6cdb8b1ce..7f30cf8e6 100644 --- a/Source/dvlnet/tcp_server.cpp +++ b/Source/dvlnet/tcp_server.cpp @@ -5,12 +5,13 @@ #include #include +#include + #include "dvlnet/base.h" #include "player.h" #include "utils/log.hpp" -namespace devilution { -namespace net { +namespace devilution::net { tcp_server::tcp_server(asio::io_context &ioc, const std::string &bindaddr, unsigned short port, packet_factory &pktfty) @@ -79,12 +80,21 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec, try { while (con->recv_queue.PacketReady()) { try { - auto pkt = pktfty.make_packet(con->recv_queue.ReadPacket()); + tl::expected, PacketError> pkt = pktfty.make_packet(con->recv_queue.ReadPacket()); + if (!pkt.has_value()) { + Log("make_packet: {}", pkt.error().what()); + DropConnection(con); + return; + } if (con->plr == PLR_BROADCAST) { - HandleReceiveNewPlayer(con, *pkt); + if (tl::expected result = HandleReceiveNewPlayer(con, **pkt); !result.has_value()) { + Log("HandleReceiveNewPlayer: {}", result.error().what()); + DropConnection(con); + return; + } } else { con->timeout = timeout_active; - HandleReceivePacket(*pkt); + HandleReceivePacket(**pkt); } } catch (dvlnet_exception &e) { Log("Network error: {}", e.what()); @@ -100,32 +110,52 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec, StartReceive(con); } -void tcp_server::HandleReceiveNewPlayer(const scc &con, packet &pkt) +tl::expected tcp_server::HandleReceiveNewPlayer(const scc &con, packet &inPkt) { auto newplr = NextFree(); if (newplr == PLR_BROADCAST) throw server_exception(); - if (Empty()) - game_init_info = pkt.Info(); + if (Empty()) { + tl::expected pktInfo = inPkt.Info(); + if (!pktInfo.has_value()) + return tl::make_unexpected(pktInfo.error()); + game_init_info = **pktInfo; + } for (plr_t player = 0; player < Players.size(); player++) { if (connections[player]) { - auto playerPacket = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, newplr); - StartSend(connections[player], *playerPacket); + { + tl::expected, PacketError> pkt + = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, newplr); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + StartSend(connections[player], **pkt); + } - auto newplrPacket = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, player); - StartSend(con, *newplrPacket); + { + tl::expected, PacketError> pkt + = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, player); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + StartSend(con, **pkt); + } } } - auto reply = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - pkt.Cookie(), newplr, - game_init_info); - StartSend(con, *reply); + tl::expected cookie = inPkt.Cookie(); + if (!cookie.has_value()) + return tl::make_unexpected(cookie.error()); + tl::expected, PacketError> pkt + = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, + *cookie, newplr, game_init_info); + if (!pkt.has_value()) + return tl::make_unexpected(pkt.error()); + StartSend(con, **pkt); con->plr = newplr; connections[newplr] = con; con->timeout = timeout_active; + return {}; } void tcp_server::HandleReceivePacket(packet &pkt) @@ -213,10 +243,15 @@ void tcp_server::HandleTimeout(const scc &con, const asio::error_code &ec) void tcp_server::DropConnection(const scc &con) { if (con->plr != PLR_BROADCAST) { - auto pkt = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, - con->plr, LEAVE_DROP); + tl::expected, PacketError> pkt + = pktfty.make_packet(PLR_MASTER, PLR_BROADCAST, + con->plr, LEAVE_DROP); connections[con->plr] = nullptr; - SendPacket(*pkt); + if (pkt.has_value()) { + SendPacket(**pkt); + } else { + LogError("make_packet: {}", pkt.error().what()); + } // TODO: investigate if it is really ok for the server to // drop a client directly. } @@ -232,5 +267,4 @@ void tcp_server::Close() tcp_server::~tcp_server() = default; -} // namespace net -} // namespace devilution +} // namespace devilution::net diff --git a/Source/dvlnet/tcp_server.h b/Source/dvlnet/tcp_server.h index c397081d6..36b69c82a 100644 --- a/Source/dvlnet/tcp_server.h +++ b/Source/dvlnet/tcp_server.h @@ -13,6 +13,8 @@ // the 3DS SDK. #include +#include + #include #include #include @@ -23,8 +25,7 @@ #include "dvlnet/packet.h" #include "multi.h" -namespace devilution { -namespace net { +namespace devilution::net { class server_exception : public dvlnet_exception { public: @@ -75,7 +76,7 @@ private: void HandleAccept(const scc &con, const asio::error_code &ec); void StartReceive(const scc &con); void HandleReceive(const scc &con, const asio::error_code &ec, size_t bytesRead); - void HandleReceiveNewPlayer(const scc &con, packet &pkt); + tl::expected HandleReceiveNewPlayer(const scc &con, packet &pkt); void HandleReceivePacket(packet &pkt); void SendPacket(packet &pkt); void StartSend(const scc &con, packet &pkt); @@ -85,5 +86,4 @@ private: void DropConnection(const scc &con); }; -} // namespace net -} // namespace devilution +} // namespace devilution::net diff --git a/Source/init.cpp b/Source/init.cpp index 654a96991..160a59ba6 100644 --- a/Source/init.cpp +++ b/Source/init.cpp @@ -405,7 +405,7 @@ void MainWndProc(const SDL_Event &event) diablo_focus_unpause(); break; default: - LogVerbose("Unhandled SDL_WINDOWEVENT event: ", event.window.event); + LogVerbose("Unhandled SDL_WINDOWEVENT event: {:d}", event.window.event); break; } #else