Skip to content
50 changes: 43 additions & 7 deletions Source/loadsave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "pfile.h"
#include "plrmsg.h"
#include "qol/stash.h"
#include "spells.h"
#include "stores.h"
#include "tables/playerdat.hpp"
#include "utils/algorithm/container.hpp"
Expand Down Expand Up @@ -636,7 +637,6 @@ void LoadPlayer(LoadHelper &file, Player &player)
sgGameInitInfo.nDifficulty = static_cast<_difficulty>(file.NextLE<uint32_t>());
player.pDamAcFlags = static_cast<ItemSpecialEffectHf>(file.NextLE<uint32_t>());
file.Skip(20); // Available bytes
CalcPlrInv(player, false);

player.executedSpell = player.queuedSpell; // Ensures backwards compatibility

Expand Down Expand Up @@ -2329,23 +2329,54 @@ size_t HotkeysSize(size_t nHotkeys = NumHotkeys)
return sizeof(uint8_t) + (nHotkeys * sizeof(int32_t)) + (nHotkeys * sizeof(uint8_t)) + sizeof(int32_t) + sizeof(uint8_t);
}

size_t LegacyHotkeysSize()
{
return HotkeysSize(4) - sizeof(uint8_t);
}

void LoadHotkeys()
{
LoadHelper file(OpenSaveArchive(gSaveNumber), "hotkeys");
if (!file.IsValid())
if (MyPlayer == nullptr)
return;

Player &myPlayer = *MyPlayer;
LoadHotkeys(gSaveNumber, *MyPlayer);
}

