Browse Source

feat(cel): implement CEL compiler with expression evaluation and cost estimation

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4607/head
maksim.nabokikh 2 weeks ago
parent
commit
4d9fb62ac2
  1. 5
      go.mod
  2. 11
      go.sum
  3. 223
      pkg/cel/cel.go
  4. 270
      pkg/cel/cel_test.go
  5. 98
      pkg/cel/cost.go
  6. 136
      pkg/cel/cost_test.go
  7. 5
      pkg/cel/doc.go
  8. 4
      pkg/cel/library/doc.go
  9. 73
      pkg/cel/library/email.go
  10. 108
      pkg/cel/library/email_test.go
  11. 119
      pkg/cel/library/groups.go
  12. 130
      pkg/cel/library/groups_test.go
  13. 80
      pkg/cel/types.go
  14. 3
      pkg/featureflags/doc.go
  15. 2
      pkg/groups/doc.go
  16. 1
      pkg/groups/groups.go
  17. 3
      pkg/httpclient/doc.go

5
go.mod

@ -16,6 +16,7 @@ require (
github.com/go-jose/go-jose/v4 v4.1.3
github.com/go-ldap/ldap/v3 v3.4.12
github.com/go-sql-driver/mysql v1.9.3
github.com/google/cel-go v0.27.0
github.com/google/uuid v1.6.0
github.com/gorilla/handlers v1.5.2
github.com/gorilla/mux v1.8.1
@ -34,7 +35,7 @@ require (
go.etcd.io/etcd/client/pkg/v3 v3.6.8
go.etcd.io/etcd/client/v3 v3.6.8
golang.org/x/crypto v0.49.0
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948
golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0
google.golang.org/api v0.271.0
@ -44,6 +45,7 @@ require (
require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
cel.dev/expr v0.25.1 // indirect
cloud.google.com/go/auth v0.18.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
dario.cat/mergo v1.0.1 // indirect
@ -52,6 +54,7 @@ require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.3.0 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/antlr4-go/antlr/v4 v4.13.1 // indirect
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect

11
go.sum

@ -1,5 +1,7 @@
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 h1:E0wvcUXTkgyN4wy4LGtNzMNGMytJN8afmIWXJVMi4cc=
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9/go.mod h1:Oe1xWPuu5q9LzyrWfbZmEZxFYeu4BHTyzfjeW2aZp/w=
cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc=
@ -30,6 +32,8 @@ github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7l
github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e h1:4dAU9FXIyQktpoUAgOJK3OTFc/xug0PCXYCqU0FgDKI=
github.com/alexbrainman/sspi v0.0.0-20250919150558-7d374ff0d59e/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw=
github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY=
github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4=
github.com/beevik/etree v1.6.0 h1:u8Kwy8pp9D9XeITj2Z0XtA5qqZEmtJtuXZRQi+j03eE=
@ -89,6 +93,8 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo=
github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
@ -264,14 +270,15 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4=
golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA=
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741 h1:fGZugkZk2UgYBxtpKmvub51Yno1LJDeEsRp2xGD+0gY=
golang.org/x/exp v0.0.0-20221004215720-b9f4876ce741/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE=
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA=
golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=

223
pkg/cel/cel.go

@ -0,0 +1,223 @@
package cel
import (
"context"
"fmt"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/ext"
"github.com/dexidp/dex/pkg/cel/library"
)
// EnvironmentVersion represents the version of the CEL environment.
// New variables, functions, or libraries are introduced in new versions.
type EnvironmentVersion uint32
const (
// EnvironmentV1 is the initial CEL environment.
EnvironmentV1 EnvironmentVersion = 1
)
// CompilationResult holds a compiled CEL program ready for evaluation.
type CompilationResult struct {
Program cel.Program
OutputType *cel.Type
Expression string
ast *cel.Ast
}
// CompilerOption configures a Compiler.
type CompilerOption func(*compilerConfig)
type compilerConfig struct {
costBudget uint64
version EnvironmentVersion
}
func defaultCompilerConfig() *compilerConfig {
return &compilerConfig{
costBudget: DefaultCostBudget,
version: EnvironmentV1,
}
}
// WithCostBudget sets a custom cost budget for expression evaluation.
func WithCostBudget(budget uint64) CompilerOption {
return func(cfg *compilerConfig) {
cfg.costBudget = budget
}
}
// WithVersion sets the target environment version for the compiler.
// Defaults to the latest version. Specifying an older version ensures
// that only functions/types available at that version are used.
func WithVersion(v EnvironmentVersion) CompilerOption {
return func(cfg *compilerConfig) {
cfg.version = v
}
}
// Compiler compiles CEL expressions against a specific environment.
type Compiler struct {
env *cel.Env
cfg *compilerConfig
}
// NewCompiler creates a new CEL compiler with the specified variable
// declarations and options.
//
// All custom Dex libraries are automatically included.
// The environment is configured with cost limits and safe defaults.
func NewCompiler(variables []VariableDeclaration, opts ...CompilerOption) (*Compiler, error) {
cfg := defaultCompilerConfig()
for _, opt := range opts {
opt(cfg)
}
envOpts := make([]cel.EnvOption, 0, 8+len(variables))
envOpts = append(envOpts,
cel.DefaultUTCTimeZone(true),
// Standard extension libraries (same set as Kubernetes)
ext.Strings(),
ext.Encoders(),
ext.Lists(),
ext.Sets(),
ext.Math(),
// Custom Dex libraries
cel.Lib(&library.Email{}),
cel.Lib(&library.Groups{}),
// Presence tests like has(field) and 'key' in map are O(1) hash
// lookups on map(string, dyn) variables, so they should not count
// toward the cost budget. Without this, expressions with multiple
// 'in' checks (e.g. "'admin' in identity.groups") would accumulate
// inflated cost estimates. This matches Kubernetes CEL behavior
// where presence tests are free for CRD validation rules.
cel.CostEstimatorOptions(
checker.PresenceTestHasCost(false),
),
)
for _, v := range variables {
envOpts = append(envOpts, cel.Variable(v.Name, v.Type))
}
env, err := cel.NewEnv(envOpts...)
if err != nil {
return nil, fmt.Errorf("failed to create CEL environment: %w", err)
}
return &Compiler{env: env, cfg: cfg}, nil
}
// CompileBool compiles a CEL expression that must evaluate to bool.
func (c *Compiler) CompileBool(expression string) (*CompilationResult, error) {
return c.compile(expression, cel.BoolType)
}
// CompileString compiles a CEL expression that must evaluate to string.
func (c *Compiler) CompileString(expression string) (*CompilationResult, error) {
return c.compile(expression, cel.StringType)
}
// CompileStringList compiles a CEL expression that must evaluate to list(string).
func (c *Compiler) CompileStringList(expression string) (*CompilationResult, error) {
return c.compile(expression, cel.ListType(cel.StringType))
}
// Compile compiles a CEL expression with any output type.
func (c *Compiler) Compile(expression string) (*CompilationResult, error) {
return c.compile(expression, nil)
}
func (c *Compiler) compile(expression string, expectedType *cel.Type) (*CompilationResult, error) {
if len(expression) > MaxExpressionLength {
return nil, fmt.Errorf("expression exceeds maximum length of %d characters", MaxExpressionLength)
}
ast, issues := c.env.Compile(expression)
if issues != nil && issues.Err() != nil {
return nil, fmt.Errorf("CEL compilation failed: %w", issues.Err())
}
if expectedType != nil && !ast.OutputType().IsEquivalentType(expectedType) {
return nil, fmt.Errorf(
"expected expression output type %s, got %s",
expectedType, ast.OutputType(),
)
}
// Estimate cost at compile time and reject expressions that are too expensive.
costEst, err := c.env.EstimateCost(ast, &defaultCostEstimator{})
if err != nil {
return nil, fmt.Errorf("CEL cost estimation failed: %w", err)
}
if costEst.Max > c.cfg.costBudget {
return nil, fmt.Errorf(
"CEL expression estimated cost %d exceeds budget %d",
costEst.Max, c.cfg.costBudget,
)
}
prog, err := c.env.Program(ast,
cel.EvalOptions(cel.OptOptimize),
cel.CostLimit(c.cfg.costBudget),
)
if err != nil {
return nil, fmt.Errorf("CEL program creation failed: %w", err)
}
return &CompilationResult{
Program: prog,
OutputType: ast.OutputType(),
Expression: expression,
ast: ast,
}, nil
}
// Eval evaluates a compiled program against the given variables.
func Eval(ctx context.Context, result *CompilationResult, variables map[string]any) (ref.Val, error) {
out, _, err := result.Program.ContextEval(ctx, variables)
if err != nil {
return nil, fmt.Errorf("CEL evaluation failed: %w", err)
}
return out, nil
}
// EvalBool is a convenience function that evaluates and asserts bool output.
func EvalBool(ctx context.Context, result *CompilationResult, variables map[string]any) (bool, error) {
out, err := Eval(ctx, result, variables)
if err != nil {
return false, err
}
v, ok := out.Value().(bool)
if !ok {
return false, fmt.Errorf("expected bool result, got %T", out.Value())
}
return v, nil
}
// EvalString is a convenience function that evaluates and asserts string output.
func EvalString(ctx context.Context, result *CompilationResult, variables map[string]any) (string, error) {
out, err := Eval(ctx, result, variables)
if err != nil {
return "", err
}
v, ok := out.Value().(string)
if !ok {
return "", fmt.Errorf("expected string result, got %T", out.Value())
}
return v, nil
}

270
pkg/cel/cel_test.go

@ -0,0 +1,270 @@
package cel_test
import (
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/dexidp/dex/connector"
dexcel "github.com/dexidp/dex/pkg/cel"
)
func TestCompileBool(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
tests := map[string]struct {
expr string
wantErr bool
}{
"true literal": {
expr: "true",
},
"comparison": {
expr: "1 == 1",
},
"string type mismatch": {
expr: "'hello'",
wantErr: true,
},
"int type mismatch": {
expr: "42",
wantErr: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result, err := compiler.CompileBool(tc.expr)
if tc.wantErr {
assert.Error(t, err)
assert.Nil(t, result)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestCompileString(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
tests := map[string]struct {
expr string
wantErr bool
}{
"string literal": {
expr: "'hello'",
},
"string concatenation": {
expr: "'hello' + ' ' + 'world'",
},
"bool type mismatch": {
expr: "true",
wantErr: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result, err := compiler.CompileString(tc.expr)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, result)
}
})
}
}
func TestCompileStringList(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
result, err := compiler.CompileStringList("['a', 'b', 'c']")
assert.NoError(t, err)
assert.NotNil(t, result)
_, err = compiler.CompileStringList("'not a list'")
assert.Error(t, err)
}
func TestCompile(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
// Compile accepts any type
result, err := compiler.Compile("true")
assert.NoError(t, err)
assert.NotNil(t, result)
result, err = compiler.Compile("'hello'")
assert.NoError(t, err)
assert.NotNil(t, result)
result, err = compiler.Compile("42")
assert.NoError(t, err)
assert.NotNil(t, result)
}
func TestCompileErrors(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
tests := map[string]struct {
expr string
}{
"syntax error": {
expr: "1 +",
},
"undefined variable": {
expr: "undefined_var",
},
"undefined function": {
expr: "undefinedFunc()",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
_, err := compiler.Compile(tc.expr)
assert.Error(t, err)
})
}
}
func TestMaxExpressionLength(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
longExpr := "'" + strings.Repeat("a", dexcel.MaxExpressionLength) + "'"
_, err = compiler.Compile(longExpr)
assert.Error(t, err)
assert.Contains(t, err.Error(), "maximum length")
}
func TestEvalBool(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
tests := map[string]struct {
expr string
identity map[string]any
want bool
}{
"email endsWith": {
expr: "identity.email.endsWith('@example.com')",
identity: map[string]any{
"email": "user@example.com",
},
want: true,
},
"email endsWith false": {
expr: "identity.email.endsWith('@example.com')",
identity: map[string]any{
"email": "user@other.com",
},
want: false,
},
"email_verified": {
expr: "identity.email_verified == true",
identity: map[string]any{
"email_verified": true,
},
want: true,
},
"group membership": {
expr: "identity.groups.exists(g, g == 'admin')",
identity: map[string]any{
"groups": []string{"admin", "dev"},
},
want: true,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.CompileBool(tc.expr)
require.NoError(t, err)
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{
"identity": tc.identity,
})
require.NoError(t, err)
assert.Equal(t, tc.want, result)
})
}
}
func TestEvalString(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
// identity.email returns dyn from map access, use Compile (not CompileString)
prog, err := compiler.Compile("identity.email")
require.NoError(t, err)
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{
"identity": map[string]any{
"email": "user@example.com",
},
})
require.NoError(t, err)
assert.Equal(t, "user@example.com", result)
}
func TestEvalWithIdentityAndRequest(t *testing.T) {
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...)
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
prog, err := compiler.CompileBool(
`identity.email.endsWith('@example.com') && 'admin' in identity.groups && request.connector_id == 'okta'`,
)
require.NoError(t, err)
identity := dexcel.IdentityFromConnector(connector.Identity{
UserID: "123",
Username: "john",
Email: "john@example.com",
Groups: []string{"admin", "dev"},
})
request := dexcel.RequestFromContext(dexcel.RequestContext{
ClientID: "my-app",
ConnectorID: "okta",
Scopes: []string{"openid", "email"},
})
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{
"identity": identity,
"request": request,
})
require.NoError(t, err)
assert.True(t, result)
}
func TestNewCompilerWithVariables(t *testing.T) {
// Claims variable
compiler, err := dexcel.NewCompiler(dexcel.ClaimsVariable())
require.NoError(t, err)
// claims.email returns dyn from map access, use Compile (not CompileString)
prog, err := compiler.Compile("claims.email")
require.NoError(t, err)
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{
"claims": map[string]any{
"email": "test@example.com",
},
})
require.NoError(t, err)
assert.Equal(t, "test@example.com", result)
}

