diff --git a/Source/engine/assets.cpp b/Source/engine/assets.cpp index 2af12afee..d69026b6f 100644 --- a/Source/engine/assets.cpp +++ b/Source/engine/assets.cpp @@ -75,6 +75,8 @@ bool FindMpqFile(std::string_view filename, MpqArchive **archive, uint32_t *file AssetRef FindAsset(std::string_view filename) { AssetRef result; + if (filename.empty() || filename.back() == '\\') + return result; result.path[0] = '\0'; char pathBuf[AssetRef::PathBufSize]; @@ -113,6 +115,9 @@ AssetRef FindAsset(std::string_view filename) AssetRef FindAsset(std::string_view filename) { AssetRef result; + if (filename.empty() || filename.back() == '\\') + return result; + std::string relativePath { filename }; #ifndef _WIN32 std::replace(relativePath.begin(), relativePath.end(), '\\', '/'); @@ -206,4 +211,26 @@ SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe) #endif } +tl::expected LoadAsset(std::string_view path) +{ + AssetRef ref = FindAsset(path); + if (!ref.ok()) { + return tl::make_unexpected(StrCat("Asset not found: ", path)); + } + + const size_t size = ref.size(); + std::unique_ptr data { new char[size] }; + + AssetHandle handle = OpenAsset(std::move(ref)); + if (!handle.ok()) { + return tl::make_unexpected(StrCat("Failed to open asset: ", path, "\n", handle.error())); + } + + if (size > 0 && !handle.read(data.get(), size)) { + return tl::make_unexpected(StrCat("Read failed: ", path, "\n", handle.error())); + } + + return AssetData { std::move(data), size }; +} + } // namespace devilution diff --git a/Source/engine/assets.hpp b/Source/engine/assets.hpp index 329225f69..65ebee51a 100644 --- a/Source/engine/assets.hpp +++ b/Source/engine/assets.hpp @@ -7,6 +7,7 @@ #include #include +#include #include "appfat.h" #include "diablo.h" @@ -246,4 +247,16 @@ AssetHandle OpenAsset(std::string_view filename, size_t &fileSize, bool threadsa SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe = false); +struct AssetData { + std::unique_ptr data; + size_t size; + + explicit operator std::string_view() const + { + return std::string_view(data.get(), size); + } +}; + +tl::expected LoadAsset(std::string_view path); + } // namespace devilution diff --git a/Source/lua/lua.cpp b/Source/lua/lua.cpp index 307eb0a18..28acb5184 100644 --- a/Source/lua/lua.cpp +++ b/Source/lua/lua.cpp @@ -2,6 +2,7 @@ #include #include +#include #include @@ -17,7 +18,58 @@ namespace devilution { namespace { -std::optional luaState; +struct LuaState { + sol::state sol; + std::unordered_map compiledScripts; +}; + +std::optional CurrentLuaState; + +// A Lua function that we use to generate a `require` implementation. +constexpr std::string_view RequireGenSrc = R"( +function requireGen(loaded, loadFn) + return function(packageName) + local p = loaded[packageName] + if p == nil then + local loader = loadFn(packageName) + if type(loader) == "string" then + error(loader) + end + p = loader(packageName) + loaded[packageName] = p + end + return p + end +end +)"; + +sol::object LuaLoadScriptFromAssets(std::string_view packageName) +{ + LuaState &luaState = *CurrentLuaState; + std::string path { packageName }; + std::replace(path.begin(), path.end(), '.', '\\'); + path.append(".lua"); + + auto iter = luaState.compiledScripts.find(path); + if (iter != luaState.compiledScripts.end()) { + return luaState.sol.load(iter->second.as_string_view(), path, sol::load_mode::binary); + } + + tl::expected assetData = LoadAsset(path); + if (!assetData.has_value()) { + sol::stack::push(luaState.sol.lua_state(), assetData.error()); + return sol::stack_object(luaState.sol.lua_state(), -1); + } + sol::load_result result = luaState.sol.load(std::string_view(*assetData), path, sol::load_mode::text); + if (!result.valid()) { + sol::stack::push(luaState.sol.lua_state(), + StrCat("Lua error when loading ", path, ": ", result.get())); + return sol::stack_object(luaState.sol.lua_state(), -1); + } + const sol::function fn = result; + luaState.compiledScripts[path] = fn.dump(); + return result; +} int LuaPrint(lua_State *state) { @@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional) void RunScript(std::string_view path, bool optional) { - AssetRef ref = FindAsset(path); - if (!ref.ok()) { - if (!optional) - app_fatal(StrCat("Asset not found: ", path)); - return; - } + tl::expected assetData = LoadAsset(path); - const size_t size = ref.size(); - std::unique_ptr luaScript { new char[size] }; - - AssetHandle handle = OpenAsset(std::move(ref)); - if (!handle.ok()) { - app_fatal(StrCat("Failed to open asset: ", path, "\n", handle.error())); - return; - } - - if (size > 0 && !handle.read(luaScript.get(), size)) { - app_fatal(StrCat("Read failed: ", path, "\n", handle.error())); + if (!assetData.has_value()) { + if (!optional) + app_fatal(assetData.error()); return; } - const std::string_view luaScriptStr(luaScript.get(), size); - CheckResult(luaState->safe_script(luaScriptStr), optional); + CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional); } void LuaPanic(sol::optional message) @@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state) void LuaInitialize() { - luaState.emplace(sol::c_call); - sol::state &lua = *luaState; + CurrentLuaState.emplace(LuaState { + .sol = { sol::c_call }, + .compiledScripts = {}, + }); + sol::state &lua = CurrentLuaState->sol; lua.open_libraries( sol::lib::base, sol::lib::package, @@ -116,11 +157,12 @@ void LuaInitialize() "_VERSION", LUA_VERSION); // Registering devilutionx object table - lua.create_named_table( - "devilutionx", - "log", LuaLogModule(lua), - "render", LuaRenderModule(lua), - "message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); }); + CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false); + const sol::table loaded = lua.create_table_with( + "devilutionx.log", LuaLogModule(lua), + "devilutionx.render", LuaRenderModule(lua), + "devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); }); + lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets); RunScript("lua\\init.lua", /*optional=*/false); RunScript("lua\\user.lua", /*optional=*/true); @@ -130,12 +172,12 @@ void LuaInitialize() void LuaShutdown() { - luaState = std::nullopt; + CurrentLuaState = std::nullopt; } void LuaEvent(std::string_view name) { - const sol::state &lua = *luaState; + const sol::state &lua = CurrentLuaState->sol; const auto trigger = lua.traverse_get>("Events", name, "Trigger"); if (!trigger.has_value() || !trigger->is()) { LogError("Events.{}.Trigger is not a function", name); @@ -145,9 +187,9 @@ void LuaEvent(std::string_view name) CheckResult(fn(), /*optional=*/true); } -sol::state &LuaState() +sol::state &GetLuaState() { - return *luaState; + return CurrentLuaState->sol; } } // namespace devilution diff --git a/Source/lua/lua.hpp b/Source/lua/lua.hpp index fed528763..28c976016 100644 --- a/Source/lua/lua.hpp +++ b/Source/lua/lua.hpp @@ -13,6 +13,6 @@ namespace devilution { void LuaInitialize(); void LuaShutdown(); void LuaEvent(std::string_view name); -sol::state &LuaState(); +sol::state &GetLuaState(); } // namespace devilution diff --git a/Source/lua/repl.cpp b/Source/lua/repl.cpp index af6600f9b..dd578fbd4 100644 --- a/Source/lua/repl.cpp +++ b/Source/lua/repl.cpp @@ -38,7 +38,7 @@ int LuaPrintToConsole(lua_State *state) void CreateReplEnvironment() { - sol::state &lua = LuaState(); + sol::state &lua = GetLuaState(); replEnv.emplace(lua, sol::create, lua.globals()); replEnv->set("print", LuaPrintToConsole); } @@ -53,7 +53,7 @@ sol::environment &ReplEnvironment() sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code) { // Try to compile as an expression first. This also how the `lua` repl is implemented. - sol::state &lua = LuaState(); + sol::state &lua = GetLuaState(); std::string expression = StrCat("return ", code, ";"); sol::detail::typical_chunk_name_t basechunkname = {}; sol::load_status status = static_cast(