Browse Source

Show an appropriate error message when wrong password is entered (TCP)

pull/8291/head
staphen 4 months ago committed by Gleb Mazovetskiy
parent
commit
cf942fa54a
  1. 17
      Source/dvlnet/frame_queue.cpp
  2. 4
      Source/dvlnet/frame_queue.h
  3. 19
      Source/dvlnet/packet.cpp
  4. 60
      Source/dvlnet/packet.h
  5. 25
      Source/dvlnet/tcp_client.cpp
  6. 1
      Source/dvlnet/tcp_client.h
  7. 16
      Source/dvlnet/tcp_server.cpp
  8. 4
      Source/dvlnet/tcp_server.h

17
Source/dvlnet/frame_queue.cpp

@ -68,19 +68,26 @@ tl::expected<bool, PacketError> frame_queue::PacketReady()
if (nextsize == 0)
return tl::make_unexpected(FrameQueueError());
}
return Size() >= nextsize;
return Size() >= (nextsize & frame_size_mask);
}
uint16_t frame_queue::ReadPacketFlags()
{
static_assert(sizeof(nextsize) == 4, "framesize_t is not 4 bytes");
return static_cast<uint16_t>(nextsize >> 16);
}
tl::expected<buffer_t, PacketError> frame_queue::ReadPacket()
{
if (nextsize == 0 || Size() < nextsize)
const framesize_t packetSize = nextsize & frame_size_mask;
if (nextsize == 0 || Size() < packetSize)
return tl::make_unexpected(FrameQueueError());
tl::expected<buffer_t, PacketError> ret = Read(nextsize);
tl::expected<buffer_t, PacketError> ret = Read(packetSize);
nextsize = 0;
return ret;
}
tl::expected<buffer_t, PacketError> frame_queue::MakeFrame(buffer_t packetbuf)
tl::expected<buffer_t, PacketError> frame_queue::MakeFrame(buffer_t packetbuf, uint16_t flags)
{
buffer_t ret;
const framesize_t size = static_cast<framesize_t>(packetbuf.size());
@ -88,7 +95,7 @@ tl::expected<buffer_t, PacketError> frame_queue::MakeFrame(buffer_t packetbuf)
return tl::make_unexpected("Buffer exceeds maximum frame size");
static_assert(sizeof(size) == 4, "framesize_t is not 4 bytes");
unsigned char sizeBuf[4];
WriteLE32(sizeBuf, size);
WriteLE32(sizeBuf, size | (static_cast<framesize_t>(flags) << 16));
ret.insert(ret.end(), sizeBuf, sizeBuf + 4);
ret.insert(ret.end(), packetbuf.begin(), packetbuf.end());
return ret;

4
Source/dvlnet/frame_queue.h

@ -17,6 +17,7 @@ typedef uint32_t framesize_t;
class frame_queue {
public:
constexpr static framesize_t frame_size_mask = 0xFFFF;
constexpr static framesize_t max_frame_size = 0xFFFF;
private:
@ -29,10 +30,11 @@ private:
public:
tl::expected<bool, PacketError> PacketReady();
uint16_t ReadPacketFlags();
tl::expected<buffer_t, PacketError> ReadPacket();
void Write(buffer_t buf);
static tl::expected<buffer_t, PacketError> MakeFrame(buffer_t packetbuf);
static tl::expected<buffer_t, PacketError> MakeFrame(buffer_t packetbuf, uint16_t flags = 0);
};
} // namespace net

19
Source/dvlnet/packet.cpp