98
pkg/cel/cost.go

@ -0,0 +1,98 @@
package cel
import (
"github.com/google/cel-go/checker"
)
// DefaultCostBudget is the default cost budget for a single expression
// evaluation. Aligned with Kubernetes defaults: enough for typical identity
// operations but prevents runaway expressions.
const DefaultCostBudget uint64 = 10_000_000
// MaxExpressionLength is the maximum length of a CEL expression string.
const MaxExpressionLength = 10_240
// DefaultStringMaxLength is the estimated max length of string values
// (emails, usernames, group names, etc.) used for compile-time cost estimation.
const DefaultStringMaxLength = 256
// DefaultListMaxLength is the estimated max length of list values
// (groups, scopes) used for compile-time cost estimation.
const DefaultListMaxLength = 100
// CostEstimate holds the estimated cost range for a compiled expression.
type CostEstimate struct {
Min uint64
Max uint64
}
// EstimateCost returns the estimated cost range for a compiled expression.
// This is computed statically at compile time without evaluating the expression.
func (c *Compiler) EstimateCost(result *CompilationResult) CostEstimate {
costEst, err := c.env.EstimateCost(result.ast, &defaultCostEstimator{})
if err != nil {
return CostEstimate{}
}
return CostEstimate{Min: costEst.Min, Max: costEst.Max}
}
// defaultCostEstimator provides size hints for compile-time cost estimation.
// Without these hints, the CEL cost estimator assumes unbounded sizes for
// variables, leading to wildly overestimated max costs.
type defaultCostEstimator struct{}
func (defaultCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate {
// Provide size hints for map(string, dyn) variables: identity, request, claims.
// Without these, the estimator assumes lists/strings can be infinitely large.
if element.Path() == nil {
return nil
}
path := element.Path()
if len(path) == 0 {
return nil
}
root := path[0]
switch root {
case "identity", "request", "claims":
// Nested field access (e.g. identity.email, identity.groups)
if len(path) >= 2 {
field := path[1]
switch field {
case "groups", "scopes":
return &checker.SizeEstimate{Min: 0, Max: DefaultListMaxLength}
default:
return &checker.SizeEstimate{Min: 0, Max: DefaultStringMaxLength}
}
}
// The map itself: number of keys
return &checker.SizeEstimate{Min: 0, Max: 20}
}
return nil
}
func (defaultCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
switch function {
case "dex.emailDomain", "dex.emailLocalPart":
// Simple string split — O(n) where n is string length, bounded.
return &checker.CallEstimate{
CostEstimate: checker.CostEstimate{Min: 1, Max: 2},
}
case "dex.groupMatches":
// Iterates over groups list and matches each against a pattern.
return &checker.CallEstimate{
CostEstimate: checker.CostEstimate{Min: 1, Max: DefaultListMaxLength},
}
case "dex.groupFilter":
// Builds a set from allowed list, then iterates groups.
return &checker.CallEstimate{
CostEstimate: checker.CostEstimate{Min: 1, Max: 2 * DefaultListMaxLength},
}
}
return nil
}

136
pkg/cel/cost_test.go

@ -0,0 +1,136 @@
package cel_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dexcel "github.com/dexidp/dex/pkg/cel"
)
func TestEstimateCost(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
tests := map[string]struct {
expr string
}{
"simple bool": {
expr: "true",
},
"string comparison": {
expr: "identity.email == 'test@example.com'",
},
"group membership": {
expr: "identity.groups.exists(g, g == 'admin')",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.Compile(tc.expr)
require.NoError(t, err)
est := compiler.EstimateCost(prog)
assert.True(t, est.Max >= est.Min, "max cost should be >= min cost")
assert.True(t, est.Max <= dexcel.DefaultCostBudget,
"estimated max cost %d should be within default budget %d", est.Max, dexcel.DefaultCostBudget)
})
}
}
func TestCompileTimeCostAcceptsSimpleExpressions(t *testing.T) {
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...)
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
tests := map[string]string{
"literal": "true",
"email endsWith": "identity.email.endsWith('@example.com')",
"group check": "'admin' in identity.groups",
"emailDomain": `dex.emailDomain(identity.email)`,
"groupMatches": `dex.groupMatches(identity.groups, "team:*")`,
"groupFilter": `dex.groupFilter(identity.groups, ["admin", "dev"])`,
"combined policy": `identity.email.endsWith('@example.com') && 'admin' in identity.groups`,
"complex policy": `identity.email.endsWith('@example.com') &&
identity.groups.exists(g, g == 'admin') &&
request.connector_id == 'okta' &&
request.scopes.exists(s, s == 'openid')`,
"filter+map chain": `identity.groups
.filter(g, g.startsWith('team:'))
.map(g, g.replace('team:', ''))
.size() > 0`,
}
for name, expr := range tests {
t.Run(name, func(t *testing.T) {
_, err := compiler.Compile(expr)
assert.NoError(t, err, "expression should compile within default budget")
})
}
}
func TestCompileTimeCostRejection(t *testing.T) {
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...)
tests := map[string]struct {
budget uint64
expr string
}{
"simple exists exceeds tiny budget": {
budget: 1,
expr: "identity.groups.exists(g, g == 'admin')",
},
"endsWith exceeds tiny budget": {
budget: 2,
expr: "identity.email.endsWith('@example.com')",
},
"nested comprehension over groups exceeds moderate budget": {
// Two nested iterations over groups: O(n^2) where n=100 → ~280K
budget: 10_000,
expr: `identity.groups.exists(g1,
identity.groups.exists(g2,
g1 != g2 && g1.startsWith(g2)
)
)`,
},
"cross-variable comprehension exceeds moderate budget": {
// filter groups then check each against scopes: O(n*m) → ~162K
budget: 10_000,
expr: `identity.groups
.filter(g, g.startsWith('team:'))
.exists(g, request.scopes.exists(s, s == g))`,
},
"chained filter+map+filter+map exceeds small budget": {
budget: 1000,
expr: `identity.groups
.filter(g, g.startsWith('team:'))
.map(g, g.replace('team:', ''))
.filter(g, g.size() > 3)
.map(g, g.upperAscii())
.size() > 0`,
},
"many independent exists exceeds small budget": {
budget: 5000,
expr: `identity.groups.exists(g, g.contains('a')) &&
identity.groups.exists(g, g.contains('b')) &&
identity.groups.exists(g, g.contains('c')) &&
identity.groups.exists(g, g.contains('d')) &&
identity.groups.exists(g, g.contains('e'))`,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
compiler, err := dexcel.NewCompiler(vars, dexcel.WithCostBudget(tc.budget))
require.NoError(t, err)
_, err = compiler.Compile(tc.expr)
assert.Error(t, err)
assert.Contains(t, err.Error(), "estimated cost")
assert.Contains(t, err.Error(), "exceeds budget")
})
}
}

