Browse Source

Lua: `require` supports loads from assets

Implements a `require` function that supports built-in modules like so:

```lua
local log = require('devilutionx.log')
```

It falls back to reading from assets, so this loads `lua/user.lua`:

```lua
local user = require('lua.user')
```

The bytecode for the asset scripts is cached, in case we want to later
support multiple isolated environments.

There may be a simpler or better way to do this.

It's good enough for now until someone more knowledgeable
about Lua comes along.
pull/6767/head
Gleb Mazovetskiy 2 years ago
parent
commit
5d9d5c6872
  1. 27
      Source/engine/assets.cpp
  2. 13
      Source/engine/assets.hpp
  3. 104
      Source/lua/lua.cpp
  4. 2
      Source/lua/lua.hpp
  5. 4
      Source/lua/repl.cpp

27
Source/engine/assets.cpp

@ -75,6 +75,8 @@ bool FindMpqFile(std::string_view filename, MpqArchive **archive, uint32_t *file
AssetRef FindAsset(std::string_view filename)
{
AssetRef result;
if (filename.empty() || filename.back() == '\\')
return result;
result.path[0] = '\0';
char pathBuf[AssetRef::PathBufSize];
@ -113,6 +115,9 @@ AssetRef FindAsset(std::string_view filename)
AssetRef FindAsset(std::string_view filename)
{
AssetRef result;
if (filename.empty() || filename.back() == '\\')
return result;
std::string relativePath { filename };
#ifndef _WIN32
std::replace(relativePath.begin(), relativePath.end(), '\\', '/');
@ -206,4 +211,26 @@ SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe)
#endif
}
tl::expected<AssetData, std::string> LoadAsset(std::string_view path)
{
AssetRef ref = FindAsset(path);
if (!ref.ok()) {
return tl::make_unexpected(StrCat("Asset not found: ", path));
}
const size_t size = ref.size();
std::unique_ptr<char[]> data { new char[size] };
AssetHandle handle = OpenAsset(std::move(ref));
if (!handle.ok()) {
return tl::make_unexpected(StrCat("Failed to open asset: ", path, "\n", handle.error()));
}
if (size > 0 && !handle.read(data.get(), size)) {
return tl::make_unexpected(StrCat("Read failed: ", path, "\n", handle.error()));
}
return AssetData { std::move(data), size };
}
} // namespace devilution

13
Source/engine/assets.hpp

@ -7,6 +7,7 @@
#include <string_view>
#include <SDL.h>
#include <expected.hpp>
#include "appfat.h"
#include "diablo.h"
@ -246,4 +247,16 @@ AssetHandle OpenAsset(std::string_view filename, size_t &fileSize, bool threadsa
SDL_RWops *OpenAssetAsSdlRwOps(std::string_view filename, bool threadsafe = false);
struct AssetData {
std::unique_ptr<char[]> data;
size_t size;
explicit operator std::string_view() const
{
return std::string_view(data.get(), size);
}
};
tl::expected<AssetData, std::string> LoadAsset(std::string_view path);
} // namespace devilution

104
Source/lua/lua.cpp

