Browse Source

Normalize spell cast state after hotkey load

Introduce SyncPlayerSpellStateFromSelections and call it in hotkey load paths so queued/executed spell metadata and spellFrom are reset consistently after sanitization.

This centralizes post-load spell-state normalization, removes duplicated per-call-site resync code, and adds writehero regression tests for missing hotkeys data and legacy 4-slot hotkeys format handling.
pull/8507/head
morfidon 5 days ago
parent
commit
4e58828335
  1. 4
      Source/loadsave.cpp
  2. 3
      Source/pfile.cpp
  3. 10
      Source/spells.cpp
  4. 1
      Source/spells.h
  5. 185
      test/writehero_test.cpp

4
Source/loadsave.cpp

@ -2342,6 +2342,7 @@ void LoadHotkeys(uint32_t saveNum, Player &myPlayer)
LoadHelper file(OpenSaveArchive(saveNum), "hotkeys");
if (!file.IsValid()) {
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
return;
}
@ -2379,6 +2380,7 @@ void LoadHotkeys(uint32_t saveNum, Player &myPlayer)
myPlayer._pRSpell = static_cast<SpellID>(file.NextLE<int32_t>());
myPlayer._pRSplType = static_cast<SpellType>(file.NextLE<uint8_t>());
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
}
void SaveHotkeys(SaveWriter &saveWriter, const Player &player)
@ -2532,8 +2534,6 @@ tl::expected<void, std::string> LoadGame(bool firstflag)
ValidatePlayer();
CalcPlrInv(myPlayer, false);
LoadHotkeys(gSaveNumber, myPlayer);
myPlayer.queuedSpell.spellId = myPlayer._pRSpell;
myPlayer.queuedSpell.spellType = myPlayer._pRSplType;
if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL)
sgGameInitInfo.nDifficulty = DIFF_NORMAL;

3
Source/pfile.cpp

@ -781,10 +781,9 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player)
CalcPlrInv(player, false);
if (&player == MyPlayer) {
LoadHotkeys(saveNum, player);
player.queuedSpell.spellId = player._pRSpell;
player.queuedSpell.spellType = player._pRSplType;
} else {
SanitizePlayerSpellSelections(player);
SyncPlayerSpellStateFromSelections(player);
}
}

10
Source/spells.cpp

@ -110,6 +110,16 @@ void SanitizePlayerSpellSelections(Player &player)
}
}
void SyncPlayerSpellStateFromSelections(Player &myPlayer)
{
myPlayer.queuedSpell.spellId = myPlayer._pRSpell;
myPlayer.queuedSpell.spellType = myPlayer._pRSplType;
myPlayer.queuedSpell.spellFrom = 0;
myPlayer.queuedSpell.spellLevel = 0;
myPlayer.executedSpell = myPlayer.queuedSpell;
myPlayer.spellFrom = 0;
}
bool IsWallSpell(SpellID spl)
{
return spl == SpellID::FireWall || spl == SpellID::LightningWall;

1
Source/spells.h

@ -23,6 +23,7 @@ bool IsValidSpell(SpellID spl);
bool IsValidSpellFrom(int spellFrom);
bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType);
void SanitizePlayerSpellSelections(Player &player);
void SyncPlayerSpellStateFromSelections(Player &myPlayer);
bool IsWallSpell(SpellID spl);
bool TargetsMonster(SpellID id);
int GetManaAmount(const Player &player, SpellID sn);

185
test/writehero_test.cpp

