diff --git a/Source/msg.cpp b/Source/msg.cpp index dc9aa5d95..2676cb7a6 100644 --- a/Source/msg.cpp +++ b/Source/msg.cpp @@ -34,6 +34,7 @@ #include "nthread.h" #include "objects.h" #include "options.h" +#include "pack.h" #include "pfile.h" #include "plrmsg.h" #include "spells.h" @@ -45,8 +46,44 @@ #include "utils/str_cat.hpp" #include "utils/utf8.hpp" +#define ValidateField(logValue, condition) \ + do { \ + if (!(condition)) { \ + LogFailedPacket(#condition, #logValue, logValue); \ + EventFailedPacket(player._pName); \ + return false; \ + } \ + } while (0) + +#define ValidateFields(logValue1, logValue2, condition) \ + do { \ + if (!(condition)) { \ + LogFailedPacket(#condition, #logValue1, logValue1, #logValue2, logValue2); \ + EventFailedPacket(player._pName); \ + return false; \ + } \ + } while (0) + namespace devilution { +void EventFailedPacket(const char *playerName) +{ + std::string message = fmt::format("Player '{}' sent an invalid packet.", playerName); + EventPlrMsg(message); +} + +template +void LogFailedPacket(const char *condition, const char *name, T value) +{ + LogDebug("Remote player packet validation failed: ValidateField({}: {}, {})", name, value, condition); +} + +template +void LogFailedPacket(const char *condition, const char *name1, T1 value1, const char *name2, T2 value2) +{ + LogDebug("Remote player packet validation failed: ValidateFields({}: {}, {}: {}, {})", name1, value1, name2, value2, condition); +} + // #define LOG_RECEIVED_MESSAGES uint8_t gbBufferMsgs; @@ -964,14 +1001,30 @@ bool IsGItemValid(const TCmdGItem &message) return IsItemAvailable(static_cast<_item_indexes>(SDL_SwapLE16(message.def.wIndx))); } -bool IsPItemValid(const TCmdPItem &message) +bool IsPItemValid(const TCmdPItem &message, const Player &player) { const Point position { message.x, message.y }; if (!InDungeonBounds(position)) return false; - return IsItemAvailable(static_cast<_item_indexes>(SDL_SwapLE16(message.def.wIndx))); + auto idx = static_cast<_item_indexes>(SDL_SwapLE16(message.def.wIndx)); + + if (idx != IDI_EAR) { + uint16_t creationFlags = SDL_SwapLE16(message.item.wCI); + uint32_t dwBuff = SDL_SwapLE16(message.item.dwBuff); + + if (idx != IDI_GOLD) + ValidateField(creationFlags, IsCreationFlagComboValid(creationFlags)); + if ((creationFlags & CF_TOWN) != 0) + ValidateField(creationFlags, IsTownItemValid(creationFlags, player)); + else if ((creationFlags & CF_USEFUL) == CF_UPER15) + ValidateFields(creationFlags, dwBuff, IsUniqueMonsterItemValid(creationFlags, dwBuff)); + else + ValidateFields(creationFlags, dwBuff, IsDungeonItemValid(creationFlags, dwBuff)); + } + + return IsItemAvailable(idx); } void PrepareItemForNetwork(const Item &item, TCmdGItem &message) @@ -1254,7 +1307,7 @@ size_t OnPutItem(const TCmd *pCmd, Player &player) if (gbBufferMsgs == 1) { SendPacket(player, &message, sizeof(message)); - } else if (IsPItemValid(message)) { + } else if (IsPItemValid(message, player)) { const Point position { message.x, message.y }; bool isSelf = &player == MyPlayer; const int32_t dwSeed = SDL_SwapLE32(message.def.dwSeed); @@ -1294,7 +1347,7 @@ size_t OnSyncPutItem(const TCmd *pCmd, Player &player) if (gbBufferMsgs == 1) SendPacket(player, &message, sizeof(message)); - else if (IsPItemValid(message)) { + else if (IsPItemValid(message, player)) { const int32_t dwSeed = SDL_SwapLE32(message.def.dwSeed); const uint16_t wCI = SDL_SwapLE16(message.def.wCI); const _item_indexes wIndx = static_cast<_item_indexes>(SDL_SwapLE16(message.def.wIndx)); @@ -1950,7 +2003,7 @@ size_t OnDropItem(const TCmd *pCmd, Player &player) if (gbBufferMsgs == 1) { SendPacket(player, &message, sizeof(message)); - } else if (IsPItemValid(message)) { + } else if (IsPItemValid(message, player)) { DeltaPutItem(message, { message.x, message.y }, player); } @@ -1963,7 +2016,7 @@ size_t OnSpawnItem(const TCmd *pCmd, Player &player) if (gbBufferMsgs == 1) { SendPacket(player, &message, sizeof(message)); - } else if (IsPItemValid(message)) { + } else if (IsPItemValid(message, player)) { if (player.isOnActiveLevel() && &player != MyPlayer) { SyncDropItem(message); } diff --git a/Source/pack.cpp b/Source/pack.cpp index 7fe7c8ab9..e9039b8fd 100644 --- a/Source/pack.cpp +++ b/Source/pack.cpp @@ -80,6 +80,8 @@ bool hasMultipleFlags(uint16_t flags) return (flags & (flags - 1)) > 0; } +} // namespace + bool IsCreationFlagComboValid(uint16_t iCreateInfo) { iCreateInfo = iCreateInfo & ~CF_LEVEL; @@ -198,8 +200,6 @@ bool RecreateHellfireSpellBook(const Player &player, const ItemNetPack &packedIt return true; } -} // namespace - void PackItem(ItemPack &packedItem, const Item &item, bool isHellfire) { packedItem = {}; @@ -562,7 +562,7 @@ bool UnPackNetPlayer(const PlayerNetPack &packed, Player &player) ValidateFields(packed.pClass, packed.pBaseDex, packed.pBaseDex <= player.GetMaximumAttributeValue(CharacterAttribute::Dexterity)); ValidateFields(packed.pClass, packed.pBaseVit, packed.pBaseVit <= player.GetMaximumAttributeValue(CharacterAttribute::Vitality)); - ValidateField(packed._pNumInv, packed._pNumInv < InventoryGridCells); + ValidateField(packed._pNumInv, packed._pNumInv <= InventoryGridCells); player.setCharacterLevel(packed.pLevel); player.position.tile = position; diff --git a/Source/pack.h b/Source/pack.h index 9cefabbff..5ff70a1c5 100644 --- a/Source/pack.h +++ b/Source/pack.h @@ -142,6 +142,10 @@ struct PlayerNetPack { }; #pragma pack(pop) +bool IsCreationFlagComboValid(uint16_t iCreateInfo); +bool IsTownItemValid(uint16_t iCreateInfo, const Player &player); +bool IsUniqueMonsterItemValid(uint16_t iCreateInfo, uint32_t dwBuff); +bool IsDungeonItemValid(uint16_t iCreateInfo, uint32_t dwBuff); void PackPlayer(PlayerPack &pPack, const Player &player); void UnPackPlayer(const PlayerPack &pPack, Player &player); void PackNetPlayer(PlayerNetPack &packed, const Player &player); diff --git a/test/pack_test.cpp b/test/pack_test.cpp index 7e45a927c..6d706b487 100644 --- a/test/pack_test.cpp +++ b/test/pack_test.cpp @@ -1038,7 +1038,7 @@ TEST_F(NetPackTest, UnPackNetPlayer_invalid_baseVit) TEST_F(NetPackTest, UnPackNetPlayer_invalid_numInv) { - MyPlayer->_pNumInv = InventoryGridCells; + MyPlayer->_pNumInv = InventoryGridCells + 1; ASSERT_FALSE(TestNetPackValidation()); }