diff --git a/Source/dvlnet/abstract_net.h b/Source/dvlnet/abstract_net.h index 832c50cba..e6deeaac4 100644 --- a/Source/dvlnet/abstract_net.h +++ b/Source/dvlnet/abstract_net.h @@ -16,8 +16,8 @@ using provider_t = unsigned long; class abstract_net { public: - virtual int create(std::string addrstr) = 0; - virtual int join(std::string addrstr) = 0; + virtual int create(std::string_view addrstr) = 0; + virtual int join(std::string_view addrstr) = 0; virtual bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) = 0; virtual bool SNetSendMessage(uint8_t dest, void *data, size_t size) = 0; virtual bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) = 0; diff --git a/Source/dvlnet/base_protocol.h b/Source/dvlnet/base_protocol.h index cf0145066..274c326d9 100644 --- a/Source/dvlnet/base_protocol.h +++ b/Source/dvlnet/base_protocol.h @@ -15,8 +15,8 @@ namespace devilution::net { template class base_protocol : public base { public: - int create(std::string addrstr) override; - int join(std::string addrstr) override; + int create(std::string_view addrstr) override; + int join(std::string_view addrstr) override; tl::expected poll() override; tl::expected send(packet &pkt) override; void DisconnectNet(plr_t plr) override; @@ -161,7 +161,7 @@ tl::expected base_protocol

::wait_join() } template -int base_protocol

::create(std::string addrstr) +int base_protocol

::create(std::string_view addrstr) { gamename = addrstr; isGameHost_ = true; @@ -183,7 +183,7 @@ int base_protocol

::create(std::string addrstr) } template -int base_protocol

::join(std::string addrstr) +int base_protocol

::join(std::string_view addrstr) { gamename = addrstr; isGameHost_ = false; diff --git a/Source/dvlnet/cdwrap.cpp b/Source/dvlnet/cdwrap.cpp index 371d472ae..3d6864d05 100644 --- a/Source/dvlnet/cdwrap.cpp +++ b/Source/dvlnet/cdwrap.cpp @@ -17,13 +17,13 @@ void cdwrap::reset() dvlnet_wrap->SNetRegisterEventHandler(eventType, eventHandler); } -int cdwrap::create(std::string addrstr) +int cdwrap::create(std::string_view addrstr) { reset(); return dvlnet_wrap->create(addrstr); } -int cdwrap::join(std::string addrstr) +int cdwrap::join(std::string_view addrstr) { game_init_info = buffer_t(); reset(); diff --git a/Source/dvlnet/cdwrap.h b/Source/dvlnet/cdwrap.h index 339477d29..55ba3523f 100644 --- a/Source/dvlnet/cdwrap.h +++ b/Source/dvlnet/cdwrap.h @@ -32,8 +32,8 @@ public: reset(); } - int create(std::string addrstr) override; - int join(std::string addrstr) override; + int create(std::string_view addrstr) override; + int join(std::string_view addrstr) override; bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) override; bool SNetSendMessage(uint8_t dest, void *data, size_t size) override; bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) override; diff --git a/Source/dvlnet/loopback.cpp b/Source/dvlnet/loopback.cpp index 21ca4e12c..3d1401e8e 100644 --- a/Source/dvlnet/loopback.cpp +++ b/Source/dvlnet/loopback.cpp @@ -9,13 +9,13 @@ namespace devilution::net { -int loopback::create(std::string /*addrstr*/) +int loopback::create(std::string_view /*addrstr*/) { IsLoopback = true; return plr_single; } -int loopback::join(std::string /*addrstr*/) +int loopback::join(std::string_view /*addrstr*/) { ABORT(); } diff --git a/Source/dvlnet/loopback.h b/Source/dvlnet/loopback.h index cbb295b5e..4bf5d57dc 100644 --- a/Source/dvlnet/loopback.h +++ b/Source/dvlnet/loopback.h @@ -17,8 +17,8 @@ private: public: loopback() = default; - int create(std::string addrstr) override; - int join(std::string addrstr) override; + int create(std::string_view addrstr) override; + int join(std::string_view addrstr) override; bool SNetReceiveMessage(uint8_t *sender, void **data, size_t *size) override; bool SNetSendMessage(uint8_t dest, void *data, size_t size) override; bool SNetReceiveTurns(char **data, size_t *size, uint32_t *status) override; diff --git a/Source/dvlnet/tcp_client.cpp b/Source/dvlnet/tcp_client.cpp index 25e15ae50..df40ce6d8 100644 --- a/Source/dvlnet/tcp_client.cpp +++ b/Source/dvlnet/tcp_client.cpp @@ -13,25 +13,56 @@ #include "options.h" #include "utils/language.h" #include "utils/str_cat.hpp" +#include "utils/str_split.hpp" namespace devilution::net { -int tcp_client::create(std::string addrstr) +int tcp_client::create(std::string_view addrstr) { auto port = *sgOptions.Network.port; - local_server = std::make_unique(ioc, addrstr, port, *pktfty); + local_server = std::make_unique(ioc, std::string(addrstr), port, *pktfty); return join(local_server->LocalhostSelf()); } -int tcp_client::join(std::string addrstr) +int tcp_client::join(std::string_view addrstr) { constexpr int MsSleep = 10; constexpr int NoSleep = 250; - std::string port = StrCat(*sgOptions.Network.port); + const char *defaultPort = "6112"; + std::string_view host; + std::string_view port = defaultPort; + if (!addrstr.empty() && addrstr[0] == '[') { + // Assume IPv6 address in square brackets, followed by port + // Example: [::1]:6113 + size_t pos = addrstr.find(']', 1); + pos = pos != std::string::npos ? pos + 1 : addrstr.length(); + host = addrstr.substr(0, pos); + + if (pos != addrstr.length()) { + if (addrstr[pos] != ':') { + SDL_SetError("Invalid hostname: expected colon after square brackets"); + return -1; + } + if (++pos != addrstr.length()) + port = addrstr.substr(pos); + } + } else { + // Assume "hostname:port" + SplitByChar splithost(addrstr, ':'); + auto it = splithost.begin(); + if (it != splithost.end()) host = *it++; + if (it != splithost.end()) port = *it++; + + // If there is more than one colon, assume it's just a plain IPv6 address + if (it != splithost.end()) { + host = addrstr; + port = defaultPort; + } + } asio::error_code errorCode; - asio::ip::basic_resolver_results range = resolver.resolve(addrstr, port, errorCode); + asio::ip::basic_resolver_results range = resolver.resolve(host, port, errorCode); if (errorCode) { SDL_SetError("%s", errorCode.message().c_str()); return -1; diff --git a/Source/dvlnet/tcp_client.h b/Source/dvlnet/tcp_client.h index 7a68072b7..3d959bdd1 100644 --- a/Source/dvlnet/tcp_client.h +++ b/Source/dvlnet/tcp_client.h @@ -27,8 +27,8 @@ namespace devilution::net { class tcp_client : public base { public: - int create(std::string addrstr) override; - int join(std::string addrstr) override; + int create(std::string_view addrstr) override; + int join(std::string_view addrstr) override; tl::expected poll() override; tl::expected send(packet &pkt) override;