diff --git a/Source/dvlnet/frame_queue.cpp b/Source/dvlnet/frame_queue.cpp index a5ca70c2b..7966dd214 100644 --- a/Source/dvlnet/frame_queue.cpp +++ b/Source/dvlnet/frame_queue.cpp @@ -68,19 +68,26 @@ tl::expected frame_queue::PacketReady() if (nextsize == 0) return tl::make_unexpected(FrameQueueError()); } - return Size() >= nextsize; + return Size() >= (nextsize & frame_size_mask); +} + +uint16_t frame_queue::ReadPacketFlags() +{ + static_assert(sizeof(nextsize) == 4, "framesize_t is not 4 bytes"); + return static_cast(nextsize >> 16); } tl::expected frame_queue::ReadPacket() { - if (nextsize == 0 || Size() < nextsize) + const framesize_t packetSize = nextsize & frame_size_mask; + if (nextsize == 0 || Size() < packetSize) return tl::make_unexpected(FrameQueueError()); - tl::expected ret = Read(nextsize); + tl::expected ret = Read(packetSize); nextsize = 0; return ret; } -tl::expected frame_queue::MakeFrame(buffer_t packetbuf) +tl::expected frame_queue::MakeFrame(buffer_t packetbuf, uint16_t flags) { buffer_t ret; const framesize_t size = static_cast(packetbuf.size()); @@ -88,7 +95,7 @@ tl::expected frame_queue::MakeFrame(buffer_t packetbuf) return tl::make_unexpected("Buffer exceeds maximum frame size"); static_assert(sizeof(size) == 4, "framesize_t is not 4 bytes"); unsigned char sizeBuf[4]; - WriteLE32(sizeBuf, size); + WriteLE32(sizeBuf, size | (static_cast(flags) << 16)); ret.insert(ret.end(), sizeBuf, sizeBuf + 4); ret.insert(ret.end(), packetbuf.begin(), packetbuf.end()); return ret; diff --git a/Source/dvlnet/frame_queue.h b/Source/dvlnet/frame_queue.h index eadf82b2d..89c6e9c3c 100644 --- a/Source/dvlnet/frame_queue.h +++ b/Source/dvlnet/frame_queue.h @@ -17,6 +17,7 @@ typedef uint32_t framesize_t; class frame_queue { public: + constexpr static framesize_t frame_size_mask = 0xFFFF; constexpr static framesize_t max_frame_size = 0xFFFF; private: @@ -29,10 +30,11 @@ private: public: tl::expected PacketReady(); + uint16_t ReadPacketFlags(); tl::expected ReadPacket(); void Write(buffer_t buf); - static tl::expected MakeFrame(buffer_t packetbuf); + static tl::expected MakeFrame(buffer_t packetbuf, uint16_t flags = 0); }; } // namespace net diff --git a/Source/dvlnet/packet.cpp b/Source/dvlnet/packet.cpp index 947952ea7..10b3261d8 100644 --- a/Source/dvlnet/packet.cpp +++ b/Source/dvlnet/packet.cpp @@ -237,8 +237,11 @@ tl::expected packet_in::Decrypt(buffer_t buf) encrypted_buffer.size() - crypto_secretbox_NONCEBYTES, encrypted_buffer.data(), key.data()); - if (status != 0) - return tl::make_unexpected(PacketError()); + if (status != 0) { + auto code = PacketError::ErrorCode::DecryptionFailed; + std::string_view message = "Failed to decrypt packet data"; + return tl::make_unexpected(PacketError(code, message)); + } have_decrypted = true; return {}; @@ -246,12 +249,12 @@ tl::expected packet_in::Decrypt(buffer_t buf) #endif #ifdef PACKET_ENCRYPTION -void packet_out::Encrypt() +tl::expected packet_out::Encrypt() { assert(have_decrypted); if (have_encrypted) - return; + return {}; auto lenCleartext = decrypted_buffer.size(); encrypted_buffer.insert(encrypted_buffer.begin(), @@ -265,10 +268,14 @@ void packet_out::Encrypt() lenCleartext, encrypted_buffer.data(), key.data()); - if (status != 0) - ABORT(); + if (status != 0) { + auto code = PacketError::ErrorCode::EncryptionFailed; + std::string_view message = "Failed to encrypt packet data"; + return tl::make_unexpected(PacketError(code, message)); + } have_encrypted = true; + return {}; } #endif diff --git a/Source/dvlnet/packet.h b/Source/dvlnet/packet.h index 7945ad530..2a3bbd7b0 100644 --- a/Source/dvlnet/packet.h +++ b/Source/dvlnet/packet.h @@ -64,33 +64,63 @@ static constexpr plr_t PLR_BROADCAST = 0xFF; class PacketError { public: + enum class ErrorCode : uint8_t { + None, + EncryptionFailed, + DecryptionFailed + }; + PacketError() - : message_(std::string_view("Incorrect package size")) + : message_(std::string_view("Incorrect packet size")) + , code_(ErrorCode::None) { } PacketError(const char message[]) : message_(std::string_view(message)) + , code_(ErrorCode::None) { } PacketError(std::string &&message) : message_(std::move(message)) + , code_(ErrorCode::None) { } PacketError(std::string_view message) : message_(message) + , code_(ErrorCode::None) + { + } + + PacketError(ErrorCode code, const char message[]) + : message_(std::string_view(message)) + , code_(code) + { + } + + PacketError(ErrorCode code, std::string &&message) + : message_(std::move(message)) + , code_(code) + { + } + + PacketError(ErrorCode code, std::string_view message) + : message_(message) + , code_(code) { } PacketError(const PacketError &error) : message_(std::string(error.message_)) + , code_(error.code_) { } PacketError(PacketError &&error) : message_(std::move(error.message_)) + , code_(error.code_) { } @@ -99,8 +129,14 @@ public: return message_; } + ErrorCode code() const + { + return code_; + } + private: StringOrView message_; + ErrorCode code_; }; inline PacketError IoHandlerError(std::string message) @@ -176,7 +212,7 @@ public: template tl::expected process_element(const T &x); static cookie_t GenerateCookie(); - void Encrypt(); + tl::expected Encrypt(); }; template @@ -442,13 +478,15 @@ inline tl::expected, PacketError> packet_factory::make_p { auto ret = std::make_unique(key); #ifndef PACKET_ENCRYPTION - ret->Create(std::move(buf)); + tl::expected isCreated = ret->Create(std::move(buf)); #else - if (!secure) - ret->Create(std::move(buf)); - else - ret->Decrypt(std::move(buf)); + tl::expected isCreated = !secure + ? ret->Create(std::move(buf)) + : ret->Decrypt(std::move(buf)); #endif + if (!isCreated.has_value()) { + return tl::make_unexpected(isCreated.error()); + } if (const tl::expected result = ret->process_data(); !result.has_value()) { return tl::make_unexpected(result.error()); } @@ -464,8 +502,12 @@ tl::expected, PacketError> packet_factory::make_packet(A return tl::make_unexpected(result.error()); } #ifdef PACKET_ENCRYPTION - if (secure) - ret->Encrypt(); + if (secure) { + tl::expected isEncrypted = ret->Encrypt(); + if (!isEncrypted.has_value()) { + return tl::make_unexpected(isEncrypted.error()); + } + } #endif return ret; } diff --git a/Source/dvlnet/tcp_client.cpp b/Source/dvlnet/tcp_client.cpp index 01857a812..e583538ca 100644 --- a/Source/dvlnet/tcp_client.cpp +++ b/Source/dvlnet/tcp_client.cpp @@ -168,6 +168,10 @@ void tcp_client::HandleReceive(const asio::error_code &error, size_t bytesRead) } if (!*ready) break; + if (recv_queue.ReadPacketFlags() == TcpErrorCodeFlags) { + HandleTcpErrorCode(); + return; + } tl::expected result = recv_queue.ReadPacket() .and_then([this](buffer_t &&pktData) { return pktfty->make_packet(pktData); }) @@ -193,6 +197,27 @@ void tcp_client::HandleSend(const asio::error_code &error, size_t bytesSent) RaiseIoHandlerError(error.message()); } +void tcp_client::HandleTcpErrorCode() +{ + tl::expected packet = recv_queue.ReadPacket(); + if (!packet.has_value()) { + RaiseIoHandlerError(packet.error()); + return; + } + + buffer_t pktData = *packet; + if (pktData.size() != 1) { + RaiseIoHandlerError(PacketError()); + return; + } + + PacketError::ErrorCode code = static_cast(pktData[0]); + if (code == PacketError::ErrorCode::DecryptionFailed) + RaiseIoHandlerError(_("Server failed to decrypt your packet. Check if you typed the password correctly.")); + else + RaiseIoHandlerError(fmt::format("Unknown error code received from server: {:#04x}", pktData[0])); +} + tl::expected tcp_client::send(packet &pkt) { tl::expected frame = frame_queue::MakeFrame(pkt.Data()); diff --git a/Source/dvlnet/tcp_client.h b/Source/dvlnet/tcp_client.h index 2775a3a61..4196a7753 100644 --- a/Source/dvlnet/tcp_client.h +++ b/Source/dvlnet/tcp_client.h @@ -57,6 +57,7 @@ private: void HandleReceive(const asio::error_code &error, size_t bytesRead); void StartReceive(); void HandleSend(const asio::error_code &error, size_t bytesSent); + void HandleTcpErrorCode(); void RaiseIoHandlerError(const PacketError &error); }; diff --git a/Source/dvlnet/tcp_server.cpp b/Source/dvlnet/tcp_server.cpp index 145b9d7be..72d627736 100644 --- a/Source/dvlnet/tcp_server.cpp +++ b/Source/dvlnet/tcp_server.cpp @@ -95,6 +95,8 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec, tl::expected, PacketError> pkt = pktfty.make_packet(*pktData); if (!pkt.has_value()) { Log("make_packet: {}", pkt.error().what()); + if (pkt.error().code() == PacketError::ErrorCode::DecryptionFailed) + StartSend(con, pkt.error().code()); DropConnection(con); return; } @@ -181,7 +183,19 @@ tl::expected tcp_server::SendPacket(packet &pkt) tl::expected tcp_server::StartSend(const scc &con, packet &pkt) { - tl::expected frame = frame_queue::MakeFrame(pkt.Data()); + return StartSend(con, pkt.Data(), 0); +} + +tl::expected tcp_server::StartSend(const scc &con, PacketError::ErrorCode errorCode) +{ + buffer_t pktData; + pktData.push_back(static_cast(errorCode)); + return StartSend(con, pktData, TcpErrorCodeFlags); +} + +tl::expected tcp_server::StartSend(const scc &con, buffer_t pktData, uint16_t flags) +{ + tl::expected frame = frame_queue::MakeFrame(pktData, flags); if (!frame.has_value()) return tl::make_unexpected(frame.error()); std::unique_ptr framePtr = std::make_unique(*frame); diff --git a/Source/dvlnet/tcp_server.h b/Source/dvlnet/tcp_server.h index 4827cbadb..5d33ddc60 100644 --- a/Source/dvlnet/tcp_server.h +++ b/Source/dvlnet/tcp_server.h @@ -28,6 +28,8 @@ namespace devilution::net { +constexpr uint16_t TcpErrorCodeFlags = 0x8000; + inline PacketError ServerError() { return PacketError("Invalid player ID"); @@ -82,6 +84,8 @@ private: tl::expected HandleReceivePacket(packet &pkt); tl::expected SendPacket(packet &pkt); tl::expected StartSend(const scc &con, packet &pkt); + tl::expected StartSend(const scc &con, PacketError::ErrorCode errorCode); + tl::expected StartSend(const scc &con, buffer_t pktData, uint16_t flags); void HandleSend(const scc &con, const asio::error_code &ec, size_t bytesSent); void StartTimeout(const scc &con); void HandleTimeout(const scc &con, const asio::error_code &ec);