void LoadHotkeys(uint32_t saveNum, Player &myPlayer)
{
LoadHelper file(OpenSaveArchive(saveNum), "hotkeys");
if (!file.IsValid()) {
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
return;
}

size_t nHotkeys = 4; // Defaults to old save format number

// Refill the spell arrays with no selection
std::fill(myPlayer._pSplHotKey, myPlayer._pSplHotKey + NumHotkeys, SpellID::Invalid);
std::fill(myPlayer._pSplTHotKey, myPlayer._pSplTHotKey + NumHotkeys, SpellType::Invalid);

// Checking if the save file has the old format with only 4 hotkeys and no header
if (file.IsValid(HotkeysSize(nHotkeys))) {
// The file contains a header byte and at least 4 entries, so we can assume it's a new format save
const size_t fileSize = file.Size();

if (fileSize == LegacyHotkeysSize()) {
// Legacy format: exactly 4 hotkeys, no leading count byte.
} else {
if (!file.IsValid(sizeof(uint8_t))) {
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
return;
}

nHotkeys = file.NextLE<uint8_t>();

const size_t payloadSize = (nHotkeys * sizeof(int32_t)) + (nHotkeys * sizeof(uint8_t)) + sizeof(int32_t) + sizeof(uint8_t);

if (!file.IsValid(payloadSize)) {
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
return;
}
}

// Read all hotkeys in the file
Expand All @@ -2369,6 +2400,8 @@ void LoadHotkeys()
// Load the selected spell last
myPlayer._pRSpell = static_cast<SpellID>(file.NextLE<int32_t>());
myPlayer._pRSplType = static_cast<SpellType>(file.NextLE<uint8_t>());
SanitizePlayerSpellSelections(myPlayer);
SyncPlayerSpellStateFromSelections(myPlayer);
}

void SaveHotkeys(SaveWriter &saveWriter, const Player &player)
Expand Down Expand Up @@ -2519,6 +2552,9 @@ tl::expected<void, std::string> LoadGame(bool firstflag)
Player &myPlayer = *MyPlayer;

LoadPlayer(file, myPlayer);
ValidatePlayer();
CalcPlrInv(myPlayer, false);
LoadHotkeys(gSaveNumber, myPlayer);

if (sgGameInitInfo.nDifficulty < DIFF_NORMAL || sgGameInitInfo.nDifficulty > DIFF_HELL)
sgGameInitInfo.nDifficulty = DIFF_NORMAL;
Expand Down
1 change: 1 addition & 0 deletions Source/loadsave.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ _item_indexes RemapItemIdxFromSpawn(_item_indexes i);
_item_indexes RemapItemIdxToSpawn(_item_indexes i);
bool IsHeaderValid(uint32_t magicNumber);
void LoadHotkeys();
void LoadHotkeys(uint32_t saveNum, Player &myPlayer);
void LoadHeroItems(Player &player);
/**
* @brief Remove invalid inventory items from the inventory grid
Expand Down
8 changes: 8 additions & 0 deletions Source/pfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mpq/mpq_common.hpp"
#include "pack.h"
#include "qol/stash.h"
#include "spells.h"
#include "tables/playerdat.hpp"
#include "utils/endian_read.hpp"
#include "utils/endian_swap.hpp"
Expand Down Expand Up @@ -690,6 +691,7 @@ bool pfile_ui_set_hero_infos(bool (*uiAddHeroInfo)(_uiheroinfo *))
LoadHeroItems(player);
RemoveAllInvalidItems(player);
CalcPlrInv(player, false);
SanitizePlayerSpellSelections(player);

Game2UiPlayer(player, &uihero, hasSaveGame);
uiAddHeroInfo(&uihero);
Expand Down Expand Up @@ -777,6 +779,12 @@ void pfile_read_player_from_save(uint32_t saveNum, Player &player)
LoadHeroItems(player);
RemoveAllInvalidItems(player);
CalcPlrInv(player, false);
if (&player == MyPlayer) {
LoadHotkeys(saveNum, player);
} else {
SanitizePlayerSpellSelections(player);
SyncPlayerSpellStateFromSelections(player);
}
}

void pfile_save_level()
Expand Down
6 changes: 4 additions & 2 deletions Source/player.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1410,6 +1410,8 @@ bool PlrDeathModeOK(Player &player)
return false;
}

} // namespace

void ValidatePlayer()
{
assert(MyPlayer != nullptr);
Expand Down Expand Up @@ -1467,6 +1469,8 @@ void ValidatePlayer()
myPlayer._pInfraFlag = false;
}

namespace {

HeroClass GetPlayerSpriteClass(HeroClass cls)
{
if (cls == HeroClass::Bard && !HaveBardAssets())
Expand Down Expand Up @@ -2483,8 +2487,6 @@ void InitPlayer(Player &player, bool firstTime)
if (firstTime) {
player._pRSplType = SpellType::Invalid;
player._pRSpell = SpellID::Invalid;
if (&player == MyPlayer)
LoadHotkeys();
player._pSBkSpell = SpellID::Invalid;
player.queuedSpell.spellId = player._pRSpell;
player.queuedSpell.spellType = player._pRSplType;
Expand Down
1 change: 1 addition & 0 deletions Source/player.h
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ void CheckPlrSpell(bool isShiftHeld, SpellID spellID = MyPlayer->_pRSpell, Spell
void SyncPlrAnim(Player &player);
void SyncInitPlrPos(Player &player);
void SyncInitPlr(Player &player);
void ValidatePlayer();
void CheckStats(Player &player);
void ModifyPlrStr(Player &player, int l);
void ModifyPlrMag(Player &player, int l);
Expand Down
65 changes: 50 additions & 15 deletions Source/spells.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,7 @@ namespace {
*/
bool IsReadiedSpellValid(const Player &player)
{
switch (player._pRSplType) {
case SpellType::Skill:
case SpellType::Spell:
case SpellType::Invalid:
return true;

case SpellType::Charges:
return (player._pISpells & GetSpellBitmask(player._pRSpell)) != 0;

case SpellType::Scroll:
return (player._pScrlSpells & GetSpellBitmask(player._pRSpell)) != 0;

default:
return false;
}
return IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType);
}