5
pkg/cel/doc.go

@ -0,0 +1,5 @@
// Package cel provides a safe, sandboxed CEL (Common Expression Language)
// environment for policy evaluation, claim mapping, and token customization
// in Dex. It includes cost budgets, Kubernetes-grade compatibility guarantees,
// and a curated set of extension libraries.
package cel

4
pkg/cel/library/doc.go

@ -0,0 +1,4 @@
// Package library provides custom CEL function libraries for Dex.
// Each library implements the cel.Library interface and can be registered
// in a CEL environment.
package library

73
pkg/cel/library/email.go

@ -0,0 +1,73 @@
package library
import (
"strings"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// Email provides email-related CEL functions.
//
// Functions (V1):
//
// dex.emailDomain(email: string) -> string
// Returns the domain portion of an email address.
// Example: dex.emailDomain("user@example.com") == "example.com"
//
// dex.emailLocalPart(email: string) -> string
// Returns the local part of an email address.
// Example: dex.emailLocalPart("user@example.com") == "user"
type Email struct{}
func (Email) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Function("dex.emailDomain",
cel.Overload("dex_email_domain_string",
[]*cel.Type{cel.StringType},
cel.StringType,
cel.UnaryBinding(emailDomainImpl),
),
),
cel.Function("dex.emailLocalPart",
cel.Overload("dex_email_local_part_string",
[]*cel.Type{cel.StringType},
cel.StringType,
cel.UnaryBinding(emailLocalPartImpl),
),
),
}
}
func (Email) ProgramOptions() []cel.ProgramOption {
return nil
}
func emailDomainImpl(arg ref.Val) ref.Val {
email, ok := arg.Value().(string)
if !ok {
return types.NewErr("dex.emailDomain: expected string argument")
}
_, domain, found := strings.Cut(email, "@")
if !found {
return types.String("")
}
return types.String(domain)
}
func emailLocalPartImpl(arg ref.Val) ref.Val {
email, ok := arg.Value().(string)
if !ok {
return types.NewErr("dex.emailLocalPart: expected string argument")
}
localPart, _, found := strings.Cut(email, "@")
if !found {
return types.String(email)
}
return types.String(localPart)
}