@ -2,6 +2,7 @@
#include <optional>
#include <string_view>
#include <unordered_map>
#include <sol/sol.hpp>
@ -17,7 +18,58 @@ namespace devilution {
namespace {
std::optional<sol::state> luaState;
struct LuaState {
sol::state sol;
std::unordered_map<std::string, sol::bytecode> compiledScripts;
};
std::optional<LuaState> CurrentLuaState;
// A Lua function that we use to generate a `require` implementation.
constexpr std::string_view RequireGenSrc = R"(
function requireGen(loaded, loadFn)
return function(packageName)
local p = loaded[packageName]
if p == nil then
local loader = loadFn(packageName)
if type(loader) == "string" then
error(loader)
end
p = loader(packageName)
loaded[packageName] = p
end
return p
end
end
)";
sol::object LuaLoadScriptFromAssets(std::string_view packageName)
{
LuaState &luaState = *CurrentLuaState;
std::string path { packageName };
std::replace(path.begin(), path.end(), '.', '\\');
path.append(".lua");
auto iter = luaState.compiledScripts.find(path);
if (iter != luaState.compiledScripts.end()) {
return luaState.sol.load(iter->second.as_string_view(), path, sol::load_mode::binary);
}
tl::expected<AssetData, std::string> assetData = LoadAsset(path);
if (!assetData.has_value()) {
sol::stack::push(luaState.sol.lua_state(), assetData.error());
return sol::stack_object(luaState.sol.lua_state(), -1);
}
sol::load_result result = luaState.sol.load(std::string_view(*assetData), path, sol::load_mode::text);
if (!result.valid()) {
sol::stack::push(luaState.sol.lua_state(),
StrCat("Lua error when loading ", path, ": ", result.get<std::string>()));
return sol::stack_object(luaState.sol.lua_state(), -1);
}
const sol::function fn = result;
luaState.compiledScripts[path] = fn.dump();
return result;
}
int LuaPrint(lua_State *state)
{
@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional)
void RunScript(std::string_view path, bool optional)
{
AssetRef ref = FindAsset(path);
if (!ref.ok()) {
if (!optional)
app_fatal(StrCat("Asset not found: ", path));
return;
}
tl::expected<AssetData, std::string> assetData = LoadAsset(path);
const size_t size = ref.size();
std::unique_ptr<char[]> luaScript { new char[size] };
AssetHandle handle = OpenAsset(std::move(ref));
if (!handle.ok()) {
app_fatal(StrCat("Failed to open asset: ", path, "\n", handle.error()));
return;
}
if (size > 0 && !handle.read(luaScript.get(), size)) {
app_fatal(StrCat("Read failed: ", path, "\n", handle.error()));
if (!assetData.has_value()) {
if (!optional)
app_fatal(assetData.error());
return;
}
const std::string_view luaScriptStr(luaScript.get(), size);
CheckResult(luaState->safe_script(luaScriptStr), optional);
CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional);
}
void LuaPanic(sol::optional<std::string> message)
@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)
void LuaInitialize()
{
luaState.emplace(sol::c_call<decltype(&LuaPanic), &LuaPanic>);
sol::state &lua = *luaState;
CurrentLuaState.emplace(LuaState {
.sol = { sol::c_call<decltype(&LuaPanic), &LuaPanic> },
.compiledScripts = {},
});
sol::state &lua = CurrentLuaState->sol;
lua.open_libraries(
sol::lib::base,
sol::lib::package,
@ -116,11 +157,12 @@ void LuaInitialize()
"_VERSION", LUA_VERSION);
// Registering devilutionx object table
lua.create_named_table(
"devilutionx",
"log", LuaLogModule(lua),
"render", LuaRenderModule(lua),
"message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false);
const sol::table loaded = lua.create_table_with(
"devilutionx.log", LuaLogModule(lua),
"devilutionx.render", LuaRenderModule(lua),
"devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets);
RunScript("lua\\init.lua", /*optional=*/false);
RunScript("lua\\user.lua", /*optional=*/true);
@ -130,12 +172,12 @@ void LuaInitialize()
void LuaShutdown()
{
luaState = std::nullopt;
CurrentLuaState = std::nullopt;
}
void LuaEvent(std::string_view name)
{
const sol::state &lua = *luaState;
const sol::state &lua = CurrentLuaState->sol;
const auto trigger = lua.traverse_get<std::optional<sol::object>>("Events", name, "Trigger");
if (!trigger.has_value() || !trigger->is<sol::protected_function>()) {
LogError("Events.{}.Trigger is not a function", name);
@ -145,9 +187,9 @@ void LuaEvent(std::string_view name)
CheckResult(fn(), /*optional=*/true);
}
sol::state &LuaState()
sol::state &GetLuaState()
{
return *luaState;
return CurrentLuaState->sol;
}
} // namespace devilution

2
Source/lua/lua.hpp

@ -13,6 +13,6 @@ namespace devilution {
void LuaInitialize();
void LuaShutdown();
void LuaEvent(std::string_view name);
sol::state &LuaState();
sol::state &GetLuaState();
} // namespace devilution

4
Source/lua/repl.cpp

@ -38,7 +38,7 @@ int LuaPrintToConsole(lua_State *state)
void CreateReplEnvironment()
{
sol::state &lua = LuaState();
sol::state &lua = GetLuaState();
replEnv.emplace(lua, sol::create, lua.globals());
replEnv->set("print", LuaPrintToConsole);
}
@ -53,7 +53,7 @@ sol::environment &ReplEnvironment()
sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code)
{
// Try to compile as an expression first. This also how the `lua` repl is implemented.
sol::state &lua = LuaState();
sol::state &lua = GetLuaState();
std::string expression = StrCat("return ", code, ";");
sol::detail::typical_chunk_name_t basechunkname = {};
sol::load_status status = static_cast<sol::load_status>(

Loading…
Cancel
Save