#include "dvlnet/tcp_client.h"

// NOTE(review): the angle-bracket include targets and the template argument
// lists in this file were not legible in the reviewed text. They have been
// restored from the names and call sites below (std::bind, std::unique_ptr,
// std::nullopt, SDL_SetError, SDL_Delay, asio::connect, tl::expected, ...).
// Confirm them against the header and the build before merging.
#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>

#ifdef USE_SDL3
#include <SDL3/SDL_error.h>
#include <SDL3/SDL_timer.h>
#else
#include <SDL.h>
#endif

#include <asio/connect.hpp>
#include <expected.hpp>

#include "options.h"
#include "utils/language.h"
#include "utils/str_cat.hpp"
#include "utils/str_split.hpp"

namespace devilution::net {

/**
 * @brief Creates a local server and joins it as a client.
 *
 * The server is constructed with @p addrstr and the port read from
 * `GetOptions().Network.port`; the client then joins the address reported by
 * the server's `LocalhostSelf()`.
 *
 * @param addrstr Address string handed to the local server.
 * @return Whatever join() returns: the local player id, or -1 on failure.
 */
int tcp_client::create(std::string_view addrstr)
{
	auto port = *GetOptions().Network.port;
	local_server = std::make_unique<tcp_server>(ioc, std::string(addrstr), port, *pktfty);
	return join(local_server->LocalhostSelf());
}

/**
 * @brief Connects to a server and performs the join handshake.
 *
 * Accepted address forms:
 *  - `[ipv6]` or `[ipv6]:port`
 *  - `hostname` or `hostname:port`
 *  - a string containing more than one colon, which is used as the host
 *    unchanged with the default port
 *
 * The port defaults to "6112" when none is given. After connecting, a join
 * request is sent and the client polls until `plr_self` is no longer
 * `PLR_BROADCAST` or the attempts are exhausted.
 *
 * @param addrstr Address of the server to join.
 * @return `plr_self` on success; -1 on failure, with the reason stored via
 *         SDL_SetError.
 */
int tcp_client::join(std::string_view addrstr)
{
	// Up to MaxPollAttempts polls, PollIntervalMs apart (about 2.5 seconds).
	constexpr int PollIntervalMs = 10;
	constexpr int MaxPollAttempts = 250;
	constexpr std::string_view DefaultPort = "6112";

	std::string_view host;
	std::string_view port = DefaultPort;

	if (!addrstr.empty() && addrstr[0] == '[') {
		// Assume IPv6 address in square brackets, followed by port
		// Example: [::1]:6113
		size_t pos = addrstr.find(']', 1);
		// Without a closing bracket the whole string is taken as the host.
		pos = pos != std::string_view::npos ? pos + 1 : addrstr.length();
		// NOTE(review): the host keeps its surrounding brackets ("[::1]").
		// Confirm that the resolver accepts this form.
		host = addrstr.substr(0, pos);
		if (pos != addrstr.length()) {
			if (addrstr[pos] != ':') {
				SDL_SetError("Invalid hostname: expected colon after square brackets");
				return -1;
			}
			// A trailing colon with nothing after it keeps the default port.
			if (++pos != addrstr.length())
				port = addrstr.substr(pos);
		}
	} else {
		// Assume "hostname:port"
		const SplitByChar splithost(addrstr, ':');
		auto it = splithost.begin();
		if (it != splithost.end())
			host = *it++;
		if (it != splithost.end())
			port = *it++;

		// If there is more than one colon, assume it's just a plain IPv6 address
		if (it != splithost.end()) {
			host = addrstr;
			port = DefaultPort;
		}
	}

	asio::error_code errorCode;
	const asio::ip::basic_resolver_results<asio::ip::tcp> range = resolver.resolve(host, port, errorCode);
	if (errorCode) {
		SDL_SetError("%s", errorCode.message().c_str());
		return -1;
	}

	asio::connect(sock, range, errorCode);
	if (errorCode) {
		SDL_SetError("%s", errorCode.message().c_str());
		return -1;
	}

	// Failing to set the option is logged but does not abort the join.
	const asio::ip::tcp::no_delay option(true);
	sock.set_option(option, errorCode);
	if (errorCode)
		LogError("Client error setting socket option: {}", errorCode.message());

	StartReceive();
	{
		cookie_self = packet_out::GenerateCookie();
		tl::expected<std::unique_ptr<packet>, PacketError> pkt
		    = pktfty->make_packet<PT_JOIN_REQUEST>(
		        PLR_BROADCAST, PLR_MASTER, cookie_self, game_init_info);
		if (!pkt.has_value()) {
			const std::string_view message = pkt.error().what();
			SDL_SetError("make_packet: %.*s", static_cast<int>(message.size()), message.data());
			return -1;
		}

		tl::expected<void, PacketError> sendResult = send(**pkt);
		if (!sendResult.has_value()) {
			const std::string_view message = sendResult.error().what();
			SDL_SetError("send: %.*s", static_cast<int>(message.size()), message.data());
			return -1;
		}

		for (int attempt = 0; attempt < MaxPollAttempts; ++attempt) {
			tl::expected<void, PacketError> pollResult = poll();
			if (!pollResult.has_value()) {
				const std::string_view message = pollResult.error().what();
				SDL_SetError("%.*s", static_cast<int>(message.size()), message.data());
				return -1;
			}
			if (plr_self != PLR_BROADCAST)
				break; // join successful
			SDL_Delay(PollIntervalMs);
		}
	}

	if (plr_self == PLR_BROADCAST) {
		const std::string_view message = _("Unable to connect");
		SDL_SetError("%.*s", static_cast<int>(message.size()), message.data());
		return -1;
	}

	return plr_self;
}

/**
 * @brief Reports whether this client owns a local server.
 * @return true when `local_server` is set.
 */
bool tcp_client::IsGameHost()
{
	return local_server != nullptr;
}

/**
 * @brief Runs ready I/O handlers one at a time and surfaces any error raised.
 *
 * After each handler, a local server (if any) is checked for an error first,
 * then the client's own pending error. A pending client error is cleared
 * before it is returned.
 *
 * @return An empty result when no handler raised an error, otherwise the
 *         first error encountered.
 */
tl::expected<void, PacketError> tcp_client::poll()
{
	while (ioc.poll_one() > 0) {
		if (IsGameHost()) {
			tl::expected<void, PacketError> serverResult = local_server->CheckIoHandlerError();
			if (!serverResult.has_value())
				return serverResult;
		}
		if (ioHandlerResult == std::nullopt)
			continue;
		tl::expected<void, PacketError> packetError = tl::make_unexpected(*ioHandlerResult);
		ioHandlerResult = std::nullopt;
		return packetError;
	}
	return {};
}

/**
 * @brief Completion handler for the asynchronous receive.
 *
 * Queues the received bytes, processes every complete packet in the queue and
 * then starts the next receive. On any error the error is recorded with
 * RaiseIoHandlerError() and no further receive is started.
 *
 * @param error Result of the receive operation.
 * @param bytesRead Number of bytes written into `recv_buffer`.
 */
void tcp_client::HandleReceive(const asio::error_code &error, size_t bytesRead)
{
	if (error) {
		const PacketError packetError = IoHandlerError(error.message());
		RaiseIoHandlerError(packetError);
		return;
	}
	if (bytesRead == 0) {
		const PacketError packetError(_("error: read 0 bytes from server"));
		RaiseIoHandlerError(packetError);
		return;
	}

	// Hand the filled part of the buffer to the queue, then restore the
	// buffer to full size for the next receive.
	recv_buffer.resize(bytesRead);
	recv_queue.Write(std::move(recv_buffer));
	recv_buffer.resize(frame_queue::max_frame_size);

	while (true) {
		tl::expected<bool, PacketError> ready = recv_queue.PacketReady();
		if (!ready.has_value()) {
			RaiseIoHandlerError(ready.error());
			return;
		}
		if (!*ready)
			break;

		if (recv_queue.ReadPacketFlags() == TcpErrorCodeFlags) {
			HandleTcpErrorCode();
			return;
		}

		tl::expected<void, PacketError> result
		    = recv_queue.ReadPacket()
		          .and_then([this](buffer_t &&pktData) { return pktfty->make_packet(pktData); })
		          .and_then([this](std::unique_ptr<packet> &&pkt) { return RecvLocal(*pkt); });
		if (!result.has_value()) {
			RaiseIoHandlerError(result.error());
			return;
		}
	}

	StartReceive();
}

/**
 * @brief Starts an asynchronous receive into `recv_buffer`, completing in
 *        HandleReceive().
 */
void tcp_client::StartReceive()
{
	sock.async_receive(
	    asio::buffer(recv_buffer),
	    std::bind(&tcp_client::HandleReceive, this, std::placeholders::_1, std::placeholders::_2));
}

/**
 * @brief Completion handler for asynchronous writes; records a failed write.
 * @param error Result of the write operation.
 */
void tcp_client::HandleSend(const asio::error_code &error, size_t /*bytesSent*/)
{
	if (error)
		RaiseIoHandlerError(error.message());
}

/**
 * @brief Reads an error-code packet from the receive queue and raises the
 *        matching error.
 *
 * The packet must carry exactly one byte, which is interpreted as a
 * `PacketError::ErrorCode`.
 */
void tcp_client::HandleTcpErrorCode()
{
	tl::expected<buffer_t, PacketError> packet = recv_queue.ReadPacket();
	if (!packet.has_value()) {
		RaiseIoHandlerError(packet.error());
		return;
	}

	// Bind by reference: only a single byte is inspected, so no copy is needed.
	const buffer_t &pktData = *packet;
	if (pktData.size() != 1) {
		RaiseIoHandlerError(PacketError());
		return;
	}

	const PacketError::ErrorCode code = static_cast<PacketError::ErrorCode>(pktData[0]);
	if (code == PacketError::ErrorCode::DecryptionFailed)
		RaiseIoHandlerError(_("Server failed to decrypt your packet. Check if you typed the password correctly."));
	else
		RaiseIoHandlerError(fmt::format("Unknown error code received from server: {:#04x}", pktData[0]));
}

/**
 * @brief Frames a packet and queues it for asynchronous transmission.
 * @param pkt Packet whose data is framed and written to the socket.
 * @return An empty result once the write is queued, or the framing error.
 *         Write failures are reported later through HandleSend().
 */
tl::expected<void, PacketError> tcp_client::send(packet &pkt)
{
	tl::expected<buffer_t, PacketError> frame = frame_queue::MakeFrame(pkt.Data());
	if (!frame.has_value())
		return tl::make_unexpected(frame.error());

	// Move the frame into heap storage that is owned by the completion
	// handler, so the bytes stay alive until the write has finished.
	std::unique_ptr<buffer_t> framePtr = std::make_unique<buffer_t>(std::move(*frame));
	const asio::mutable_buffer buf = asio::buffer(*framePtr);
	asio::async_write(sock, buf, [this, keepAlive = std::move(framePtr)](const asio::error_code &error, size_t bytesSent) {
		HandleSend(error, bytesSent);
	});
	return {};
}

/**
 * @brief Forwards a disconnect request to the local server, if there is one.
 * @param plr Player to disconnect.
 */
void tcp_client::DisconnectNet(plr_t plr)
{
	if (local_server != nullptr)
		local_server->DisconnectNet(plr);
}

/**
 * @brief Leaves the game, then closes the local server (if any) and the
 *        socket.
 * @param type Leave reason passed to the base implementation.
 * @return The result of `base::SNetLeaveGame`.
 */
bool tcp_client::SNetLeaveGame(net::leaveinfo_t type)
{
	auto ret = base::SNetLeaveGame(type);
	process_network_packets();
	if (local_server != nullptr)
		local_server->Close();

	// Use the non-throwing overload: a failed close should not prevent
	// leaving the game.
	asio::error_code errorCode;
	sock.close(errorCode);
	if (errorCode)
		LogError("Client error closing socket: {}", errorCode.message());
	return ret;
}

/**
 * @brief Returns the default game name.
 * @return A copy of `GetOptions().Network.szBindAddress`.
 */
std::string tcp_client::make_default_gamename()
{
	return std::string(GetOptions().Network.szBindAddress);
}

/**
 * @brief Stores an error for the next poll() call to return.
 * @param error Error to record; replaces any error already stored.
 */
void tcp_client::RaiseIoHandlerError(const PacketError &error)
{
	ioHandlerResult.emplace(error);
}

tcp_client::~tcp_client() = default;

} // namespace devilution::net