108
pkg/cel/library/email_test.go

@ -0,0 +1,108 @@
package library_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dexcel "github.com/dexidp/dex/pkg/cel"
)
func TestEmailDomain(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
tests := map[string]struct {
expr string
want string
}{
"standard email": {
expr: `dex.emailDomain("user@example.com")`,
want: "example.com",
},
"subdomain": {
expr: `dex.emailDomain("admin@sub.domain.org")`,
want: "sub.domain.org",
},
"no at sign": {
expr: `dex.emailDomain("nodomain")`,
want: "",
},
"empty string": {
expr: `dex.emailDomain("")`,
want: "",
},
"multiple at signs": {
expr: `dex.emailDomain("user@name@example.com")`,
want: "name@example.com",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.CompileString(tc.expr)
require.NoError(t, err)
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{})
require.NoError(t, err)
assert.Equal(t, tc.want, result)
})
}
}
func TestEmailLocalPart(t *testing.T) {
compiler, err := dexcel.NewCompiler(nil)
require.NoError(t, err)
tests := map[string]struct {
expr string
want string
}{
"standard email": {
expr: `dex.emailLocalPart("user@example.com")`,
want: "user",
},
"no at sign": {
expr: `dex.emailLocalPart("justuser")`,
want: "justuser",
},
"empty string": {
expr: `dex.emailLocalPart("")`,
want: "",
},
"multiple at signs": {
expr: `dex.emailLocalPart("user@name@example.com")`,
want: "user",
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.CompileString(tc.expr)
require.NoError(t, err)
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{})
require.NoError(t, err)
assert.Equal(t, tc.want, result)
})
}
}
func TestEmailDomainWithIdentityVariable(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
prog, err := compiler.CompileString(`dex.emailDomain(identity.email)`)
require.NoError(t, err)
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{
"identity": map[string]any{
"email": "admin@corp.example.com",
},
})
require.NoError(t, err)
assert.Equal(t, "corp.example.com", result)
}

