Browse Source

Sanitize player spell selections on load to prevent invalid spell hotkeys

pull/8507/head
morfidon 5 days ago
parent
commit
94b1a8677a
  1. 7
      Source/loadsave.cpp
  2. 3
      Source/pfile.cpp
  3. 15
      Source/player.cpp
  4. 55
      Source/spells.cpp
  5. 2
      Source/spells.h
  6. 50
      test/player_test.cpp
  7. 120
      test/writehero_test.cpp

7
Source/loadsave.cpp

@ -35,6 +35,7 @@
#include "pfile.h"
#include "plrmsg.h"
#include "qol/stash.h"
#include "spells.h"
#include "stores.h"
#include "tables/playerdat.hpp"
#include "utils/algorithm/container.hpp"
@ -49,6 +50,8 @@ namespace devilution {
bool gbIsHellfireSaveGame;
uint8_t giNumberOfLevels;
void ValidatePlayerForLoad();
namespace {
constexpr size_t MaxMissilesForSaveGame = 125;
@ -636,7 +639,6 @@ void LoadPlayer(LoadHelper &file, Player &player)
sgGameInitInfo.nDifficulty = static_cast<_difficulty>(file.NextLE<uint32_t>());
player.pDamAcFlags = static_cast<ItemSpecialEffectHf>(file.NextLE<uint32_t>());
file.Skip(20); // Available bytes
CalcPlrInv(player, false);
player.executedSpell = player.queuedSpell; // Ensures backwards compatibility
@ -2519,6 +2521,9 @@ tl::expected<void, std::string> LoadGame(bool firstflag)
Player &myPlayer = *MyPlayer;
LoadPlayer(file, myPlayer);
ValidatePlayerForLoad();
CalcPlrInv(myPlayer, false);
SanitizePlayerSpellSelections(myPlayer);
if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL)
sgGameInitInfo.nDifficulty = DIFF_NORMAL;

3
Source/pfile.cpp

@ -28,6 +28,7 @@
#include "mpq/mpq_common.hpp"
#include "pack.h"
#include "qol/stash.h"
#include "spells.h"
#include "tables/playerdat.hpp"
#include "utils/endian_read.hpp"
#include "utils/endian_swap.hpp"
@ -690,6 +691,7 @@ bool pfile_ui_set_hero_infos(bool (*uiAddHeroInfo)(_uiheroinfo *))
LoadHeroItems(player);
RemoveAllInvalidItems(player);
CalcPlrInv(player, false);
SanitizePlayerSpellSelections(player);
Game2UiPlayer(player, &uihero, hasSaveGame);
uiAddHeroInfo(&uihero);
@ -777,6 +779,7 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player)
LoadHeroItems(player);
RemoveAllInvalidItems(player);
CalcPlrInv(player, false);
SanitizePlayerSpellSelections(player);
}
void pfile_save_level()

15
Source/player.cpp