@ -1,12 +1,16 @@
#include "player_test.h"
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>
#include <gtest/gtest.h>
#include <picosha2.h>
#include "codec.h"
#include "cursor.h"
#include "engine/assets.hpp"
#include "game_mode.hpp"
@ -263,6 +267,54 @@ void PackPlayerTest(PlayerPack *pPack)
SwapLE(*pPack);
}
void AppendLE32(std::vector<std::byte> &buffer, int32_t value)
{
const uint32_t rawValue = static_cast<uint32_t>(value);
buffer.push_back(static_cast<std::byte>(rawValue & 0xFF));
buffer.push_back(static_cast<std::byte>((rawValue >> 8) & 0xFF));
buffer.push_back(static_cast<std::byte>((rawValue >> 16) & 0xFF));
buffer.push_back(static_cast<std::byte>((rawValue >> 24) & 0xFF));
}
void AppendU8(std::vector<std::byte> &buffer, uint8_t value)
{
buffer.push_back(static_cast<std::byte>(value));
}
void WriteEncodedArchiveEntry(const std::string &savePath, const char *entryName, std::vector<std::byte> decodedData)
{
std::vector<std::byte> encodedData(codec_get_encoded_len(decodedData.size()));
std::copy(decodedData.begin(), decodedData.end(), encodedData.begin());
codec_encode(encodedData.data(), decodedData.size(), encodedData.size(), pfile_get_password());
std::string savePathCopy = savePath;
SaveWriter saveWriter(std::move(savePathCopy));
ASSERT_TRUE(saveWriter.WriteFile(entryName, encodedData.data(), encodedData.size()));
}
void WriteLegacyHotkeys(
const std::string &savePath,
const std::array<SpellID, 4> &hotkeySpells,
const std::array<SpellType, 4> &hotkeyTypes,
SpellID selectedSpell,
SpellType selectedSpellType)
{
std::vector<std::byte> decodedData;
decodedData.reserve(4 * sizeof(int32_t) + 4 * sizeof(uint8_t) + sizeof(int32_t) + sizeof(uint8_t));
for (SpellID spellId : hotkeySpells) {
AppendLE32(decodedData, static_cast<int8_t>(spellId));
}
for (SpellType spellType : hotkeyTypes) {
AppendU8(decodedData, static_cast<uint8_t>(spellType));
}
AppendLE32(decodedData, static_cast<int8_t>(selectedSpell));
AppendU8(decodedData, static_cast<uint8_t>(selectedSpellType));
WriteEncodedArchiveEntry(savePath, "hotkeys", std::move(decodedData));
}
void AssertPlayer(Player &player)
{
ASSERT_EQ(CountU8(player._pSplLvl, 64), 23);
@ -540,6 +592,139 @@ TEST(Writehero, pfile_read_player_from_save_preserves_valid_spell_selections)
EXPECT_EQ(player._pRSplType, SpellType::Spell);
}
TEST(Writehero, LoadHotkeysWithoutFileSanitizesAndNormalizesSpellState)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
const std::string savePath = paths::BasePath() + "multi_0.sv";
paths::SetPrefPath(paths::BasePath());
RemoveFile(savePath.c_str());
gbVanilla = true;
gbIsHellfire = false;
gbIsSpawn = false;
gbIsMultiplayer = true;
gbIsHellfireSaveGame = false;
leveltype = DTYPE_TOWN;
giNumberOfLevels = 17;
Players.resize(1);
MyPlayerId = 0;
MyPlayer = &Players[MyPlayerId];
LoadSpellData();
LoadPlayerDataFiles();
LoadMonsterData();
LoadItemData();
_uiheroinfo info {};
info.heroclass = HeroClass::Rogue;
pfile_ui_save_create(&info);
Player &player = *MyPlayer;
player._pMemSpells = GetSpellBitmask(SpellID::Healing);
player._pSplLvl[static_cast<size_t>(SpellID::Healing)] = 1;
player._pRSpell = SpellID::Apocalypse;
player._pRSplType = SpellType::Spell;
player.queuedSpell.spellId = SpellID::Healing;
player.queuedSpell.spellType = SpellType::Spell;
player.queuedSpell.spellFrom = INVITEM_BELT_FIRST;
player.queuedSpell.spellLevel = 7;
player.executedSpell.spellId = SpellID::Healing;
player.executedSpell.spellType = SpellType::Scroll;
player.executedSpell.spellFrom = INVITEM_INV_FIRST;
player.executedSpell.spellLevel = 3;
player.spellFrom = INVITEM_INV_FIRST;
LoadHotkeys(info.saveNumber, player);
EXPECT_EQ(player._pRSpell, SpellID::Invalid);
EXPECT_EQ(player._pRSplType, SpellType::Invalid);
EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid);
EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid);
EXPECT_EQ(player.queuedSpell.spellFrom, 0);
EXPECT_EQ(player.queuedSpell.spellLevel, 0);
EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid);
EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid);
EXPECT_EQ(player.executedSpell.spellFrom, 0);
EXPECT_EQ(player.executedSpell.spellLevel, 0);
EXPECT_EQ(player.spellFrom, 0);
}
TEST(Writehero, LoadHotkeysLegacyFormatSanitizesInvalidSelections)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
const std::string savePath = paths::BasePath() + "multi_0.sv";
paths::SetPrefPath(paths::BasePath());
RemoveFile(savePath.c_str());
gbVanilla = false;
gbIsHellfire = false;
gbIsSpawn = false;
gbIsMultiplayer = true;
gbIsHellfireSaveGame = false;
leveltype = DTYPE_TOWN;
giNumberOfLevels = 17;
Players.resize(1);
MyPlayerId = 0;
MyPlayer = &Players[MyPlayerId];
LoadSpellData();
LoadPlayerDataFiles();
LoadMonsterData();
LoadItemData();
_uiheroinfo info {};
info.heroclass = HeroClass::Rogue;
pfile_ui_save_create(&info);
Player &player = *MyPlayer;
player._pMemSpells = GetSpellBitmask(SpellID::Healing);
player._pSplLvl[static_cast<size_t>(SpellID::Healing)] = 1;
WriteLegacyHotkeys(
savePath,
{ SpellID::Apocalypse, SpellID::Healing, SpellID::Invalid, SpellID::Invalid },
{ SpellType::Spell, SpellType::Spell, SpellType::Invalid, SpellType::Invalid },
SpellID::Apocalypse,
SpellType::Spell);
player.queuedSpell.spellFrom = INVITEM_BELT_FIRST;
player.queuedSpell.spellLevel = 9;
player.executedSpell.spellFrom = INVITEM_INV_FIRST;
player.executedSpell.spellLevel = 4;
player.spellFrom = INVITEM_INV_FIRST;
LoadHotkeys(info.saveNumber, player);
EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid);
EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid);
EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing);
EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell);
EXPECT_EQ(player._pRSpell, SpellID::Invalid);
EXPECT_EQ(player._pRSplType, SpellType::Invalid);
EXPECT_EQ(player.queuedSpell.spellId, SpellID::Invalid);
EXPECT_EQ(player.queuedSpell.spellType, SpellType::Invalid);
EXPECT_EQ(player.queuedSpell.spellFrom, 0);
EXPECT_EQ(player.queuedSpell.spellLevel, 0);
EXPECT_EQ(player.executedSpell.spellId, SpellID::Invalid);
EXPECT_EQ(player.executedSpell.spellType, SpellType::Invalid);
EXPECT_EQ(player.executedSpell.spellFrom, 0);
EXPECT_EQ(player.executedSpell.spellLevel, 0);
EXPECT_EQ(player.spellFrom, 0);
}
TEST(Writehero, DiabloRewritePersistsSanitizedSpellSelectionsFromHellfireSave)
{
LoadCoreArchives();

Loading…
Cancel
Save