119
pkg/cel/library/groups.go

@ -0,0 +1,119 @@
package library
import (
"path/filepath"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
)
// Groups provides group-related CEL functions.
//
// Functions (V1):
//
// dex.groupMatches(groups: list(string), pattern: string) -> list(string)
// Returns groups matching a glob pattern.
// Example: dex.groupMatches(["team:dev", "team:ops", "admin"], "team:*")
//
// dex.groupFilter(groups: list(string), allowed: list(string)) -> list(string)
// Returns only groups present in the allowed list.
// Example: dex.groupFilter(["admin", "dev", "ops"], ["admin", "ops"])
type Groups struct{}
func (Groups) CompileOptions() []cel.EnvOption {
return []cel.EnvOption{
cel.Function("dex.groupMatches",
cel.Overload("dex_group_matches_list_string",
[]*cel.Type{cel.ListType(cel.StringType), cel.StringType},
cel.ListType(cel.StringType),
cel.BinaryBinding(groupMatchesImpl),
),
),
cel.Function("dex.groupFilter",
cel.Overload("dex_group_filter_list_list",
[]*cel.Type{cel.ListType(cel.StringType), cel.ListType(cel.StringType)},
cel.ListType(cel.StringType),
cel.BinaryBinding(groupFilterImpl),
),
),
}
}
func (Groups) ProgramOptions() []cel.ProgramOption {
return nil
}
func groupMatchesImpl(lhs, rhs ref.Val) ref.Val {
groupList, ok := lhs.(traits.Lister)
if !ok {
return types.NewErr("dex.groupMatches: expected list(string) as first argument")
}
pattern, ok := rhs.Value().(string)
if !ok {
return types.NewErr("dex.groupMatches: expected string pattern as second argument")
}
iter := groupList.Iterator()
var matched []ref.Val
for iter.HasNext() == types.True {
item := iter.Next()
group, ok := item.Value().(string)
if !ok {
continue
}
if ok, _ := filepath.Match(pattern, group); ok {
matched = append(matched, types.String(group))
}
}
return types.NewRefValList(types.DefaultTypeAdapter, matched)
}
func groupFilterImpl(lhs, rhs ref.Val) ref.Val {
groupList, ok := lhs.(traits.Lister)
if !ok {
return types.NewErr("dex.groupFilter: expected list(string) as first argument")
}
allowedList, ok := rhs.(traits.Lister)
if !ok {
return types.NewErr("dex.groupFilter: expected list(string) as second argument")
}
allowed := make(map[string]struct{})
iter := allowedList.Iterator()
for iter.HasNext() == types.True {
item := iter.Next()
s, ok := item.Value().(string)
if !ok {
continue
}
allowed[s] = struct{}{}
}
var filtered []ref.Val
iter = groupList.Iterator()
for iter.HasNext() == types.True {
item := iter.Next()
group, ok := item.Value().(string)
if !ok {
continue
}
if _, exists := allowed[group]; exists {
filtered = append(filtered, types.String(group))
}
}
return types.NewRefValList(types.DefaultTypeAdapter, filtered)
}