@ -1528,11 +1528,16 @@ void GetPlayerGraphicsPath(std::string_view path, std::string_view prefix, std::
*BufCopy(out, "plrgfx\\", path, "\\", prefix, "\\", prefix, type) = '\0';
}
} // namespace
void Player::CalcScrolls()
{
_pScrlSpells = 0;
} // namespace
void ValidatePlayerForLoad()
{
ValidatePlayer();
}
void Player::CalcScrolls()
{
_pScrlSpells = 0;
for (const Item &item : InventoryAndBeltPlayerItemsRange { *this }) {
if (item.isScroll() && item._iStatFlag) {
_pScrlSpells |= GetSpellBitmask(item._iSpell);

55
Source/spells.cpp

@ -32,21 +32,7 @@ namespace {
*/
bool IsReadiedSpellValid(const Player &player)
{
switch (player._pRSplType) {
case SpellType::Skill:
case SpellType::Spell:
case SpellType::Invalid:
return true;
case SpellType::Charges:
return (player._pISpells & GetSpellBitmask(player._pRSpell)) != 0;
case SpellType::Scroll:
return (player._pScrlSpells & GetSpellBitmask(player._pRSpell)) != 0;
default:
return false;
}
return IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType);
}
/**
@ -85,6 +71,45 @@ bool IsValidSpellFrom(int spellFrom)
return false;
}
bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType)
{
if (spellType == SpellType::Invalid) {
return spellId == SpellID::Invalid;
}
if (!IsValidSpell(spellId)) {
return false;
}
switch (spellType) {
case SpellType::Skill:
return (player._pAblSpells & GetSpellBitmask(spellId)) != 0;
case SpellType::Spell:
return (player._pMemSpells & GetSpellBitmask(spellId)) != 0 && player.GetSpellLevel(spellId) > 0;
case SpellType::Scroll:
return (player._pScrlSpells & GetSpellBitmask(spellId)) != 0;
case SpellType::Charges:
return (player._pISpells & GetSpellBitmask(spellId)) != 0;
default:
return false;
}
}
void SanitizePlayerSpellSelections(Player &player)
{
for (size_t i = 0; i < NumHotkeys; ++i) {
if (!IsPlayerSpellSelectionValid(player, player._pSplHotKey[i], player._pSplTHotKey[i])) {
player._pSplHotKey[i] = SpellID::Invalid;
player._pSplTHotKey[i] = SpellType::Invalid;
}
}
if (!IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType)) {
player._pRSpell = SpellID::Invalid;
player._pRSplType = SpellType::Invalid;
}
}
bool IsWallSpell(SpellID spl)
{
return spl == SpellID::FireWall || spl == SpellID::LightningWall;

2
Source/spells.h

@ -21,6 +21,8 @@ enum class SpellCheckResult : uint8_t {
bool IsValidSpell(SpellID spl);
bool IsValidSpellFrom(int spellFrom);
bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType);
void SanitizePlayerSpellSelections(Player &player);
bool IsWallSpell(SpellID spl);
bool TargetsMonster(SpellID id);
int GetManaAmount(const Player &player, SpellID sn);

50
test/player_test.cpp

@ -5,6 +5,7 @@
#include "cursor.h"
#include "engine/assets.hpp"
#include "init.hpp"
#include "spells.h"
#include "tables/playerdat.hpp"
using namespace devilution;
@ -204,3 +205,52 @@ TEST(Player, CreatePlayer)
CreatePlayer(Players[0], HeroClass::Rogue);
AssertPlayer(Players[0]);
}
TEST(Player, IsPlayerSpellSelectionValidChecksSpellSources)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
LoadSpellData();
const SpellID spell = SpellID::Healing;
const uint64_t mask = GetSpellBitmask(spell);
Player player {};
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));
player._pMemSpells = mask;
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));
player._pSplLvl[static_cast<size_t>(spell)] = 1;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll));
player._pScrlSpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges));
player._pISpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill));
player._pAblSpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill));
}
TEST(Player, IsPlayerSpellSelectionValidRejectsInvalidSelections)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
LoadSpellData();
Player player {};
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Invalid));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Healing, SpellType::Invalid));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Spell));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Null, SpellType::Spell));
}

120
test/writehero_test.cpp

@ -14,6 +14,7 @@
#include "loadsave.h"
#include "pack.h"
#include "pfile.h"
#include "spells.h"
#include "tables/playerdat.hpp"
#include "utils/endian_swap.hpp"
#include "utils/file_util.h"
@ -416,5 +417,124 @@ TEST(Writehero, pfile_write_hero)
"a79367caae6192d54703168d82e0316aa289b2a33251255fad8abe34889c1d3a");
}
TEST(Writehero, pfile_read_player_from_save_clears_invalid_spell_selections)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
const std::string savePath = paths::BasePath() + "multi_0.sv";
paths::SetPrefPath(paths::BasePath());
RemoveFile(savePath.c_str());
gbVanilla = false;
gbIsHellfire = false;
gbIsSpawn = false;
gbIsMultiplayer = true;
gbIsHellfireSaveGame = false;
leveltype = DTYPE_TOWN;
giNumberOfLevels = 17;
Players.resize(1);
MyPlayerId = 0;
MyPlayer = &Players[MyPlayerId];
LoadSpellData();
LoadPlayerDataFiles();
LoadMonsterData();
LoadItemData();
_uiheroinfo info {};
info.heroclass = HeroClass::Rogue;
pfile_ui_save_create(&info);
Player &player = *MyPlayer;
player._pMemSpells = GetSpellBitmask(SpellID::Healing);
player._pSplLvl[static_cast<size_t>(SpellID::Healing)] = 1;
player._pMemSpells |= GetSpellBitmask(SpellID::Apocalypse);
player._pSplHotKey[0] = SpellID::Apocalypse;
player._pSplTHotKey[0] = SpellType::Spell;
player._pSplHotKey[1] = SpellID::Healing;
player._pSplTHotKey[1] = SpellType::Spell;
player._pRSpell = SpellID::Apocalypse;
player._pRSplType = SpellType::Spell;
pfile_write_hero();
player._pSplHotKey[0] = SpellID::Invalid;
player._pSplTHotKey[0] = SpellType::Invalid;
player._pSplHotKey[1] = SpellID::Invalid;
player._pSplTHotKey[1] = SpellType::Invalid;
player._pRSpell = SpellID::Invalid;
player._pRSplType = SpellType::Invalid;
pfile_read_player_from_save(info.saveNumber, player);
EXPECT_EQ(player._pSplHotKey[0], SpellID::Invalid);
EXPECT_EQ(player._pSplTHotKey[0], SpellType::Invalid);
EXPECT_EQ(player._pSplHotKey[1], SpellID::Healing);
EXPECT_EQ(player._pSplTHotKey[1], SpellType::Spell);
EXPECT_EQ(player._pRSpell, SpellID::Invalid);
EXPECT_EQ(player._pRSplType, SpellType::Invalid);
}
TEST(Writehero, pfile_read_player_from_save_preserves_valid_spell_selections)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
const std::string savePath = paths::BasePath() + "multi_0.sv";
paths::SetPrefPath(paths::BasePath());
RemoveFile(savePath.c_str());
gbVanilla = false;
gbIsHellfire = false;
gbIsSpawn = false;
gbIsMultiplayer = true;
gbIsHellfireSaveGame = false;
leveltype = DTYPE_TOWN;
giNumberOfLevels = 17;
Players.resize(1);
MyPlayerId = 0;
MyPlayer = &Players[MyPlayerId];
LoadSpellData();
LoadPlayerDataFiles();
LoadMonsterData();
LoadItemData();
_uiheroinfo info {};
info.heroclass = HeroClass::Rogue;
pfile_ui_save_create(&info);
Player &player = *MyPlayer;
player._pMemSpells = GetSpellBitmask(SpellID::Healing);
player._pSplLvl[static_cast<size_t>(SpellID::Healing)] = 1;
player._pSplHotKey[0] = SpellID::Healing;
player._pSplTHotKey[0] = SpellType::Spell;
player._pRSpell = SpellID::Healing;
player._pRSplType = SpellType::Spell;
pfile_write_hero();
player._pSplHotKey[0] = SpellID::Invalid;
player._pSplTHotKey[0] = SpellType::Invalid;
player._pRSpell = SpellID::Invalid;
player._pRSplType = SpellType::Invalid;
pfile_read_player_from_save(info.saveNumber, player);
EXPECT_EQ(player._pSplHotKey[0], SpellID::Healing);
EXPECT_EQ(player._pSplTHotKey[0], SpellType::Spell);
EXPECT_EQ(player._pRSpell, SpellID::Healing);
EXPECT_EQ(player._pRSplType, SpellType::Spell);
}
} // namespace
} // namespace devilution

Loading…
Cancel
Save