|
|
|
|
@ -3,7 +3,9 @@ package sqlite3
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"io" |
|
|
|
|
"iter" |
|
|
|
|
"sync" |
|
|
|
|
"sync/atomic" |
|
|
|
|
|
|
|
|
|
"github.com/tetratelabs/wazero/api" |
|
|
|
|
|
|
|
|
|
@ -45,7 +47,7 @@ func (c Conn) AnyCollationNeeded() error {
|
|
|
|
|
// CreateCollation defines a new collating sequence.
|
|
|
|
|
//
|
|
|
|
|
// https://sqlite.org/c3ref/create_collation.html
|
|
|
|
|
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { |
|
|
|
|
func (c *Conn) CreateCollation(name string, fn CollatingFunction) error { |
|
|
|
|
var funcPtr ptr_t |
|
|
|
|
defer c.arena.mark()() |
|
|
|
|
namePtr := c.arena.string(name) |
|
|
|
|
@ -57,6 +59,10 @@ func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
|
|
|
|
|
return c.error(rc) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Collating function is the type of a collation callback.
|
|
|
|
|
// Implementations must not retain a or b.
|
|
|
|
|
type CollatingFunction func(a, b []byte) int |
|
|
|
|
|
|
|
|
|
// CreateFunction defines a new scalar SQL function.
|
|
|
|
|
//
|
|
|
|
|
// https://sqlite.org/c3ref/create_function.html
|
|
|
|
|
@ -77,34 +83,67 @@ func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn Scala
|
|
|
|
|
// Implementations must not retain arg.
|
|
|
|
|
type ScalarFunction func(ctx Context, arg ...Value) |
|
|
|
|
|
|
|
|
|
// CreateAggregateFunction defines a new aggregate SQL function.
|
|
|
|
|
//
|
|
|
|
|
// https://sqlite.org/c3ref/create_function.html
|
|
|
|
|
func (c *Conn) CreateAggregateFunction(name string, nArg int, flag FunctionFlag, fn AggregateSeqFunction) error { |
|
|
|
|
var funcPtr ptr_t |
|
|
|
|
defer c.arena.mark()() |
|
|
|
|
namePtr := c.arena.string(name) |
|
|
|
|
if fn != nil { |
|
|
|
|
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { |
|
|
|
|
var a aggregateFunc |
|
|
|
|
coro := func(yieldCoro func(struct{}) bool) { |
|
|
|
|
seq := func(yieldSeq func([]Value) bool) { |
|
|
|
|
for yieldSeq(a.arg) { |
|
|
|
|
if !yieldCoro(struct{}{}) { |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
fn(&a.ctx, seq) |
|
|
|
|
} |
|
|
|
|
a.next, a.stop = iter.Pull(coro) |
|
|
|
|
return &a |
|
|
|
|
})) |
|
|
|
|
} |
|
|
|
|
rc := res_t(c.call("sqlite3_create_aggregate_function_go", |
|
|
|
|
stk_t(c.handle), stk_t(namePtr), stk_t(nArg), |
|
|
|
|
stk_t(flag), stk_t(funcPtr))) |
|
|
|
|
return c.error(rc) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// AggregateSeqFunction is the type of an aggregate SQL function.
|
|
|
|
|
// Implementations must not retain the slices yielded by seq.
|
|
|
|
|
type AggregateSeqFunction func(ctx *Context, seq iter.Seq[[]Value]) |
|
|
|
|
|
|
|
|
|
// CreateWindowFunction defines a new aggregate or aggregate window SQL function.
|
|
|
|
|
// If fn returns a [WindowFunction], then an aggregate window function is created.
|
|
|
|
|
// If fn returns a [WindowFunction], an aggregate window function is created.
|
|
|
|
|
// If fn returns an [io.Closer], it will be called to free resources.
|
|
|
|
|
//
|
|
|
|
|
// https://sqlite.org/c3ref/create_function.html
|
|
|
|
|
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { |
|
|
|
|
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn AggregateConstructor) error { |
|
|
|
|
var funcPtr ptr_t |
|
|
|
|
defer c.arena.mark()() |
|
|
|
|
namePtr := c.arena.string(name) |
|
|
|
|
call := "sqlite3_create_aggregate_function_go" |
|
|
|
|
if fn != nil { |
|
|
|
|
agg := fn() |
|
|
|
|
if c, ok := agg.(io.Closer); ok { |
|
|
|
|
if err := c.Close(); err != nil { |
|
|
|
|
return err |
|
|
|
|
funcPtr = util.AddHandle(c.ctx, AggregateConstructor(func() AggregateFunction { |
|
|
|
|
agg := fn() |
|
|
|
|
if win, ok := agg.(WindowFunction); ok { |
|
|
|
|
return win |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if _, ok := agg.(WindowFunction); ok { |
|
|
|
|
call = "sqlite3_create_window_function_go" |
|
|
|
|
} |
|
|
|
|
funcPtr = util.AddHandle(c.ctx, fn) |
|
|
|
|
return windowFunc{agg, name} |
|
|
|
|
})) |
|
|
|
|
} |
|
|
|
|
rc := res_t(c.call(call, |
|
|
|
|
rc := res_t(c.call("sqlite3_create_window_function_go", |
|
|
|
|
stk_t(c.handle), stk_t(namePtr), stk_t(nArg), |
|
|
|
|
stk_t(flag), stk_t(funcPtr))) |
|
|
|
|
return c.error(rc) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// AggregateConstructor is a an [AggregateFunction] constructor.
|
|
|
|
|
type AggregateConstructor func() AggregateFunction |
|
|
|
|
|
|
|
|
|
// AggregateFunction is the interface an aggregate function should implement.
|
|
|
|
|
//
|
|
|
|
|
// https://sqlite.org/appfunc.html
|
|
|
|
|
@ -153,26 +192,24 @@ func collationCallback(ctx context.Context, mod api.Module, pArg, pDB ptr_t, eTe
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func compareCallback(ctx context.Context, mod api.Module, pApp ptr_t, nKey1 int32, pKey1 ptr_t, nKey2 int32, pKey2 ptr_t) uint32 { |
|
|
|
|
fn := util.GetHandle(ctx, pApp).(func(a, b []byte) int) |
|
|
|
|
fn := util.GetHandle(ctx, pApp).(CollatingFunction) |
|
|
|
|
return uint32(fn(util.View(mod, pKey1, int64(nKey1)), util.View(mod, pKey2, int64(nKey2)))) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func funcCallback(ctx context.Context, mod api.Module, pCtx, pApp ptr_t, nArg int32, pArg ptr_t) { |
|
|
|
|
args := getFuncArgs() |
|
|
|
|
defer putFuncArgs(args) |
|
|
|
|
db := ctx.Value(connKey{}).(*Conn) |
|
|
|
|
args := callbackArgs(db, nArg, pArg) |
|
|
|
|
defer returnArgs(args) |
|
|
|
|
fn := util.GetHandle(db.ctx, pApp).(ScalarFunction) |
|
|
|
|
callbackArgs(db, args[:nArg], pArg) |
|
|
|
|
fn(Context{db, pCtx}, args[:nArg]...) |
|
|
|
|
fn(Context{db, pCtx}, *args...) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func stepCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, nArg int32, pArg ptr_t) { |
|
|
|
|
args := getFuncArgs() |
|
|
|
|
defer putFuncArgs(args) |
|
|
|
|
db := ctx.Value(connKey{}).(*Conn) |
|
|
|
|
callbackArgs(db, args[:nArg], pArg) |
|
|
|
|
args := callbackArgs(db, nArg, pArg) |
|
|
|
|
defer returnArgs(args) |
|
|
|
|
fn, _ := callbackAggregate(db, pAgg, pApp) |
|
|
|
|
fn.Step(Context{db, pCtx}, args[:nArg]...) |
|
|
|
|
fn.Step(Context{db, pCtx}, *args...) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t, final int32) { |
|
|
|
|
@ -196,12 +233,11 @@ func valueCallback(ctx context.Context, mod api.Module, pCtx, pAgg, pApp ptr_t,
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func inverseCallback(ctx context.Context, mod api.Module, pCtx, pAgg ptr_t, nArg int32, pArg ptr_t) { |
|
|
|
|
args := getFuncArgs() |
|
|
|
|
defer putFuncArgs(args) |
|
|
|
|
db := ctx.Value(connKey{}).(*Conn) |
|
|
|
|
callbackArgs(db, args[:nArg], pArg) |
|
|
|
|
args := callbackArgs(db, nArg, pArg) |
|
|
|
|
defer returnArgs(args) |
|
|
|
|
fn := util.GetHandle(db.ctx, pAgg).(WindowFunction) |
|
|
|
|
fn.Inverse(Context{db, pCtx}, args[:nArg]...) |
|
|
|
|
fn.Inverse(Context{db, pCtx}, *args...) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) { |
|
|
|
|
@ -211,7 +247,7 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// We need to create the aggregate.
|
|
|
|
|
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() |
|
|
|
|
fn := util.GetHandle(db.ctx, pApp).(AggregateConstructor)() |
|
|
|
|
if pAgg != 0 { |
|
|
|
|
handle := util.AddHandle(db.ctx, fn) |
|
|
|
|
util.Write32(db.mod, pAgg, handle) |
|
|
|
|
@ -220,25 +256,64 @@ func callbackAggregate(db *Conn, pAgg, pApp ptr_t) (AggregateFunction, ptr_t) {
|
|
|
|
|
return fn, 0 |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func callbackArgs(db *Conn, arg []Value, pArg ptr_t) { |
|
|
|
|
for i := range arg { |
|
|
|
|
arg[i] = Value{ |
|
|
|
|
var ( |
|
|
|
|
valueArgsPool sync.Pool |
|
|
|
|
valueArgsLen atomic.Int32 |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
func callbackArgs(db *Conn, nArg int32, pArg ptr_t) *[]Value { |
|
|
|
|
arg, ok := valueArgsPool.Get().(*[]Value) |
|
|
|
|
if !ok || cap(*arg) < int(nArg) { |
|
|
|
|
max := valueArgsLen.Or(nArg) | nArg |
|
|
|
|
lst := make([]Value, max) |
|
|
|
|
arg = &lst |
|
|
|
|
} |
|
|
|
|
lst := (*arg)[:nArg] |
|
|
|
|
for i := range lst { |
|
|
|
|
lst[i] = Value{ |
|
|
|
|
c: db, |
|
|
|
|
handle: util.Read32[ptr_t](db.mod, pArg+ptr_t(i)*ptrlen), |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
*arg = lst |
|
|
|
|
return arg |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var funcArgsPool sync.Pool |
|
|
|
|
func returnArgs(p *[]Value) { |
|
|
|
|
valueArgsPool.Put(p) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func putFuncArgs(p *[_MAX_FUNCTION_ARG]Value) { |
|
|
|
|
funcArgsPool.Put(p) |
|
|
|
|
type aggregateFunc struct { |
|
|
|
|
next func() (struct{}, bool) |
|
|
|
|
stop func() |
|
|
|
|
ctx Context |
|
|
|
|
arg []Value |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getFuncArgs() *[_MAX_FUNCTION_ARG]Value { |
|
|
|
|
if p := funcArgsPool.Get(); p == nil { |
|
|
|
|
return new([_MAX_FUNCTION_ARG]Value) |
|
|
|
|
} else { |
|
|
|
|
return p.(*[_MAX_FUNCTION_ARG]Value) |
|
|
|
|
func (a *aggregateFunc) Step(ctx Context, arg ...Value) { |
|
|
|
|
a.ctx = ctx |
|
|
|
|
a.arg = append(a.arg[:0], arg...) |
|
|
|
|
if _, more := a.next(); !more { |
|
|
|
|
a.stop() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (a *aggregateFunc) Value(ctx Context) { |
|
|
|
|
a.ctx = ctx |
|
|
|
|
a.stop() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (a *aggregateFunc) Close() error { |
|
|
|
|
a.stop() |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type windowFunc struct { |
|
|
|
|
AggregateFunction |
|
|
|
|
name string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (w windowFunc) Inverse(ctx Context, arg ...Value) { |
|
|
|
|
// Implementing inverse allows certain queries that don't really need it to succeed.
|
|
|
|
|
ctx.ResultError(util.ErrorString(w.name + ": may not be used as a window function")) |
|
|
|
|
} |
|
|
|
|
|