diff --git a/Source/DiabloUI/multi/selgame.cpp b/Source/DiabloUI/multi/selgame.cpp index dc8af704d..7c28260ea 100644 --- a/Source/DiabloUI/multi/selgame.cpp +++ b/Source/DiabloUI/multi/selgame.cpp @@ -94,12 +94,14 @@ bool IsGameCompatible(const GameData &data) return (data.versionMajor == PROJECT_VERSION_MAJOR && data.versionMinor == PROJECT_VERSION_MINOR && data.versionPatch == PROJECT_VERSION_PATCH - && data.programid == GAME_ID); - return false; + && data.programid == GAME_ID + && data.modHash == sgGameInitInfo.modHash); } static std::string GetErrorMessageIncompatibility(const GameData &data) { + if (data.modHash != sgGameInitInfo.modHash) + return std::string(_("The host is using a different set of mods.")); if (data.programid != GAME_ID) { std::string_view gameMode; switch (data.programid) { diff --git a/Source/dvlnet/base_protocol.h b/Source/dvlnet/base_protocol.h index 3967186a4..62317698d 100644 --- a/Source/dvlnet/base_protocol.h +++ b/Source/dvlnet/base_protocol.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -318,6 +319,12 @@ void base_protocol

::recv() template tl::expected base_protocol

::handle_join_request(packet &inPkt, endpoint_t sender) { + tl::expected pktInfo = inPkt.Info(); + if (pktInfo.has_value() && (*pktInfo)->size() == sizeof(GameData) && game_init_info.size() == sizeof(GameData)) { + constexpr size_t ModHashOffset = offsetof(GameData, modHash); + if (LoadLE32((*pktInfo)->data() + ModHashOffset) != LoadLE32(game_init_info.data() + ModHashOffset)) + return {}; + } plr_t i; for (i = 0; i < Players.size(); ++i) { Peer &peer = peers[i]; diff --git a/Source/dvlnet/packet.h b/Source/dvlnet/packet.h index 2a3bbd7b0..25ac981d2 100644 --- a/Source/dvlnet/packet.h +++ b/Source/dvlnet/packet.h @@ -67,7 +67,8 @@ public: enum class ErrorCode : uint8_t { None, EncryptionFailed, - DecryptionFailed + DecryptionFailed, + ModMismatch }; PacketError() diff --git a/Source/dvlnet/tcp_client.cpp b/Source/dvlnet/tcp_client.cpp index 0e1a18e9b..deab56e82 100644 --- a/Source/dvlnet/tcp_client.cpp +++ b/Source/dvlnet/tcp_client.cpp @@ -214,6 +214,8 @@ void tcp_client::HandleTcpErrorCode() PacketError::ErrorCode code = static_cast(pktData[0]); if (code == PacketError::ErrorCode::DecryptionFailed) RaiseIoHandlerError(_("Server failed to decrypt your packet. Check if you typed the password correctly.")); + else if (code == PacketError::ErrorCode::ModMismatch) + RaiseIoHandlerError(_("The host is using a different set of mods.")); else RaiseIoHandlerError(fmt::format("Unknown error code received from server: {:#04x}", pktData[0])); } diff --git a/Source/dvlnet/tcp_server.cpp b/Source/dvlnet/tcp_server.cpp index c0fb9bc0d..95fa9e90e 100644 --- a/Source/dvlnet/tcp_server.cpp +++ b/Source/dvlnet/tcp_server.cpp @@ -1,6 +1,7 @@ #include "dvlnet/tcp_server.h" #include +#include #include #include #include @@ -126,11 +127,20 @@ tl::expected tcp_server::HandleReceiveNewPlayer(const scc &co if (newplr == PLR_BROADCAST) return tl::make_unexpected(ServerError()); + tl::expected pktInfo = inPkt.Info(); + if (!pktInfo.has_value()) + return tl::make_unexpected(pktInfo.error()); + const buffer_t &joinerInfo = **pktInfo; + if (Empty()) { - tl::expected pktInfo = inPkt.Info(); - if (!pktInfo.has_value()) - return tl::make_unexpected(pktInfo.error()); - game_init_info = **pktInfo; + game_init_info = joinerInfo; + } else if (joinerInfo.size() == sizeof(GameData) && game_init_info.size() == sizeof(GameData)) { + constexpr size_t ModHashOffset = offsetof(GameData, modHash); + if (LoadLE32(joinerInfo.data() + ModHashOffset) != LoadLE32(game_init_info.data() + ModHashOffset)) { + StartSend(con, PacketError::ErrorCode::ModMismatch); + DropConnection(con); + return {}; + } } for (plr_t player = 0; player < Players.size(); player++) { diff --git a/Source/multi.cpp b/Source/multi.cpp index 85afa1539..db5f3ec09 100644 --- a/Source/multi.cpp +++ b/Source/multi.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #ifdef USE_SDL3 @@ -95,6 +96,7 @@ void GameData::swapLE() gameSeed[1] = Swap32LE(gameSeed[1]); gameSeed[2] = Swap32LE(gameSeed[2]); gameSeed[3] = Swap32LE(gameSeed[3]); + modHash = Swap32LE(modHash); } namespace { @@ -554,6 +556,22 @@ std::string FormatGameSeed(const uint32_t gameSeed[4]) gameSeed[0], gameSeed[1], gameSeed[2], gameSeed[3]); } +uint32_t ComputeModListHash(std::span mods) +{ + constexpr uint32_t FnvPrime = 16777619U; + constexpr uint32_t FnvOffsetBasis = 2166136261U; + uint32_t result = 0; + for (const std::string_view mod : mods) { + uint32_t hash = FnvOffsetBasis; + for (const char c : mod) { + hash ^= static_cast(c); + hash *= FnvPrime; + } + result ^= hash; + } + return result; +} + void InitGameInfo() { const xoshiro128plusplus gameGenerator = ReserveSeedSequence(); @@ -571,6 +589,8 @@ void InitGameInfo() sgGameInitInfo.bCowQuest = *options.Gameplay.cowQuest ? 1 : 0; sgGameInitInfo.bFriendlyFire = *options.Gameplay.friendlyFire ? 1 : 0; sgGameInitInfo.fullQuests = (!gbIsMultiplayer || *options.Gameplay.multiplayerFullQuests) ? 1 : 0; + const std::vector activeMods = GetOptions().Mods.GetActiveModList(); + sgGameInitInfo.modHash = ComputeModListHash(activeMods); } void NetSendLoPri(uint8_t playerId, const std::byte *data, size_t size) diff --git a/Source/multi.h b/Source/multi.h index 43ebdf3e1..9cd11ab95 100644 --- a/Source/multi.h +++ b/Source/multi.h @@ -6,7 +6,9 @@ #pragma once #include +#include #include +#include #include #include "dvlnet/leaveinfo.hpp" @@ -39,6 +41,8 @@ struct GameData { uint8_t fullQuests; /** Used to initialise the seed table for dungeon levels so players in multiplayer games generate the same layout */ uint32_t gameSeed[4]; + /** FNV-1a hash of active mod list for multiplayer compatibility check */ + uint32_t modHash; void swapLE(); }; @@ -68,6 +72,7 @@ extern bool IsLoopback; DVL_API_FOR_TEST std::string DescribeLeaveReason(leaveinfo_t leaveReason); std::string FormatGameSeed(const uint32_t gameSeed[4]); +uint32_t ComputeModListHash(std::span mods); void InitGameInfo(); void NetSendLoPri(uint8_t playerId, const std::byte *data, size_t size); diff --git a/test/multi_logging_test.cpp b/test/multi_logging_test.cpp index 0911ad5e3..7205cc633 100644 --- a/test/multi_logging_test.cpp +++ b/test/multi_logging_test.cpp @@ -1,9 +1,53 @@ +#include +#include +#include + #include #include "multi.h" namespace devilution { +TEST(ComputeModListHash, EmptyListProducesZero) +{ + // An empty mod list produces zero (XOR identity with no contributors). + EXPECT_EQ(ComputeModListHash({}), 0U); +} + +TEST(ComputeModListHash, Deterministic) +{ + const std::array mods = { "mod-a", "mod-b" }; + EXPECT_EQ(ComputeModListHash(mods), ComputeModListHash(mods)); +} + +TEST(ComputeModListHash, DifferentModsProduceDifferentHashes) +{ + const std::array modsA = { "mod-a" }; + const std::array modsB = { "mod-b" }; + EXPECT_NE(ComputeModListHash(modsA), ComputeModListHash(modsB)); +} + +TEST(ComputeModListHash, OrderDoesNotMatter) +{ + const std::array ab = { "mod-a", "mod-b" }; + const std::array ba = { "mod-b", "mod-a" }; + EXPECT_EQ(ComputeModListHash(ab), ComputeModListHash(ba)); +} + +TEST(ComputeModListHash, DifferentModNamesDifferentHashes) +{ + // ["ab", "c"] must not collide with ["a", "bc"]. + const std::array splitFirst = { "ab", "c" }; + const std::array splitSecond = { "a", "bc" }; + EXPECT_NE(ComputeModListHash(splitFirst), ComputeModListHash(splitSecond)); +} + +TEST(ComputeModListHash, NoModsDifferFromSomeMods) +{ + const std::array oneMod = { "any-mod" }; + EXPECT_NE(ComputeModListHash({}), ComputeModListHash(oneMod)); +} + TEST(MultiplayerLogging, NormalExitReason) { EXPECT_EQ("normal exit", DescribeLeaveReason(net::leaveinfo_t::LEAVE_EXIT));