130
pkg/cel/library/groups_test.go

@ -0,0 +1,130 @@
package library_test
import (
"context"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dexcel "github.com/dexidp/dex/pkg/cel"
)
func TestGroupMatches(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
tests := map[string]struct {
expr string
groups []string
want []string
}{
"wildcard pattern": {
expr: `dex.groupMatches(identity.groups, "team:*")`,
groups: []string{"team:dev", "team:ops", "admin"},
want: []string{"team:dev", "team:ops"},
},
"exact match": {
expr: `dex.groupMatches(identity.groups, "admin")`,
groups: []string{"team:dev", "admin", "user"},
want: []string{"admin"},
},
"no matches": {
expr: `dex.groupMatches(identity.groups, "nonexistent")`,
groups: []string{"team:dev", "admin"},
want: []string{},
},
"question mark pattern": {
expr: `dex.groupMatches(identity.groups, "team?")`,
groups: []string{"teamA", "teamB", "teams-long"},
want: []string{"teamA", "teamB"},
},
"match all": {
expr: `dex.groupMatches(identity.groups, "*")`,
groups: []string{"a", "b", "c"},
want: []string{"a", "b", "c"},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.CompileStringList(tc.expr)
require.NoError(t, err)
out, err := dexcel.Eval(context.Background(), prog, map[string]any{
"identity": map[string]any{
"groups": tc.groups,
},
})
require.NoError(t, err)
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{}))
require.NoError(t, err)
got, ok := nativeVal.([]string)
require.True(t, ok, "expected []string, got %T", nativeVal)
assert.Equal(t, tc.want, got)
})
}
}
func TestGroupFilter(t *testing.T) {
vars := dexcel.IdentityVariables()
compiler, err := dexcel.NewCompiler(vars)
require.NoError(t, err)
tests := map[string]struct {
expr string
groups []string
want []string
}{
"filter to allowed": {
expr: `dex.groupFilter(identity.groups, ["admin", "ops"])`,
groups: []string{"admin", "dev", "ops"},
want: []string{"admin", "ops"},
},
"no overlap": {
expr: `dex.groupFilter(identity.groups, ["marketing"])`,
groups: []string{"admin", "dev"},
want: []string{},
},
"all allowed": {
expr: `dex.groupFilter(identity.groups, ["a", "b", "c"])`,
groups: []string{"a", "b", "c"},
want: []string{"a", "b", "c"},
},
"empty allowed list": {
expr: `dex.groupFilter(identity.groups, [])`,
groups: []string{"admin", "dev"},
want: []string{},
},
"preserves order": {
expr: `dex.groupFilter(identity.groups, ["z", "a"])`,
groups: []string{"a", "b", "z"},
want: []string{"a", "z"},
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
prog, err := compiler.CompileStringList(tc.expr)
require.NoError(t, err)
out, err := dexcel.Eval(context.Background(), prog, map[string]any{
"identity": map[string]any{
"groups": tc.groups,
},
})
require.NoError(t, err)
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{}))
require.NoError(t, err)
got, ok := nativeVal.([]string)
require.True(t, ok, "expected []string, got %T", nativeVal)
assert.Equal(t, tc.want, got)
})
}
}

