From 94b1a8677aaa5d4f6d80772182d877fbe6198aee Mon Sep 17 00:00:00 2001 From: morfidon <57798071+morfidon@users.noreply.github.com> Date: Wed, 11 Mar 2026 17:48:54 +0100 Subject: [PATCH] Sanitize player spell selections on load to prevent invalid spell hotkeys --- Source/loadsave.cpp | 7 ++- Source/pfile.cpp | 3 + Source/player.cpp | 15 +++-- Source/spells.cpp | 55 +++++++++++++----- Source/spells.h | 2 + test/player_test.cpp | 50 +++++++++++++++++ test/writehero_test.cpp | 120 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 231 insertions(+), 21 deletions(-) diff --git a/Source/loadsave.cpp b/Source/loadsave.cpp index 010658713..634c5b7e3 100644 --- a/Source/loadsave.cpp +++ b/Source/loadsave.cpp @@ -35,6 +35,7 @@ #include "pfile.h" #include "plrmsg.h" #include "qol/stash.h" +#include "spells.h" #include "stores.h" #include "tables/playerdat.hpp" #include "utils/algorithm/container.hpp" @@ -49,6 +50,8 @@ namespace devilution { bool gbIsHellfireSaveGame; uint8_t giNumberOfLevels; +void ValidatePlayerForLoad(); + namespace { constexpr size_t MaxMissilesForSaveGame = 125; @@ -636,7 +639,6 @@ void LoadPlayer(LoadHelper &file, Player &player) sgGameInitInfo.nDifficulty = static_cast<_difficulty>(file.NextLE()); player.pDamAcFlags = static_cast(file.NextLE()); file.Skip(20); // Available bytes - CalcPlrInv(player, false); player.executedSpell = player.queuedSpell; // Ensures backwards compatibility @@ -2519,6 +2521,9 @@ tl::expected LoadGame(bool firstflag) Player &myPlayer = *MyPlayer; LoadPlayer(file, myPlayer); + ValidatePlayerForLoad(); + CalcPlrInv(myPlayer, false); + SanitizePlayerSpellSelections(myPlayer); if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL) sgGameInitInfo.nDifficulty = DIFF_NORMAL; diff --git a/Source/pfile.cpp b/Source/pfile.cpp index ec96c6608..2580cee36 100644 --- a/Source/pfile.cpp +++ b/Source/pfile.cpp @@ -28,6 +28,7 @@ #include "mpq/mpq_common.hpp" #include "pack.h" #include "qol/stash.h" +#include "spells.h" #include "tables/playerdat.hpp" #include "utils/endian_read.hpp" #include "utils/endian_swap.hpp" @@ -690,6 +691,7 @@ bool pfile_ui_set_hero_infos(bool (*uiAddHeroInfo)(_uiheroinfo *)) LoadHeroItems(player); RemoveAllInvalidItems(player); CalcPlrInv(player, false); + SanitizePlayerSpellSelections(player); Game2UiPlayer(player, &uihero, hasSaveGame); uiAddHeroInfo(&uihero); @@ -777,6 +779,7 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player) LoadHeroItems(player); RemoveAllInvalidItems(player); CalcPlrInv(player, false); + SanitizePlayerSpellSelections(player); } void pfile_save_level() diff --git a/Source/player.cpp b/Source/player.cpp index 1b6f14096..ca3226d5c 100644 --- a/Source/player.cpp +++ b/Source/player.cpp @@ -1528,11 +1528,16 @@ void GetPlayerGraphicsPath(std::string_view path, std::string_view prefix, std:: *BufCopy(out, "plrgfx\\", path, "\\", prefix, "\\", prefix, type) = '\0'; } -} // namespace - -void Player::CalcScrolls() -{ - _pScrlSpells = 0; +} // namespace + +void ValidatePlayerForLoad() +{ + ValidatePlayer(); +} + +void Player::CalcScrolls() +{ + _pScrlSpells = 0; for (const Item &item : InventoryAndBeltPlayerItemsRange { *this }) { if (item.isScroll() && item._iStatFlag) { _pScrlSpells |= GetSpellBitmask(item._iSpell); diff --git a/Source/spells.cpp b/Source/spells.cpp index 476e2013b..438a64821 100644 --- a/Source/spells.cpp +++ b/Source/spells.cpp @@ -32,21 +32,7 @@ namespace { */ bool IsReadiedSpellValid(const Player &player) { - switch (player._pRSplType) { - case SpellType::Skill: - case SpellType::Spell: - case SpellType::Invalid: - return true; - - case SpellType::Charges: - return (player._pISpells & GetSpellBitmask(player._pRSpell)) != 0; - - case SpellType::Scroll: - return (player._pScrlSpells & GetSpellBitmask(player._pRSpell)) != 0; - - default: - return false; - } + return IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType); } /** @@ -85,6 +71,45 @@ bool IsValidSpellFrom(int spellFrom) return false; } +bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType) +{ + if (spellType == SpellType::Invalid) { + return spellId == SpellID::Invalid; + } + + if (!IsValidSpell(spellId)) { + return false; + } + + switch (spellType) { + case SpellType::Skill: + return (player._pAblSpells & GetSpellBitmask(spellId)) != 0; + case SpellType::Spell: + return (player._pMemSpells & GetSpellBitmask(spellId)) != 0 && player.GetSpellLevel(spellId) > 0; + case SpellType::Scroll: + return (player._pScrlSpells & GetSpellBitmask(spellId)) != 0; + case SpellType::Charges: + return (player._pISpells & GetSpellBitmask(spellId)) != 0; + default: + return false; + } +} + +void SanitizePlayerSpellSelections(Player &player) +{ + for (size_t i = 0; i < NumHotkeys; ++i) { + if (!IsPlayerSpellSelectionValid(player, player._pSplHotKey[i], player._pSplTHotKey[i])) { + player._pSplHotKey[i] = SpellID::Invalid; + player._pSplTHotKey[i] = SpellType::Invalid; + } + } + + if (!IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType)) { + player._pRSpell = SpellID::Invalid; + player._pRSplType = SpellType::Invalid; + } +} + bool IsWallSpell(SpellID spl) { return spl == SpellID::FireWall || spl == SpellID::LightningWall; diff --git a/Source/spells.h b/Source/spells.h index 01b72e8b4..d7b67ed87 100644 --- a/Source/spells.h +++ b/Source/spells.h @@ -21,6 +21,8 @@ enum class SpellCheckResult : uint8_t { bool IsValidSpell(SpellID spl); bool IsValidSpellFrom(int spellFrom); +bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType); +void SanitizePlayerSpellSelections(Player &player); bool IsWallSpell(SpellID spl); bool TargetsMonster(SpellID id); int GetManaAmount(const Player &player, SpellID sn); diff --git a/test/player_test.cpp b/test/player_test.cpp index 4d4108e7f..32a9f9d4b 100644 --- a/test/player_test.cpp +++ b/test/player_test.cpp @@ -5,6 +5,7 @@ #include "cursor.h" #include "engine/assets.hpp" #include "init.hpp" +#include "spells.h" #include "tables/playerdat.hpp" using namespace devilution; @@ -204,3 +205,52 @@ TEST(Player, CreatePlayer) CreatePlayer(Players[0], HeroClass::Rogue); AssertPlayer(Players[0]); } + +TEST(Player, IsPlayerSpellSelectionValidChecksSpellSources) +{ + LoadCoreArchives(); + LoadGameArchives(); + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + LoadSpellData(); + + const SpellID spell = SpellID::Healing; + const uint64_t mask = GetSpellBitmask(spell); + Player player {}; + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + player._pMemSpells = mask; + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + player._pSplLvl[static_cast(spell)] = 1; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll)); + player._pScrlSpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges)); + player._pISpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill)); + player._pAblSpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill)); +} + +TEST(Player, IsPlayerSpellSelectionValidRejectsInvalidSelections) +{ + LoadCoreArchives(); + LoadGameArchives(); + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + LoadSpellData(); + + Player player {}; + + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Invalid)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Healing, SpellType::Invalid)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Spell)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Null, SpellType::Spell)); +} diff --git a/test/writehero_test.cpp b/test/writehero_test.cpp index 892f7e18d..cd9f4f1f7 100644 --- a/test/writehero_test.cpp +++ b/test/writehero_test.cpp @@ -14,6 +14,7 @@ #include "loadsave.h" #include "pack.h" #include "pfile.h" +#include "spells.h" #include "tables/playerdat.hpp" #include "utils/endian_swap.hpp" #include "utils/file_util.h" @@ -416,5 +417,124 @@ TEST(Writehero, pfile_write_hero) "a79367caae6192d54703168d82e0316aa289b2a33251255fad8abe34889c1d3a"); } +TEST(Writehero, pfile_read_player_from_save_clears_invalid_spell_selections) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + player._pMemSpells |= GetSpellBitmask(SpellID::Apocalypse); + player._pSplHotKey[0] = SpellID::Apocalypse; + player._pSplTHotKey[0] = SpellType::Spell; + player._pSplHotKey[1] = SpellID::Healing; + player._pSplTHotKey[1] = SpellType::Spell; + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + + pfile_write_hero(); + + player._pSplHotKey[0] = SpellID::Invalid; + player._pSplTHotKey[0] = SpellType::Invalid; + player._pSplHotKey[1] = SpellID::Invalid; + player._pSplTHotKey[1] = SpellType::Invalid; + player._pRSpell = SpellID::Invalid; + player._pRSplType = SpellType::Invalid; + + pfile_read_player_from_save(info.saveNumber, player); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid); + EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); +} + +TEST(Writehero, pfile_read_player_from_save_preserves_valid_spell_selections) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + player._pSplHotKey[0] = SpellID::Healing; + player._pSplTHotKey[0] = SpellType::Spell; + player._pRSpell = SpellID::Healing; + player._pRSplType = SpellType::Spell; + + pfile_write_hero(); + + player._pSplHotKey[0] = SpellID::Invalid; + player._pSplTHotKey[0] = SpellType::Invalid; + player._pRSpell = SpellID::Invalid; + player._pRSplType = SpellType::Invalid; + + pfile_read_player_from_save(info.saveNumber, player); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Healing); + EXPECT_EQ(player._pRSplType, SpellType::Spell); +} + } // namespace } // namespace devilution