/**
Expand Down Expand Up @@ -85,6 +71,55 @@ bool IsValidSpellFrom(int spellFrom)
return false;
}

bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType)
{
if (spellType == SpellType::Invalid) {
return spellId == SpellID::Invalid;
}

if (!IsValidSpell(spellId)) {
return false;
}

switch (spellType) {
case SpellType::Skill:
return (player._pAblSpells & GetSpellBitmask(spellId)) != 0;
case SpellType::Spell:
return (player._pMemSpells & GetSpellBitmask(spellId)) != 0 && player.GetSpellLevel(spellId) > 0;
case SpellType::Scroll:
return (player._pScrlSpells & GetSpellBitmask(spellId)) != 0;
case SpellType::Charges:
return (player._pISpells & GetSpellBitmask(spellId)) != 0;
default:
return false;
}
}

void SanitizePlayerSpellSelections(Player &player)
{
for (size_t i = 0; i < NumHotkeys; ++i) {
if (!IsPlayerSpellSelectionValid(player, player._pSplHotKey[i], player._pSplTHotKey[i])) {
player._pSplHotKey[i] = SpellID::Invalid;
player._pSplTHotKey[i] = SpellType::Invalid;
}
}

if (!IsPlayerSpellSelectionValid(player, player._pRSpell, player._pRSplType)) {
player._pRSpell = SpellID::Invalid;
player._pRSplType = SpellType::Invalid;
}
}

void SyncPlayerSpellStateFromSelections(Player &myPlayer)
{
myPlayer.queuedSpell.spellId = myPlayer._pRSpell;
myPlayer.queuedSpell.spellType = myPlayer._pRSplType;
myPlayer.queuedSpell.spellFrom = 0;
myPlayer.queuedSpell.spellLevel = 0;
myPlayer.executedSpell = myPlayer.queuedSpell;
myPlayer.spellFrom = 0;
}

bool IsWallSpell(SpellID spl)
{
return spl == SpellID::FireWall || spl == SpellID::LightningWall;
Expand Down
3 changes: 3 additions & 0 deletions Source/spells.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ enum class SpellCheckResult : uint8_t {

bool IsValidSpell(SpellID spl);
bool IsValidSpellFrom(int spellFrom);
bool IsPlayerSpellSelectionValid(const Player &player, SpellID spellId, SpellType spellType);
void SanitizePlayerSpellSelections(Player &player);
void SyncPlayerSpellStateFromSelections(Player &myPlayer);
bool IsWallSpell(SpellID spl);
bool TargetsMonster(SpellID id);
int GetManaAmount(const Player &player, SpellID sn);
Expand Down
50 changes: 50 additions & 0 deletions test/player_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "cursor.h"
#include "engine/assets.hpp"
#include "init.hpp"
#include "spells.h"
#include "tables/playerdat.hpp"

using namespace devilution;
Expand Down Expand Up @@ -204,3 +205,52 @@ TEST(Player, CreatePlayer)
CreatePlayer(Players[0], HeroClass::Rogue);
AssertPlayer(Players[0]);
}

TEST(Player, IsPlayerSpellSelectionValidChecksSpellSources)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
LoadSpellData();

const SpellID spell = SpellID::Healing;
const uint64_t mask = GetSpellBitmask(spell);
Player player {};

EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));
player._pMemSpells = mask;
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));
player._pSplLvl[static_cast<size_t>(spell)] = 1;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Spell));

EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll));
player._pScrlSpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Scroll));

EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges));
player._pISpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Charges));

EXPECT_FALSE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill));
player._pAblSpells = mask;
EXPECT_TRUE(IsPlayerSpellSelectionValid(player, spell, SpellType::Skill));
}

TEST(Player, IsPlayerSpellSelectionValidRejectsInvalidSelections)
{
LoadCoreArchives();
LoadGameArchives();
if (!HaveMainData()) {
GTEST_SKIP() << "MPQ assets (spawn.mpq or DIABDAT.MPQ) not found - skipping test";
}
LoadSpellData();

Player player {};

EXPECT_TRUE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Invalid));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Healing, SpellType::Invalid));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Invalid, SpellType::Spell));
EXPECT_FALSE(IsPlayerSpellSelectionValid(player, SpellID::Null, SpellType::Spell));
}
Loading