From 12d03388626048c4e8136b2770ac3288b7eefc25 Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Sat, 15 Feb 2020 16:28:42 +0000 Subject: [PATCH] dvlnet: Better error on packet type mismatch --- SourceX/dvlnet/packet.cpp | 70 +++++++++++++++++++++++++++++++------- SourceX/dvlnet/packet.h | 16 +++++++++ SourceX/dvlnet/udp_p2p.cpp | 6 +++- 3 files changed, 78 insertions(+), 14 deletions(-) diff --git a/SourceX/dvlnet/packet.cpp b/SourceX/dvlnet/packet.cpp index 8c822879c..5165b3268 100644 --- a/SourceX/dvlnet/packet.cpp +++ b/SourceX/dvlnet/packet.cpp @@ -7,6 +7,57 @@ namespace net { static constexpr bool disable_encryption = false; #endif +const char *packet_type_to_string(uint8_t packet_type) +{ + switch (packet_type) { + case PT_MESSAGE: + return "PT_MESSAGE"; + case PT_TURN: + return "PT_TURN"; + case PT_JOIN_REQUEST: + return "PT_JOIN_REQUEST"; + case PT_JOIN_ACCEPT: + return "PT_JOIN_ACCEPT"; + case PT_CONNECT: + return "PT_CONNECT"; + case PT_DISCONNECT: + return "PT_DISCONNECT"; + default: + return nullptr; + } +} + +wrong_packet_type_exception::wrong_packet_type_exception(std::initializer_list expected_types, std::uint8_t actual) +{ + message_ = "Expected packet of type "; + const auto append_packet_type = [this](std::uint8_t t) { + const char *type_str = packet_type_to_string(t); + if (type_str != nullptr) + message_.append(type_str); + else + message_.append(std::to_string(t)); + }; + + constexpr char kJoinTypes[] = " or "; + for (const packet_type t : expected_types) { + append_packet_type(t); + message_.append(kJoinTypes); + } + message_.resize(message_.size() - (sizeof(kJoinTypes) - 1)); + message_.append(", got"); + append_packet_type(actual); +} + +namespace { + +void CheckPacketTypeOneOf(std::initializer_list expected_types, std::uint8_t actual_type) { + for (std::uint8_t packet_type : expected_types) + if (actual_type == packet_type) return; + throw wrong_packet_type_exception(std::move(expected_types), actual_type); +} + +} // namespace + const buffer_t &packet::data() { if (!have_decrypted || !have_encrypted) @@ -39,8 +90,7 @@ const buffer_t &packet::message() { if (!have_decrypted) ABORT(); - if (m_type != PT_MESSAGE) - throw packet_exception(); + CheckPacketTypeOneOf({PT_MESSAGE}, m_type); return m_message; } @@ -48,8 +98,7 @@ turn_t packet::turn() { if (!have_decrypted) ABORT(); - if (m_type != PT_TURN) - throw packet_exception(); + CheckPacketTypeOneOf({PT_TURN}, m_type); return m_turn; } @@ -57,8 +106,7 @@ cookie_t packet::cookie() { if (!have_decrypted) ABORT(); - if (m_type != PT_JOIN_REQUEST && m_type != PT_JOIN_ACCEPT) - throw packet_exception(); + CheckPacketTypeOneOf({PT_JOIN_REQUEST, PT_JOIN_ACCEPT}, m_type); return m_cookie; } @@ -66,9 +114,7 @@ plr_t packet::newplr() { if (!have_decrypted) ABORT(); - if (m_type != PT_JOIN_ACCEPT && m_type != PT_CONNECT - && m_type != PT_DISCONNECT) - throw packet_exception(); + CheckPacketTypeOneOf({PT_JOIN_ACCEPT, PT_CONNECT, PT_DISCONNECT}, m_type); return m_newplr; } @@ -76,8 +122,7 @@ const buffer_t &packet::info() { if (!have_decrypted) ABORT(); - if (m_type != PT_JOIN_REQUEST && m_type != PT_JOIN_ACCEPT) - throw packet_exception(); + CheckPacketTypeOneOf({PT_JOIN_REQUEST, PT_JOIN_ACCEPT}, m_type); return m_info; } @@ -85,8 +130,7 @@ leaveinfo_t packet::leaveinfo() { if (!have_decrypted) ABORT(); - if (m_type != PT_DISCONNECT) - throw packet_exception(); + CheckPacketTypeOneOf({PT_DISCONNECT}, m_type); return m_leaveinfo; } diff --git a/SourceX/dvlnet/packet.h b/SourceX/dvlnet/packet.h index d5584e605..4f1e1e804 100644 --- a/SourceX/dvlnet/packet.h +++ b/SourceX/dvlnet/packet.h @@ -23,6 +23,9 @@ enum packet_type : uint8_t { PT_DISCONNECT = 0x14, }; +// Returns nullptr for an invalid packet type. +const char *packet_type_to_string(uint8_t packet_type); + typedef uint8_t plr_t; typedef uint32_t cookie_t; typedef int turn_t; // change int to something else in devilution code later @@ -42,6 +45,19 @@ public: } }; +class wrong_packet_type_exception : public packet_exception { +public: + wrong_packet_type_exception(std::initializer_list expected_types, std::uint8_t actual); + + const char *what() const throw() override + { + return message_.c_str(); + } + +private: + std::string message_; +}; + class packet { protected: packet_type m_type; diff --git a/SourceX/dvlnet/udp_p2p.cpp b/SourceX/dvlnet/udp_p2p.cpp index 0a065704e..4a9d15fec 100644 --- a/SourceX/dvlnet/udp_p2p.cpp +++ b/SourceX/dvlnet/udp_p2p.cpp @@ -2,6 +2,10 @@ #include +#ifdef USE_SDL1 +#include "sdl2_to_1_2_backports.h" +#endif + namespace dvl { namespace net { @@ -92,7 +96,7 @@ void udp_p2p::recv() auto pkt = pktfty->make_packet(pkt_buf); recv_decrypted(*pkt, sender); } catch (packet_exception &e) { - SDL_Log("Incorrect package size"); + SDL_Log(e.what()); // drop packet } }