#pragma once #include #include #include #include #include #include #include #ifdef PACKET_ENCRYPTION #include #endif #include "appfat.h" #include "dvlnet/abstract_net.h" #include "utils/attributes.h" #include "utils/endian_read.hpp" #include "utils/endian_write.hpp" #include "utils/str_cat.hpp" #include "utils/stubs.h" namespace devilution { namespace net { enum packet_type : uint8_t { // clang-format off PT_MESSAGE = 0x01, PT_TURN = 0x02, PT_JOIN_REQUEST = 0x11, PT_JOIN_ACCEPT = 0x12, PT_CONNECT = 0x13, PT_DISCONNECT = 0x14, PT_INFO_REQUEST = 0x21, PT_INFO_REPLY = 0x22, PT_ECHO_REQUEST = 0x31, PT_ECHO_REPLY = 0x32, // clang-format on }; // Returns NULL for an invalid packet type. const char *packet_type_to_string(uint8_t packetType); typedef uint8_t plr_t; typedef uint8_t seq_t; typedef uint32_t cookie_t; typedef uint32_t timestamp_t; typedef uint32_t leaveinfo_t; #ifdef PACKET_ENCRYPTION typedef std::array key_t; #else // Stub out the key_t definition as we're not doing any encryption. using key_t = uint8_t; #endif struct turn_t { seq_t SequenceNumber; int32_t Value; }; static constexpr plr_t PLR_MASTER = 0xFE; static constexpr plr_t PLR_BROADCAST = 0xFF; class PacketError { public: PacketError() : message_(std::string_view("Incorrect package size")) { } PacketError(const char message[]) : message_(std::string_view(message)) { } PacketError(std::string &&message) : message_(std::move(message)) { } PacketError(std::string_view message) : message_(message) { } PacketError(const PacketError &error) : message_(std::string(error.message_)) { } PacketError(PacketError &&error) : message_(std::move(error.message_)) { } std::string_view what() const { return message_; } private: StringOrView message_; }; inline PacketError IoHandlerError(std::string message) { return PacketError(std::move(message)); } PacketError PacketTypeError(std::uint8_t unknownPacketType); PacketError PacketTypeError(std::initializer_list expectedTypes, std::uint8_t actual); class packet { protected: packet_type m_type; plr_t m_src; plr_t m_dest; buffer_t m_message; turn_t m_turn; cookie_t m_cookie; plr_t m_newplr; timestamp_t m_time; buffer_t m_info; leaveinfo_t m_leaveinfo; const key_t &key; bool have_encrypted = false; bool have_decrypted = false; buffer_t encrypted_buffer; buffer_t decrypted_buffer; public: packet(const key_t &k) : key(k) {}; const buffer_t &Data(); packet_type Type(); plr_t Source() const; plr_t Destination() const; tl::expected Message(); tl::expected Turn(); tl::expected Cookie(); tl::expected NewPlayer(); tl::expected Time(); tl::expected Info(); tl::expected LeaveInfo(); }; template class packet_proc : public packet { public: using packet::packet; tl::expected process_data(); }; class packet_in : public packet_proc { public: using packet_proc::packet_proc; tl::expected Create(buffer_t buf); tl::expected process_element(buffer_t &x); template tl::expected process_element(T &x); tl::expected Decrypt(buffer_t buf); }; class packet_out : public packet_proc { public: using packet_proc::packet_proc; template void create(Args... args); tl::expected process_element(buffer_t &x); template tl::expected process_element(const T &x); static cookie_t GenerateCookie(); void Encrypt(); }; template tl::expected packet_proc