@ -237,8 +237,11 @@ tl::expected<void, PacketError> packet_in::Decrypt(buffer_t buf)
encrypted_buffer.size() - crypto_secretbox_NONCEBYTES,
encrypted_buffer.data(),
key.data());
if (status != 0)
return tl::make_unexpected(PacketError());
if (status != 0) {
auto code = PacketError::ErrorCode::DecryptionFailed;
std::string_view message = "Failed to decrypt packet data";
return tl::make_unexpected(PacketError(code, message));
}
have_decrypted = true;
return {};
@ -246,12 +249,12 @@ tl::expected<void, PacketError> packet_in::Decrypt(buffer_t buf)
#endif
#ifdef PACKET_ENCRYPTION
void packet_out::Encrypt()
tl::expected<void, PacketError> packet_out::Encrypt()
{
assert(have_decrypted);
if (have_encrypted)
return;
return {};
auto lenCleartext = decrypted_buffer.size();
encrypted_buffer.insert(encrypted_buffer.begin(),
@ -265,10 +268,14 @@ void packet_out::Encrypt()
lenCleartext,
encrypted_buffer.data(),
key.data());
if (status != 0)
ABORT();
if (status != 0) {
auto code = PacketError::ErrorCode::EncryptionFailed;
std::string_view message = "Failed to encrypt packet data";
return tl::make_unexpected(PacketError(code, message));
}
have_encrypted = true;
return {};
}
#endif

60
Source/dvlnet/packet.h

@ -64,33 +64,63 @@ static constexpr plr_t PLR_BROADCAST = 0xFF;
class PacketError {
public:
enum class ErrorCode : uint8_t {
None,
EncryptionFailed,
DecryptionFailed
};
PacketError()
: message_(std::string_view("Incorrect package size"))
: message_(std::string_view("Incorrect packet size"))
, code_(ErrorCode::None)
{
}
PacketError(const char message[])
: message_(std::string_view(message))
, code_(ErrorCode::None)
{
}
PacketError(std::string &&message)
: message_(std::move(message))
, code_(ErrorCode::None)
{
}
PacketError(std::string_view message)
: message_(message)
, code_(ErrorCode::None)
{
}
PacketError(ErrorCode code, const char message[])
: message_(std::string_view(message))
, code_(code)
{
}
PacketError(ErrorCode code, std::string &&message)
: message_(std::move(message))
, code_(code)
{
}
PacketError(ErrorCode code, std::string_view message)
: message_(message)
, code_(code)
{
}
PacketError(const PacketError &error)
: message_(std::string(error.message_))
, code_(error.code_)
{
}
PacketError(PacketError &&error)
: message_(std::move(error.message_))
, code_(error.code_)
{
}
@ -99,8 +129,14 @@ public:
return message_;
}
ErrorCode code() const
{
return code_;
}
private:
StringOrView message_;
ErrorCode code_;
};
inline PacketError IoHandlerError(std::string message)
@ -176,7 +212,7 @@ public:
template <class T>
tl::expected<void, PacketError> process_element(const T &x);
static cookie_t GenerateCookie();
void Encrypt();
tl::expected<void, PacketError> Encrypt();
};
template <class P>
@ -442,13 +478,15 @@ inline tl::expected<std::unique_ptr<packet>, PacketError> packet_factory::make_p
{
auto ret = std::make_unique<packet_in>(key);
#ifndef PACKET_ENCRYPTION
ret->Create(std::move(buf));
tl::expected<void, PacketError> isCreated = ret->Create(std::move(buf));
#else
if (!secure)
ret->Create(std::move(buf));
else
ret->Decrypt(std::move(buf));
tl::expected<void, PacketError> isCreated = !secure
? ret->Create(std::move(buf))
: ret->Decrypt(std::move(buf));
#endif
if (!isCreated.has_value()) {
return tl::make_unexpected(isCreated.error());
}
if (const tl::expected<void, PacketError> result = ret->process_data(); !result.has_value()) {
return tl::make_unexpected(result.error());
}
@ -464,8 +502,12 @@ tl::expected<std::unique_ptr<packet>, PacketError> packet_factory::make_packet(A
return tl::make_unexpected(result.error());
}
#ifdef PACKET_ENCRYPTION
if (secure)
ret->Encrypt();
if (secure) {
tl::expected<void, PacketError> isEncrypted = ret->Encrypt();
if (!isEncrypted.has_value()) {
return tl::make_unexpected(isEncrypted.error());
}
}
#endif
return ret;
}

25
Source/dvlnet/tcp_client.cpp

