From 4e58828335dd9d912773f5530276914aae232ff9 Mon Sep 17 00:00:00 2001 From: morfidon <57798071+morfidon@users.noreply.github.com> Date: Thu, 12 Mar 2026 11:22:35 +0100 Subject: [PATCH] Normalize spell cast state after hotkey load Introduce SyncPlayerSpellStateFromSelections and call it in hotkey load paths so queued/executed spell metadata and spellFrom are reset consistently after sanitization. This centralizes post-load spell-state normalization, removes duplicated per-call-site resync code, and adds writehero regression tests for missing hotkeys data and legacy 4-slot hotkeys format handling. --- Source/loadsave.cpp | 4 +- Source/pfile.cpp | 3 +- Source/spells.cpp | 10 +++ Source/spells.h | 1 + test/writehero_test.cpp | 185 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 199 insertions(+), 4 deletions(-) diff --git a/Source/loadsave.cpp b/Source/loadsave.cpp index 92ab7cf64..8bd6cba6d 100644 --- a/Source/loadsave.cpp +++ b/Source/loadsave.cpp @@ -2342,6 +2342,7 @@ void LoadHotkeys(uint32_t saveNum, Player &myPlayer) LoadHelper file(OpenSaveArchive(saveNum), "hotkeys"); if (!file.IsValid()) { SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); return; } @@ -2379,6 +2380,7 @@ void LoadHotkeys(uint32_t saveNum, Player &myPlayer) myPlayer._pRSpell = static_cast(file.NextLE()); myPlayer._pRSplType = static_cast(file.NextLE()); SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); } void SaveHotkeys(SaveWriter &saveWriter, const Player &player) @@ -2532,8 +2534,6 @@ tl::expected LoadGame(bool firstflag) ValidatePlayer(); CalcPlrInv(myPlayer, false); LoadHotkeys(gSaveNumber, myPlayer); - myPlayer.queuedSpell.spellId = myPlayer._pRSpell; - myPlayer.queuedSpell.spellType = myPlayer._pRSplType; if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL) sgGameInitInfo.nDifficulty = DIFF_NORMAL; diff --git a/Source/pfile.cpp b/Source/pfile.cpp index 4e08bd4c8..0543adb83 100644 --- a/Source/pfile.cpp +++ b/Source/pfile.cpp @@ -781,10 +781,9 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player) CalcPlrInv(player, false); if (&player == MyPlayer) { LoadHotkeys(saveNum, player); - player.queuedSpell.spellId = player._pRSpell; - player.queuedSpell.spellType = player._pRSplType; } else { SanitizePlayerSpellSelections(player); + SyncPlayerSpellStateFromSelections(player); } } diff --git a/Source/spells.cpp b/Source/spells.cpp index 438a64821..f6c69ffcf 100644 --- a/Source/spells.cpp +++ b/Source/spells.cpp @@ -110,6 +110,16 @@ void SanitizePlayerSpellSelections(Player &player) } } +void SyncPlayerSpellStateFromSelections(Player &myPlayer) +{ + myPlayer.queuedSpell.spellId = myPlayer._pRSpell; + myPlayer.queuedSpell.spellType = myPlayer._pRSplType; + myPlayer.queuedSpell.spellFrom = 0; + myPlayer.queuedSpell.spellLevel = 0; + myPlayer.executedSpell = myPlayer.queuedSpell; + myPlayer.spellFrom = 0; +} + bool IsWallSpell(SpellID spl) { return spl == SpellID::FireWall || spl == SpellID::LightningWall; diff --git a/Source/spells.h b/Source/spells.h index d7b67ed87..a315bd0b1 100644 --- a/Source/spells.h +++ b/Source/spells.h @@ -23,6 +23,7 @@ bool IsValidSpell(SpellID spl); bool IsValidSpellFrom(int spellFrom); bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType); void SanitizePlayerSpellSelections(Player &player); +void SyncPlayerSpellStateFromSelections(Player &myPlayer); bool IsWallSpell(SpellID spl); bool TargetsMonster(SpellID id); int GetManaAmount(const Player &player, SpellID sn); diff --git a/test/writehero_test.cpp b/test/writehero_test.cpp index 4a4fa2dda..4c7ae7efd 100644 --- a/test/writehero_test.cpp +++ b/test/writehero_test.cpp @@ -1,12 +1,16 @@ #include "player_test.h" +#include +#include #include #include +#include #include #include #include +#include "codec.h" #include "cursor.h" #include "engine/assets.hpp" #include "game_mode.hpp" @@ -263,6 +267,54 @@ void PackPlayerTest(PlayerPack *pPack) SwapLE(*pPack); } +void AppendLE32(std::vector &buffer, int32_t value) +{ + const uint32_t rawValue = static_cast(value); + buffer.push_back(static_cast(rawValue & 0xFF)); + buffer.push_back(static_cast((rawValue >> 8) & 0xFF)); + buffer.push_back(static_cast((rawValue >> 16) & 0xFF)); + buffer.push_back(static_cast((rawValue >> 24) & 0xFF)); +} + +void AppendU8(std::vector &buffer, uint8_t value) +{ + buffer.push_back(static_cast(value)); +} + +void WriteEncodedArchiveEntry(const std::string &savePath, const char *entryName, std::vector decodedData) +{ + std::vector encodedData(codec_get_encoded_len(decodedData.size())); + std::copy(decodedData.begin(), decodedData.end(), encodedData.begin()); + codec_encode(encodedData.data(), decodedData.size(), encodedData.size(), pfile_get_password()); + + std::string savePathCopy = savePath; + SaveWriter saveWriter(std::move(savePathCopy)); + ASSERT_TRUE(saveWriter.WriteFile(entryName, encodedData.data(), encodedData.size())); +} + +void WriteLegacyHotkeys( + const std::string &savePath, + const std::array &hotkeySpells, + const std::array &hotkeyTypes, + SpellID selectedSpell, + SpellType selectedSpellType) +{ + std::vector decodedData; + decodedData.reserve(4 * sizeof(int32_t) + 4 * sizeof(uint8_t) + sizeof(int32_t) + sizeof(uint8_t)); + + for (SpellID spellId : hotkeySpells) { + AppendLE32(decodedData, static_cast(spellId)); + } + for (SpellType spellType : hotkeyTypes) { + AppendU8(decodedData, static_cast(spellType)); + } + + AppendLE32(decodedData, static_cast(selectedSpell)); + AppendU8(decodedData, static_cast(selectedSpellType)); + + WriteEncodedArchiveEntry(savePath, "hotkeys", std::move(decodedData)); +} + void AssertPlayer(Player &player) { ASSERT_EQ(CountU8(player._pSplLvl, 64), 23); @@ -540,6 +592,139 @@ TEST(Writehero, pfile_read_player_from_save_preserves_valid_spell_selections) EXPECT_EQ(player._pRSplType, SpellType::Spell); } +TEST(Writehero, LoadHotkeysWithoutFileSanitizesAndNormalizesSpellState) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = true; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + player.queuedSpell.spellId = SpellID::Healing; + player.queuedSpell.spellType = SpellType::Spell; + player.queuedSpell.spellFrom = INVITEM_BELT_FIRST; + player.queuedSpell.spellLevel = 7; + player.executedSpell.spellId = SpellID::Healing; + player.executedSpell.spellType = SpellType::Scroll; + player.executedSpell.spellFrom = INVITEM_INV_FIRST; + player.executedSpell.spellLevel = 3; + player.spellFrom = INVITEM_INV_FIRST; + + LoadHotkeys(info.saveNumber, player); + + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellFrom, 0); + EXPECT_EQ(player.queuedSpell.spellLevel, 0); + EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.executedSpell.spellFrom, 0); + EXPECT_EQ(player.executedSpell.spellLevel, 0); + EXPECT_EQ(player.spellFrom, 0); +} + +TEST(Writehero, LoadHotkeysLegacyFormatSanitizesInvalidSelections) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + + WriteLegacyHotkeys( + savePath, + { SpellID::Apocalypse, SpellID::Healing, SpellID::Invalid, SpellID::Invalid }, + { SpellType::Spell, SpellType::Spell, SpellType::Invalid, SpellType::Invalid }, + SpellID::Apocalypse, + SpellType::Spell); + + player.queuedSpell.spellFrom = INVITEM_BELT_FIRST; + player.queuedSpell.spellLevel = 9; + player.executedSpell.spellFrom = INVITEM_INV_FIRST; + player.executedSpell.spellLevel = 4; + player.spellFrom = INVITEM_INV_FIRST; + + LoadHotkeys(info.saveNumber, player); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid); + EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellFrom, 0); + EXPECT_EQ(player.queuedSpell.spellLevel, 0); + EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.executedSpell.spellFrom, 0); + EXPECT_EQ(player.executedSpell.spellLevel, 0); + EXPECT_EQ(player.spellFrom, 0); +} + TEST(Writehero, DiabloRewritePersistsSanitizedSpellSelectionsFromHellfireSave) { LoadCoreArchives();