::process_data() { P &self = static_cast

(*this); { tl::expected result = self.process_element(m_type) .and_then([&]() { return self.process_element(m_src); }) .and_then([&]() { return self.process_element(m_dest); }); if (!result.has_value()) return result; } switch (m_type) { case PT_MESSAGE: return self.process_element(m_message); case PT_TURN: return self.process_element(m_turn.SequenceNumber) .and_then([&]() { return self.process_element(m_turn.Value); }); case PT_JOIN_REQUEST: return self.process_element(m_cookie) .and_then([&]() { return self.process_element(m_info); }); case PT_JOIN_ACCEPT: return self.process_element(m_cookie) .and_then([&]() { return self.process_element(m_newplr); }) .and_then([&]() { return self.process_element(m_info); }); case PT_CONNECT: return self.process_element(m_newplr) .and_then([&]() { return self.process_element(m_info); }); case PT_DISCONNECT: return self.process_element(m_newplr) .and_then([&]() { return self.process_element(m_leaveinfo); }); case PT_INFO_REPLY: return self.process_element(m_info); case PT_INFO_REQUEST: return {}; case PT_ECHO_REQUEST: case PT_ECHO_REPLY: return self.process_element(m_time); } return tl::make_unexpected(PacketTypeError(m_type)); } inline tl::expected packet_in::process_element(buffer_t &x) { x.insert(x.begin(), decrypted_buffer.begin(), decrypted_buffer.end()); decrypted_buffer.resize(0); return {}; } template tl::expected packet_in::process_element(T &x) { static_assert(std::is_integral::value || std::is_enum::value, "Unsupported T"); static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported T"); if (decrypted_buffer.size() < sizeof(T)) { return tl::make_unexpected(PacketError()); } if (sizeof(T) == 4) { x = static_cast(LoadLE32(decrypted_buffer.data())); } else if (sizeof(T) == 2) { x = static_cast(LoadLE16(decrypted_buffer.data())); } else if (sizeof(T) == 1) { std::memcpy(&x, decrypted_buffer.data(), sizeof(T)); } decrypted_buffer.erase(decrypted_buffer.begin(), decrypted_buffer.begin() + sizeof(T)); return {}; } template <> inline void packet_out::create(plr_t s, plr_t d) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_INFO_REQUEST; m_src = s; m_dest = d; } template <> inline void packet_out::create(plr_t s, plr_t d, buffer_t i) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_INFO_REPLY; m_src = s; m_dest = d; m_info = std::move(i); } template <> inline void packet_out::create(plr_t s, plr_t d, buffer_t m) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_MESSAGE; m_src = s; m_dest = d; m_message = std::move(m); } template <> inline void packet_out::create(plr_t s, plr_t d, turn_t u) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_TURN; m_src = s; m_dest = d; m_turn = u; } template <> inline void packet_out::create(plr_t s, plr_t d, cookie_t c, buffer_t i) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_JOIN_REQUEST; m_src = s; m_dest = d; m_cookie = c; m_info = i; } template <> inline void packet_out::create(plr_t s, plr_t d, cookie_t c, plr_t n, buffer_t i) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_JOIN_ACCEPT; m_src = s; m_dest = d; m_cookie = c; m_newplr = n; m_info = i; } template <> inline void packet_out::create(plr_t s, plr_t d, plr_t n, buffer_t i) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_CONNECT; m_src = s; m_dest = d; m_newplr = n; m_info = i; } template <> inline void packet_out::create(plr_t s, plr_t d, plr_t n) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_CONNECT; m_src = s; m_dest = d; m_newplr = n; } template <> inline void packet_out::create(plr_t s, plr_t d, plr_t n, leaveinfo_t l) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_DISCONNECT; m_src = s; m_dest = d; m_newplr = n; m_leaveinfo = l; } template <> inline void packet_out::create(plr_t s, plr_t d, timestamp_t t) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_ECHO_REQUEST; m_src = s; m_dest = d; m_time = t; } template <> inline void packet_out::create(plr_t s, plr_t d, timestamp_t t) { if (have_encrypted || have_decrypted) ABORT(); have_decrypted = true; m_type = PT_ECHO_REPLY; m_src = s; m_dest = d; m_time = t; } inline tl::expected packet_out::process_element(buffer_t &x) { decrypted_buffer.insert(decrypted_buffer.end(), x.begin(), x.end()); return {}; } template tl::expected packet_out::process_element(const T &x) { static_assert(std::is_integral::value || std::is_enum::value, "Unsupported T"); static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported T"); if (sizeof(T) == 4) { unsigned char buf[4]; WriteLE32(buf, x); decrypted_buffer.insert(decrypted_buffer.end(), buf, buf + 4); } else if (sizeof(T) == 2) { unsigned char buf[2]; WriteLE16(buf, x); decrypted_buffer.insert(decrypted_buffer.end(), buf, buf + 2); } else if (sizeof(T) == 1) { decrypted_buffer.push_back(static_cast(x)); } return {}; } class packet_factory { key_t key = {}; bool secure; public: static constexpr unsigned short max_packet_size = 0xFFFF; packet_factory(); packet_factory(std::string pw); tl::expected, PacketError> make_packet(buffer_t buf); template tl::expected, PacketError> make_packet(Args... args); }; inline tl::expected, PacketError> packet_factory::make_packet(buffer_t buf) { auto ret = std::make_unique(key); #ifndef PACKET_ENCRYPTION ret->Create(std::move(buf)); #else if (!secure) ret->Create(std::move(buf)); else ret->Decrypt(std::move(buf)); #endif if (const tl::expected result = ret->process_data(); !result.has_value()) { return tl::make_unexpected(result.error()); } return ret; } template tl::expected, PacketError> packet_factory::make_packet(Args... args) { auto ret = std::make_unique(key); ret->create(args...); if (const tl::expected result = ret->process_data(); !result.has_value()) { return tl::make_unexpected(result.error()); } #ifdef PACKET_ENCRYPTION if (secure) ret->Encrypt(); #endif return ret; } } // namespace net } // namespace devilution