@ -168,6 +168,10 @@ void tcp_client::HandleReceive(const asio::error_code &error, size_t bytesRead)
}
if (!*ready)
break;
if (recv_queue.ReadPacketFlags() == TcpErrorCodeFlags) {
HandleTcpErrorCode();
return;
}
tl::expected<void, PacketError> result
= recv_queue.ReadPacket()
.and_then([this](buffer_t &&pktData) { return pktfty->make_packet(pktData); })
@ -193,6 +197,27 @@ void tcp_client::HandleSend(const asio::error_code &error, size_t bytesSent)
RaiseIoHandlerError(error.message());
}
void tcp_client::HandleTcpErrorCode()
{
tl::expected<buffer_t, PacketError> packet = recv_queue.ReadPacket();
if (!packet.has_value()) {
RaiseIoHandlerError(packet.error());
return;
}
buffer_t pktData = *packet;
if (pktData.size() != 1) {
RaiseIoHandlerError(PacketError());
return;
}
PacketError::ErrorCode code = static_cast<PacketError::ErrorCode>(pktData[0]);
if (code == PacketError::ErrorCode::DecryptionFailed)
RaiseIoHandlerError(_("Server failed to decrypt your packet. Check if you typed the password correctly."));
else
RaiseIoHandlerError(fmt::format("Unknown error code received from server: {:#04x}", pktData[0]));
}
tl::expected<void, PacketError> tcp_client::send(packet &pkt)
{
tl::expected<buffer_t, PacketError> frame = frame_queue::MakeFrame(pkt.Data());

1
Source/dvlnet/tcp_client.h

@ -57,6 +57,7 @@ private:
void HandleReceive(const asio::error_code &error, size_t bytesRead);
void StartReceive();
void HandleSend(const asio::error_code &error, size_t bytesSent);
void HandleTcpErrorCode();
void RaiseIoHandlerError(const PacketError &error);
};

16
Source/dvlnet/tcp_server.cpp

@ -95,6 +95,8 @@ void tcp_server::HandleReceive(const scc &con, const asio::error_code &ec,
tl::expected<std::unique_ptr<packet>, PacketError> pkt = pktfty.make_packet(*pktData);
if (!pkt.has_value()) {
Log("make_packet: {}", pkt.error().what());
if (pkt.error().code() == PacketError::ErrorCode::DecryptionFailed)
StartSend(con, pkt.error().code());
DropConnection(con);
return;
}
@ -181,7 +183,19 @@ tl::expected<void, PacketError> tcp_server::SendPacket(packet &pkt)
tl::expected<void, PacketError> tcp_server::StartSend(const scc &con, packet &pkt)
{
tl::expected<buffer_t, PacketError> frame = frame_queue::MakeFrame(pkt.Data());
return StartSend(con, pkt.Data(), 0);
}
tl::expected<void, PacketError> tcp_server::StartSend(const scc &con, PacketError::ErrorCode errorCode)
{
buffer_t pktData;
pktData.push_back(static_cast<unsigned char>(errorCode));
return StartSend(con, pktData, TcpErrorCodeFlags);
}
tl::expected<void, PacketError> tcp_server::StartSend(const scc &con, buffer_t pktData, uint16_t flags)
{
tl::expected<buffer_t, PacketError> frame = frame_queue::MakeFrame(pktData, flags);
if (!frame.has_value())
return tl::make_unexpected(frame.error());
std::unique_ptr<buffer_t> framePtr = std::make_unique<buffer_t>(*frame);

4
Source/dvlnet/tcp_server.h

@ -28,6 +28,8 @@
namespace devilution::net {
constexpr uint16_t TcpErrorCodeFlags = 0x8000;
inline PacketError ServerError()
{
return PacketError("Invalid player ID");
@ -82,6 +84,8 @@ private:
tl::expected<void, PacketError> HandleReceivePacket(packet &pkt);
tl::expected<void, PacketError> SendPacket(packet &pkt);
tl::expected<void, PacketError> StartSend(const scc &con, packet &pkt);
tl::expected<void, PacketError> StartSend(const scc &con, PacketError::ErrorCode errorCode);
tl::expected<void, PacketError> StartSend(const scc &con, buffer_t pktData, uint16_t flags);
void HandleSend(const scc &con, const asio::error_code &ec, size_t bytesSent);
void StartTimeout(const scc &con);
void HandleTimeout(const scc &con, const asio::error_code &ec);

Loading…
Cancel
Save