Browse Source

Add CEL integration with cost estimation and error handling

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4607/head
maksim.nabokikh 2 weeks ago
parent
commit
4cc4a491e2
  1. 8
      pkg/cel/cost.go
  2. 3
      pkg/cel/cost_test.go
  3. 8
      pkg/cel/library/groups.go
  4. 17
      pkg/cel/library/groups_test.go

8
pkg/cel/cost.go

@ -1,6 +1,8 @@
package cel
import (
"fmt"
"github.com/google/cel-go/checker"
)
@ -28,13 +30,13 @@ type CostEstimate struct {
// EstimateCost returns the estimated cost range for a compiled expression.
// This is computed statically at compile time without evaluating the expression.
func (c *Compiler) EstimateCost(result *CompilationResult) CostEstimate {
func (c *Compiler) EstimateCost(result *CompilationResult) (CostEstimate, error) {
costEst, err := c.env.EstimateCost(result.ast, &defaultCostEstimator{})
if err != nil {
return CostEstimate{}
return CostEstimate{}, fmt.Errorf("CEL cost estimation failed: %w", err)
}
return CostEstimate{Min: costEst.Min, Max: costEst.Max}
return CostEstimate{Min: costEst.Min, Max: costEst.Max}, nil
}
// defaultCostEstimator provides size hints for compile-time cost estimation.

3
pkg/cel/cost_test.go

@ -33,7 +33,8 @@ func TestEstimateCost(t *testing.T) {
prog, err := compiler.Compile(tc.expr)
require.NoError(t, err)
est := compiler.EstimateCost(prog)
est, err := compiler.EstimateCost(prog)
require.NoError(t, err)
assert.True(t, est.Max >= est.Min, "max cost should be >= min cost")
assert.True(t, est.Max <= dexcel.DefaultCostBudget,
"estimated max cost %d should be within default budget %d", est.Max, dexcel.DefaultCostBudget)

8
pkg/cel/library/groups.go

@ -1,7 +1,7 @@
package library
import (
"path/filepath"
"path"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
@ -67,7 +67,11 @@ func groupMatchesImpl(lhs, rhs ref.Val) ref.Val {
continue
}
if ok, _ := filepath.Match(pattern, group); ok {
ok, err := path.Match(pattern, group)
if err != nil {
return types.NewErr("dex.groupMatches: invalid pattern %q: %v", pattern, err)
}
if ok {
matched = append(matched, types.String(group))
}
}

17
pkg/cel/library/groups_test.go

@ -70,6 +70,23 @@ func TestGroupMatches(t *testing.T) {
}
}
func TestGroupMatchesInvalidPattern(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
prog, err := compiler.CompileStringList(`dex.groupMatches(identity.groups, "[invalid")`)
require.NoError(t, err)
_, err = dexcel.Eval(context.Background(), prog, map[string]any{
"identity": map[string]any{
"groups": []string{"admin"},
},
})
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid pattern")
}
func TestGroupFilter(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)

Loading…
Cancel
Save