diff --git a/Source/loadsave.cpp b/Source/loadsave.cpp index 010658713..78bdd9d9d 100644 --- a/Source/loadsave.cpp +++ b/Source/loadsave.cpp @@ -35,6 +35,7 @@ #include "pfile.h" #include "plrmsg.h" #include "qol/stash.h" +#include "spells.h" #include "stores.h" #include "tables/playerdat.hpp" #include "utils/algorithm/container.hpp" @@ -636,7 +637,6 @@ void LoadPlayer(LoadHelper &file, Player &player) sgGameInitInfo.nDifficulty = static_cast<_difficulty>(file.NextLE()); player.pDamAcFlags = static_cast(file.NextLE()); file.Skip(20); // Available bytes - CalcPlrInv(player, false); player.executedSpell = player.queuedSpell; // Ensures backwards compatibility @@ -2329,23 +2329,54 @@ size_t HotkeysSize(size_t nHotkeys = NumHotkeys) return sizeof(uint8_t) + (nHotkeys * sizeof(int32_t)) + (nHotkeys * sizeof(uint8_t)) + sizeof(int32_t) + sizeof(uint8_t); } +size_t LegacyHotkeysSize() +{ + return HotkeysSize(4) - sizeof(uint8_t); +} + void LoadHotkeys() { - LoadHelper file(OpenSaveArchive(gSaveNumber), "hotkeys"); - if (!file.IsValid()) + if (MyPlayer == nullptr) return; - Player &myPlayer = *MyPlayer; + LoadHotkeys(gSaveNumber, *MyPlayer); +} + +void LoadHotkeys(uint32_t saveNum, Player &myPlayer) +{ + LoadHelper file(OpenSaveArchive(saveNum), "hotkeys"); + if (!file.IsValid()) { + SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); + return; + } + size_t nHotkeys = 4; // Defaults to old save format number // Refill the spell arrays with no selection std::fill(myPlayer._pSplHotKey, myPlayer._pSplHotKey + NumHotkeys, SpellID::Invalid); std::fill(myPlayer._pSplTHotKey, myPlayer._pSplTHotKey + NumHotkeys, SpellType::Invalid); - // Checking if the save file has the old format with only 4 hotkeys and no header - if (file.IsValid(HotkeysSize(nHotkeys))) { - // The file contains a header byte and at least 4 entries, so we can assume it's a new format save + const size_t fileSize = file.Size(); + + if (fileSize == LegacyHotkeysSize()) { + // Legacy format: exactly 4 hotkeys, no leading count byte. + } else { + if (!file.IsValid(sizeof(uint8_t))) { + SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); + return; + } + nHotkeys = file.NextLE(); + + const size_t payloadSize = (nHotkeys * sizeof(int32_t)) + (nHotkeys * sizeof(uint8_t)) + sizeof(int32_t) + sizeof(uint8_t); + + if (!file.IsValid(payloadSize)) { + SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); + return; + } } // Read all hotkeys in the file @@ -2369,6 +2400,8 @@ void LoadHotkeys() // Load the selected spell last myPlayer._pRSpell = static_cast(file.NextLE()); myPlayer._pRSplType = static_cast(file.NextLE()); + SanitizePlayerSpellSelections(myPlayer); + SyncPlayerSpellStateFromSelections(myPlayer); } void SaveHotkeys(SaveWriter &saveWriter, const Player &player) @@ -2519,6 +2552,9 @@ tl::expected LoadGame(bool firstflag) Player &myPlayer = *MyPlayer; LoadPlayer(file, myPlayer); + ValidatePlayer(); + CalcPlrInv(myPlayer, false); + LoadHotkeys(gSaveNumber, myPlayer); if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL) sgGameInitInfo.nDifficulty = DIFF_NORMAL; diff --git a/Source/loadsave.h b/Source/loadsave.h index 0d139e143..309f6887b 100644 --- a/Source/loadsave.h +++ b/Source/loadsave.h @@ -25,6 +25,7 @@ _item_indexes RemapItemIdxFromSpawn(_item_indexes i); _item_indexes RemapItemIdxToSpawn(_item_indexes i); bool IsHeaderValid(uint32_t magicNumber); void LoadHotkeys(); +void LoadHotkeys(uint32_t saveNum, Player &myPlayer); void LoadHeroItems(Player &player); /** * @brief Remove invalid inventory items from the inventory grid diff --git a/Source/pfile.cpp b/Source/pfile.cpp index ec96c6608..0543adb83 100644 --- a/Source/pfile.cpp +++ b/Source/pfile.cpp @@ -28,6 +28,7 @@ #include "mpq/mpq_common.hpp" #include "pack.h" #include "qol/stash.h" +#include "spells.h" #include "tables/playerdat.hpp" #include "utils/endian_read.hpp" #include "utils/endian_swap.hpp" @@ -690,6 +691,7 @@ bool pfile_ui_set_hero_infos(bool (*uiAddHeroInfo)(_uiheroinfo *)) LoadHeroItems(player); RemoveAllInvalidItems(player); CalcPlrInv(player, false); + SanitizePlayerSpellSelections(player); Game2UiPlayer(player, &uihero, hasSaveGame); uiAddHeroInfo(&uihero); @@ -777,6 +779,12 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player) LoadHeroItems(player); RemoveAllInvalidItems(player); CalcPlrInv(player, false); + if (&player == MyPlayer) { + LoadHotkeys(saveNum, player); + } else { + SanitizePlayerSpellSelections(player); + SyncPlayerSpellStateFromSelections(player); + } } void pfile_save_level() diff --git a/Source/player.cpp b/Source/player.cpp index 1b6f14096..1a0e2251d 100644 --- a/Source/player.cpp +++ b/Source/player.cpp @@ -1410,6 +1410,8 @@ bool PlrDeathModeOK(Player &player) return false; } +} // namespace + void ValidatePlayer() { assert(MyPlayer != nullptr); @@ -1467,6 +1469,8 @@ void ValidatePlayer() myPlayer._pInfraFlag = false; } +namespace { + HeroClass GetPlayerSpriteClass(HeroClass cls) { if (cls == HeroClass::Bard && !HaveBardAssets()) @@ -2483,8 +2487,6 @@ void InitPlayer(Player &player, bool firstTime) if (firstTime) { player._pRSplType = SpellType::Invalid; player._pRSpell = SpellID::Invalid; - if (&player == MyPlayer) - LoadHotkeys(); player._pSBkSpell = SpellID::Invalid; player.queuedSpell.spellId = player._pRSpell; player.queuedSpell.spellType = player._pRSplType; diff --git a/Source/player.h b/Source/player.h index 5cc99ca19..d84bdf94d 100644 --- a/Source/player.h +++ b/Source/player.h @@ -985,6 +985,7 @@ void CheckPlrSpell(bool isShiftHeld, SpellID spellID = MyPlayer->_pRSpell, Spell void SyncPlrAnim(Player &player); void SyncInitPlrPos(Player &player); void SyncInitPlr(Player &player); +void ValidatePlayer(); void CheckStats(Player &player); void ModifyPlrStr(Player &player, int l); void ModifyPlrMag(Player &player, int l); diff --git a/Source/spells.cpp b/Source/spells.cpp index 476e2013b..f6c69ffcf 100644 --- a/Source/spells.cpp +++ b/Source/spells.cpp @@ -32,21 +32,7 @@ namespace { */ bool IsReadiedSpellValid(const Player &player) { - switch (player._pRSplType) { - case SpellType::Skill: - case SpellType::Spell: - case SpellType::Invalid: - return true; - - case SpellType::Charges: - return (player._pISpells & GetSpellBitmask(player._pRSpell)) != 0; - - case SpellType::Scroll: - return (player._pScrlSpells & GetSpellBitmask(player._pRSpell)) != 0; - - default: - return false; - } + return IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType); } /** @@ -85,6 +71,55 @@ bool IsValidSpellFrom(int spellFrom) return false; } +bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType) +{ + if (spellType == SpellType::Invalid) { + return spellId == SpellID::Invalid; + } + + if (!IsValidSpell(spellId)) { + return false; + } + + switch (spellType) { + case SpellType::Skill: + return (player._pAblSpells & GetSpellBitmask(spellId)) != 0; + case SpellType::Spell: + return (player._pMemSpells & GetSpellBitmask(spellId)) != 0 && player.GetSpellLevel(spellId) > 0; + case SpellType::Scroll: + return (player._pScrlSpells & GetSpellBitmask(spellId)) != 0; + case SpellType::Charges: + return (player._pISpells & GetSpellBitmask(spellId)) != 0; + default: + return false; + } +} + +void SanitizePlayerSpellSelections(Player &player) +{ + for (size_t i = 0; i < NumHotkeys; ++i) { + if (!IsPlayerSpellSelectionValid(player, player._pSplHotKey[i], player._pSplTHotKey[i])) { + player._pSplHotKey[i] = SpellID::Invalid; + player._pSplTHotKey[i] = SpellType::Invalid; + } + } + + if (!IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType)) { + player._pRSpell = SpellID::Invalid; + player._pRSplType = SpellType::Invalid; + } +} + +void SyncPlayerSpellStateFromSelections(Player &myPlayer) +{ + myPlayer.queuedSpell.spellId = myPlayer._pRSpell; + myPlayer.queuedSpell.spellType = myPlayer._pRSplType; + myPlayer.queuedSpell.spellFrom = 0; + myPlayer.queuedSpell.spellLevel = 0; + myPlayer.executedSpell = myPlayer.queuedSpell; + myPlayer.spellFrom = 0; +} + bool IsWallSpell(SpellID spl) { return spl == SpellID::FireWall || spl == SpellID::LightningWall; diff --git a/Source/spells.h b/Source/spells.h index 01b72e8b4..a315bd0b1 100644 --- a/Source/spells.h +++ b/Source/spells.h @@ -21,6 +21,9 @@ enum class SpellCheckResult : uint8_t { bool IsValidSpell(SpellID spl); bool IsValidSpellFrom(int spellFrom); +bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType); +void SanitizePlayerSpellSelections(Player &player); +void SyncPlayerSpellStateFromSelections(Player &myPlayer); bool IsWallSpell(SpellID spl); bool TargetsMonster(SpellID id); int GetManaAmount(const Player &player, SpellID sn); diff --git a/test/player_test.cpp b/test/player_test.cpp index 4d4108e7f..32a9f9d4b 100644 --- a/test/player_test.cpp +++ b/test/player_test.cpp @@ -5,6 +5,7 @@ #include "cursor.h" #include "engine/assets.hpp" #include "init.hpp" +#include "spells.h" #include "tables/playerdat.hpp" using namespace devilution; @@ -204,3 +205,52 @@ TEST(Player, CreatePlayer) CreatePlayer(Players[0], HeroClass::Rogue); AssertPlayer(Players[0]); } + +TEST(Player, IsPlayerSpellSelectionValidChecksSpellSources) +{ + LoadCoreArchives(); + LoadGameArchives(); + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + LoadSpellData(); + + const SpellID spell = SpellID::Healing; + const uint64_t mask = GetSpellBitmask(spell); + Player player {}; + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + player._pMemSpells = mask; + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + player._pSplLvl[static_cast(spell)] = 1; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll)); + player._pScrlSpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges)); + player._pISpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges)); + + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill)); + player._pAblSpells = mask; + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill)); +} + +TEST(Player, IsPlayerSpellSelectionValidRejectsInvalidSelections) +{ + LoadCoreArchives(); + LoadGameArchives(); + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + LoadSpellData(); + + Player player {}; + + EXPECT_TRUE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Invalid)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Healing, SpellType::Invalid)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Spell)); + EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Null, SpellType::Spell)); +} diff --git a/test/writehero_test.cpp b/test/writehero_test.cpp index 892f7e18d..10fcac3cd 100644 --- a/test/writehero_test.cpp +++ b/test/writehero_test.cpp @@ -1,25 +1,34 @@ #include "player_test.h" +#include +#include #include #include +#include #include #include #include +#include "codec.h" #include "cursor.h" #include "engine/assets.hpp" #include "game_mode.hpp" #include "init.hpp" #include "loadsave.h" +#include "menu.h" #include "pack.h" #include "pfile.h" +#include "spells.h" #include "tables/playerdat.hpp" #include "utils/endian_swap.hpp" #include "utils/file_util.h" #include "utils/paths.h" namespace devilution { + +uint32_t gSaveNumber = 0; + namespace { constexpr int SpellDatVanilla[] = { @@ -258,6 +267,54 @@ void PackPlayerTest(PlayerPack *pPack) SwapLE(*pPack); } +void AppendLE32(std::vector &buffer, int32_t value) +{ + const uint32_t rawValue = static_cast(value); + buffer.push_back(static_cast(rawValue & 0xFF)); + buffer.push_back(static_cast((rawValue >> 8) & 0xFF)); + buffer.push_back(static_cast((rawValue >> 16) & 0xFF)); + buffer.push_back(static_cast((rawValue >> 24) & 0xFF)); +} + +void AppendU8(std::vector &buffer, uint8_t value) +{ + buffer.push_back(static_cast(value)); +} + +void WriteEncodedArchiveEntry(const std::string &savePath, const char *entryName, std::vector decodedData) +{ + std::vector encodedData(codec_get_encoded_len(decodedData.size())); + std::copy(decodedData.begin(), decodedData.end(), encodedData.begin()); + codec_encode(encodedData.data(), decodedData.size(), encodedData.size(), pfile_get_password()); + + std::string savePathCopy = savePath; + SaveWriter saveWriter(std::move(savePathCopy)); + ASSERT_TRUE(saveWriter.WriteFile(entryName, encodedData.data(), encodedData.size())); +} + +void WriteLegacyHotkeys( + const std::string &savePath, + const std::array &hotkeySpells, + const std::array &hotkeyTypes, + SpellID selectedSpell, + SpellType selectedSpellType) +{ + std::vector decodedData; + decodedData.reserve(4 * sizeof(int32_t) + 4 * sizeof(uint8_t) + sizeof(int32_t) + sizeof(uint8_t)); + + for (SpellID spellId : hotkeySpells) { + AppendLE32(decodedData, static_cast(spellId)); + } + for (SpellType spellType : hotkeyTypes) { + AppendU8(decodedData, static_cast(spellType)); + } + + AppendLE32(decodedData, static_cast(selectedSpell)); + AppendU8(decodedData, static_cast(selectedSpellType)); + + WriteEncodedArchiveEntry(savePath, "hotkeys", std::move(decodedData)); +} + void AssertPlayer(Player &player) { ASSERT_EQ(CountU8(player._pSplLvl, 64), 23); @@ -416,5 +473,343 @@ TEST(Writehero, pfile_write_hero) "a79367caae6192d54703168d82e0316aa289b2a33251255fad8abe34889c1d3a"); } +TEST(Writehero, pfile_read_player_from_save_clears_invalid_spell_selections) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + player._pMemSpells |= GetSpellBitmask(SpellID::Apocalypse); + player._pSplHotKey[0] = SpellID::Apocalypse; + player._pSplTHotKey[0] = SpellType::Spell; + player._pSplHotKey[1] = SpellID::Healing; + player._pSplTHotKey[1] = SpellType::Spell; + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + + pfile_write_hero(); + + player._pSplHotKey[0] = SpellID::Invalid; + player._pSplTHotKey[0] = SpellType::Invalid; + player._pSplHotKey[1] = SpellID::Invalid; + player._pSplTHotKey[1] = SpellType::Invalid; + player._pRSpell = SpellID::Invalid; + player._pRSplType = SpellType::Invalid; + + pfile_read_player_from_save(info.saveNumber, player); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid); + EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); +} + +TEST(Writehero, LoadHotkeysWithoutFileSanitizesAndNormalizesSpellState) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = true; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + player.queuedSpell.spellId = SpellID::Healing; + player.queuedSpell.spellType = SpellType::Spell; + player.queuedSpell.spellFrom = INVITEM_BELT_FIRST; + player.queuedSpell.spellLevel = 7; + player.executedSpell.spellId = SpellID::Healing; + player.executedSpell.spellType = SpellType::Scroll; + player.executedSpell.spellFrom = INVITEM_INV_FIRST; + player.executedSpell.spellLevel = 3; + player.spellFrom = INVITEM_INV_FIRST; + + LoadHotkeys(info.saveNumber, player); + + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellFrom, 0); + EXPECT_EQ(player.queuedSpell.spellLevel, 0); + EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.executedSpell.spellFrom, 0); + EXPECT_EQ(player.executedSpell.spellLevel, 0); + EXPECT_EQ(player.spellFrom, 0); +} + +TEST(Writehero, LoadHotkeysLegacyFormatSanitizesInvalidSelections) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + + WriteLegacyHotkeys( + savePath, + { SpellID::Apocalypse, SpellID::Healing, SpellID::Invalid, SpellID::Invalid }, + { SpellType::Spell, SpellType::Spell, SpellType::Invalid, SpellType::Invalid }, + SpellID::Apocalypse, + SpellType::Spell); + + player.queuedSpell.spellFrom = INVITEM_BELT_FIRST; + player.queuedSpell.spellLevel = 9; + player.executedSpell.spellFrom = INVITEM_INV_FIRST; + player.executedSpell.spellLevel = 4; + player.spellFrom = INVITEM_INV_FIRST; + + LoadHotkeys(info.saveNumber, player); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid); + EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.queuedSpell.spellFrom, 0); + EXPECT_EQ(player.queuedSpell.spellLevel, 0); + EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid); + EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid); + EXPECT_EQ(player.executedSpell.spellFrom, 0); + EXPECT_EQ(player.executedSpell.spellLevel, 0); + EXPECT_EQ(player.spellFrom, 0); +} + +TEST(Writehero, LoadHotkeysLegacyFormatPreservesValidScrollSelection) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + + gbVanilla = false; + gbIsHellfire = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + gbIsHellfireSaveGame = false; + leveltype = DTYPE_TOWN; + giNumberOfLevels = 17; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + + Player &player = *MyPlayer; + player._pNumInv = 1; + InitializeItem(player.InvList[0], ItemMiscIdIdx(IMISC_SCROLL)); + player.InvList[0]._iSpell = SpellID::Healing; + player.InvList[0]._iStatFlag = true; + player._pScrlSpells = GetSpellBitmask(SpellID::Healing); + ASSERT_TRUE((player._pScrlSpells & GetSpellBitmask(SpellID::Healing)) != 0); + ASSERT_TRUE(IsPlayerSpellSelectionValid(player, SpellID::Healing, SpellType::Scroll)); + + WriteLegacyHotkeys( + savePath, + { SpellID::Healing, SpellID::Invalid, SpellID::Invalid, SpellID::Invalid }, + { SpellType::Scroll, SpellType::Invalid, SpellType::Invalid, SpellType::Invalid }, + SpellID::Healing, + SpellType::Scroll); + + LoadHotkeys(info.saveNumber, player); + + EXPECT_EQ(player._pRSpell, SpellID::Healing); + EXPECT_EQ(player._pRSplType, SpellType::Scroll); + EXPECT_EQ(player.queuedSpell.spellId, SpellID::Healing); + EXPECT_EQ(player.queuedSpell.spellType, SpellType::Scroll); + EXPECT_EQ(player.queuedSpell.spellFrom, 0); + leveltype = DTYPE_CATHEDRAL; + EXPECT_TRUE(CanUseScroll(player, SpellID::Healing)); + + player.InvList[0].clear(); + player._pNumInv = 0; + player._pScrlSpells = 0; + + EXPECT_FALSE(CanUseScroll(player, SpellID::Healing)); +} + +TEST(Writehero, DiabloRewritePersistsSanitizedSpellSelectionsFromHellfireSave) +{ + LoadCoreArchives(); + LoadGameArchives(); + + if (!HaveMainData()) { + GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test"; + } + + const std::string savePath = paths::BasePath() + "multi_0.sv"; + const std::string hellfireSavePath = paths::BasePath() + "multi_0.hsv"; + paths::SetPrefPath(paths::BasePath()); + RemoveFile(savePath.c_str()); + RemoveFile(hellfireSavePath.c_str()); + + gbVanilla = false; + gbIsSpawn = false; + gbIsMultiplayer = true; + leveltype = DTYPE_TOWN; + currlevel = 0; + ViewPosition = {}; + giNumberOfLevels = 25; + gbIsHellfire = true; + gbIsHellfireSaveGame = true; + + Players.resize(1); + MyPlayerId = 0; + MyPlayer = &Players[MyPlayerId]; + + LoadSpellData(); + LoadPlayerDataFiles(); + LoadMonsterData(); + LoadItemData(); + _uiheroinfo info {}; + info.heroclass = HeroClass::Rogue; + pfile_ui_save_create(&info); + gSaveNumber = info.saveNumber; + + Player &player = *MyPlayer; + player._pMemSpells = GetSpellBitmask(SpellID::Healing) | GetSpellBitmask(SpellID::Apocalypse); + player._pSplLvl[static_cast(SpellID::Healing)] = 1; + player._pSplLvl[static_cast(SpellID::Apocalypse)] = 1; + player._pSplHotKey[0] = SpellID::Apocalypse; + player._pSplTHotKey[0] = SpellType::Spell; + player._pSplHotKey[1] = SpellID::Healing; + player._pSplTHotKey[1] = SpellType::Spell; + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + + pfile_write_hero(/*writeGameData=*/true); + RenameFile(hellfireSavePath.c_str(), savePath.c_str()); + + gbIsHellfire = false; + gbIsHellfireSaveGame = false; + giNumberOfLevels = 17; + + pfile_read_player_from_save(info.saveNumber, player); + pfile_write_hero(); + + player._pSplHotKey[0] = SpellID::Apocalypse; + player._pSplTHotKey[0] = SpellType::Spell; + player._pSplHotKey[1] = SpellID::Invalid; + player._pSplTHotKey[1] = SpellType::Invalid; + player._pRSpell = SpellID::Apocalypse; + player._pRSplType = SpellType::Spell; + + LoadHotkeys(); + + EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid); + EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid); + EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing); + EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell); + EXPECT_EQ(player._pRSpell, SpellID::Invalid); + EXPECT_EQ(player._pRSplType, SpellType::Invalid); +} + } // namespace } // namespace devilution