mirror of https://github.com/dexidp/dex.git
17 changed files with 1267 additions and 4 deletions
@ -0,0 +1,223 @@ |
|||||||
|
package cel |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/google/cel-go/cel" |
||||||
|
"github.com/google/cel-go/checker" |
||||||
|
"github.com/google/cel-go/common/types/ref" |
||||||
|
"github.com/google/cel-go/ext" |
||||||
|
|
||||||
|
"github.com/dexidp/dex/pkg/cel/library" |
||||||
|
) |
||||||
|
|
||||||
|
// EnvironmentVersion represents the version of the CEL environment.
|
||||||
|
// New variables, functions, or libraries are introduced in new versions.
|
||||||
|
type EnvironmentVersion uint32 |
||||||
|
|
||||||
|
const ( |
||||||
|
// EnvironmentV1 is the initial CEL environment.
|
||||||
|
EnvironmentV1 EnvironmentVersion = 1 |
||||||
|
) |
||||||
|
|
||||||
|
// CompilationResult holds a compiled CEL program ready for evaluation.
|
||||||
|
type CompilationResult struct { |
||||||
|
Program cel.Program |
||||||
|
OutputType *cel.Type |
||||||
|
Expression string |
||||||
|
|
||||||
|
ast *cel.Ast |
||||||
|
} |
||||||
|
|
||||||
|
// CompilerOption configures a Compiler.
|
||||||
|
type CompilerOption func(*compilerConfig) |
||||||
|
|
||||||
|
type compilerConfig struct { |
||||||
|
costBudget uint64 |
||||||
|
version EnvironmentVersion |
||||||
|
} |
||||||
|
|
||||||
|
func defaultCompilerConfig() *compilerConfig { |
||||||
|
return &compilerConfig{ |
||||||
|
costBudget: DefaultCostBudget, |
||||||
|
version: EnvironmentV1, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// WithCostBudget sets a custom cost budget for expression evaluation.
|
||||||
|
func WithCostBudget(budget uint64) CompilerOption { |
||||||
|
return func(cfg *compilerConfig) { |
||||||
|
cfg.costBudget = budget |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// WithVersion sets the target environment version for the compiler.
|
||||||
|
// Defaults to the latest version. Specifying an older version ensures
|
||||||
|
// that only functions/types available at that version are used.
|
||||||
|
func WithVersion(v EnvironmentVersion) CompilerOption { |
||||||
|
return func(cfg *compilerConfig) { |
||||||
|
cfg.version = v |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Compiler compiles CEL expressions against a specific environment.
|
||||||
|
type Compiler struct { |
||||||
|
env *cel.Env |
||||||
|
cfg *compilerConfig |
||||||
|
} |
||||||
|
|
||||||
|
// NewCompiler creates a new CEL compiler with the specified variable
|
||||||
|
// declarations and options.
|
||||||
|
//
|
||||||
|
// All custom Dex libraries are automatically included.
|
||||||
|
// The environment is configured with cost limits and safe defaults.
|
||||||
|
func NewCompiler(variables []VariableDeclaration, opts ...CompilerOption) (*Compiler, error) { |
||||||
|
cfg := defaultCompilerConfig() |
||||||
|
for _, opt := range opts { |
||||||
|
opt(cfg) |
||||||
|
} |
||||||
|
|
||||||
|
envOpts := make([]cel.EnvOption, 0, 8+len(variables)) |
||||||
|
envOpts = append(envOpts, |
||||||
|
cel.DefaultUTCTimeZone(true), |
||||||
|
|
||||||
|
// Standard extension libraries (same set as Kubernetes)
|
||||||
|
ext.Strings(), |
||||||
|
ext.Encoders(), |
||||||
|
ext.Lists(), |
||||||
|
ext.Sets(), |
||||||
|
ext.Math(), |
||||||
|
|
||||||
|
// Custom Dex libraries
|
||||||
|
cel.Lib(&library.Email{}), |
||||||
|
cel.Lib(&library.Groups{}), |
||||||
|
|
||||||
|
// Presence tests like has(field) and 'key' in map are O(1) hash
|
||||||
|
// lookups on map(string, dyn) variables, so they should not count
|
||||||
|
// toward the cost budget. Without this, expressions with multiple
|
||||||
|
// 'in' checks (e.g. "'admin' in identity.groups") would accumulate
|
||||||
|
// inflated cost estimates. This matches Kubernetes CEL behavior
|
||||||
|
// where presence tests are free for CRD validation rules.
|
||||||
|
cel.CostEstimatorOptions( |
||||||
|
checker.PresenceTestHasCost(false), |
||||||
|
), |
||||||
|
) |
||||||
|
|
||||||
|
for _, v := range variables { |
||||||
|
envOpts = append(envOpts, cel.Variable(v.Name, v.Type)) |
||||||
|
} |
||||||
|
|
||||||
|
env, err := cel.NewEnv(envOpts...) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("failed to create CEL environment: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return &Compiler{env: env, cfg: cfg}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// CompileBool compiles a CEL expression that must evaluate to bool.
|
||||||
|
func (c *Compiler) CompileBool(expression string) (*CompilationResult, error) { |
||||||
|
return c.compile(expression, cel.BoolType) |
||||||
|
} |
||||||
|
|
||||||
|
// CompileString compiles a CEL expression that must evaluate to string.
|
||||||
|
func (c *Compiler) CompileString(expression string) (*CompilationResult, error) { |
||||||
|
return c.compile(expression, cel.StringType) |
||||||
|
} |
||||||
|
|
||||||
|
// CompileStringList compiles a CEL expression that must evaluate to list(string).
|
||||||
|
func (c *Compiler) CompileStringList(expression string) (*CompilationResult, error) { |
||||||
|
return c.compile(expression, cel.ListType(cel.StringType)) |
||||||
|
} |
||||||
|
|
||||||
|
// Compile compiles a CEL expression with any output type.
|
||||||
|
func (c *Compiler) Compile(expression string) (*CompilationResult, error) { |
||||||
|
return c.compile(expression, nil) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Compiler) compile(expression string, expectedType *cel.Type) (*CompilationResult, error) { |
||||||
|
if len(expression) > MaxExpressionLength { |
||||||
|
return nil, fmt.Errorf("expression exceeds maximum length of %d characters", MaxExpressionLength) |
||||||
|
} |
||||||
|
|
||||||
|
ast, issues := c.env.Compile(expression) |
||||||
|
if issues != nil && issues.Err() != nil { |
||||||
|
return nil, fmt.Errorf("CEL compilation failed: %w", issues.Err()) |
||||||
|
} |
||||||
|
|
||||||
|
if expectedType != nil && !ast.OutputType().IsEquivalentType(expectedType) { |
||||||
|
return nil, fmt.Errorf( |
||||||
|
"expected expression output type %s, got %s", |
||||||
|
expectedType, ast.OutputType(), |
||||||
|
) |
||||||
|
} |
||||||
|
|
||||||
|
// Estimate cost at compile time and reject expressions that are too expensive.
|
||||||
|
costEst, err := c.env.EstimateCost(ast, &defaultCostEstimator{}) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("CEL cost estimation failed: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
if costEst.Max > c.cfg.costBudget { |
||||||
|
return nil, fmt.Errorf( |
||||||
|
"CEL expression estimated cost %d exceeds budget %d", |
||||||
|
costEst.Max, c.cfg.costBudget, |
||||||
|
) |
||||||
|
} |
||||||
|
|
||||||
|
prog, err := c.env.Program(ast, |
||||||
|
cel.EvalOptions(cel.OptOptimize), |
||||||
|
cel.CostLimit(c.cfg.costBudget), |
||||||
|
) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("CEL program creation failed: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return &CompilationResult{ |
||||||
|
Program: prog, |
||||||
|
OutputType: ast.OutputType(), |
||||||
|
Expression: expression, |
||||||
|
ast: ast, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
// Eval evaluates a compiled program against the given variables.
|
||||||
|
func Eval(ctx context.Context, result *CompilationResult, variables map[string]any) (ref.Val, error) { |
||||||
|
out, _, err := result.Program.ContextEval(ctx, variables) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("CEL evaluation failed: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
// EvalBool is a convenience function that evaluates and asserts bool output.
|
||||||
|
func EvalBool(ctx context.Context, result *CompilationResult, variables map[string]any) (bool, error) { |
||||||
|
out, err := Eval(ctx, result, variables) |
||||||
|
if err != nil { |
||||||
|
return false, err |
||||||
|
} |
||||||
|
|
||||||
|
v, ok := out.Value().(bool) |
||||||
|
if !ok { |
||||||
|
return false, fmt.Errorf("expected bool result, got %T", out.Value()) |
||||||
|
} |
||||||
|
|
||||||
|
return v, nil |
||||||
|
} |
||||||
|
|
||||||
|
// EvalString is a convenience function that evaluates and asserts string output.
|
||||||
|
func EvalString(ctx context.Context, result *CompilationResult, variables map[string]any) (string, error) { |
||||||
|
out, err := Eval(ctx, result, variables) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
v, ok := out.Value().(string) |
||||||
|
if !ok { |
||||||
|
return "", fmt.Errorf("expected string result, got %T", out.Value()) |
||||||
|
} |
||||||
|
|
||||||
|
return v, nil |
||||||
|
} |
||||||
@ -0,0 +1,270 @@ |
|||||||
|
package cel_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"strings" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
"github.com/stretchr/testify/require" |
||||||
|
|
||||||
|
"github.com/dexidp/dex/connector" |
||||||
|
dexcel "github.com/dexidp/dex/pkg/cel" |
||||||
|
) |
||||||
|
|
||||||
|
func TestCompileBool(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
wantErr bool |
||||||
|
}{ |
||||||
|
"true literal": { |
||||||
|
expr: "true", |
||||||
|
}, |
||||||
|
"comparison": { |
||||||
|
expr: "1 == 1", |
||||||
|
}, |
||||||
|
"string type mismatch": { |
||||||
|
expr: "'hello'", |
||||||
|
wantErr: true, |
||||||
|
}, |
||||||
|
"int type mismatch": { |
||||||
|
expr: "42", |
||||||
|
wantErr: true, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
result, err := compiler.CompileBool(tc.expr) |
||||||
|
if tc.wantErr { |
||||||
|
assert.Error(t, err) |
||||||
|
assert.Nil(t, result) |
||||||
|
} else { |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
} |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompileString(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
wantErr bool |
||||||
|
}{ |
||||||
|
"string literal": { |
||||||
|
expr: "'hello'", |
||||||
|
}, |
||||||
|
"string concatenation": { |
||||||
|
expr: "'hello' + ' ' + 'world'", |
||||||
|
}, |
||||||
|
"bool type mismatch": { |
||||||
|
expr: "true", |
||||||
|
wantErr: true, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
result, err := compiler.CompileString(tc.expr) |
||||||
|
if tc.wantErr { |
||||||
|
assert.Error(t, err) |
||||||
|
} else { |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
} |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompileStringList(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := compiler.CompileStringList("['a', 'b', 'c']") |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
|
||||||
|
_, err = compiler.CompileStringList("'not a list'") |
||||||
|
assert.Error(t, err) |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompile(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
// Compile accepts any type
|
||||||
|
result, err := compiler.Compile("true") |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
|
||||||
|
result, err = compiler.Compile("'hello'") |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
|
||||||
|
result, err = compiler.Compile("42") |
||||||
|
assert.NoError(t, err) |
||||||
|
assert.NotNil(t, result) |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompileErrors(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
}{ |
||||||
|
"syntax error": { |
||||||
|
expr: "1 +", |
||||||
|
}, |
||||||
|
"undefined variable": { |
||||||
|
expr: "undefined_var", |
||||||
|
}, |
||||||
|
"undefined function": { |
||||||
|
expr: "undefinedFunc()", |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
_, err := compiler.Compile(tc.expr) |
||||||
|
assert.Error(t, err) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestMaxExpressionLength(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
longExpr := "'" + strings.Repeat("a", dexcel.MaxExpressionLength) + "'" |
||||||
|
_, err = compiler.Compile(longExpr) |
||||||
|
assert.Error(t, err) |
||||||
|
assert.Contains(t, err.Error(), "maximum length") |
||||||
|
} |
||||||
|
|
||||||
|
func TestEvalBool(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
identity map[string]any |
||||||
|
want bool |
||||||
|
}{ |
||||||
|
"email endsWith": { |
||||||
|
expr: "identity.email.endsWith('@example.com')", |
||||||
|
identity: map[string]any{ |
||||||
|
"email": "user@example.com", |
||||||
|
}, |
||||||
|
want: true, |
||||||
|
}, |
||||||
|
"email endsWith false": { |
||||||
|
expr: "identity.email.endsWith('@example.com')", |
||||||
|
identity: map[string]any{ |
||||||
|
"email": "user@other.com", |
||||||
|
}, |
||||||
|
want: false, |
||||||
|
}, |
||||||
|
"email_verified": { |
||||||
|
expr: "identity.email_verified == true", |
||||||
|
identity: map[string]any{ |
||||||
|
"email_verified": true, |
||||||
|
}, |
||||||
|
want: true, |
||||||
|
}, |
||||||
|
"group membership": { |
||||||
|
expr: "identity.groups.exists(g, g == 'admin')", |
||||||
|
identity: map[string]any{ |
||||||
|
"groups": []string{"admin", "dev"}, |
||||||
|
}, |
||||||
|
want: true, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.CompileBool(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{ |
||||||
|
"identity": tc.identity, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, tc.want, result) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestEvalString(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
// identity.email returns dyn from map access, use Compile (not CompileString)
|
||||||
|
prog, err := compiler.Compile("identity.email") |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||||
|
"identity": map[string]any{ |
||||||
|
"email": "user@example.com", |
||||||
|
}, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, "user@example.com", result) |
||||||
|
} |
||||||
|
|
||||||
|
func TestEvalWithIdentityAndRequest(t *testing.T) { |
||||||
|
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
prog, err := compiler.CompileBool( |
||||||
|
`identity.email.endsWith('@example.com') && 'admin' in identity.groups && request.connector_id == 'okta'`, |
||||||
|
) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
identity := dexcel.IdentityFromConnector(connector.Identity{ |
||||||
|
UserID: "123", |
||||||
|
Username: "john", |
||||||
|
Email: "john@example.com", |
||||||
|
Groups: []string{"admin", "dev"}, |
||||||
|
}) |
||||||
|
request := dexcel.RequestFromContext(dexcel.RequestContext{ |
||||||
|
ClientID: "my-app", |
||||||
|
ConnectorID: "okta", |
||||||
|
Scopes: []string{"openid", "email"}, |
||||||
|
}) |
||||||
|
|
||||||
|
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{ |
||||||
|
"identity": identity, |
||||||
|
"request": request, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.True(t, result) |
||||||
|
} |
||||||
|
|
||||||
|
func TestNewCompilerWithVariables(t *testing.T) { |
||||||
|
// Claims variable
|
||||||
|
compiler, err := dexcel.NewCompiler(dexcel.ClaimsVariable()) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
// claims.email returns dyn from map access, use Compile (not CompileString)
|
||||||
|
prog, err := compiler.Compile("claims.email") |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||||
|
"claims": map[string]any{ |
||||||
|
"email": "test@example.com", |
||||||
|
}, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, "test@example.com", result) |
||||||
|
} |
||||||
@ -0,0 +1,98 @@ |
|||||||
|
package cel |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/google/cel-go/checker" |
||||||
|
) |
||||||
|
|
||||||
|
// DefaultCostBudget is the default cost budget for a single expression
|
||||||
|
// evaluation. Aligned with Kubernetes defaults: enough for typical identity
|
||||||
|
// operations but prevents runaway expressions.
|
||||||
|
const DefaultCostBudget uint64 = 10_000_000 |
||||||
|
|
||||||
|
// MaxExpressionLength is the maximum length of a CEL expression string.
|
||||||
|
const MaxExpressionLength = 10_240 |
||||||
|
|
||||||
|
// DefaultStringMaxLength is the estimated max length of string values
|
||||||
|
// (emails, usernames, group names, etc.) used for compile-time cost estimation.
|
||||||
|
const DefaultStringMaxLength = 256 |
||||||
|
|
||||||
|
// DefaultListMaxLength is the estimated max length of list values
|
||||||
|
// (groups, scopes) used for compile-time cost estimation.
|
||||||
|
const DefaultListMaxLength = 100 |
||||||
|
|
||||||
|
// CostEstimate holds the estimated cost range for a compiled expression.
|
||||||
|
type CostEstimate struct { |
||||||
|
Min uint64 |
||||||
|
Max uint64 |
||||||
|
} |
||||||
|
|
||||||
|
// EstimateCost returns the estimated cost range for a compiled expression.
|
||||||
|
// This is computed statically at compile time without evaluating the expression.
|
||||||
|
func (c *Compiler) EstimateCost(result *CompilationResult) CostEstimate { |
||||||
|
costEst, err := c.env.EstimateCost(result.ast, &defaultCostEstimator{}) |
||||||
|
if err != nil { |
||||||
|
return CostEstimate{} |
||||||
|
} |
||||||
|
|
||||||
|
return CostEstimate{Min: costEst.Min, Max: costEst.Max} |
||||||
|
} |
||||||
|
|
||||||
|
// defaultCostEstimator provides size hints for compile-time cost estimation.
|
||||||
|
// Without these hints, the CEL cost estimator assumes unbounded sizes for
|
||||||
|
// variables, leading to wildly overestimated max costs.
|
||||||
|
type defaultCostEstimator struct{} |
||||||
|
|
||||||
|
func (defaultCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { |
||||||
|
// Provide size hints for map(string, dyn) variables: identity, request, claims.
|
||||||
|
// Without these, the estimator assumes lists/strings can be infinitely large.
|
||||||
|
if element.Path() == nil { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
path := element.Path() |
||||||
|
if len(path) == 0 { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
root := path[0] |
||||||
|
|
||||||
|
switch root { |
||||||
|
case "identity", "request", "claims": |
||||||
|
// Nested field access (e.g. identity.email, identity.groups)
|
||||||
|
if len(path) >= 2 { |
||||||
|
field := path[1] |
||||||
|
switch field { |
||||||
|
case "groups", "scopes": |
||||||
|
return &checker.SizeEstimate{Min: 0, Max: DefaultListMaxLength} |
||||||
|
default: |
||||||
|
return &checker.SizeEstimate{Min: 0, Max: DefaultStringMaxLength} |
||||||
|
} |
||||||
|
} |
||||||
|
// The map itself: number of keys
|
||||||
|
return &checker.SizeEstimate{Min: 0, Max: 20} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (defaultCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { |
||||||
|
switch function { |
||||||
|
case "dex.emailDomain", "dex.emailLocalPart": |
||||||
|
// Simple string split — O(n) where n is string length, bounded.
|
||||||
|
return &checker.CallEstimate{ |
||||||
|
CostEstimate: checker.CostEstimate{Min: 1, Max: 2}, |
||||||
|
} |
||||||
|
case "dex.groupMatches": |
||||||
|
// Iterates over groups list and matches each against a pattern.
|
||||||
|
return &checker.CallEstimate{ |
||||||
|
CostEstimate: checker.CostEstimate{Min: 1, Max: DefaultListMaxLength}, |
||||||
|
} |
||||||
|
case "dex.groupFilter": |
||||||
|
// Builds a set from allowed list, then iterates groups.
|
||||||
|
return &checker.CallEstimate{ |
||||||
|
CostEstimate: checker.CostEstimate{Min: 1, Max: 2 * DefaultListMaxLength}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
@ -0,0 +1,136 @@ |
|||||||
|
package cel_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
"github.com/stretchr/testify/require" |
||||||
|
|
||||||
|
dexcel "github.com/dexidp/dex/pkg/cel" |
||||||
|
) |
||||||
|
|
||||||
|
func TestEstimateCost(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
}{ |
||||||
|
"simple bool": { |
||||||
|
expr: "true", |
||||||
|
}, |
||||||
|
"string comparison": { |
||||||
|
expr: "identity.email == 'test@example.com'", |
||||||
|
}, |
||||||
|
"group membership": { |
||||||
|
expr: "identity.groups.exists(g, g == 'admin')", |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.Compile(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
est := compiler.EstimateCost(prog) |
||||||
|
assert.True(t, est.Max >= est.Min, "max cost should be >= min cost") |
||||||
|
assert.True(t, est.Max <= dexcel.DefaultCostBudget, |
||||||
|
"estimated max cost %d should be within default budget %d", est.Max, dexcel.DefaultCostBudget) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompileTimeCostAcceptsSimpleExpressions(t *testing.T) { |
||||||
|
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]string{ |
||||||
|
"literal": "true", |
||||||
|
"email endsWith": "identity.email.endsWith('@example.com')", |
||||||
|
"group check": "'admin' in identity.groups", |
||||||
|
"emailDomain": `dex.emailDomain(identity.email)`, |
||||||
|
"groupMatches": `dex.groupMatches(identity.groups, "team:*")`, |
||||||
|
"groupFilter": `dex.groupFilter(identity.groups, ["admin", "dev"])`, |
||||||
|
"combined policy": `identity.email.endsWith('@example.com') && 'admin' in identity.groups`, |
||||||
|
"complex policy": `identity.email.endsWith('@example.com') && |
||||||
|
identity.groups.exists(g, g == 'admin') && |
||||||
|
request.connector_id == 'okta' && |
||||||
|
request.scopes.exists(s, s == 'openid')`, |
||||||
|
"filter+map chain": `identity.groups |
||||||
|
.filter(g, g.startsWith('team:')) |
||||||
|
.map(g, g.replace('team:', '')) |
||||||
|
.size() > 0`, |
||||||
|
} |
||||||
|
|
||||||
|
for name, expr := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
_, err := compiler.Compile(expr) |
||||||
|
assert.NoError(t, err, "expression should compile within default budget") |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestCompileTimeCostRejection(t *testing.T) { |
||||||
|
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
budget uint64 |
||||||
|
expr string |
||||||
|
}{ |
||||||
|
"simple exists exceeds tiny budget": { |
||||||
|
budget: 1, |
||||||
|
expr: "identity.groups.exists(g, g == 'admin')", |
||||||
|
}, |
||||||
|
"endsWith exceeds tiny budget": { |
||||||
|
budget: 2, |
||||||
|
expr: "identity.email.endsWith('@example.com')", |
||||||
|
}, |
||||||
|
"nested comprehension over groups exceeds moderate budget": { |
||||||
|
// Two nested iterations over groups: O(n^2) where n=100 → ~280K
|
||||||
|
budget: 10_000, |
||||||
|
expr: `identity.groups.exists(g1, |
||||||
|
identity.groups.exists(g2, |
||||||
|
g1 != g2 && g1.startsWith(g2) |
||||||
|
) |
||||||
|
)`, |
||||||
|
}, |
||||||
|
"cross-variable comprehension exceeds moderate budget": { |
||||||
|
// filter groups then check each against scopes: O(n*m) → ~162K
|
||||||
|
budget: 10_000, |
||||||
|
expr: `identity.groups |
||||||
|
.filter(g, g.startsWith('team:')) |
||||||
|
.exists(g, request.scopes.exists(s, s == g))`, |
||||||
|
}, |
||||||
|
"chained filter+map+filter+map exceeds small budget": { |
||||||
|
budget: 1000, |
||||||
|
expr: `identity.groups |
||||||
|
.filter(g, g.startsWith('team:')) |
||||||
|
.map(g, g.replace('team:', '')) |
||||||
|
.filter(g, g.size() > 3) |
||||||
|
.map(g, g.upperAscii()) |
||||||
|
.size() > 0`, |
||||||
|
}, |
||||||
|
"many independent exists exceeds small budget": { |
||||||
|
budget: 5000, |
||||||
|
expr: `identity.groups.exists(g, g.contains('a')) && |
||||||
|
identity.groups.exists(g, g.contains('b')) && |
||||||
|
identity.groups.exists(g, g.contains('c')) && |
||||||
|
identity.groups.exists(g, g.contains('d')) && |
||||||
|
identity.groups.exists(g, g.contains('e'))`, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(vars, dexcel.WithCostBudget(tc.budget)) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
_, err = compiler.Compile(tc.expr) |
||||||
|
assert.Error(t, err) |
||||||
|
assert.Contains(t, err.Error(), "estimated cost") |
||||||
|
assert.Contains(t, err.Error(), "exceeds budget") |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,5 @@ |
|||||||
|
// Package cel provides a safe, sandboxed CEL (Common Expression Language)
|
||||||
|
// environment for policy evaluation, claim mapping, and token customization
|
||||||
|
// in Dex. It includes cost budgets, Kubernetes-grade compatibility guarantees,
|
||||||
|
// and a curated set of extension libraries.
|
||||||
|
package cel |
||||||
@ -0,0 +1,4 @@ |
|||||||
|
// Package library provides custom CEL function libraries for Dex.
|
||||||
|
// Each library implements the cel.Library interface and can be registered
|
||||||
|
// in a CEL environment.
|
||||||
|
package library |
||||||
@ -0,0 +1,73 @@ |
|||||||
|
package library |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/google/cel-go/cel" |
||||||
|
"github.com/google/cel-go/common/types" |
||||||
|
"github.com/google/cel-go/common/types/ref" |
||||||
|
) |
||||||
|
|
||||||
|
// Email provides email-related CEL functions.
|
||||||
|
//
|
||||||
|
// Functions (V1):
|
||||||
|
//
|
||||||
|
// dex.emailDomain(email: string) -> string
|
||||||
|
// Returns the domain portion of an email address.
|
||||||
|
// Example: dex.emailDomain("user@example.com") == "example.com"
|
||||||
|
//
|
||||||
|
// dex.emailLocalPart(email: string) -> string
|
||||||
|
// Returns the local part of an email address.
|
||||||
|
// Example: dex.emailLocalPart("user@example.com") == "user"
|
||||||
|
type Email struct{} |
||||||
|
|
||||||
|
func (Email) CompileOptions() []cel.EnvOption { |
||||||
|
return []cel.EnvOption{ |
||||||
|
cel.Function("dex.emailDomain", |
||||||
|
cel.Overload("dex_email_domain_string", |
||||||
|
[]*cel.Type{cel.StringType}, |
||||||
|
cel.StringType, |
||||||
|
cel.UnaryBinding(emailDomainImpl), |
||||||
|
), |
||||||
|
), |
||||||
|
cel.Function("dex.emailLocalPart", |
||||||
|
cel.Overload("dex_email_local_part_string", |
||||||
|
[]*cel.Type{cel.StringType}, |
||||||
|
cel.StringType, |
||||||
|
cel.UnaryBinding(emailLocalPartImpl), |
||||||
|
), |
||||||
|
), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (Email) ProgramOptions() []cel.ProgramOption { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func emailDomainImpl(arg ref.Val) ref.Val { |
||||||
|
email, ok := arg.Value().(string) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.emailDomain: expected string argument") |
||||||
|
} |
||||||
|
|
||||||
|
_, domain, found := strings.Cut(email, "@") |
||||||
|
if !found { |
||||||
|
return types.String("") |
||||||
|
} |
||||||
|
|
||||||
|
return types.String(domain) |
||||||
|
} |
||||||
|
|
||||||
|
func emailLocalPartImpl(arg ref.Val) ref.Val { |
||||||
|
email, ok := arg.Value().(string) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.emailLocalPart: expected string argument") |
||||||
|
} |
||||||
|
|
||||||
|
localPart, _, found := strings.Cut(email, "@") |
||||||
|
if !found { |
||||||
|
return types.String(email) |
||||||
|
} |
||||||
|
|
||||||
|
return types.String(localPart) |
||||||
|
} |
||||||
@ -0,0 +1,108 @@ |
|||||||
|
package library_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
"github.com/stretchr/testify/require" |
||||||
|
|
||||||
|
dexcel "github.com/dexidp/dex/pkg/cel" |
||||||
|
) |
||||||
|
|
||||||
|
func TestEmailDomain(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
want string |
||||||
|
}{ |
||||||
|
"standard email": { |
||||||
|
expr: `dex.emailDomain("user@example.com")`, |
||||||
|
want: "example.com", |
||||||
|
}, |
||||||
|
"subdomain": { |
||||||
|
expr: `dex.emailDomain("admin@sub.domain.org")`, |
||||||
|
want: "sub.domain.org", |
||||||
|
}, |
||||||
|
"no at sign": { |
||||||
|
expr: `dex.emailDomain("nodomain")`, |
||||||
|
want: "", |
||||||
|
}, |
||||||
|
"empty string": { |
||||||
|
expr: `dex.emailDomain("")`, |
||||||
|
want: "", |
||||||
|
}, |
||||||
|
"multiple at signs": { |
||||||
|
expr: `dex.emailDomain("user@name@example.com")`, |
||||||
|
want: "name@example.com", |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.CompileString(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, tc.want, result) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestEmailLocalPart(t *testing.T) { |
||||||
|
compiler, err := dexcel.NewCompiler(nil) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
want string |
||||||
|
}{ |
||||||
|
"standard email": { |
||||||
|
expr: `dex.emailLocalPart("user@example.com")`, |
||||||
|
want: "user", |
||||||
|
}, |
||||||
|
"no at sign": { |
||||||
|
expr: `dex.emailLocalPart("justuser")`, |
||||||
|
want: "justuser", |
||||||
|
}, |
||||||
|
"empty string": { |
||||||
|
expr: `dex.emailLocalPart("")`, |
||||||
|
want: "", |
||||||
|
}, |
||||||
|
"multiple at signs": { |
||||||
|
expr: `dex.emailLocalPart("user@name@example.com")`, |
||||||
|
want: "user", |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.CompileString(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, tc.want, result) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestEmailDomainWithIdentityVariable(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
prog, err := compiler.CompileString(`dex.emailDomain(identity.email)`) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||||
|
"identity": map[string]any{ |
||||||
|
"email": "admin@corp.example.com", |
||||||
|
}, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
assert.Equal(t, "corp.example.com", result) |
||||||
|
} |
||||||
@ -0,0 +1,119 @@ |
|||||||
|
package library |
||||||
|
|
||||||
|
import ( |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
"github.com/google/cel-go/cel" |
||||||
|
"github.com/google/cel-go/common/types" |
||||||
|
"github.com/google/cel-go/common/types/ref" |
||||||
|
"github.com/google/cel-go/common/types/traits" |
||||||
|
) |
||||||
|
|
||||||
|
// Groups provides group-related CEL functions.
|
||||||
|
//
|
||||||
|
// Functions (V1):
|
||||||
|
//
|
||||||
|
// dex.groupMatches(groups: list(string), pattern: string) -> list(string)
|
||||||
|
// Returns groups matching a glob pattern.
|
||||||
|
// Example: dex.groupMatches(["team:dev", "team:ops", "admin"], "team:*")
|
||||||
|
//
|
||||||
|
// dex.groupFilter(groups: list(string), allowed: list(string)) -> list(string)
|
||||||
|
// Returns only groups present in the allowed list.
|
||||||
|
// Example: dex.groupFilter(["admin", "dev", "ops"], ["admin", "ops"])
|
||||||
|
type Groups struct{} |
||||||
|
|
||||||
|
func (Groups) CompileOptions() []cel.EnvOption { |
||||||
|
return []cel.EnvOption{ |
||||||
|
cel.Function("dex.groupMatches", |
||||||
|
cel.Overload("dex_group_matches_list_string", |
||||||
|
[]*cel.Type{cel.ListType(cel.StringType), cel.StringType}, |
||||||
|
cel.ListType(cel.StringType), |
||||||
|
cel.BinaryBinding(groupMatchesImpl), |
||||||
|
), |
||||||
|
), |
||||||
|
cel.Function("dex.groupFilter", |
||||||
|
cel.Overload("dex_group_filter_list_list", |
||||||
|
[]*cel.Type{cel.ListType(cel.StringType), cel.ListType(cel.StringType)}, |
||||||
|
cel.ListType(cel.StringType), |
||||||
|
cel.BinaryBinding(groupFilterImpl), |
||||||
|
), |
||||||
|
), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (Groups) ProgramOptions() []cel.ProgramOption { |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func groupMatchesImpl(lhs, rhs ref.Val) ref.Val { |
||||||
|
groupList, ok := lhs.(traits.Lister) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.groupMatches: expected list(string) as first argument") |
||||||
|
} |
||||||
|
|
||||||
|
pattern, ok := rhs.Value().(string) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.groupMatches: expected string pattern as second argument") |
||||||
|
} |
||||||
|
|
||||||
|
iter := groupList.Iterator() |
||||||
|
var matched []ref.Val |
||||||
|
|
||||||
|
for iter.HasNext() == types.True { |
||||||
|
item := iter.Next() |
||||||
|
|
||||||
|
group, ok := item.Value().(string) |
||||||
|
if !ok { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if ok, _ := filepath.Match(pattern, group); ok { |
||||||
|
matched = append(matched, types.String(group)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return types.NewRefValList(types.DefaultTypeAdapter, matched) |
||||||
|
} |
||||||
|
|
||||||
|
func groupFilterImpl(lhs, rhs ref.Val) ref.Val { |
||||||
|
groupList, ok := lhs.(traits.Lister) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.groupFilter: expected list(string) as first argument") |
||||||
|
} |
||||||
|
|
||||||
|
allowedList, ok := rhs.(traits.Lister) |
||||||
|
if !ok { |
||||||
|
return types.NewErr("dex.groupFilter: expected list(string) as second argument") |
||||||
|
} |
||||||
|
|
||||||
|
allowed := make(map[string]struct{}) |
||||||
|
iter := allowedList.Iterator() |
||||||
|
for iter.HasNext() == types.True { |
||||||
|
item := iter.Next() |
||||||
|
|
||||||
|
s, ok := item.Value().(string) |
||||||
|
if !ok { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
allowed[s] = struct{}{} |
||||||
|
} |
||||||
|
|
||||||
|
var filtered []ref.Val |
||||||
|
iter = groupList.Iterator() |
||||||
|
|
||||||
|
for iter.HasNext() == types.True { |
||||||
|
item := iter.Next() |
||||||
|
|
||||||
|
group, ok := item.Value().(string) |
||||||
|
if !ok { |
||||||
|
continue |
||||||
|
} |
||||||
|
|
||||||
|
if _, exists := allowed[group]; exists { |
||||||
|
filtered = append(filtered, types.String(group)) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return types.NewRefValList(types.DefaultTypeAdapter, filtered) |
||||||
|
} |
||||||
@ -0,0 +1,130 @@ |
|||||||
|
package library_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"reflect" |
||||||
|
"testing" |
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert" |
||||||
|
"github.com/stretchr/testify/require" |
||||||
|
|
||||||
|
dexcel "github.com/dexidp/dex/pkg/cel" |
||||||
|
) |
||||||
|
|
||||||
|
func TestGroupMatches(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
groups []string |
||||||
|
want []string |
||||||
|
}{ |
||||||
|
"wildcard pattern": { |
||||||
|
expr: `dex.groupMatches(identity.groups, "team:*")`, |
||||||
|
groups: []string{"team:dev", "team:ops", "admin"}, |
||||||
|
want: []string{"team:dev", "team:ops"}, |
||||||
|
}, |
||||||
|
"exact match": { |
||||||
|
expr: `dex.groupMatches(identity.groups, "admin")`, |
||||||
|
groups: []string{"team:dev", "admin", "user"}, |
||||||
|
want: []string{"admin"}, |
||||||
|
}, |
||||||
|
"no matches": { |
||||||
|
expr: `dex.groupMatches(identity.groups, "nonexistent")`, |
||||||
|
groups: []string{"team:dev", "admin"}, |
||||||
|
want: []string{}, |
||||||
|
}, |
||||||
|
"question mark pattern": { |
||||||
|
expr: `dex.groupMatches(identity.groups, "team?")`, |
||||||
|
groups: []string{"teamA", "teamB", "teams-long"}, |
||||||
|
want: []string{"teamA", "teamB"}, |
||||||
|
}, |
||||||
|
"match all": { |
||||||
|
expr: `dex.groupMatches(identity.groups, "*")`, |
||||||
|
groups: []string{"a", "b", "c"}, |
||||||
|
want: []string{"a", "b", "c"}, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.CompileStringList(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
out, err := dexcel.Eval(context.Background(), prog, map[string]any{ |
||||||
|
"identity": map[string]any{ |
||||||
|
"groups": tc.groups, |
||||||
|
}, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{})) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
got, ok := nativeVal.([]string) |
||||||
|
require.True(t, ok, "expected []string, got %T", nativeVal) |
||||||
|
assert.Equal(t, tc.want, got) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TestGroupFilter(t *testing.T) { |
||||||
|
vars := dexcel.IdentityVariables() |
||||||
|
compiler, err := dexcel.NewCompiler(vars) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
tests := map[string]struct { |
||||||
|
expr string |
||||||
|
groups []string |
||||||
|
want []string |
||||||
|
}{ |
||||||
|
"filter to allowed": { |
||||||
|
expr: `dex.groupFilter(identity.groups, ["admin", "ops"])`, |
||||||
|
groups: []string{"admin", "dev", "ops"}, |
||||||
|
want: []string{"admin", "ops"}, |
||||||
|
}, |
||||||
|
"no overlap": { |
||||||
|
expr: `dex.groupFilter(identity.groups, ["marketing"])`, |
||||||
|
groups: []string{"admin", "dev"}, |
||||||
|
want: []string{}, |
||||||
|
}, |
||||||
|
"all allowed": { |
||||||
|
expr: `dex.groupFilter(identity.groups, ["a", "b", "c"])`, |
||||||
|
groups: []string{"a", "b", "c"}, |
||||||
|
want: []string{"a", "b", "c"}, |
||||||
|
}, |
||||||
|
"empty allowed list": { |
||||||
|
expr: `dex.groupFilter(identity.groups, [])`, |
||||||
|
groups: []string{"admin", "dev"}, |
||||||
|
want: []string{}, |
||||||
|
}, |
||||||
|
"preserves order": { |
||||||
|
expr: `dex.groupFilter(identity.groups, ["z", "a"])`, |
||||||
|
groups: []string{"a", "b", "z"}, |
||||||
|
want: []string{"a", "z"}, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
for name, tc := range tests { |
||||||
|
t.Run(name, func(t *testing.T) { |
||||||
|
prog, err := compiler.CompileStringList(tc.expr) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
out, err := dexcel.Eval(context.Background(), prog, map[string]any{ |
||||||
|
"identity": map[string]any{ |
||||||
|
"groups": tc.groups, |
||||||
|
}, |
||||||
|
}) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{})) |
||||||
|
require.NoError(t, err) |
||||||
|
|
||||||
|
got, ok := nativeVal.([]string) |
||||||
|
require.True(t, ok, "expected []string, got %T", nativeVal) |
||||||
|
assert.Equal(t, tc.want, got) |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,80 @@ |
|||||||
|
package cel |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/google/cel-go/cel" |
||||||
|
|
||||||
|
"github.com/dexidp/dex/connector" |
||||||
|
) |
||||||
|
|
||||||
|
// VariableDeclaration declares a named variable and its CEL type
|
||||||
|
// that will be available in expressions.
|
||||||
|
type VariableDeclaration struct { |
||||||
|
Name string |
||||||
|
Type *cel.Type |
||||||
|
} |
||||||
|
|
||||||
|
// IdentityVariables provides the 'identity' variable with user claims.
|
||||||
|
//
|
||||||
|
// identity.user_id — string
|
||||||
|
// identity.username — string
|
||||||
|
// identity.preferred_username — string
|
||||||
|
// identity.email — string
|
||||||
|
// identity.email_verified — bool
|
||||||
|
// identity.groups — list(string)
|
||||||
|
func IdentityVariables() []VariableDeclaration { |
||||||
|
return []VariableDeclaration{ |
||||||
|
{Name: "identity", Type: cel.MapType(cel.StringType, cel.DynType)}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// RequestVariables provides the 'request' variable with request context.
|
||||||
|
//
|
||||||
|
// request.client_id — string
|
||||||
|
// request.connector_id — string
|
||||||
|
// request.scopes — list(string)
|
||||||
|
// request.redirect_uri — string
|
||||||
|
func RequestVariables() []VariableDeclaration { |
||||||
|
return []VariableDeclaration{ |
||||||
|
{Name: "request", Type: cel.MapType(cel.StringType, cel.DynType)}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// ClaimsVariable provides a 'claims' map for raw upstream claims.
|
||||||
|
//
|
||||||
|
// claims — map(string, dyn)
|
||||||
|
func ClaimsVariable() []VariableDeclaration { |
||||||
|
return []VariableDeclaration{ |
||||||
|
{Name: "claims", Type: cel.MapType(cel.StringType, cel.DynType)}, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// IdentityFromConnector converts a connector.Identity to a CEL-compatible map.
|
||||||
|
func IdentityFromConnector(id connector.Identity) map[string]any { |
||||||
|
return map[string]any{ |
||||||
|
"user_id": id.UserID, |
||||||
|
"username": id.Username, |
||||||
|
"preferred_username": id.PreferredUsername, |
||||||
|
"email": id.Email, |
||||||
|
"email_verified": id.EmailVerified, |
||||||
|
"groups": id.Groups, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// RequestContext represents the authentication/token request context
|
||||||
|
// available as the 'request' variable in CEL expressions.
|
||||||
|
type RequestContext struct { |
||||||
|
ClientID string |
||||||
|
ConnectorID string |
||||||
|
Scopes []string |
||||||
|
RedirectURI string |
||||||
|
} |
||||||
|
|
||||||
|
// RequestFromContext converts a RequestContext to a CEL-compatible map.
|
||||||
|
func RequestFromContext(rc RequestContext) map[string]any { |
||||||
|
return map[string]any{ |
||||||
|
"client_id": rc.ClientID, |
||||||
|
"connector_id": rc.ConnectorID, |
||||||
|
"scopes": rc.Scopes, |
||||||
|
"redirect_uri": rc.RedirectURI, |
||||||
|
} |
||||||
|
} |
||||||
@ -0,0 +1,3 @@ |
|||||||
|
// Package featureflags provides a mechanism for toggling experimental or
|
||||||
|
// optional Dex features via environment variables (DEX_<FLAG_NAME>).
|
||||||
|
package featureflags |
||||||
@ -0,0 +1,2 @@ |
|||||||
|
// Package groups contains helper functions related to groups.
|
||||||
|
package groups |
||||||
Loading…
Reference in new issue