From cec1f1f37cb0ac7f29f62b6dedea6728a38817d9 Mon Sep 17 00:00:00 2001 From: Gleb Mazovetskiy Date: Sat, 4 Nov 2023 21:02:02 +0000 Subject: [PATCH] Lua: Implement sandbox 1. Every script runs in its own sandbox. 2. `Events` is the only global shared state. 3. `os` datetime functions are now available. --- Packaging/resources/assets/lua/init.lua | 2 - Source/lua/lua.cpp | 121 ++++++++++++++++-------- Source/lua/lua.hpp | 1 + Source/lua/repl.cpp | 7 +- 4 files changed, 89 insertions(+), 42 deletions(-) diff --git a/Packaging/resources/assets/lua/init.lua b/Packaging/resources/assets/lua/init.lua index ee06d9c9b..f59bd0171 100644 --- a/Packaging/resources/assets/lua/init.lua +++ b/Packaging/resources/assets/lua/init.lua @@ -1,5 +1,3 @@ -Events = {} - function Events:RegisterEvent(eventName) self[eventName] = { Functions = {}, diff --git a/Source/lua/lua.cpp b/Source/lua/lua.cpp index 015502ff4..c002b4532 100644 --- a/Source/lua/lua.cpp +++ b/Source/lua/lua.cpp @@ -27,19 +27,22 @@ namespace devilution { namespace { struct LuaState { - sol::state sol; - std::unordered_map compiledScripts; + sol::state sol = {}; + sol::table commonPackages = {}; + std::unordered_map compiledScripts = {}; + sol::environment sandbox = {}; }; std::optional CurrentLuaState; // A Lua function that we use to generate a `require` implementation. -constexpr std::string_view RequireGenSrc = R"( -function requireGen(loaded, loadFn) +constexpr std::string_view RequireGenSrc = R"lua( +function requireGen(env, loaded, loadFn) return function(packageName) local p = loaded[packageName] if p == nil then local loader = loadFn(packageName) + setEnvironment(loader, env) if type(loader) == "string" then error(loader) end @@ -49,7 +52,7 @@ function requireGen(loaded, loadFn) return p end end -)"; +)lua"; sol::object LuaLoadScriptFromAssets(std::string_view packageName) { @@ -107,7 +110,7 @@ void LuaWarn(void *userData, const char *message, int continued) warnBuffer.clear(); } -bool CheckResult(sol::protected_function_result result, bool optional) +sol::object CheckResult(sol::protected_function_result result, bool optional) { const bool valid = result.valid(); if (!valid) { @@ -118,25 +121,29 @@ bool CheckResult(sol::protected_function_result result, bool optional) app_fatal(error); LogError(error); } - return valid; + return result; } -void RunScript(std::string_view path, bool optional) +sol::object RunScript(std::optional env, std::string_view packageName, bool optional) { - tl::expected assetData = LoadAsset(path); - - if (!assetData.has_value()) { + sol::object result = LuaLoadScriptFromAssets(packageName); + // We return a string on error: + if (result.get_type() == sol::type::string) { if (!optional) - app_fatal(assetData.error()); - return; + app_fatal(result.as()); + LogError("{}", result.as()); + return sol::lua_nil; } - - CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional); + auto fn = result.as(); + if (env.has_value()) { + sol::set_environment(*env, fn); + } + return CheckResult(fn(), optional); } void LuaPanic(sol::optional message) { - LogError("Lua is in a panic state and will now abort() the application:\n", + LogError("Lua is in a panic state and will now abort() the application:\n{}", message.value_or("unknown error")); } @@ -152,12 +159,58 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state) LogDebug("-- {} -- [ {} ]", message, sol::detail::debug::dump_types(state)); } +sol::environment CreateLuaSandbox() +{ + sol::state &lua = CurrentLuaState->sol; + sol::environment sandbox(CurrentLuaState->sol, sol::create); + + // Registering globals + sandbox.set( + "print", LuaPrint, + "_DEBUG", +#ifdef _DEBUG + true, +#else + false, +#endif + "_VERSION", LUA_VERSION); + + // Register safe built-in globals. + for (const std::string_view global : { + // DevilutionX + "Events", + // Built-ins: + "assert", "warn", "error", "ipairs", "next", "pairs", "pcall", + "select", "tonumber", "tostring", "type", "xpcall", + "rawequal", "rawget", "rawset", "setmetatable", + // Built-in packages: + +#ifdef _DEBUG + "debug", +#endif + "base", "coroutine", "table", "string", "math", "utf8" }) { + const sol::object obj = lua[global]; + if (obj.get_type() == sol::type::lua_nil) { + app_fatal(StrCat("Missing Lua global [", global, "]")); + } + sandbox[global] = obj; + } + + // We only allow datetime-related functions from `os`: + const sol::table os = lua["os"]; + sandbox.create_named("os", + "date", os["date"], + "difftime", os["difftime"], + "time", os["time"]); + + sandbox["require"] = lua["requireGen"](sandbox, CurrentLuaState->commonPackages, LuaLoadScriptFromAssets); + + return sandbox; +} + void LuaInitialize() { - CurrentLuaState.emplace(LuaState { - .sol = { sol::c_call }, - .compiledScripts = {}, - }); + CurrentLuaState.emplace(LuaState { .sol = { sol::c_call } }); sol::state &lua = CurrentLuaState->sol; lua_setwarnf(lua.lua_state(), LuaWarn, /*ud=*/nullptr); lua.open_libraries( @@ -167,26 +220,16 @@ void LuaInitialize() sol::lib::table, sol::lib::string, sol::lib::math, - sol::lib::utf8); + sol::lib::utf8, + sol::lib::os); #ifdef _DEBUG lua.open_libraries(sol::lib::debug); #endif - // Registering globals - lua.set( - "print", LuaPrint, - "_DEBUG", -#ifdef _DEBUG - true, -#else - false, -#endif - "_VERSION", LUA_VERSION); - // Registering devilutionx object table CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false); - const sol::table loaded = lua.create_table_with( + CurrentLuaState->commonPackages = lua.create_table_with( #ifdef _DEBUG "devilutionx.dev", LuaDevModule(lua), #endif @@ -194,11 +237,15 @@ void LuaInitialize() "devilutionx.log", LuaLogModule(lua), "devilutionx.audio", LuaAudioModule(lua), "devilutionx.render", LuaRenderModule(lua), - "devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); }); - lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets); + "devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); }, + // Load the "inspect" package without the sandbox. + "inspect", RunScript(/*env=*/std::nullopt, "inspect", /*optional=*/false)); + + // This table is set up by the init script. + lua.create_named_table("Events"); - RunScript("lua\\init.lua", /*optional=*/false); - RunScript("lua\\user.lua", /*optional=*/true); + RunScript(CreateLuaSandbox(), "init", /*optional=*/false); + RunScript(CreateLuaSandbox(), "user", /*optional=*/true); LuaEvent("OnGameBoot"); } diff --git a/Source/lua/lua.hpp b/Source/lua/lua.hpp index 70191f211..f75c3f426 100644 --- a/Source/lua/lua.hpp +++ b/Source/lua/lua.hpp @@ -11,5 +11,6 @@ void LuaInitialize(); void LuaShutdown(); void LuaEvent(std::string_view name); sol::state &GetLuaState(); +sol::environment CreateLuaSandbox(); } // namespace devilution diff --git a/Source/lua/repl.cpp b/Source/lua/repl.cpp index 26c39575e..5c0240880 100644 --- a/Source/lua/repl.cpp +++ b/Source/lua/repl.cpp @@ -49,9 +49,10 @@ int LuaPrintToConsole(lua_State *state) void CreateReplEnvironment() { sol::state &lua = GetLuaState(); - replEnv.emplace(lua, sol::create, lua.globals()); - replEnv->set("print", LuaPrintToConsole); - lua_setwarnf(replEnv->lua_state(), LuaConsoleWarn, /*ud=*/nullptr); + sol::environment env = CreateLuaSandbox(); + env["print"] = LuaPrintToConsole; + lua_setwarnf(env.lua_state(), LuaConsoleWarn, /*ud=*/nullptr); + replEnv.emplace(env); } sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code)