Browse Source

Lua: Implement sandbox

1. Every script runs in its own sandbox.
2. `Events` is the only global shared state.
3. `os` datetime functions are now available.
pull/6779/head
Gleb Mazovetskiy 2 years ago
parent
commit
cec1f1f37c
  1. 2
      Packaging/resources/assets/lua/init.lua
  2. 121
      Source/lua/lua.cpp
  3. 1
      Source/lua/lua.hpp
  4. 7
      Source/lua/repl.cpp

2
Packaging/resources/assets/lua/init.lua

@ -1,5 +1,3 @@
Events = {}
function Events:RegisterEvent(eventName)
self[eventName] = {
Functions = {},

121
Source/lua/lua.cpp

@ -27,19 +27,22 @@ namespace devilution {
namespace {
struct LuaState {
sol::state sol;
std::unordered_map<std::string, sol::bytecode> compiledScripts;
sol::state sol = {};
sol::table commonPackages = {};
std::unordered_map<std::string, sol::bytecode> compiledScripts = {};
sol::environment sandbox = {};
};
std::optional<LuaState> CurrentLuaState;
// A Lua function that we use to generate a `require` implementation.
constexpr std::string_view RequireGenSrc = R"(
function requireGen(loaded, loadFn)
constexpr std::string_view RequireGenSrc = R"lua(
function requireGen(env, loaded, loadFn)
return function(packageName)
local p = loaded[packageName]
if p == nil then
local loader = loadFn(packageName)
setEnvironment(loader, env)
if type(loader) == "string" then
error(loader)
end
@ -49,7 +52,7 @@ function requireGen(loaded, loadFn)
return p
end
end
)";
)lua";
sol::object LuaLoadScriptFromAssets(std::string_view packageName)
{
@ -107,7 +110,7 @@ void LuaWarn(void *userData, const char *message, int continued)
warnBuffer.clear();
}
bool CheckResult(sol::protected_function_result result, bool optional)
sol::object CheckResult(sol::protected_function_result result, bool optional)
{
const bool valid = result.valid();
if (!valid) {
@ -118,25 +121,29 @@ bool CheckResult(sol::protected_function_result result, bool optional)
app_fatal(error);
LogError(error);
}
return valid;
return result;
}
void RunScript(std::string_view path, bool optional)
sol::object RunScript(std::optional<sol::environment> env, std::string_view packageName, bool optional)
{
tl::expected<AssetData, std::string> assetData = LoadAsset(path);
if (!assetData.has_value()) {
sol::object result = LuaLoadScriptFromAssets(packageName);
// We return a string on error:
if (result.get_type() == sol::type::string) {
if (!optional)
app_fatal(assetData.error());
return;
app_fatal(result.as<std::string>());
LogError("{}", result.as<std::string>());
return sol::lua_nil;
}
CheckResult(CurrentLuaState->sol.safe_script(std::string_view(*assetData)), optional);
auto fn = result.as<sol::protected_function>();
if (env.has_value()) {
sol::set_environment(*env, fn);
}
return CheckResult(fn(), optional);
}
void LuaPanic(sol::optional<std::string> message)
{
LogError("Lua is in a panic state and will now abort() the application:\n",
LogError("Lua is in a panic state and will now abort() the application:\n{}",
message.value_or("unknown error"));
}
@ -152,12 +159,58 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)
LogDebug("-- {} -- [ {} ]", message, sol::detail::debug::dump_types(state));
}
sol::environment CreateLuaSandbox()
{
sol::state &lua = CurrentLuaState->sol;
sol::environment sandbox(CurrentLuaState->sol, sol::create);
// Registering globals
sandbox.set(
"print", LuaPrint,
"_DEBUG",
#ifdef _DEBUG
true,
#else
false,
#endif
"_VERSION", LUA_VERSION);
// Register safe built-in globals.
for (const std::string_view global : {
// DevilutionX
"Events",
// Built-ins:
"assert", "warn", "error", "ipairs", "next", "pairs", "pcall",
"select", "tonumber", "tostring", "type", "xpcall",
"rawequal", "rawget", "rawset", "setmetatable",
// Built-in packages:
#ifdef _DEBUG
"debug",
#endif
"base", "coroutine", "table", "string", "math", "utf8" }) {
const sol::object obj = lua[global];
if (obj.get_type() == sol::type::lua_nil) {
app_fatal(StrCat("Missing Lua global [", global, "]"));
}
sandbox[global] = obj;
}
// We only allow datetime-related functions from `os`:
const sol::table os = lua["os"];
sandbox.create_named("os",
"date", os["date"],
"difftime", os["difftime"],
"time", os["time"]);
sandbox["require"] = lua["requireGen"](sandbox, CurrentLuaState->commonPackages, LuaLoadScriptFromAssets);
return sandbox;
}
void LuaInitialize()
{
CurrentLuaState.emplace(LuaState {
.sol = { sol::c_call<decltype(&LuaPanic), &LuaPanic> },
.compiledScripts = {},
});
CurrentLuaState.emplace(LuaState { .sol = { sol::c_call<decltype(&LuaPanic), &LuaPanic> } });
sol::state &lua = CurrentLuaState->sol;
lua_setwarnf(lua.lua_state(), LuaWarn, /*ud=*/nullptr);
lua.open_libraries(
@ -167,26 +220,16 @@ void LuaInitialize()
sol::lib::table,
sol::lib::string,
sol::lib::math,
sol::lib::utf8);
sol::lib::utf8,
sol::lib::os);
#ifdef _DEBUG
lua.open_libraries(sol::lib::debug);
#endif
// Registering globals
lua.set(
"print", LuaPrint,
"_DEBUG",
#ifdef _DEBUG
true,
#else
false,
#endif
"_VERSION", LUA_VERSION);
// Registering devilutionx object table
CheckResult(lua.safe_script(RequireGenSrc), /*optional=*/false);
const sol::table loaded = lua.create_table_with(
CurrentLuaState->commonPackages = lua.create_table_with(
#ifdef _DEBUG
"devilutionx.dev", LuaDevModule(lua),
#endif
@ -194,11 +237,15 @@ void LuaInitialize()
"devilutionx.log", LuaLogModule(lua),
"devilutionx.audio", LuaAudioModule(lua),
"devilutionx.render", LuaRenderModule(lua),
"devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); });
lua["require"] = lua["requireGen"](loaded, LuaLoadScriptFromAssets);
"devilutionx.message", [](std::string_view text) { EventPlrMsg(text, UiFlags::ColorRed); },
// Load the "inspect" package without the sandbox.
"inspect", RunScript(/*env=*/std::nullopt, "inspect", /*optional=*/false));
// This table is set up by the init script.
lua.create_named_table("Events");
RunScript("lua\\init.lua", /*optional=*/false);
RunScript("lua\\user.lua", /*optional=*/true);
RunScript(CreateLuaSandbox(), "init", /*optional=*/false);
RunScript(CreateLuaSandbox(), "user", /*optional=*/true);
LuaEvent("OnGameBoot");
}

1
Source/lua/lua.hpp

@ -11,5 +11,6 @@ void LuaInitialize();
void LuaShutdown();
void LuaEvent(std::string_view name);
sol::state &GetLuaState();
sol::environment CreateLuaSandbox();
} // namespace devilution

7
Source/lua/repl.cpp

@ -49,9 +49,10 @@ int LuaPrintToConsole(lua_State *state)
void CreateReplEnvironment()
{
sol::state &lua = GetLuaState();
replEnv.emplace(lua, sol::create, lua.globals());
replEnv->set("print", LuaPrintToConsole);
lua_setwarnf(replEnv->lua_state(), LuaConsoleWarn, /*ud=*/nullptr);
sol::environment env = CreateLuaSandbox();
env["print"] = LuaPrintToConsole;
lua_setwarnf(env.lua_state(), LuaConsoleWarn, /*ud=*/nullptr);
replEnv.emplace(env);
}
sol::protected_function_result TryRunLuaAsExpressionThenStatement(std::string_view code)

Loading…
Cancel
Save