80
pkg/cel/types.go

@ -0,0 +1,80 @@
package cel
import (
"github.com/google/cel-go/cel"
"github.com/dexidp/dex/connector"
)
// VariableDeclaration declares a named variable and its CEL type
// that will be available in expressions.
type VariableDeclaration struct {
Name string
Type *cel.Type
}
// IdentityVariables provides the 'identity' variable with user claims.
//
// identity.user_id — string
// identity.username — string
// identity.preferred_username — string
// identity.email — string
// identity.email_verified — bool
// identity.groups — list(string)
func IdentityVariables() []VariableDeclaration {
return []VariableDeclaration{
{Name: "identity", Type: cel.MapType(cel.StringType, cel.DynType)},
}
}
// RequestVariables provides the 'request' variable with request context.
//
// request.client_id — string
// request.connector_id — string
// request.scopes — list(string)
// request.redirect_uri — string
func RequestVariables() []VariableDeclaration {
return []VariableDeclaration{
{Name: "request", Type: cel.MapType(cel.StringType, cel.DynType)},
}
}
// ClaimsVariable provides a 'claims' map for raw upstream claims.
//
// claims — map(string, dyn)
func ClaimsVariable() []VariableDeclaration {
return []VariableDeclaration{
{Name: "claims", Type: cel.MapType(cel.StringType, cel.DynType)},
}
}
// IdentityFromConnector converts a connector.Identity to a CEL-compatible map.
func IdentityFromConnector(id connector.Identity) map[string]any {
return map[string]any{
"user_id": id.UserID,
"username": id.Username,
"preferred_username": id.PreferredUsername,
"email": id.Email,
"email_verified": id.EmailVerified,
"groups": id.Groups,
}
}
// RequestContext represents the authentication/token request context
// available as the 'request' variable in CEL expressions.
type RequestContext struct {
ClientID string
ConnectorID string
Scopes []string
RedirectURI string
}
// RequestFromContext converts a RequestContext to a CEL-compatible map.
func RequestFromContext(rc RequestContext) map[string]any {
return map[string]any{
"client_id": rc.ClientID,
"connector_id": rc.ConnectorID,
"scopes": rc.Scopes,
"redirect_uri": rc.RedirectURI,
}
}

3
pkg/featureflags/doc.go

@ -0,0 +1,3 @@
// Package featureflags provides a mechanism for toggling experimental or
// optional Dex features via environment variables (DEX_<FLAG_NAME>).
package featureflags

2
pkg/groups/doc.go

@ -0,0 +1,2 @@
// Package groups contains helper functions related to groups.
package groups

1
pkg/groups/groups.go

@ -1,4 +1,3 @@
// Package groups contains helper functions related to groups
package groups
// Filter filters out any groups of given that are not in required. Thus it may

3
pkg/httpclient/doc.go

@ -0,0 +1,3 @@
// Package httpclient provides a configurable HTTP client constructor with
// support for custom CA certificates, root CAs, and TLS settings.
package httpclient
Loading…
Cancel
Save