@ -2,6 +2,7 @@
# include <optional>
# include <string_view>
# include <unordered_map>
# include <sol/sol.hpp>
@ -17,7 +18,58 @@ namespace devilution {
namespace {
std : : optional < sol : : state > luaState ;
struct LuaState {
sol : : state sol ;
std : : unordered_map < std : : string , sol : : bytecode > compiledScripts ;
} ;
std : : optional < LuaState > CurrentLuaState ;
// A Lua function that we use to generate a `require` implementation.
constexpr std : : string_view RequireGenSrc = R " (
function requireGen ( loaded , loadFn )
return function ( packageName )
local p = loaded [ packageName ]
if p = = nil then
local loader = loadFn ( packageName )
if type ( loader ) = = " string " then
error ( loader )
end
p = loader ( packageName )
loaded [ packageName ] = p
end
return p
end
end
) " ;
sol : : object LuaLoadScriptFromAssets ( std : : string_view packageName )
{
LuaState & luaState = * CurrentLuaState ;
std : : string path { packageName } ;
std : : replace ( path . begin ( ) , path . end ( ) , ' . ' , ' \\ ' ) ;
path . append ( " .lua " ) ;
auto iter = luaState . compiledScripts . find ( path ) ;
if ( iter ! = luaState . compiledScripts . end ( ) ) {
return luaState . sol . load ( iter - > second . as_string_view ( ) , path , sol : : load_mode : : binary ) ;
}
tl : : expected < AssetData , std : : string > assetData = LoadAsset ( path ) ;
if ( ! assetData . has_value ( ) ) {
sol : : stack : : push ( luaState . sol . lua_state ( ) , assetData . error ( ) ) ;
return sol : : stack_object ( luaState . sol . lua_state ( ) , - 1 ) ;
}
sol : : load_result result = luaState . sol . load ( std : : string_view ( * assetData ) , path , sol : : load_mode : : text ) ;
if ( ! result . valid ( ) ) {
sol : : stack : : push ( luaState . sol . lua_state ( ) ,
StrCat ( " Lua error when loading " , path , " : " , result . get < std : : string > ( ) ) ) ;
return sol : : stack_object ( luaState . sol . lua_state ( ) , - 1 ) ;
}
const sol : : function fn = result ;
luaState . compiledScripts [ path ] = fn . dump ( ) ;
return result ;
}
int LuaPrint ( lua_State * state )
{
@ -50,29 +102,15 @@ bool CheckResult(sol::protected_function_result result, bool optional)
void RunScript ( std : : string_view path , bool optional )
{
AssetRef ref = FindAsset ( path ) ;
if ( ! ref . ok ( ) ) {
if ( ! optional )
app_fatal ( StrCat ( " Asset not found: " , path ) ) ;
return ;
}
tl : : expected < AssetData , std : : string > assetData = LoadAsset ( path ) ;
const size_t size = ref . size ( ) ;
std : : unique_ptr < char [ ] > luaScript { new char [ size ] } ;
AssetHandle handle = OpenAsset ( std : : move ( ref ) ) ;
if ( ! handle . ok ( ) ) {
app_fatal ( StrCat ( " Failed to open asset: " , path , " \n " , handle . error ( ) ) ) ;
return ;
}
if ( size > 0 & & ! handle . read ( luaScript . get ( ) , size ) ) {
app_fatal ( StrCat ( " Read failed: " , path , " \n " , handle . error ( ) ) ) ;
if ( ! assetData . has_value ( ) ) {
if ( ! optional )
app_fatal ( assetData . error ( ) ) ;
return ;
}
const std : : string_view luaScriptStr ( luaScript . get ( ) , size ) ;
CheckResult ( luaState - > safe_script ( luaScriptStr ) , optional ) ;
CheckResult ( CurrentLuaState - > sol . safe_script ( std : : string_view ( * assetData ) ) , optional ) ;
}
void LuaPanic ( sol : : optional < std : : string > message )
@ -95,8 +133,11 @@ void Sol2DebugPrintSection(const std::string &message, lua_State *state)
void LuaInitialize ( )
{
luaState . emplace ( sol : : c_call < decltype ( & LuaPanic ) , & LuaPanic > ) ;
sol : : state & lua = * luaState ;
CurrentLuaState . emplace ( LuaState {
. sol = { sol : : c_call < decltype ( & LuaPanic ) , & LuaPanic > } ,
. compiledScripts = { } ,
} ) ;
sol : : state & lua = CurrentLuaState - > sol ;
lua . open_libraries (
sol : : lib : : base ,
sol : : lib : : package ,
@ -116,11 +157,12 @@ void LuaInitialize()
" _VERSION " , LUA_VERSION ) ;
// Registering devilutionx object table
lua . create_named_table (
" devilutionx " ,
" log " , LuaLogModule ( lua ) ,
" render " , LuaRenderModule ( lua ) ,
" message " , [ ] ( std : : string_view text ) { EventPlrMsg ( text , UiFlags : : ColorRed ) ; } ) ;
CheckResult ( lua . safe_script ( RequireGenSrc ) , /*optional=*/ false ) ;
const sol : : table loaded = lua . create_table_with (
" devilutionx.log " , LuaLogModule ( lua ) ,
" devilutionx.render " , LuaRenderModule ( lua ) ,
" devilutionx.message " , [ ] ( std : : string_view text ) { EventPlrMsg ( text , UiFlags : : ColorRed ) ; } ) ;
lua [ " require " ] = lua [ " requireGen " ] ( loaded , LuaLoadScriptFromAssets ) ;
RunScript ( " lua \\ init.lua " , /*optional=*/ false ) ;
RunScript ( " lua \\ user.lua " , /*optional=*/ true ) ;
@ -130,12 +172,12 @@ void LuaInitialize()
void LuaShutdown ( )
{
l uaState = std : : nullopt ;
CurrentL uaState = std : : nullopt ;
}
void LuaEvent ( std : : string_view name )
{
const sol : : state & lua = * luaState ;
const sol : : state & lua = CurrentLuaState - > sol ;
const auto trigger = lua . traverse_get < std : : optional < sol : : object > > ( " Events " , name , " Trigger " ) ;
if ( ! trigger . has_value ( ) | | ! trigger - > is < sol : : protected_function > ( ) ) {
LogError ( " Events.{}.Trigger is not a function " , name ) ;
@ -145,9 +187,9 @@ void LuaEvent(std::string_view name)
CheckResult ( fn ( ) , /*optional=*/ true ) ;
}
sol : : state & LuaState ( )
sol : : state & Get LuaState( )
{
return * luaState ;
return CurrentLuaState - > sol ;
}
} // namespace devilution