mirror of https://github.com/dexidp/dex.git
142 changed files with 18881 additions and 3624 deletions
@ -0,0 +1,56 @@
|
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"log/slog" |
||||
) |
||||
|
||||
// excludingHandler is an slog.Handler wrapper that drops log attributes
|
||||
// whose keys match a configured set. This allows PII fields like email,
|
||||
// username, or groups to be redacted at the logger level rather than
|
||||
// requiring per-callsite suppression logic.
|
||||
type excludingHandler struct { |
||||
inner slog.Handler |
||||
exclude map[string]bool |
||||
} |
||||
|
||||
func newExcludingHandler(inner slog.Handler, fields []string) slog.Handler { |
||||
if len(fields) == 0 { |
||||
return inner |
||||
} |
||||
m := make(map[string]bool, len(fields)) |
||||
for _, f := range fields { |
||||
m[f] = true |
||||
} |
||||
return &excludingHandler{inner: inner, exclude: m} |
||||
} |
||||
|
||||
func (h *excludingHandler) Enabled(ctx context.Context, level slog.Level) bool { |
||||
return h.inner.Enabled(ctx, level) |
||||
} |
||||
|
||||
func (h *excludingHandler) Handle(ctx context.Context, record slog.Record) error { |
||||
// Rebuild the record without excluded attributes.
|
||||
filtered := slog.NewRecord(record.Time, record.Level, record.Message, record.PC) |
||||
record.Attrs(func(a slog.Attr) bool { |
||||
if !h.exclude[a.Key] { |
||||
filtered.AddAttrs(a) |
||||
} |
||||
return true |
||||
}) |
||||
return h.inner.Handle(ctx, filtered) |
||||
} |
||||
|
||||
func (h *excludingHandler) WithAttrs(attrs []slog.Attr) slog.Handler { |
||||
var kept []slog.Attr |
||||
for _, a := range attrs { |
||||
if !h.exclude[a.Key] { |
||||
kept = append(kept, a) |
||||
} |
||||
} |
||||
return &excludingHandler{inner: h.inner.WithAttrs(kept), exclude: h.exclude} |
||||
} |
||||
|
||||
func (h *excludingHandler) WithGroup(name string) slog.Handler { |
||||
return &excludingHandler{inner: h.inner.WithGroup(name), exclude: h.exclude} |
||||
} |
||||
@ -0,0 +1,141 @@
|
||||
package main |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"encoding/json" |
||||
"log/slog" |
||||
"testing" |
||||
) |
||||
|
||||
func TestExcludingHandler(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
exclude []string |
||||
logAttrs []slog.Attr |
||||
wantKeys []string |
||||
absentKeys []string |
||||
}{ |
||||
{ |
||||
name: "no exclusions", |
||||
exclude: nil, |
||||
logAttrs: []slog.Attr{ |
||||
slog.String("email", "user@example.com"), |
||||
slog.String("connector_id", "github"), |
||||
}, |
||||
wantKeys: []string{"email", "connector_id"}, |
||||
}, |
||||
{ |
||||
name: "exclude email", |
||||
exclude: []string{"email"}, |
||||
logAttrs: []slog.Attr{ |
||||
slog.String("email", "user@example.com"), |
||||
slog.String("connector_id", "github"), |
||||
}, |
||||
wantKeys: []string{"connector_id"}, |
||||
absentKeys: []string{"email"}, |
||||
}, |
||||
{ |
||||
name: "exclude multiple fields", |
||||
exclude: []string{"email", "username", "groups"}, |
||||
logAttrs: []slog.Attr{ |
||||
slog.String("email", "user@example.com"), |
||||
slog.String("username", "johndoe"), |
||||
slog.String("connector_id", "github"), |
||||
slog.Any("groups", []string{"admin"}), |
||||
}, |
||||
wantKeys: []string{"connector_id"}, |
||||
absentKeys: []string{"email", "username", "groups"}, |
||||
}, |
||||
{ |
||||
name: "exclude non-existent field is harmless", |
||||
exclude: []string{"nonexistent"}, |
||||
logAttrs: []slog.Attr{ |
||||
slog.String("email", "user@example.com"), |
||||
}, |
||||
wantKeys: []string{"email"}, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
var buf bytes.Buffer |
||||
inner := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) |
||||
handler := newExcludingHandler(inner, tt.exclude) |
||||
logger := slog.New(handler) |
||||
|
||||
attrs := make([]any, 0, len(tt.logAttrs)*2) |
||||
for _, a := range tt.logAttrs { |
||||
attrs = append(attrs, a) |
||||
} |
||||
logger.Info("test message", attrs...) |
||||
|
||||
var result map[string]any |
||||
if err := json.Unmarshal(buf.Bytes(), &result); err != nil { |
||||
t.Fatalf("failed to parse log output: %v", err) |
||||
} |
||||
|
||||
for _, key := range tt.wantKeys { |
||||
if _, ok := result[key]; !ok { |
||||
t.Errorf("expected key %q in log output", key) |
||||
} |
||||
} |
||||
for _, key := range tt.absentKeys { |
||||
if _, ok := result[key]; ok { |
||||
t.Errorf("expected key %q to be absent from log output", key) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestExcludingHandlerWithAttrs(t *testing.T) { |
||||
var buf bytes.Buffer |
||||
inner := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) |
||||
handler := newExcludingHandler(inner, []string{"email"}) |
||||
logger := slog.New(handler) |
||||
|
||||
// Pre-bind an excluded attr via With
|
||||
child := logger.With("email", "user@example.com", "connector_id", "github") |
||||
child.Info("login successful") |
||||
|
||||
var result map[string]any |
||||
if err := json.Unmarshal(buf.Bytes(), &result); err != nil { |
||||
t.Fatalf("failed to parse log output: %v", err) |
||||
} |
||||
|
||||
if _, ok := result["email"]; ok { |
||||
t.Error("expected email to be excluded from WithAttrs output") |
||||
} |
||||
if _, ok := result["connector_id"]; !ok { |
||||
t.Error("expected connector_id to be present") |
||||
} |
||||
} |
||||
|
||||
func TestExcludingHandlerEnabled(t *testing.T) { |
||||
inner := slog.NewJSONHandler(&bytes.Buffer{}, &slog.HandlerOptions{Level: slog.LevelWarn}) |
||||
handler := newExcludingHandler(inner, []string{"email"}) |
||||
|
||||
if handler.Enabled(context.Background(), slog.LevelInfo) { |
||||
t.Error("expected Info to be disabled when handler level is Warn") |
||||
} |
||||
if !handler.Enabled(context.Background(), slog.LevelWarn) { |
||||
t.Error("expected Warn to be enabled") |
||||
} |
||||
} |
||||
|
||||
func TestExcludingHandlerNilFields(t *testing.T) { |
||||
var buf bytes.Buffer |
||||
inner := slog.NewJSONHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo}) |
||||
|
||||
// With nil/empty fields, should return the inner handler directly
|
||||
handler := newExcludingHandler(inner, nil) |
||||
if _, ok := handler.(*excludingHandler); ok { |
||||
t.Error("expected nil fields to return inner handler directly, not wrap it") |
||||
} |
||||
|
||||
handler = newExcludingHandler(inner, []string{}) |
||||
if _, ok := handler.(*excludingHandler); ok { |
||||
t.Error("expected empty fields to return inner handler directly, not wrap it") |
||||
} |
||||
} |
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,732 @@
|
||||
# Dex Enhancement Proposal (DEP) - 2026-02-28 - CEL (Common Expression Language) Integration |
||||
|
||||
## Table of Contents |
||||
|
||||
- [Summary](#summary) |
||||
- [Context](#context) |
||||
- [Motivation](#motivation) |
||||
- [Goals/Pain](#goalspain) |
||||
- [Non-Goals](#non-goals) |
||||
- [Proposal](#proposal) |
||||
- [User Experience](#user-experience) |
||||
- [Implementation Details/Notes/Constraints](#implementation-detailsnotesconstraints) |
||||
- [Phase 1: pkg/cel - Core CEL Library](#phase-1-pkgcel---core-cel-library) |
||||
- [Phase 2: Authentication Policies](#phase-2-authentication-policies) |
||||
- [Phase 3: Token Policies](#phase-3-token-policies) |
||||
- [Phase 4: OIDC Connector Claim Mapping](#phase-4-oidc-connector-claim-mapping) |
||||
- [Policy Application Flow](#policy-application-flow) |
||||
- [Risks and Mitigations](#risks-and-mitigations) |
||||
- [Alternatives](#alternatives) |
||||
- [Future Improvements](#future-improvements) |
||||
|
||||
## Summary |
||||
|
||||
This DEP proposes integrating [CEL (Common Expression Language)][cel-spec] into Dex as a first-class |
||||
expression engine for policy evaluation, claim mapping, and token customization. A new reusable |
||||
`pkg/cel` package will provide a safe, sandboxed CEL environment with Kubernetes-grade compatibility |
||||
guarantees, cost budgets, and a curated set of extension libraries. Subsequent phases will leverage |
||||
this package to implement authentication policies, token policies, advanced claim mapping in |
||||
connectors, and per-client/global access rules — replacing the need for ad-hoc configuration fields |
||||
and external policy engines. |
||||
|
||||
[cel-spec]: https://github.com/google/cel-spec |
||||
|
||||
## Context |
||||
|
||||
- [#1583 Add allowedGroups option for clients config][#1583] — a long-standing request for a |
||||
configuration option to allow a client to specify a list of allowed groups. |
||||
- [#1635 Connector Middleware][#1635] — long-standing request for a policy/middleware layer between |
||||
connectors and the server for claim transformations and access control. |
||||
- [#1052 Allow restricting connectors per client][#1052] — frequently requested feature to restrict |
||||
which connectors are available to specific OAuth2 clients. |
||||
- [#2178 Custom claims in ID tokens][#2178] — requests for including additional payload in issued tokens. |
||||
- [#2812 Token Exchange DEP][dep-token-exchange] — mentions CEL/Rego as future improvement for |
||||
policy-based assertions on exchanged tokens. |
||||
- The OIDC connector already has a growing set of ad-hoc claim mutation options |
||||
(`ClaimMapping`, `ClaimMutations.NewGroupFromClaims`, `FilterGroupClaims`, `ModifyGroupNames`) |
||||
that would benefit from a unified expression language. |
||||
- Previous community discussions explored OPA/Rego and JMESPath, but CEL offers a better fit |
||||
(see [Alternatives](#alternatives)). |
||||
|
||||
[#1583]: https://github.com/dexidp/dex/pull/1583 |
||||
[#1635]: https://github.com/dexidp/dex/issues/1635 |
||||
[#1052]: https://github.com/dexidp/dex/issues/1052 |
||||
[#2178]: https://github.com/dexidp/dex/issues/2178 |
||||
[dep-token-exchange]: /docs/enhancements/token-exchange-2023-02-03-%232812.md |
||||
|
||||
## Motivation |
||||
|
||||
### Goals/Pain |
||||
|
||||
1. **Complex query/filter capabilities** — Dex needs a way to express complex validations and |
||||
mutations in multiple places (authentication flow, token issuance, claim mapping). Today each |
||||
feature requires new Go code, new config fields, and a new release cycle. CEL allows operators |
||||
to express these rules declaratively without code changes. |
||||
|
||||
2. **Authentication policies** — Operators want to control _who_ can log in based on rich |
||||
conditions: restrict specific connectors to specific clients, require group membership for |
||||
certain clients, deny login based on email domain, enforce MFA claims, etc. Currently there is |
||||
no unified mechanism; users rely on downstream applications or external proxies. |
||||
|
||||
3. **Token policies** — Operators want to customize issued tokens: add extra claims to ID tokens, |
||||
restrict scopes per client, modify `aud` claims, include upstream connector metadata, etc. |
||||
Today this requires forking Dex or using a reverse proxy. |
||||
|
||||
4. **Claim mapping in OIDC connector** — The OIDC connector has accumulated multiple ad-hoc config |
||||
options for claim mapping and group mutations (`ClaimMapping`, `NewGroupFromClaims`, |
||||
`FilterGroupClaims`, `ModifyGroupNames`). A single CEL expression field would replace all of |
||||
these with a more powerful and composable approach. |
||||
|
||||
5. **Per-client and global policies** — One of the most frequent requests is allowing different |
||||
connectors for different clients and restricting group-based access per client. CEL policies at |
||||
the global and per-client level address this cleanly. |
||||
|
||||
6. **CNCF ecosystem alignment** — CEL has massive adoption across the CNCF ecosystem: |
||||
|
||||
| Project | CEL Usage | Evidence | |
||||
|---------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------|----------| |
||||
| **Kubernetes** | ValidatingAdmissionPolicy, CRD validation rules (`x-kubernetes-validations`), AuthorizationPolicy, field selectors, CEL-based match conditions in webhooks | [KEP-3488][k8s-cel-kep], [CRD Validation Rules][k8s-crd-cel], [AuthorizationPolicy KEP-3221][k8s-authz-cel] | |
||||
| **Kyverno** | CEL expressions in validation/mutation policies (v1.12+), preconditions | [Kyverno CEL docs][kyverno-cel] | |
||||
| **OPA Gatekeeper** | Partially added support for CEL in constraint templates | [Gatekeeper CEL][gatekeeper-cel] | |
||||
| **Istio** | AuthorizationPolicy conditions, request routing, telemetry | [Istio CEL docs][istio-cel] | |
||||
| **Envoy / Envoy Gateway** | RBAC filter, ext_authz, rate limiting, route matching, access logging | [Envoy CEL docs][envoy-cel] | |
||||
| **Tekton** | Pipeline when expressions, CEL custom tasks | [Tekton CEL Interceptor][tekton-cel] | |
||||
| **Knative** | Trigger filters using CEL expressions | [Knative CEL filters][knative-cel] | |
||||
| **Google Cloud** | IAM Conditions, Cloud Deploy, Security Command Center | [Google IAM CEL][gcp-cel] | |
||||
| **Cert-Manager** | CertificateRequestPolicy approval using CEL | [cert-manager approver-policy CEL][cert-manager-cel] | |
||||
| **Cilium** | Hubble CEL filter logic | [Cilium CEL docs][cilium-cel] | |
||||
| **Crossplane** | Composition functions with CEL-based patch transforms | [Crossplane CEL transforms][crossplane-cel] | |
||||
| **Kube-OVN** | Network policy extensions using CEL | [Kube-OVN CEL][kube-ovn-cel] | |
||||
|
||||
[k8s-cel-kep]: https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/3488-cel-admission-control |
||||
[k8s-crd-cel]: https://kubernetes.io/docs/tasks/extend-kubernetes/custom-resources/custom-resource-definitions/#validation-rules |
||||
[k8s-authz-cel]: https://github.com/kubernetes/enhancements/tree/master/keps/sig-auth/3221-structured-authorization-configuration |
||||
[kyverno-cel]: https://kyverno.io/docs/writing-policies/cel/ |
||||
[gatekeeper-cel]: https://open-policy-agent.github.io/gatekeeper/website/docs/validating-admission-policy/#policy-updates-to-add-vap-cel |
||||
[istio-cel]: https://istio.io/latest/docs/reference/config/security/conditions/ |
||||
[envoy-cel]: https://www.envoyproxy.io/docs/envoy/latest/xds/type/v3/cel.proto |
||||
[tekton-cel]: https://tekton.dev/docs/triggers/cel_expressions/ |
||||
[knative-cel]: https://github.com/knative/eventing/blob/main/docs/broker/filtering.md#add-cel-expression-filter |
||||
[gcp-cel]: https://cloud.google.com/iam/docs/conditions-overview |
||||
[cert-manager-cel]: https://cert-manager.io/docs/policy/approval/approver-policy/#validations |
||||
[cilium-cel]: https://docs.cilium.io/en/stable/_api/v1/flow/README/#flowfilter-experimental |
||||
[crossplane-cel]: https://github.com/crossplane-contrib/function-cel-filter |
||||
[kube-ovn-cel]: https://kubeovn.github.io/docs/stable/en/advance/cel-expression/ |
||||
|
||||
By choosing CEL, Dex operators who already use Kubernetes or other CNCF tools can reuse their |
||||
existing knowledge of the expression language. |
||||
|
||||
### Non-Goals |
||||
|
||||
- **Full policy engine** — This DEP does not aim to replace dedicated external policy engines |
||||
(OPA, Kyverno). CEL in Dex is scoped to identity and token operations. |
||||
- **Breaking changes to existing configuration** — All existing config fields (`ClaimMapping`, |
||||
`ClaimMutations`, etc.) will continue to work. CEL expressions are additive/opt-in. |
||||
- **Authorization (beyond Dex scope)** — Dex is an identity provider; downstream authorization |
||||
decisions remain the responsibility of relying parties. CEL policies in Dex are limited to |
||||
authentication and token issuance concerns. |
||||
- **Multi-phase CEL in a single DEP** — Only Phase 1 (`pkg/cel` package) is targeted for |
||||
immediate implementation. Phases 2-4 are included here for design context and will have their |
||||
own implementation PRs. |
||||
- **Multi-step logic** — CEL in Dex is scoped to single-expression evaluation. Each expression |
||||
is a standalone, stateless computation with no intermediate variables, chaining, or |
||||
multi-step transformations. If a use case requires sequential logic or conditionally chained |
||||
expressions, it belongs outside Dex (e.g. in an external policy engine or middleware). |
||||
This boundary protects the design from scope creep that pushes CEL beyond what it's good at. |
||||
|
||||
## Proposal |
||||
|
||||
### User Experience |
||||
|
||||
#### Authentication Policy (Phase 2) |
||||
|
||||
Operators can define global and per-client authentication policies in the Dex config: |
||||
|
||||
```yaml |
||||
# Global authentication policy — each expression evaluates to bool. |
||||
# If true — the request is denied. Evaluated in order; first match wins. |
||||
authPolicy: |
||||
- expression: "!identity.email.endsWith('@example.com')" |
||||
message: "'Login restricted to example.com domain'" |
||||
- expression: "!identity.email_verified" |
||||
message: "'Email must be verified'" |
||||
|
||||
staticClients: |
||||
- id: admin-app |
||||
name: Admin Application |
||||
secret: ... |
||||
redirectURIs: [...] |
||||
# Per-client policy — same structure as global |
||||
authPolicy: |
||||
- expression: "!(request.connector_id in ['okta', 'ldap'])" |
||||
message: "'This application requires Okta or LDAP login'" |
||||
- expression: "!('admin' in identity.groups)" |
||||
message: "'Admin group membership required'" |
||||
``` |
||||
|
||||
#### Token Policy (Phase 3) |
||||
|
||||
Operators can add extra claims or mutate token contents: |
||||
|
||||
```yaml |
||||
tokenPolicy: |
||||
# Global mutations applied to all ID tokens |
||||
claims: |
||||
# Add a custom claim based on group membership |
||||
- key: "'role'" |
||||
value: "identity.groups.exists(g, g == 'admin') ? 'admin' : 'user'" |
||||
# Include connector ID as a claim |
||||
- key: "'idp'" |
||||
value: "request.connector_id" |
||||
# Add department from upstream claims (only if present) |
||||
- key: "'department'" |
||||
value: "identity.extra['department']" |
||||
condition: "'department' in identity.extra" |
||||
|
||||
staticClients: |
||||
- id: internal-api |
||||
name: Internal API |
||||
secret: ... |
||||
redirectURIs: [...] |
||||
tokenPolicy: |
||||
claims: |
||||
- key: "'custom-claim.company.com/team'" |
||||
value: "identity.extra['team'].orValue('engineering')" |
||||
# Only add on-call claim for ops group members |
||||
- key: "'on_call'" |
||||
value: "true" |
||||
condition: "identity.groups.exists(g, g == 'ops')" |
||||
# Restrict scopes |
||||
filter: |
||||
expression: "request.scopes.all(s, s in ['openid', 'email', 'profile'])" |
||||
message: "'Unsupported scope requested'" |
||||
``` |
||||
|
||||
#### OIDC Connector Claim Mapping (Phase 4) |
||||
|
||||
Replace ad-hoc claim mapping with CEL: |
||||
|
||||
```yaml |
||||
connectors: |
||||
- type: oidc |
||||
id: corporate-idp |
||||
name: Corporate IdP |
||||
config: |
||||
issuer: https://idp.example.com |
||||
clientID: dex-client |
||||
clientSecret: ... |
||||
# CEL-based claim mapping — replaces claimMapping and claimModifications |
||||
claimMappingExpressions: |
||||
username: "claims.preferred_username.orValue(claims.email)" |
||||
email: "claims.email" |
||||
groups: > |
||||
claims.groups |
||||
.filter(g, g.startsWith('dex:')) |
||||
.map(g, g.trimPrefix('dex:')) |
||||
emailVerified: "claims.email_verified.orValue(true)" |
||||
# Extra claims to pass through to token policies |
||||
extra: |
||||
department: "claims.department.orValue('unknown')" |
||||
cost_center: "claims.cost_center.orValue('')" |
||||
``` |
||||
|
||||
### Implementation Details/Notes/Constraints |
||||
|
||||
### Phase 1: `pkg/cel` — Core CEL Library |
||||
|
||||
This is the foundation that all subsequent phases build upon. The package provides a safe, |
||||
reusable CEL environment with Kubernetes-grade guarantees. |
||||
|
||||
#### Package Structure |
||||
|
||||
``` |
||||
pkg/ |
||||
cel/ |
||||
cel.go # Core Environment, compilation, evaluation |
||||
types.go # CEL type declarations (Identity, Request, etc.) |
||||
cost.go # Cost estimation and budgeting |
||||
doc.go # Package documentation |
||||
library/ |
||||
email.go # Email-related CEL functions |
||||
groups.go # Group-related CEL functions |
||||
``` |
||||
|
||||
#### Dependencies |
||||
|
||||
``` |
||||
github.com/google/cel-go v0.27.0 |
||||
``` |
||||
|
||||
The `cel-go` library is the canonical Go implementation maintained by Google, used by Kubernetes |
||||
and all major CNCF projects. It follows semantic versioning and provides strong backward |
||||
compatibility guarantees. |
||||
|
||||
#### Core API Design |
||||
|
||||
**Public types:** |
||||
|
||||
```go |
||||
// CompilationResult holds a compiled CEL program ready for evaluation. |
||||
type CompilationResult struct { |
||||
Program cel.Program |
||||
OutputType *cel.Type |
||||
Expression string |
||||
} |
||||
|
||||
// Compiler compiles CEL expressions against a specific environment. |
||||
type Compiler struct { /* ... */ } |
||||
|
||||
// CompilerOption configures a Compiler. |
||||
type CompilerOption func(*compilerConfig) |
||||
``` |
||||
|
||||
**Compilation pipeline:** |
||||
|
||||
Each `Compile*` call performs these steps sequentially: |
||||
1. Reject expressions exceeding `MaxExpressionLength` (10,240 chars). |
||||
2. Compile and type-check the expression via `cel-go`. |
||||
3. Validate output type matches the expected type (for typed variants). |
||||
4. Estimate cost using `defaultCostEstimator` with size hints — reject if estimated max cost |
||||
exceeds the cost budget. |
||||
5. Create an optimized `cel.Program` with runtime cost limit. |
||||
|
||||
Presence tests (`has(field)`, `'key' in map`) have zero cost, matching Kubernetes CEL behavior. |
||||
|
||||
#### Variable Declarations |
||||
|
||||
Variables are declared via `VariableDeclaration{Name, Type}` and registered with `NewCompiler`. |
||||
Helper constructors provide pre-defined variable sets: |
||||
|
||||
**`IdentityVariables()`** — the `identity` variable (from `connector.Identity`), |
||||
typed as `cel.ObjectType`: |
||||
|
||||
| Field | CEL Type | Source | |
||||
|-------|----------|--------| |
||||
| `identity.user_id` | `string` | `connector.Identity.UserID` | |
||||
| `identity.username` | `string` | `connector.Identity.Username` | |
||||
| `identity.preferred_username` | `string` | `connector.Identity.PreferredUsername` | |
||||
| `identity.email` | `string` | `connector.Identity.Email` | |
||||
| `identity.email_verified` | `bool` | `connector.Identity.EmailVerified` | |
||||
| `identity.groups` | `list(string)` | `connector.Identity.Groups` | |
||||
|
||||
**`RequestVariables()`** — the `request` variable (from `RequestContext`), |
||||
typed as `cel.ObjectType`: |
||||
|
||||
| Field | CEL Type | |
||||
|-------|----------| |
||||
| `request.client_id` | `string` | |
||||
| `request.connector_id` | `string` | |
||||
| `request.scopes` | `list(string)` | |
||||
| `request.redirect_uri` | `string` | |
||||
|
||||
**`ClaimsVariable()`** — the `claims` variable for raw upstream claims as `map(string, dyn)`. |
||||
|
||||
**Typing strategy:** |
||||
|
||||
`identity` and `request` use `cel.ObjectType` with explicitly declared fields. This gives |
||||
compile-time type checking: a typo like `identity.emial` is rejected at config load time |
||||
rather than silently evaluating to null in production — critical for an auth system where a |
||||
misconfigured policy could lock users out. |
||||
|
||||
`claims` remains `map(string, dyn)` because its shape is genuinely unknown — it carries |
||||
arbitrary upstream IdP data. |
||||
|
||||
#### Compatibility Guarantees |
||||
|
||||
Following the Kubernetes CEL compatibility model |
||||
([KEP-3488: CEL for Admission Control][kep-3488], [Kubernetes CEL Migration Guide][k8s-cel-compat]): |
||||
|
||||
1. **Environment versioning** — The CEL environment is versioned. When new functions or variables |
||||
are added, they are introduced under a new environment version. Existing expressions compiled |
||||
against an older version continue to work. |
||||
|
||||
```go |
||||
// EnvironmentVersion represents the version of the CEL environment. |
||||
// New variables, functions, or libraries are introduced in new versions. |
||||
type EnvironmentVersion uint32 |
||||
|
||||
const ( |
||||
// EnvironmentV1 is the initial CEL environment. |
||||
EnvironmentV1 EnvironmentVersion = 1 |
||||
) |
||||
|
||||
// WithVersion sets the target environment version for the compiler. |
||||
func WithVersion(v EnvironmentVersion) CompilerOption |
||||
``` |
||||
|
||||
This is directly modeled on `k8s.io/apiserver/pkg/cel/environment`. |
||||
|
||||
2. **Library stability** — Custom functions in the `pkg/cel/library` subpackage follow these rules: |
||||
- Functions MUST NOT be removed once released. |
||||
- Function signatures MUST NOT change once released. |
||||
- New functions MUST be added under a new `EnvironmentVersion`. |
||||
- If a function needs to be replaced, the old one is deprecated but kept forever. |
||||
|
||||
3. **Type stability** — CEL types (`Identity`, `Request`, `Claims`) follow the same rules: |
||||
- Fields MUST NOT be removed. |
||||
- Field types MUST NOT change. |
||||
- New fields are added in a new `EnvironmentVersion`. |
||||
|
||||
4. **Semantic versioning of `cel-go`** — The `cel-go` dependency follows semver. Dex pins to a |
||||
minor version range and updates are tested for behavioral changes. This is exactly the approach |
||||
Kubernetes takes: `k8s.io/apiextensions-apiserver` pins `cel-go` and gates new features behind |
||||
environment versions. |
||||
|
||||
5. **Feature gates** — New CEL-powered features are gated behind Dex feature flags (using the |
||||
existing `pkg/featureflags` mechanism) during their alpha phase. |
||||
|
||||
[kep-3488]: https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/3488-cel-admission-control |
||||
[k8s-cel-compat]: https://kubernetes.io/docs/reference/using-api/cel/ |
||||
|
||||
#### Cost Estimation and Budgets |
||||
|
||||
Like Kubernetes, Dex CEL expressions must be bounded to prevent denial-of-service. |
||||
|
||||
**Constants:** |
||||
|
||||
| Constant | Value | Description | |
||||
|----------|-------|-------------| |
||||
| `DefaultCostBudget` | `10_000_000` | Max cost units per evaluation (aligned with Kubernetes) | |
||||
| `MaxExpressionLength` | `10_240` | Max expression string length in characters | |
||||
| `DefaultStringMaxLength` | `256` | Estimated max string size for cost estimation | |
||||
| `DefaultListMaxLength` | `100` | Estimated max list size for cost estimation | |
||||
|
||||
**How it works:** |
||||
|
||||
A `defaultCostEstimator` (implementing `checker.CostEstimator`) provides size hints for known |
||||
variables (`identity`, `request`, `claims`) so the `cel-go` cost estimator doesn't assume |
||||
unbounded sizes. It also provides call cost estimates for custom Dex functions |
||||
(`dex.emailDomain`, `dex.emailLocalPart`, `dex.groupMatches`, `dex.groupFilter`). |
||||
|
||||
Expressions are validated at three levels: |
||||
1. **Length check** — reject expressions exceeding `MaxExpressionLength`. |
||||
2. **Compile-time cost estimation** — reject expressions whose estimated max cost exceeds |
||||
the cost budget. |
||||
3. **Runtime cost limit** — abort evaluation if actual cost exceeds the budget. |
||||
|
||||
#### Extension Libraries |
||||
|
||||
The `pkg/cel` environment includes these cel-go standard extensions (same set as Kubernetes): |
||||
|
||||
| Library | Description | Examples | |
||||
|---------|-------------|---------| |
||||
| `ext.Strings()` | Extended string functions | `"hello".upperAscii()`, `"foo:bar".split(':')`, `s.trim()`, `s.replace('a','b')` | |
||||
| `ext.Encoders()` | Base64 encoding/decoding | `base64.encode(bytes)`, `base64.decode(str)` | |
||||
| `ext.Lists()` | Extended list functions | `list.slice(1, 3)`, `list.flatten()` | |
||||
| `ext.Sets()` | Set operations on lists | `sets.contains(a, b)`, `sets.intersects(a, b)`, `sets.equivalent(a, b)` | |
||||
| `ext.Math()` | Math functions | `math.greatest(a, b)`, `math.least(a, b)` | |
||||
|
||||
Plus custom Dex libraries in the `pkg/cel/library` subpackage, each implementing the |
||||
`cel.Library` interface: |
||||
|
||||
**`library.Email`** — email-related helpers: |
||||
|
||||
| Function | Signature | Description | |
||||
|----------|-----------|-------------| |
||||
| `dex.emailDomain` | `(string) -> string` | Returns the domain portion of an email address. `dex.emailDomain("user@example.com") == "example.com"` | |
||||
| `dex.emailLocalPart` | `(string) -> string` | Returns the local part of an email address. `dex.emailLocalPart("user@example.com") == "user"` | |
||||
|
||||
**`library.Groups`** — group-related helpers: |
||||
|
||||
| Function | Signature | Description | |
||||
|----------|-----------|-------------| |
||||
| `dex.groupMatches` | `(list(string), string) -> list(string)` | Returns groups matching a glob pattern. `dex.groupMatches(identity.groups, "team:*")` | |
||||
| `dex.groupFilter` | `(list(string), list(string)) -> list(string)` | Returns only groups present in the allowed list. `dex.groupFilter(identity.groups, ["admin", "ops"])` | |
||||
|
||||
#### Example: Compile and Evaluate |
||||
|
||||
```go |
||||
// 1. Create a compiler with identity and request variables |
||||
compiler, _ := cel.NewCompiler( |
||||
append(cel.IdentityVariables(), cel.RequestVariables()...), |
||||
) |
||||
|
||||
// 2. Compile a policy expression (type-checked, cost-estimated) |
||||
prog, _ := compiler.CompileBool( |
||||
`identity.email.endsWith('@example.com') && 'admin' in identity.groups`, |
||||
) |
||||
|
||||
// 3. Evaluate against real data |
||||
result, _ := cel.EvalBool(ctx, prog, map[string]any{ |
||||
"identity": cel.IdentityFromConnector(connectorIdentity), |
||||
"request": cel.RequestFromContext(cel.RequestContext{...}), |
||||
}) |
||||
// result == true |
||||
``` |
||||
|
||||
### Phase 2: Authentication Policies |
||||
|
||||
**Config Model:** |
||||
|
||||
```go |
||||
// AuthPolicy is a list of deny expressions evaluated after a user |
||||
// authenticates with a connector. Each expression evaluates to bool. |
||||
// If true — the request is denied. Evaluated in order; first match wins. |
||||
type AuthPolicy []PolicyExpression |
||||
|
||||
// PolicyExpression is a CEL expression with an optional human-readable message. |
||||
type PolicyExpression struct { |
||||
// Expression is a CEL expression that evaluates to bool. |
||||
Expression string `json:"expression"` |
||||
// Message is a CEL expression that evaluates to string (displayed to the user on deny). |
||||
// If empty, a generic message is shown. |
||||
Message string `json:"message,omitempty"` |
||||
} |
||||
``` |
||||
|
||||
**Evaluation point:** After `connector.CallbackConnector.HandleCallback()` or |
||||
`connector.PasswordConnector.Login()` returns an identity, and before the auth request is |
||||
finalized. Implemented in `server/handlers.go` at `handleConnectorCallback`. |
||||
|
||||
**Available CEL variables:** `identity` (from connector), `request` (client_id, connector_id, |
||||
scopes, redirect_uri). |
||||
|
||||
**Compilation:** All policy expressions are compiled once at config load time (in |
||||
`cmd/dex/serve.go`) and stored in the `Server` struct. This ensures: |
||||
- Syntax/type errors are caught at startup, not at runtime. |
||||
- No compilation overhead per request. |
||||
- Cost estimation can warn operators about expensive expressions at startup. |
||||
|
||||
**Evaluation flow:** |
||||
|
||||
``` |
||||
User authenticates via connector |
||||
│ |
||||
v |
||||
connector.HandleCallback() returns Identity |
||||
│ |
||||
v |
||||
Evaluate global authPolicy (in order) |
||||
- For each expression: evaluate → bool |
||||
- If true → deny with message, HTTP 403 |
||||
│ |
||||
v |
||||
Evaluate per-client authPolicy (in order) |
||||
- Same logic as global |
||||
│ |
||||
v |
||||
Continue normal flow (approval screen or redirect) |
||||
``` |
||||
|
||||
### Phase 3: Token Policies |
||||
|
||||
**Config Model:** |
||||
|
||||
```go |
||||
// TokenPolicy defines policies for token issuance. |
||||
type TokenPolicy struct { |
||||
// Claims adds or overrides claims in the issued ID token. |
||||
Claims []ClaimExpression `json:"claims,omitempty"` |
||||
// Filter validates the token request. If expression evaluates to false, |
||||
// the request is denied. |
||||
Filter *PolicyExpression `json:"filter,omitempty"` |
||||
} |
||||
|
||||
type ClaimExpression struct { |
||||
// Key is a CEL expression evaluating to string — the claim name. |
||||
Key string `json:"key"` |
||||
// Value is a CEL expression evaluating to dyn — the claim value. |
||||
Value string `json:"value"` |
||||
// Condition is an optional CEL expression evaluating to bool. |
||||
// When set, the claim is only included in the token if the condition |
||||
// evaluates to true. If omitted, the claim is always included. |
||||
Condition string `json:"condition,omitempty"` |
||||
} |
||||
``` |
||||
|
||||
**Evaluation point:** In `server/oauth2.go` during ID token construction, after standard |
||||
claims are built but before JWT signing. |
||||
|
||||
**Available CEL variables:** `identity`, `request`, `existing_claims` (the standard claims already |
||||
computed as `map(string, dyn)`). |
||||
|
||||
**Claim merge order:** |
||||
1. Standard Dex claims (sub, iss, aud, email, groups, etc.) |
||||
2. Global `tokenPolicy.claims` evaluated and merged |
||||
3. Per-client `tokenPolicy.claims` evaluated and merged (overrides global) |
||||
|
||||
**Reserved (forbidden) claim names:** |
||||
|
||||
Certain claim names are reserved and MUST NOT be set or overridden by CEL token policy |
||||
expressions. Attempting to use a reserved claim key will result in a config validation error at |
||||
startup. This prevents operators from accidentally breaking the OIDC/OAuth2 contract or |
||||
undermining Dex's security guarantees. |
||||
|
||||
```go |
||||
// ReservedClaimNames is the set of claim names that CEL token policy |
||||
// expressions are forbidden from setting. These are core OIDC/OAuth2 claims |
||||
// managed exclusively by Dex. |
||||
var ReservedClaimNames = map[string]struct{}{ |
||||
"iss": {}, // Issuer — always set by Dex to its own issuer URL |
||||
"sub": {}, // Subject — derived from connector identity, must not be spoofed |
||||
"aud": {}, // Audience — determined by the OAuth2 client, not policy |
||||
"exp": {}, // Expiration — controlled by Dex token TTL configuration |
||||
"iat": {}, // Issued At — set by Dex at signing time |
||||
"nbf": {}, // Not Before — set by Dex at signing time |
||||
"jti": {}, // JWT ID — generated by Dex for token revocation/uniqueness |
||||
"auth_time": {}, // Authentication Time — set by Dex from the auth session |
||||
"nonce": {}, // Nonce — echoed from the client's authorization request |
||||
"at_hash": {}, // Access Token Hash — computed by Dex from the access token |
||||
"c_hash": {}, // Code Hash — computed by Dex from the authorization code |
||||
} |
||||
``` |
||||
|
||||
The reserved list is enforced in two places: |
||||
1. **Config load time** — When compiling token policy `ClaimExpression` entries, Dex statically |
||||
evaluates the `Key` expression (which must be a string literal or constant-foldable) and rejects |
||||
it if the result is in `ReservedClaimNames`. |
||||
2. **Runtime (defense in depth)** — Before merging evaluated claims into the ID token, Dex checks |
||||
each key against `ReservedClaimNames` and logs a warning + skips the claim if it matches. This |
||||
guards against dynamic key expressions that couldn't be statically checked. |
||||
|
||||
### Phase 4: OIDC Connector Claim Mapping |
||||
|
||||
**Config Model:** |
||||
|
||||
In `connector/oidc/oidc.go`: |
||||
|
||||
```go |
||||
type Config struct { |
||||
// ... existing fields ... |
||||
|
||||
// ClaimMappingExpressions provides CEL-based claim mapping. |
||||
// When set, these take precedence over ClaimMapping and ClaimMutations. |
||||
ClaimMappingExpressions *ClaimMappingExpression `json:"claimMappingExpressions,omitempty"` |
||||
} |
||||
|
||||
type ClaimMappingExpression struct { |
||||
// Username is a CEL expression evaluating to string. |
||||
// Available variable: 'claims' (map of upstream claims). |
||||
Username string `json:"username,omitempty"` |
||||
// Email is a CEL expression evaluating to string. |
||||
Email string `json:"email,omitempty"` |
||||
// Groups is a CEL expression evaluating to list(string). |
||||
Groups string `json:"groups,omitempty"` |
||||
// EmailVerified is a CEL expression evaluating to bool. |
||||
EmailVerified string `json:"emailVerified,omitempty"` |
||||
// Extra is a map of claim names to CEL expressions evaluating to dyn. |
||||
// These are carried through to token policies. |
||||
Extra map[string]string `json:"extra,omitempty"` |
||||
} |
||||
``` |
||||
|
||||
**Available CEL variable:** `claims` — `map(string, dyn)` containing all raw upstream claims from |
||||
the ID token and/or UserInfo endpoint. |
||||
|
||||
This replaces the need for `ClaimMapping`, `NewGroupFromClaims`, `FilterGroupClaims`, and |
||||
`ModifyGroupNames` with a single, more powerful mechanism. |
||||
|
||||
**Backward compatibility:** When `claimMappingExpressions` is nil, the existing `ClaimMapping` and |
||||
`ClaimMutations` logic is used unchanged. When `claimMappingExpressions` is set, a startup warning is |
||||
logged if legacy mapping fields are also configured. |
||||
|
||||
### Policy Application Flow |
||||
|
||||
The following diagram shows the order in which CEL policies are applied. |
||||
Each step is optional — if not configured, it is skipped. |
||||
|
||||
``` |
||||
Connector Authentication |
||||
│ |
||||
│ upstream claims → connector.Identity |
||||
│ |
||||
v |
||||
Authentication Policies |
||||
│ |
||||
│ Global authPolicy |
||||
│ Per-client authPolicy |
||||
│ |
||||
v |
||||
Token Issuance |
||||
│ |
||||
│ Global tokenPolicy.filter |
||||
│ Per-client tokenPolicy.filter |
||||
│ |
||||
│ Global tokenPolicy.claims |
||||
│ Per-client tokenPolicy.claims |
||||
│ |
||||
│ Sign JWT |
||||
│ |
||||
v |
||||
Token Response |
||||
``` |
||||
|
||||
| Step | Policy | Scope | Action on match | |
||||
|------|--------|-------|-----------------| |
||||
| 2 | `authPolicy` (global) | Global | Expression → `true` = DENY login | |
||||
| 3 | `authPolicy` (per-client) | Per-client | Expression → `true` = DENY login | |
||||
| 4 | `tokenPolicy.filter` (global) | Global | Expression → `false` = DENY token | |
||||
| 5 | `tokenPolicy.filter` (per-client) | Per-client | Expression → `false` = DENY token | |
||||
| 6 | `tokenPolicy.claims` (global) | Global | Adds/overrides claims (with optional condition) | |
||||
| 7 | `tokenPolicy.claims` (per-client) | Per-client | Adds/overrides claims (overrides global) | |
||||
|
||||
### Risks and Mitigations |
||||
|
||||
| Risk | Mitigation | |
||||
|------|------------| |
||||
| **CEL expression complexity / DoS** | Cost budgets with configurable limits (default aligned with Kubernetes). Expressions are validated at config load time. Runtime evaluation is aborted if cost exceeds budget. | |
||||
| **Learning curve for operators** | CEL has excellent documentation, playground ([cel.dev](https://cel.dev)), and massive CNCF adoption. Dex docs will include a dedicated CEL guide with examples. Most operators already know CEL from Kubernetes. | |
||||
| **`cel-go` dependency size** | `cel-go` adds ~5MB to binary. This is acceptable for the functionality provided. Kubernetes, Istio, Envoy all accept this trade-off. | |
||||
| **Breaking changes in `cel-go`** | Pin to semver minor range. Environment versioning ensures existing expressions continue to work across upgrades. | |
||||
| **Security: CEL expression injection** | CEL expressions are defined by operators in the server config, not by end users. No CEL expression is ever constructed from user input at runtime. | |
||||
| **Config migration** | Old config fields (`ClaimMapping`, `ClaimMutations`) continue to work. CEL expressions are opt-in. If both are specified, CEL takes precedence with a config-time warning. | |
||||
| **Error messages exposing internals** | CEL deny `message` expressions are controlled by the operator. Default messages are generic. Evaluation errors are logged server-side, not exposed to end users. | |
||||
| **Performance** | Expressions are compiled once at startup. Evaluation is sub-millisecond for typical identity operations. Cost budgets prevent pathological cases. Benchmarks will be included in `pkg/cel` tests. | |
||||
|
||||
### Alternatives |
||||
|
||||
#### OPA/Rego |
||||
|
||||
OPA was previously considered ([#1635], token exchange DEP). While powerful, it has significant |
||||
drawbacks for Dex: |
||||
|
||||
- **Separate daemon** — OPA typically runs as a sidecar or daemon; adds operational complexity. |
||||
Even the embedded Go library (`github.com/open-policy-agent/opa/rego`) is significantly |
||||
heavier than `cel-go`. |
||||
- **Rego learning curve** — Rego is a Datalog-derived language unfamiliar to most developers. |
||||
CEL syntax is closer to C/Java/Go and is immediately readable. |
||||
- **Overkill** — Dex needs simple expression evaluation, not a full policy engine with data |
||||
loading, bundles, and partial evaluation. |
||||
- **No inline expressions** — Rego policies are typically separate files, not inline config |
||||
expressions. This makes the config harder to understand and deploy. |
||||
- **Smaller CNCF footprint for embedding** — While OPA is a graduated CNCF project, CEL has |
||||
broader adoption as an _embedded_ language (Kubernetes, Istio, Envoy, Kyverno, etc.). |
||||
|
||||
#### JMESPath |
||||
|
||||
JMESPath was proposed for claim mapping. Drawbacks: |
||||
|
||||
- **Query-only** — JMESPath is a JSON query language. It cannot express boolean conditions, |
||||
mutations, or string operations naturally. |
||||
- **Limited type system** — No type checking at compile time. Errors are only caught at runtime. |
||||
- **Small ecosystem** — Limited adoption compared to CEL. No CNCF projects use JMESPath for |
||||
policy evaluation. |
||||
- **No cost estimation** — No way to bound execution time. |
||||
|
||||
#### Hardcoded Go Logic |
||||
|
||||
The current approach: each feature requires new Go structs, config fields, and code. This is |
||||
unsustainable: |
||||
- `ClaimMapping`, `NewGroupFromClaims`, `FilterGroupClaims`, `ModifyGroupNames` are each separate |
||||
features that could be one CEL expression. |
||||
- Every new policy need requires a Dex code change and release. |
||||
- Combinatorial explosion of config options. |
||||
|
||||
#### No Change |
||||
|
||||
Without CEL or an equivalent: |
||||
- Operators continue to request per-client connector restrictions, custom claims, claim |
||||
transformations, and access policies — issues remain open indefinitely. |
||||
- Dex accumulates more ad-hoc config fields, increasing maintenance burden. |
||||
- Complex use cases require external reverse proxies, forking Dex, or middleware. |
||||
|
||||
## Future Improvements |
||||
|
||||
- **CEL in other connectors** — Extend CEL claim mapping beyond OIDC to LDAP (attribute mapping), |
||||
SAML (assertion mapping), and other connectors with complex attribute mapping needs. |
||||
- **Policy testing framework** — Unit test framework for operators to validate their CEL |
||||
expressions against fixture data before deployment. |
||||
- **Connector selection via CEL** — Replace the static connector-per-client mapping with a CEL |
||||
expression that dynamically determines which connectors to show based on request attributes. |
||||
|
||||
|
||||
@ -0,0 +1,232 @@
|
||||
package cel |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"reflect" |
||||
|
||||
"github.com/google/cel-go/cel" |
||||
"github.com/google/cel-go/checker" |
||||
"github.com/google/cel-go/common/types/ref" |
||||
"github.com/google/cel-go/ext" |
||||
|
||||
"github.com/dexidp/dex/pkg/cel/library" |
||||
) |
||||
|
||||
// EnvironmentVersion represents the version of the CEL environment.
|
||||
// New variables, functions, or libraries are introduced in new versions.
|
||||
type EnvironmentVersion uint32 |
||||
|
||||
const ( |
||||
// EnvironmentV1 is the initial CEL environment.
|
||||
EnvironmentV1 EnvironmentVersion = 1 |
||||
) |
||||
|
||||
// CompilationResult holds a compiled CEL program ready for evaluation.
|
||||
type CompilationResult struct { |
||||
Program cel.Program |
||||
OutputType *cel.Type |
||||
Expression string |
||||
|
||||
ast *cel.Ast |
||||
} |
||||
|
||||
// CompilerOption configures a Compiler.
|
||||
type CompilerOption func(*compilerConfig) |
||||
|
||||
type compilerConfig struct { |
||||
costBudget uint64 |
||||
version EnvironmentVersion |
||||
} |
||||
|
||||
func defaultCompilerConfig() *compilerConfig { |
||||
return &compilerConfig{ |
||||
costBudget: DefaultCostBudget, |
||||
version: EnvironmentV1, |
||||
} |
||||
} |
||||
|
||||
// WithCostBudget sets a custom cost budget for expression evaluation.
|
||||
func WithCostBudget(budget uint64) CompilerOption { |
||||
return func(cfg *compilerConfig) { |
||||
cfg.costBudget = budget |
||||
} |
||||
} |
||||
|
||||
// WithVersion sets the target environment version for the compiler.
|
||||
// Defaults to the latest version. Specifying an older version ensures
|
||||
// that only functions/types available at that version are used.
|
||||
func WithVersion(v EnvironmentVersion) CompilerOption { |
||||
return func(cfg *compilerConfig) { |
||||
cfg.version = v |
||||
} |
||||
} |
||||
|
||||
// Compiler compiles CEL expressions against a specific environment.
|
||||
type Compiler struct { |
||||
env *cel.Env |
||||
cfg *compilerConfig |
||||
} |
||||
|
||||
// NewCompiler creates a new CEL compiler with the specified variable
|
||||
// declarations and options.
|
||||
//
|
||||
// All custom Dex libraries are automatically included.
|
||||
// The environment is configured with cost limits and safe defaults.
|
||||
func NewCompiler(variables []VariableDeclaration, opts ...CompilerOption) (*Compiler, error) { |
||||
cfg := defaultCompilerConfig() |
||||
for _, opt := range opts { |
||||
opt(cfg) |
||||
} |
||||
|
||||
envOpts := make([]cel.EnvOption, 0, 8+len(variables)) |
||||
envOpts = append(envOpts, |
||||
cel.DefaultUTCTimeZone(true), |
||||
|
||||
// Standard extension libraries (same set as Kubernetes)
|
||||
ext.Strings(), |
||||
ext.Encoders(), |
||||
ext.Lists(), |
||||
ext.Sets(), |
||||
ext.Math(), |
||||
|
||||
// Native Go types for typed variable access.
|
||||
// This gives compile-time field checking: identity.emial → error at config load.
|
||||
ext.NativeTypes( |
||||
ext.ParseStructTags(true), |
||||
reflect.TypeOf(IdentityVal{}), |
||||
reflect.TypeOf(RequestVal{}), |
||||
), |
||||
|
||||
// Custom Dex libraries
|
||||
cel.Lib(&library.Email{}), |
||||
cel.Lib(&library.Groups{}), |
||||
|
||||
// Presence tests like has(field) and 'key' in map are O(1) hash
|
||||
// lookups on map(string, dyn) variables, so they should not count
|
||||
// toward the cost budget. Without this, expressions with multiple
|
||||
// 'in' checks (e.g. "'admin' in identity.groups") would accumulate
|
||||
// inflated cost estimates. This matches Kubernetes CEL behavior
|
||||
// where presence tests are free for CRD validation rules.
|
||||
cel.CostEstimatorOptions( |
||||
checker.PresenceTestHasCost(false), |
||||
), |
||||
) |
||||
|
||||
for _, v := range variables { |
||||
envOpts = append(envOpts, cel.Variable(v.Name, v.Type)) |
||||
} |
||||
|
||||
env, err := cel.NewEnv(envOpts...) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to create CEL environment: %w", err) |
||||
} |
||||
|
||||
return &Compiler{env: env, cfg: cfg}, nil |
||||
} |
||||
|
||||
// CompileBool compiles a CEL expression that must evaluate to bool.
|
||||
func (c *Compiler) CompileBool(expression string) (*CompilationResult, error) { |
||||
return c.compile(expression, cel.BoolType) |
||||
} |
||||
|
||||
// CompileString compiles a CEL expression that must evaluate to string.
|
||||
func (c *Compiler) CompileString(expression string) (*CompilationResult, error) { |
||||
return c.compile(expression, cel.StringType) |
||||
} |
||||
|
||||
// CompileStringList compiles a CEL expression that must evaluate to list(string).
|
||||
func (c *Compiler) CompileStringList(expression string) (*CompilationResult, error) { |
||||
return c.compile(expression, cel.ListType(cel.StringType)) |
||||
} |
||||
|
||||
// Compile compiles a CEL expression with any output type.
|
||||
func (c *Compiler) Compile(expression string) (*CompilationResult, error) { |
||||
return c.compile(expression, nil) |
||||
} |
||||
|
||||
func (c *Compiler) compile(expression string, expectedType *cel.Type) (*CompilationResult, error) { |
||||
if len(expression) > MaxExpressionLength { |
||||
return nil, fmt.Errorf("expression exceeds maximum length of %d characters", MaxExpressionLength) |
||||
} |
||||
|
||||
ast, issues := c.env.Compile(expression) |
||||
if issues != nil && issues.Err() != nil { |
||||
return nil, fmt.Errorf("CEL compilation failed: %w", issues.Err()) |
||||
} |
||||
|
||||
if expectedType != nil && !ast.OutputType().IsEquivalentType(expectedType) { |
||||
return nil, fmt.Errorf( |
||||
"expected expression output type %s, got %s", |
||||
expectedType, ast.OutputType(), |
||||
) |
||||
} |
||||
|
||||
// Estimate cost at compile time and reject expressions that are too expensive.
|
||||
costEst, err := c.env.EstimateCost(ast, &defaultCostEstimator{}) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("CEL cost estimation failed: %w", err) |
||||
} |
||||
|
||||
if costEst.Max > c.cfg.costBudget { |
||||
return nil, fmt.Errorf( |
||||
"CEL expression estimated cost %d exceeds budget %d", |
||||
costEst.Max, c.cfg.costBudget, |
||||
) |
||||
} |
||||
|
||||
prog, err := c.env.Program(ast, |
||||
cel.EvalOptions(cel.OptOptimize), |
||||
cel.CostLimit(c.cfg.costBudget), |
||||
) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("CEL program creation failed: %w", err) |
||||
} |
||||
|
||||
return &CompilationResult{ |
||||
Program: prog, |
||||
OutputType: ast.OutputType(), |
||||
Expression: expression, |
||||
ast: ast, |
||||
}, nil |
||||
} |
||||
|
||||
// Eval evaluates a compiled program against the given variables.
|
||||
func Eval(ctx context.Context, result *CompilationResult, variables map[string]any) (ref.Val, error) { |
||||
out, _, err := result.Program.ContextEval(ctx, variables) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("CEL evaluation failed: %w", err) |
||||
} |
||||
|
||||
return out, nil |
||||
} |
||||
|
||||
// EvalBool is a convenience function that evaluates and asserts bool output.
|
||||
func EvalBool(ctx context.Context, result *CompilationResult, variables map[string]any) (bool, error) { |
||||
out, err := Eval(ctx, result, variables) |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
v, ok := out.Value().(bool) |
||||
if !ok { |
||||
return false, fmt.Errorf("expected bool result, got %T", out.Value()) |
||||
} |
||||
|
||||
return v, nil |
||||
} |
||||
|
||||
// EvalString is a convenience function that evaluates and asserts string output.
|
||||
func EvalString(ctx context.Context, result *CompilationResult, variables map[string]any) (string, error) { |
||||
out, err := Eval(ctx, result, variables) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
v, ok := out.Value().(string) |
||||
if !ok { |
||||
return "", fmt.Errorf("expected string result, got %T", out.Value()) |
||||
} |
||||
|
||||
return v, nil |
||||
} |
||||
@ -0,0 +1,280 @@
|
||||
package cel_test |
||||
|
||||
import ( |
||||
"context" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/dexidp/dex/connector" |
||||
dexcel "github.com/dexidp/dex/pkg/cel" |
||||
) |
||||
|
||||
func TestCompileBool(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
wantErr bool |
||||
}{ |
||||
"true literal": { |
||||
expr: "true", |
||||
}, |
||||
"comparison": { |
||||
expr: "1 == 1", |
||||
}, |
||||
"string type mismatch": { |
||||
expr: "'hello'", |
||||
wantErr: true, |
||||
}, |
||||
"int type mismatch": { |
||||
expr: "42", |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
result, err := compiler.CompileBool(tc.expr) |
||||
if tc.wantErr { |
||||
assert.Error(t, err) |
||||
assert.Nil(t, result) |
||||
} else { |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestCompileString(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
wantErr bool |
||||
}{ |
||||
"string literal": { |
||||
expr: "'hello'", |
||||
}, |
||||
"string concatenation": { |
||||
expr: "'hello' + ' ' + 'world'", |
||||
}, |
||||
"bool type mismatch": { |
||||
expr: "true", |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
result, err := compiler.CompileString(tc.expr) |
||||
if tc.wantErr { |
||||
assert.Error(t, err) |
||||
} else { |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestCompileStringList(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
result, err := compiler.CompileStringList("['a', 'b', 'c']") |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
|
||||
_, err = compiler.CompileStringList("'not a list'") |
||||
assert.Error(t, err) |
||||
} |
||||
|
||||
func TestCompile(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
// Compile accepts any type
|
||||
result, err := compiler.Compile("true") |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
|
||||
result, err = compiler.Compile("'hello'") |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
|
||||
result, err = compiler.Compile("42") |
||||
assert.NoError(t, err) |
||||
assert.NotNil(t, result) |
||||
} |
||||
|
||||
func TestCompileErrors(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
}{ |
||||
"syntax error": { |
||||
expr: "1 +", |
||||
}, |
||||
"undefined variable": { |
||||
expr: "undefined_var", |
||||
}, |
||||
"undefined function": { |
||||
expr: "undefinedFunc()", |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
_, err := compiler.Compile(tc.expr) |
||||
assert.Error(t, err) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestCompileRejectsUnknownFields(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
// Typo in field name: should fail at compile time with ObjectType
|
||||
_, err = compiler.CompileBool("identity.emial == 'test@example.com'") |
||||
assert.Error(t, err) |
||||
assert.Contains(t, err.Error(), "compilation failed") |
||||
|
||||
// Type mismatch: comparing string field to int should fail at compile time
|
||||
_, err = compiler.CompileBool("identity.email == 123") |
||||
assert.Error(t, err) |
||||
assert.Contains(t, err.Error(), "compilation failed") |
||||
|
||||
// Valid field: should compile fine
|
||||
_, err = compiler.CompileBool("identity.email == 'test@example.com'") |
||||
assert.NoError(t, err) |
||||
} |
||||
|
||||
func TestMaxExpressionLength(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
longExpr := "'" + strings.Repeat("a", dexcel.MaxExpressionLength) + "'" |
||||
_, err = compiler.Compile(longExpr) |
||||
assert.Error(t, err) |
||||
assert.Contains(t, err.Error(), "maximum length") |
||||
} |
||||
|
||||
func TestEvalBool(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
identity dexcel.IdentityVal |
||||
want bool |
||||
}{ |
||||
"email endsWith": { |
||||
expr: "identity.email.endsWith('@example.com')", |
||||
identity: dexcel.IdentityVal{Email: "user@example.com"}, |
||||
want: true, |
||||
}, |
||||
"email endsWith false": { |
||||
expr: "identity.email.endsWith('@example.com')", |
||||
identity: dexcel.IdentityVal{Email: "user@other.com"}, |
||||
want: false, |
||||
}, |
||||
"email_verified": { |
||||
expr: "identity.email_verified == true", |
||||
identity: dexcel.IdentityVal{EmailVerified: true}, |
||||
want: true, |
||||
}, |
||||
"group membership": { |
||||
expr: "identity.groups.exists(g, g == 'admin')", |
||||
identity: dexcel.IdentityVal{Groups: []string{"admin", "dev"}}, |
||||
want: true, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.CompileBool(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{ |
||||
"identity": tc.identity, |
||||
}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, tc.want, result) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestEvalString(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
// With ObjectType, identity.email is typed as string, so CompileString works.
|
||||
prog, err := compiler.CompileString("identity.email") |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||
"identity": dexcel.IdentityVal{Email: "user@example.com"}, |
||||
}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "user@example.com", result) |
||||
} |
||||
|
||||
func TestEvalWithIdentityAndRequest(t *testing.T) { |
||||
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
prog, err := compiler.CompileBool( |
||||
`identity.email.endsWith('@example.com') && 'admin' in identity.groups && request.connector_id == 'okta'`, |
||||
) |
||||
require.NoError(t, err) |
||||
|
||||
identity := dexcel.IdentityFromConnector(connector.Identity{ |
||||
UserID: "123", |
||||
Username: "john", |
||||
Email: "john@example.com", |
||||
Groups: []string{"admin", "dev"}, |
||||
}) |
||||
request := dexcel.RequestFromContext(dexcel.RequestContext{ |
||||
ClientID: "my-app", |
||||
ConnectorID: "okta", |
||||
Scopes: []string{"openid", "email"}, |
||||
}) |
||||
|
||||
result, err := dexcel.EvalBool(context.Background(), prog, map[string]any{ |
||||
"identity": identity, |
||||
"request": request, |
||||
}) |
||||
require.NoError(t, err) |
||||
assert.True(t, result) |
||||
} |
||||
|
||||
func TestNewCompilerWithVariables(t *testing.T) { |
||||
// Claims variable — remains map(string, dyn)
|
||||
compiler, err := dexcel.NewCompiler(dexcel.ClaimsVariable()) |
||||
require.NoError(t, err) |
||||
|
||||
// claims.email returns dyn from map access, use Compile (not CompileString)
|
||||
prog, err := compiler.Compile("claims.email") |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||
"claims": map[string]any{ |
||||
"email": "test@example.com", |
||||
}, |
||||
}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "test@example.com", result) |
||||
} |
||||
@ -0,0 +1,105 @@
|
||||
package cel |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/google/cel-go/checker" |
||||
) |
||||
|
||||
// DefaultCostBudget is the default cost budget for a single expression
|
||||
// evaluation. Aligned with Kubernetes defaults: enough for typical identity
|
||||
// operations but prevents runaway expressions.
|
||||
const DefaultCostBudget uint64 = 10_000_000 |
||||
|
||||
// MaxExpressionLength is the maximum length of a CEL expression string.
|
||||
const MaxExpressionLength = 10_240 |
||||
|
||||
// DefaultStringMaxLength is the estimated max length of string values
|
||||
// (emails, usernames, group names, etc.) used for compile-time cost estimation.
|
||||
const DefaultStringMaxLength = 256 |
||||
|
||||
// DefaultListMaxLength is the estimated max length of list values
|
||||
// (groups, scopes) used for compile-time cost estimation.
|
||||
const DefaultListMaxLength = 100 |
||||
|
||||
// CostEstimate holds the estimated cost range for a compiled expression.
|
||||
type CostEstimate struct { |
||||
Min uint64 |
||||
Max uint64 |
||||
} |
||||
|
||||
// EstimateCost returns the estimated cost range for a compiled expression.
|
||||
// This is computed statically at compile time without evaluating the expression.
|
||||
func (c *Compiler) EstimateCost(result *CompilationResult) (CostEstimate, error) { |
||||
costEst, err := c.env.EstimateCost(result.ast, &defaultCostEstimator{}) |
||||
if err != nil { |
||||
return CostEstimate{}, fmt.Errorf("CEL cost estimation failed: %w", err) |
||||
} |
||||
|
||||
return CostEstimate{Min: costEst.Min, Max: costEst.Max}, nil |
||||
} |
||||
|
||||
// defaultCostEstimator provides size hints for compile-time cost estimation.
|
||||
// Without these hints, the CEL cost estimator assumes unbounded sizes for
|
||||
// variables, leading to wildly overestimated max costs.
|
||||
type defaultCostEstimator struct{} |
||||
|
||||
func (defaultCostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { |
||||
// Provide size hints for map(string, dyn) variables: identity, request, claims.
|
||||
// Without these, the estimator assumes lists/strings can be infinitely large.
|
||||
if element.Path() == nil { |
||||
return nil |
||||
} |
||||
|
||||
path := element.Path() |
||||
if len(path) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
root := path[0] |
||||
|
||||
switch root { |
||||
case "identity", "request", "claims": |
||||
// Nested field access (e.g. identity.email, identity.groups)
|
||||
if len(path) >= 2 { |
||||
field := path[1] |
||||
switch field { |
||||
case "groups", "scopes": |
||||
// list(string) fields
|
||||
return &checker.SizeEstimate{Min: 0, Max: DefaultListMaxLength} |
||||
case "email_verified": |
||||
// bool field — size is always 1
|
||||
return &checker.SizeEstimate{Min: 1, Max: 1} |
||||
default: |
||||
// string fields (email, username, user_id, client_id, etc.)
|
||||
return &checker.SizeEstimate{Min: 0, Max: DefaultStringMaxLength} |
||||
} |
||||
} |
||||
// The map itself: number of keys
|
||||
return &checker.SizeEstimate{Min: 0, Max: 20} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (defaultCostEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { |
||||
switch function { |
||||
case "dex.emailDomain", "dex.emailLocalPart": |
||||
// Simple string split — O(n) where n is string length, bounded.
|
||||
return &checker.CallEstimate{ |
||||
CostEstimate: checker.CostEstimate{Min: 1, Max: 2}, |
||||
} |
||||
case "dex.groupMatches": |
||||
// Iterates over groups list and matches each against a pattern.
|
||||
return &checker.CallEstimate{ |
||||
CostEstimate: checker.CostEstimate{Min: 1, Max: DefaultListMaxLength}, |
||||
} |
||||
case "dex.groupFilter": |
||||
// Builds a set from allowed list, then iterates groups.
|
||||
return &checker.CallEstimate{ |
||||
CostEstimate: checker.CostEstimate{Min: 1, Max: 2 * DefaultListMaxLength}, |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,137 @@
|
||||
package cel_test |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
dexcel "github.com/dexidp/dex/pkg/cel" |
||||
) |
||||
|
||||
func TestEstimateCost(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
}{ |
||||
"simple bool": { |
||||
expr: "true", |
||||
}, |
||||
"string comparison": { |
||||
expr: "identity.email == 'test@example.com'", |
||||
}, |
||||
"group membership": { |
||||
expr: "identity.groups.exists(g, g == 'admin')", |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.Compile(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
est, err := compiler.EstimateCost(prog) |
||||
require.NoError(t, err) |
||||
assert.True(t, est.Max >= est.Min, "max cost should be >= min cost") |
||||
assert.True(t, est.Max <= dexcel.DefaultCostBudget, |
||||
"estimated max cost %d should be within default budget %d", est.Max, dexcel.DefaultCostBudget) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestCompileTimeCostAcceptsSimpleExpressions(t *testing.T) { |
||||
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]string{ |
||||
"literal": "true", |
||||
"email endsWith": "identity.email.endsWith('@example.com')", |
||||
"group check": "'admin' in identity.groups", |
||||
"emailDomain": `dex.emailDomain(identity.email)`, |
||||
"groupMatches": `dex.groupMatches(identity.groups, "team:*")`, |
||||
"groupFilter": `dex.groupFilter(identity.groups, ["admin", "dev"])`, |
||||
"combined policy": `identity.email.endsWith('@example.com') && 'admin' in identity.groups`, |
||||
"complex policy": `identity.email.endsWith('@example.com') && |
||||
identity.groups.exists(g, g == 'admin') && |
||||
request.connector_id == 'okta' && |
||||
request.scopes.exists(s, s == 'openid')`, |
||||
"filter+map chain": `identity.groups |
||||
.filter(g, g.startsWith('team:')) |
||||
.map(g, g.replace('team:', '')) |
||||
.size() > 0`, |
||||
} |
||||
|
||||
for name, expr := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
_, err := compiler.Compile(expr) |
||||
assert.NoError(t, err, "expression should compile within default budget") |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestCompileTimeCostRejection(t *testing.T) { |
||||
vars := append(dexcel.IdentityVariables(), dexcel.RequestVariables()...) |
||||
|
||||
tests := map[string]struct { |
||||
budget uint64 |
||||
expr string |
||||
}{ |
||||
"simple exists exceeds tiny budget": { |
||||
budget: 1, |
||||
expr: "identity.groups.exists(g, g == 'admin')", |
||||
}, |
||||
"endsWith exceeds tiny budget": { |
||||
budget: 2, |
||||
expr: "identity.email.endsWith('@example.com')", |
||||
}, |
||||
"nested comprehension over groups exceeds moderate budget": { |
||||
// Two nested iterations over groups: O(n^2) where n=100 → ~280K
|
||||
budget: 10_000, |
||||
expr: `identity.groups.exists(g1, |
||||
identity.groups.exists(g2, |
||||
g1 != g2 && g1.startsWith(g2) |
||||
) |
||||
)`, |
||||
}, |
||||
"cross-variable comprehension exceeds moderate budget": { |
||||
// filter groups then check each against scopes: O(n*m) → ~162K
|
||||
budget: 10_000, |
||||
expr: `identity.groups |
||||
.filter(g, g.startsWith('team:')) |
||||
.exists(g, request.scopes.exists(s, s == g))`, |
||||
}, |
||||
"chained filter+map+filter+map exceeds small budget": { |
||||
budget: 1000, |
||||
expr: `identity.groups |
||||
.filter(g, g.startsWith('team:')) |
||||
.map(g, g.replace('team:', '')) |
||||
.filter(g, g.size() > 3) |
||||
.map(g, g.upperAscii()) |
||||
.size() > 0`, |
||||
}, |
||||
"many independent exists exceeds small budget": { |
||||
budget: 5000, |
||||
expr: `identity.groups.exists(g, g.contains('a')) && |
||||
identity.groups.exists(g, g.contains('b')) && |
||||
identity.groups.exists(g, g.contains('c')) && |
||||
identity.groups.exists(g, g.contains('d')) && |
||||
identity.groups.exists(g, g.contains('e'))`, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(vars, dexcel.WithCostBudget(tc.budget)) |
||||
require.NoError(t, err) |
||||
|
||||
_, err = compiler.Compile(tc.expr) |
||||
assert.Error(t, err) |
||||
assert.Contains(t, err.Error(), "estimated cost") |
||||
assert.Contains(t, err.Error(), "exceeds budget") |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,5 @@
|
||||
// Package cel provides a safe, sandboxed CEL (Common Expression Language)
|
||||
// environment for policy evaluation, claim mapping, and token customization
|
||||
// in Dex. It includes cost budgets, Kubernetes-grade compatibility guarantees,
|
||||
// and a curated set of extension libraries.
|
||||
package cel |
||||
@ -0,0 +1,4 @@
|
||||
// Package library provides custom CEL function libraries for Dex.
|
||||
// Each library implements the cel.Library interface and can be registered
|
||||
// in a CEL environment.
|
||||
package library |
||||
@ -0,0 +1,73 @@
|
||||
package library |
||||
|
||||
import ( |
||||
"strings" |
||||
|
||||
"github.com/google/cel-go/cel" |
||||
"github.com/google/cel-go/common/types" |
||||
"github.com/google/cel-go/common/types/ref" |
||||
) |
||||
|
||||
// Email provides email-related CEL functions.
|
||||
//
|
||||
// Functions (V1):
|
||||
//
|
||||
// dex.emailDomain(email: string) -> string
|
||||
// Returns the domain portion of an email address.
|
||||
// Example: dex.emailDomain("user@example.com") == "example.com"
|
||||
//
|
||||
// dex.emailLocalPart(email: string) -> string
|
||||
// Returns the local part of an email address.
|
||||
// Example: dex.emailLocalPart("user@example.com") == "user"
|
||||
type Email struct{} |
||||
|
||||
func (Email) CompileOptions() []cel.EnvOption { |
||||
return []cel.EnvOption{ |
||||
cel.Function("dex.emailDomain", |
||||
cel.Overload("dex_email_domain_string", |
||||
[]*cel.Type{cel.StringType}, |
||||
cel.StringType, |
||||
cel.UnaryBinding(emailDomainImpl), |
||||
), |
||||
), |
||||
cel.Function("dex.emailLocalPart", |
||||
cel.Overload("dex_email_local_part_string", |
||||
[]*cel.Type{cel.StringType}, |
||||
cel.StringType, |
||||
cel.UnaryBinding(emailLocalPartImpl), |
||||
), |
||||
), |
||||
} |
||||
} |
||||
|
||||
func (Email) ProgramOptions() []cel.ProgramOption { |
||||
return nil |
||||
} |
||||
|
||||
func emailDomainImpl(arg ref.Val) ref.Val { |
||||
email, ok := arg.Value().(string) |
||||
if !ok { |
||||
return types.NewErr("dex.emailDomain: expected string argument") |
||||
} |
||||
|
||||
_, domain, found := strings.Cut(email, "@") |
||||
if !found { |
||||
return types.String("") |
||||
} |
||||
|
||||
return types.String(domain) |
||||
} |
||||
|
||||
func emailLocalPartImpl(arg ref.Val) ref.Val { |
||||
email, ok := arg.Value().(string) |
||||
if !ok { |
||||
return types.NewErr("dex.emailLocalPart: expected string argument") |
||||
} |
||||
|
||||
localPart, _, found := strings.Cut(email, "@") |
||||
if !found { |
||||
return types.String(email) |
||||
} |
||||
|
||||
return types.String(localPart) |
||||
} |
||||
@ -0,0 +1,106 @@
|
||||
package library_test |
||||
|
||||
import ( |
||||
"context" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
dexcel "github.com/dexidp/dex/pkg/cel" |
||||
) |
||||
|
||||
func TestEmailDomain(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
want string |
||||
}{ |
||||
"standard email": { |
||||
expr: `dex.emailDomain("user@example.com")`, |
||||
want: "example.com", |
||||
}, |
||||
"subdomain": { |
||||
expr: `dex.emailDomain("admin@sub.domain.org")`, |
||||
want: "sub.domain.org", |
||||
}, |
||||
"no at sign": { |
||||
expr: `dex.emailDomain("nodomain")`, |
||||
want: "", |
||||
}, |
||||
"empty string": { |
||||
expr: `dex.emailDomain("")`, |
||||
want: "", |
||||
}, |
||||
"multiple at signs": { |
||||
expr: `dex.emailDomain("user@name@example.com")`, |
||||
want: "name@example.com", |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.CompileString(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, tc.want, result) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestEmailLocalPart(t *testing.T) { |
||||
compiler, err := dexcel.NewCompiler(nil) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
want string |
||||
}{ |
||||
"standard email": { |
||||
expr: `dex.emailLocalPart("user@example.com")`, |
||||
want: "user", |
||||
}, |
||||
"no at sign": { |
||||
expr: `dex.emailLocalPart("justuser")`, |
||||
want: "justuser", |
||||
}, |
||||
"empty string": { |
||||
expr: `dex.emailLocalPart("")`, |
||||
want: "", |
||||
}, |
||||
"multiple at signs": { |
||||
expr: `dex.emailLocalPart("user@name@example.com")`, |
||||
want: "user", |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.CompileString(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, tc.want, result) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestEmailDomainWithIdentityVariable(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
prog, err := compiler.CompileString(`dex.emailDomain(identity.email)`) |
||||
require.NoError(t, err) |
||||
|
||||
result, err := dexcel.EvalString(context.Background(), prog, map[string]any{ |
||||
"identity": dexcel.IdentityVal{Email: "admin@corp.example.com"}, |
||||
}) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, "corp.example.com", result) |
||||
} |
||||
@ -0,0 +1,123 @@
|
||||
package library |
||||
|
||||
import ( |
||||
"path" |
||||
|
||||
"github.com/google/cel-go/cel" |
||||
"github.com/google/cel-go/common/types" |
||||
"github.com/google/cel-go/common/types/ref" |
||||
"github.com/google/cel-go/common/types/traits" |
||||
) |
||||
|
||||
// Groups provides group-related CEL functions.
|
||||
//
|
||||
// Functions (V1):
|
||||
//
|
||||
// dex.groupMatches(groups: list(string), pattern: string) -> list(string)
|
||||
// Returns groups matching a glob pattern.
|
||||
// Example: dex.groupMatches(["team:dev", "team:ops", "admin"], "team:*")
|
||||
//
|
||||
// dex.groupFilter(groups: list(string), allowed: list(string)) -> list(string)
|
||||
// Returns only groups present in the allowed list.
|
||||
// Example: dex.groupFilter(["admin", "dev", "ops"], ["admin", "ops"])
|
||||
type Groups struct{} |
||||
|
||||
func (Groups) CompileOptions() []cel.EnvOption { |
||||
return []cel.EnvOption{ |
||||
cel.Function("dex.groupMatches", |
||||
cel.Overload("dex_group_matches_list_string", |
||||
[]*cel.Type{cel.ListType(cel.StringType), cel.StringType}, |
||||
cel.ListType(cel.StringType), |
||||
cel.BinaryBinding(groupMatchesImpl), |
||||
), |
||||
), |
||||
cel.Function("dex.groupFilter", |
||||
cel.Overload("dex_group_filter_list_list", |
||||
[]*cel.Type{cel.ListType(cel.StringType), cel.ListType(cel.StringType)}, |
||||
cel.ListType(cel.StringType), |
||||
cel.BinaryBinding(groupFilterImpl), |
||||
), |
||||
), |
||||
} |
||||
} |
||||
|
||||
func (Groups) ProgramOptions() []cel.ProgramOption { |
||||
return nil |
||||
} |
||||
|
||||
func groupMatchesImpl(lhs, rhs ref.Val) ref.Val { |
||||
groupList, ok := lhs.(traits.Lister) |
||||
if !ok { |
||||
return types.NewErr("dex.groupMatches: expected list(string) as first argument") |
||||
} |
||||
|
||||
pattern, ok := rhs.Value().(string) |
||||
if !ok { |
||||
return types.NewErr("dex.groupMatches: expected string pattern as second argument") |
||||
} |
||||
|
||||
iter := groupList.Iterator() |
||||
var matched []ref.Val |
||||
|
||||
for iter.HasNext() == types.True { |
||||
item := iter.Next() |
||||
|
||||
group, ok := item.Value().(string) |
||||
if !ok { |
||||
continue |
||||
} |
||||
|
||||
ok, err := path.Match(pattern, group) |
||||
if err != nil { |
||||
return types.NewErr("dex.groupMatches: invalid pattern %q: %v", pattern, err) |
||||
} |
||||
if ok { |
||||
matched = append(matched, types.String(group)) |
||||
} |
||||
} |
||||
|
||||
return types.NewRefValList(types.DefaultTypeAdapter, matched) |
||||
} |
||||
|
||||
func groupFilterImpl(lhs, rhs ref.Val) ref.Val { |
||||
groupList, ok := lhs.(traits.Lister) |
||||
if !ok { |
||||
return types.NewErr("dex.groupFilter: expected list(string) as first argument") |
||||
} |
||||
|
||||
allowedList, ok := rhs.(traits.Lister) |
||||
if !ok { |
||||
return types.NewErr("dex.groupFilter: expected list(string) as second argument") |
||||
} |
||||
|
||||
allowed := make(map[string]struct{}) |
||||
iter := allowedList.Iterator() |
||||
for iter.HasNext() == types.True { |
||||
item := iter.Next() |
||||
|
||||
s, ok := item.Value().(string) |
||||
if !ok { |
||||
continue |
||||
} |
||||
|
||||
allowed[s] = struct{}{} |
||||
} |
||||
|
||||
var filtered []ref.Val |
||||
iter = groupList.Iterator() |
||||
|
||||
for iter.HasNext() == types.True { |
||||
item := iter.Next() |
||||
|
||||
group, ok := item.Value().(string) |
||||
if !ok { |
||||
continue |
||||
} |
||||
|
||||
if _, exists := allowed[group]; exists { |
||||
filtered = append(filtered, types.String(group)) |
||||
} |
||||
} |
||||
|
||||
return types.NewRefValList(types.DefaultTypeAdapter, filtered) |
||||
} |
||||
@ -0,0 +1,141 @@
|
||||
package library_test |
||||
|
||||
import ( |
||||
"context" |
||||
"reflect" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
dexcel "github.com/dexidp/dex/pkg/cel" |
||||
) |
||||
|
||||
func TestGroupMatches(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
groups []string |
||||
want []string |
||||
}{ |
||||
"wildcard pattern": { |
||||
expr: `dex.groupMatches(identity.groups, "team:*")`, |
||||
groups: []string{"team:dev", "team:ops", "admin"}, |
||||
want: []string{"team:dev", "team:ops"}, |
||||
}, |
||||
"exact match": { |
||||
expr: `dex.groupMatches(identity.groups, "admin")`, |
||||
groups: []string{"team:dev", "admin", "user"}, |
||||
want: []string{"admin"}, |
||||
}, |
||||
"no matches": { |
||||
expr: `dex.groupMatches(identity.groups, "nonexistent")`, |
||||
groups: []string{"team:dev", "admin"}, |
||||
want: []string{}, |
||||
}, |
||||
"question mark pattern": { |
||||
expr: `dex.groupMatches(identity.groups, "team?")`, |
||||
groups: []string{"teamA", "teamB", "teams-long"}, |
||||
want: []string{"teamA", "teamB"}, |
||||
}, |
||||
"match all": { |
||||
expr: `dex.groupMatches(identity.groups, "*")`, |
||||
groups: []string{"a", "b", "c"}, |
||||
want: []string{"a", "b", "c"}, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.CompileStringList(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
out, err := dexcel.Eval(context.Background(), prog, map[string]any{ |
||||
"identity": dexcel.IdentityVal{Groups: tc.groups}, |
||||
}) |
||||
require.NoError(t, err) |
||||
|
||||
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{})) |
||||
require.NoError(t, err) |
||||
|
||||
got, ok := nativeVal.([]string) |
||||
require.True(t, ok, "expected []string, got %T", nativeVal) |
||||
assert.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestGroupMatchesInvalidPattern(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
prog, err := compiler.CompileStringList(`dex.groupMatches(identity.groups, "[invalid")`) |
||||
require.NoError(t, err) |
||||
|
||||
_, err = dexcel.Eval(context.Background(), prog, map[string]any{ |
||||
"identity": dexcel.IdentityVal{Groups: []string{"admin"}}, |
||||
}) |
||||
require.Error(t, err) |
||||
assert.Contains(t, err.Error(), "invalid pattern") |
||||
} |
||||
|
||||
func TestGroupFilter(t *testing.T) { |
||||
vars := dexcel.IdentityVariables() |
||||
compiler, err := dexcel.NewCompiler(vars) |
||||
require.NoError(t, err) |
||||
|
||||
tests := map[string]struct { |
||||
expr string |
||||
groups []string |
||||
want []string |
||||
}{ |
||||
"filter to allowed": { |
||||
expr: `dex.groupFilter(identity.groups, ["admin", "ops"])`, |
||||
groups: []string{"admin", "dev", "ops"}, |
||||
want: []string{"admin", "ops"}, |
||||
}, |
||||
"no overlap": { |
||||
expr: `dex.groupFilter(identity.groups, ["marketing"])`, |
||||
groups: []string{"admin", "dev"}, |
||||
want: []string{}, |
||||
}, |
||||
"all allowed": { |
||||
expr: `dex.groupFilter(identity.groups, ["a", "b", "c"])`, |
||||
groups: []string{"a", "b", "c"}, |
||||
want: []string{"a", "b", "c"}, |
||||
}, |
||||
"empty allowed list": { |
||||
expr: `dex.groupFilter(identity.groups, [])`, |
||||
groups: []string{"admin", "dev"}, |
||||
want: []string{}, |
||||
}, |
||||
"preserves order": { |
||||
expr: `dex.groupFilter(identity.groups, ["z", "a"])`, |
||||
groups: []string{"a", "b", "z"}, |
||||
want: []string{"a", "z"}, |
||||
}, |
||||
} |
||||
|
||||
for name, tc := range tests { |
||||
t.Run(name, func(t *testing.T) { |
||||
prog, err := compiler.CompileStringList(tc.expr) |
||||
require.NoError(t, err) |
||||
|
||||
out, err := dexcel.Eval(context.Background(), prog, map[string]any{ |
||||
"identity": dexcel.IdentityVal{Groups: tc.groups}, |
||||
}) |
||||
require.NoError(t, err) |
||||
|
||||
nativeVal, err := out.ConvertToNative(reflect.TypeOf([]string{})) |
||||
require.NoError(t, err) |
||||
|
||||
got, ok := nativeVal.([]string) |
||||
require.True(t, ok, "expected []string, got %T", nativeVal) |
||||
assert.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,109 @@
|
||||
package cel |
||||
|
||||
import ( |
||||
"github.com/google/cel-go/cel" |
||||
|
||||
"github.com/dexidp/dex/connector" |
||||
) |
||||
|
||||
// VariableDeclaration declares a named variable and its CEL type
|
||||
// that will be available in expressions.
|
||||
type VariableDeclaration struct { |
||||
Name string |
||||
Type *cel.Type |
||||
} |
||||
|
||||
// IdentityVal is the CEL native type for the identity variable.
|
||||
// Fields are typed so that the CEL compiler rejects unknown field access
|
||||
// (e.g. identity.emial) at config load time rather than at evaluation time.
|
||||
type IdentityVal struct { |
||||
UserID string `cel:"user_id"` |
||||
Username string `cel:"username"` |
||||
PreferredUsername string `cel:"preferred_username"` |
||||
Email string `cel:"email"` |
||||
EmailVerified bool `cel:"email_verified"` |
||||
Groups []string `cel:"groups"` |
||||
} |
||||
|
||||
// RequestVal is the CEL native type for the request variable.
|
||||
type RequestVal struct { |
||||
ClientID string `cel:"client_id"` |
||||
ConnectorID string `cel:"connector_id"` |
||||
Scopes []string `cel:"scopes"` |
||||
RedirectURI string `cel:"redirect_uri"` |
||||
} |
||||
|
||||
// identityTypeName is the CEL type name for IdentityVal.
|
||||
// Derived by ext.NativeTypes as simplePkgAlias(pkgPath) + "." + structName.
|
||||
const identityTypeName = "cel.IdentityVal" |
||||
|
||||
// requestTypeName is the CEL type name for RequestVal.
|
||||
const requestTypeName = "cel.RequestVal" |
||||
|
||||
// IdentityVariables provides the 'identity' variable with typed fields.
|
||||
//
|
||||
// identity.user_id — string
|
||||
// identity.username — string
|
||||
// identity.preferred_username — string
|
||||
// identity.email — string
|
||||
// identity.email_verified — bool
|
||||
// identity.groups — list(string)
|
||||
func IdentityVariables() []VariableDeclaration { |
||||
return []VariableDeclaration{ |
||||
{Name: "identity", Type: cel.ObjectType(identityTypeName)}, |
||||
} |
||||
} |
||||
|
||||
// RequestVariables provides the 'request' variable with typed fields.
|
||||
//
|
||||
// request.client_id — string
|
||||
// request.connector_id — string
|
||||
// request.scopes — list(string)
|
||||
// request.redirect_uri — string
|
||||
func RequestVariables() []VariableDeclaration { |
||||
return []VariableDeclaration{ |
||||
{Name: "request", Type: cel.ObjectType(requestTypeName)}, |
||||
} |
||||
} |
||||
|
||||
// ClaimsVariable provides a 'claims' map for raw upstream claims.
|
||||
// Claims remain map(string, dyn) because their shape is genuinely
|
||||
// unknown — they carry arbitrary upstream IdP data.
|
||||
//
|
||||
// claims — map(string, dyn)
|
||||
func ClaimsVariable() []VariableDeclaration { |
||||
return []VariableDeclaration{ |
||||
{Name: "claims", Type: cel.MapType(cel.StringType, cel.DynType)}, |
||||
} |
||||
} |
||||
|
||||
// IdentityFromConnector converts a connector.Identity to a CEL-compatible IdentityVal.
|
||||
func IdentityFromConnector(id connector.Identity) IdentityVal { |
||||
return IdentityVal{ |
||||
UserID: id.UserID, |
||||
Username: id.Username, |
||||
PreferredUsername: id.PreferredUsername, |
||||
Email: id.Email, |
||||
EmailVerified: id.EmailVerified, |
||||
Groups: id.Groups, |
||||
} |
||||
} |
||||
|
||||
// RequestContext represents the authentication/token request context
|
||||
// available as the 'request' variable in CEL expressions.
|
||||
type RequestContext struct { |
||||
ClientID string |
||||
ConnectorID string |
||||
Scopes []string |
||||
RedirectURI string |
||||
} |
||||
|
||||
// RequestFromContext converts a RequestContext to a CEL-compatible RequestVal.
|
||||
func RequestFromContext(rc RequestContext) RequestVal { |
||||
return RequestVal{ |
||||
ClientID: rc.ClientID, |
||||
ConnectorID: rc.ConnectorID, |
||||
Scopes: rc.Scopes, |
||||
RedirectURI: rc.RedirectURI, |
||||
} |
||||
} |
||||
@ -0,0 +1,3 @@
|
||||
// Package featureflags provides a mechanism for toggling experimental or
|
||||
// optional Dex features via environment variables (DEX_<FLAG_NAME>).
|
||||
package featureflags |
||||
@ -0,0 +1,2 @@
|
||||
// Package groups contains helper functions related to groups.
|
||||
package groups |
||||
@ -0,0 +1,3 @@
|
||||
// Package httpclient provides a configurable HTTP client constructor with
|
||||
// support for custom CA certificates, root CAs, and TLS settings.
|
||||
package httpclient |
||||
@ -0,0 +1,133 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"testing" |
||||
|
||||
"github.com/dexidp/dex/api/v2" |
||||
"github.com/dexidp/dex/connector" |
||||
"github.com/dexidp/dex/connector/mock" |
||||
"github.com/dexidp/dex/storage/memory" |
||||
) |
||||
|
||||
func TestConnectorCacheInvalidation(t *testing.T) { |
||||
t.Setenv("DEX_API_CONNECTORS_CRUD", "true") |
||||
|
||||
logger := newLogger(t) |
||||
s := memory.New(logger) |
||||
|
||||
serv := &Server{ |
||||
storage: s, |
||||
logger: logger, |
||||
connectors: make(map[string]Connector), |
||||
} |
||||
|
||||
apiServer := NewAPI(s, logger, "test", serv) |
||||
ctx := context.Background() |
||||
|
||||
connID := "mock-conn" |
||||
|
||||
// 1. Create a connector via API
|
||||
config1 := mock.PasswordConfig{ |
||||
Username: "user", |
||||
Password: "first-password", |
||||
} |
||||
config1Bytes, _ := json.Marshal(config1) |
||||
|
||||
_, err := apiServer.CreateConnector(ctx, &api.CreateConnectorReq{ |
||||
Connector: &api.Connector{ |
||||
Id: connID, |
||||
Type: "mockPassword", |
||||
Name: "Mock", |
||||
Config: config1Bytes, |
||||
}, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("failed to create connector: %v", err) |
||||
} |
||||
|
||||
// 2. Load it into server cache
|
||||
c1, err := serv.getConnector(ctx, connID) |
||||
if err != nil { |
||||
t.Fatalf("failed to get connector: %v", err) |
||||
} |
||||
|
||||
pc1 := c1.Connector.(connector.PasswordConnector) |
||||
_, valid, err := pc1.Login(ctx, connector.Scopes{}, "user", "first-password") |
||||
if err != nil || !valid { |
||||
t.Fatalf("failed to login with first password: %v", err) |
||||
} |
||||
|
||||
// 3. Delete it via API
|
||||
_, err = apiServer.DeleteConnector(ctx, &api.DeleteConnectorReq{Id: connID}) |
||||
if err != nil { |
||||
t.Fatalf("failed to delete connector: %v", err) |
||||
} |
||||
|
||||
// 4. Create it again with different password
|
||||
config2 := mock.PasswordConfig{ |
||||
Username: "user", |
||||
Password: "second-password", |
||||
} |
||||
config2Bytes, _ := json.Marshal(config2) |
||||
|
||||
_, err = apiServer.CreateConnector(ctx, &api.CreateConnectorReq{ |
||||
Connector: &api.Connector{ |
||||
Id: connID, |
||||
Type: "mockPassword", |
||||
Name: "Mock", |
||||
Config: config2Bytes, |
||||
}, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("failed to create connector: %v", err) |
||||
} |
||||
|
||||
// 5. Load it again
|
||||
c2, err := serv.getConnector(ctx, connID) |
||||
if err != nil { |
||||
t.Fatalf("failed to get connector second time: %v", err) |
||||
} |
||||
|
||||
pc2 := c2.Connector.(connector.PasswordConnector) |
||||
|
||||
// If the fix works, it should now use the second password.
|
||||
_, valid2, err := pc2.Login(ctx, connector.Scopes{}, "user", "second-password") |
||||
if err != nil || !valid2 { |
||||
t.Errorf("failed to login with second password, cache might still be stale") |
||||
} |
||||
|
||||
_, valid1, _ := pc2.Login(ctx, connector.Scopes{}, "user", "first-password") |
||||
if valid1 { |
||||
t.Errorf("unexpectedly logged in with first password, cache is definitely stale") |
||||
} |
||||
|
||||
// 6. Update it via API with a third password
|
||||
config3 := mock.PasswordConfig{ |
||||
Username: "user", |
||||
Password: "third-password", |
||||
} |
||||
config3Bytes, _ := json.Marshal(config3) |
||||
|
||||
_, err = apiServer.UpdateConnector(ctx, &api.UpdateConnectorReq{ |
||||
Id: connID, |
||||
NewConfig: config3Bytes, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("failed to update connector: %v", err) |
||||
} |
||||
|
||||
// 7. Load it again
|
||||
c3, err := serv.getConnector(ctx, connID) |
||||
if err != nil { |
||||
t.Fatalf("failed to get connector third time: %v", err) |
||||
} |
||||
|
||||
pc3 := c3.Connector.(connector.PasswordConnector) |
||||
|
||||
_, valid3, err := pc3.Login(ctx, connector.Scopes{}, "user", "third-password") |
||||
if err != nil || !valid3 { |
||||
t.Errorf("failed to login with third password, UpdateConnector might be missing cache invalidation") |
||||
} |
||||
} |
||||
@ -0,0 +1,117 @@
|
||||
package server |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/hmac" |
||||
"crypto/sha256" |
||||
"encoding/base64" |
||||
"errors" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"net/url" |
||||
"strings" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
type getAuthRequestErrorStorage struct { |
||||
storage.Storage |
||||
err error |
||||
} |
||||
|
||||
func (s *getAuthRequestErrorStorage) GetAuthRequest(context.Context, string) (storage.AuthRequest, error) { |
||||
return storage.AuthRequest{}, s.err |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestErrorGET(t *testing.T) { |
||||
httpServer, server := newTestServer(t, func(c *Config) { |
||||
c.Storage = &getAuthRequestErrorStorage{Storage: c.Storage, err: errors.New("storage unavailable")} |
||||
}) |
||||
defer httpServer.Close() |
||||
|
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodGet, "/approval?req=any&hmac=AQ", nil) |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusInternalServerError, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestNotFoundGET(t *testing.T) { |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodGet, "/approval?req=does-not-exist&hmac=AQ", nil) |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "User session error.") |
||||
require.NotContains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalGetAuthRequestNotFoundPOST(t *testing.T) { |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
body := strings.NewReader("approval=approve&req=does-not-exist&hmac=AQ") |
||||
rr := httptest.NewRecorder() |
||||
req := httptest.NewRequest(http.MethodPost, "/approval", body) |
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
|
||||
server.ServeHTTP(rr, req) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, rr.Code) |
||||
require.Contains(t, rr.Body.String(), "User session error.") |
||||
require.NotContains(t, rr.Body.String(), "Database error.") |
||||
} |
||||
|
||||
func TestHandleApprovalDoubleSubmitPOST(t *testing.T) { |
||||
ctx := t.Context() |
||||
httpServer, server := newTestServer(t, nil) |
||||
defer httpServer.Close() |
||||
|
||||
authReq := storage.AuthRequest{ |
||||
ID: "approval-double-submit", |
||||
ClientID: "test", |
||||
ResponseTypes: []string{responseTypeCode}, |
||||
RedirectURI: "https://client.example/callback", |
||||
Expiry: time.Now().Add(time.Minute), |
||||
LoggedIn: true, |
||||
HMACKey: []byte("approval-double-submit-key"), |
||||
} |
||||
require.NoError(t, server.storage.CreateAuthRequest(ctx, authReq)) |
||||
|
||||
h := hmac.New(sha256.New, authReq.HMACKey) |
||||
h.Write([]byte(authReq.ID)) |
||||
mac := base64.RawURLEncoding.EncodeToString(h.Sum(nil)) |
||||
|
||||
form := url.Values{ |
||||
"approval": {"approve"}, |
||||
"req": {authReq.ID}, |
||||
"hmac": {mac}, |
||||
} |
||||
|
||||
firstRR := httptest.NewRecorder() |
||||
firstReq := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode())) |
||||
firstReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
server.ServeHTTP(firstRR, firstReq) |
||||
|
||||
require.Equal(t, http.StatusSeeOther, firstRR.Code) |
||||
require.Contains(t, firstRR.Header().Get("Location"), "https://client.example/callback") |
||||
|
||||
secondRR := httptest.NewRecorder() |
||||
secondReq := httptest.NewRequest(http.MethodPost, "/approval", strings.NewReader(form.Encode())) |
||||
secondReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
||||
server.ServeHTTP(secondRR, secondReq) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, secondRR.Code) |
||||
require.Contains(t, secondRR.Body.String(), "User session error.") |
||||
require.NotContains(t, secondRR.Body.String(), "Database error.") |
||||
} |
||||
@ -0,0 +1,108 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateAuthSession saves provided auth session into the database.
|
||||
func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSession) error { |
||||
if session.ClientStates == nil { |
||||
session.ClientStates = make(map[string]*storage.ClientAuthState) |
||||
} |
||||
encodedStates, err := json.Marshal(session.ClientStates) |
||||
if err != nil { |
||||
return fmt.Errorf("encode client states auth session: %w", err) |
||||
} |
||||
|
||||
_, err = d.client.AuthSession.Create(). |
||||
SetID(session.ID). |
||||
SetClientStates(encodedStates). |
||||
SetCreatedAt(session.CreatedAt). |
||||
SetLastActivity(session.LastActivity). |
||||
SetIPAddress(session.IPAddress). |
||||
SetUserAgent(session.UserAgent). |
||||
Save(ctx) |
||||
if err != nil { |
||||
return convertDBError("create auth session: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetAuthSession extracts an auth session from the database by session ID.
|
||||
func (d *Database) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { |
||||
authSession, err := d.client.AuthSession.Get(ctx, sessionID) |
||||
if err != nil { |
||||
return storage.AuthSession{}, convertDBError("get auth session: %w", err) |
||||
} |
||||
return toStorageAuthSession(authSession), nil |
||||
} |
||||
|
||||
// ListAuthSessions extracts all auth sessions from the database.
|
||||
func (d *Database) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { |
||||
authSessions, err := d.client.AuthSession.Query().All(ctx) |
||||
if err != nil { |
||||
return nil, convertDBError("list auth sessions: %w", err) |
||||
} |
||||
|
||||
storageAuthSessions := make([]storage.AuthSession, 0, len(authSessions)) |
||||
for _, s := range authSessions { |
||||
storageAuthSessions = append(storageAuthSessions, toStorageAuthSession(s)) |
||||
} |
||||
return storageAuthSessions, nil |
||||
} |
||||
|
||||
// DeleteAuthSession deletes an auth session from the database by session ID.
|
||||
func (d *Database) DeleteAuthSession(ctx context.Context, sessionID string) error { |
||||
err := d.client.AuthSession.DeleteOneID(sessionID).Exec(ctx) |
||||
if err != nil { |
||||
return convertDBError("delete auth session: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateAuthSession changes an auth session using an updater function.
|
||||
func (d *Database) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error { |
||||
tx, err := d.BeginTx(ctx) |
||||
if err != nil { |
||||
return convertDBError("update auth session tx: %w", err) |
||||
} |
||||
|
||||
authSession, err := tx.AuthSession.Get(ctx, sessionID) |
||||
if err != nil { |
||||
return rollback(tx, "update auth session database: %w", err) |
||||
} |
||||
|
||||
newSession, err := updater(toStorageAuthSession(authSession)) |
||||
if err != nil { |
||||
return rollback(tx, "update auth session updating: %w", err) |
||||
} |
||||
|
||||
if newSession.ClientStates == nil { |
||||
newSession.ClientStates = make(map[string]*storage.ClientAuthState) |
||||
} |
||||
|
||||
encodedStates, err := json.Marshal(newSession.ClientStates) |
||||
if err != nil { |
||||
return rollback(tx, "encode client states auth session: %w", err) |
||||
} |
||||
|
||||
_, err = tx.AuthSession.UpdateOneID(sessionID). |
||||
SetClientStates(encodedStates). |
||||
SetLastActivity(newSession.LastActivity). |
||||
SetIPAddress(newSession.IPAddress). |
||||
SetUserAgent(newSession.UserAgent). |
||||
Save(ctx) |
||||
if err != nil { |
||||
return rollback(tx, "update auth session updating: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update auth session commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,130 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateUserIdentity saves provided user identity into the database.
|
||||
func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.UserIdentity) error { |
||||
if identity.Consents == nil { |
||||
identity.Consents = make(map[string][]string) |
||||
} |
||||
encodedConsents, err := json.Marshal(identity.Consents) |
||||
if err != nil { |
||||
return fmt.Errorf("encode consents user identity: %w", err) |
||||
} |
||||
|
||||
id := compositeKeyID(identity.UserID, identity.ConnectorID, d.hasher) |
||||
_, err = d.client.UserIdentity.Create(). |
||||
SetID(id). |
||||
SetUserID(identity.UserID). |
||||
SetConnectorID(identity.ConnectorID). |
||||
SetClaimsUserID(identity.Claims.UserID). |
||||
SetClaimsUsername(identity.Claims.Username). |
||||
SetClaimsPreferredUsername(identity.Claims.PreferredUsername). |
||||
SetClaimsEmail(identity.Claims.Email). |
||||
SetClaimsEmailVerified(identity.Claims.EmailVerified). |
||||
SetClaimsGroups(identity.Claims.Groups). |
||||
SetConsents(encodedConsents). |
||||
SetCreatedAt(identity.CreatedAt). |
||||
SetLastLogin(identity.LastLogin). |
||||
SetBlockedUntil(identity.BlockedUntil). |
||||
Save(ctx) |
||||
if err != nil { |
||||
return convertDBError("create user identity: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetUserIdentity extracts a user identity from the database by user id and connector id.
|
||||
func (d *Database) GetUserIdentity(ctx context.Context, userID, connectorID string) (storage.UserIdentity, error) { |
||||
id := compositeKeyID(userID, connectorID, d.hasher) |
||||
|
||||
userIdentity, err := d.client.UserIdentity.Get(ctx, id) |
||||
if err != nil { |
||||
return storage.UserIdentity{}, convertDBError("get user identity: %w", err) |
||||
} |
||||
return toStorageUserIdentity(userIdentity), nil |
||||
} |
||||
|
||||
// DeleteUserIdentity deletes a user identity from the database by user id and connector id.
|
||||
func (d *Database) DeleteUserIdentity(ctx context.Context, userID, connectorID string) error { |
||||
id := compositeKeyID(userID, connectorID, d.hasher) |
||||
|
||||
err := d.client.UserIdentity.DeleteOneID(id).Exec(ctx) |
||||
if err != nil { |
||||
return convertDBError("delete user identity: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateUserIdentity changes a user identity by user id and connector id using an updater function.
|
||||
func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connectorID string, updater func(u storage.UserIdentity) (storage.UserIdentity, error)) error { |
||||
id := compositeKeyID(userID, connectorID, d.hasher) |
||||
|
||||
tx, err := d.BeginTx(ctx) |
||||
if err != nil { |
||||
return convertDBError("update user identity tx: %w", err) |
||||
} |
||||
|
||||
userIdentity, err := tx.UserIdentity.Get(ctx, id) |
||||
if err != nil { |
||||
return rollback(tx, "update user identity database: %w", err) |
||||
} |
||||
|
||||
newUserIdentity, err := updater(toStorageUserIdentity(userIdentity)) |
||||
if err != nil { |
||||
return rollback(tx, "update user identity updating: %w", err) |
||||
} |
||||
|
||||
if newUserIdentity.Consents == nil { |
||||
newUserIdentity.Consents = make(map[string][]string) |
||||
} |
||||
|
||||
encodedConsents, err := json.Marshal(newUserIdentity.Consents) |
||||
if err != nil { |
||||
return rollback(tx, "encode consents user identity: %w", err) |
||||
} |
||||
|
||||
_, err = tx.UserIdentity.UpdateOneID(id). |
||||
SetUserID(newUserIdentity.UserID). |
||||
SetConnectorID(newUserIdentity.ConnectorID). |
||||
SetClaimsUserID(newUserIdentity.Claims.UserID). |
||||
SetClaimsUsername(newUserIdentity.Claims.Username). |
||||
SetClaimsPreferredUsername(newUserIdentity.Claims.PreferredUsername). |
||||
SetClaimsEmail(newUserIdentity.Claims.Email). |
||||
SetClaimsEmailVerified(newUserIdentity.Claims.EmailVerified). |
||||
SetClaimsGroups(newUserIdentity.Claims.Groups). |
||||
SetConsents(encodedConsents). |
||||
SetCreatedAt(newUserIdentity.CreatedAt). |
||||
SetLastLogin(newUserIdentity.LastLogin). |
||||
SetBlockedUntil(newUserIdentity.BlockedUntil). |
||||
Save(ctx) |
||||
if err != nil { |
||||
return rollback(tx, "update user identity uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update user identity commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// ListUserIdentities lists all user identities in the database.
|
||||
func (d *Database) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, error) { |
||||
userIdentities, err := d.client.UserIdentity.Query().All(ctx) |
||||
if err != nil { |
||||
return nil, convertDBError("list user identities: %w", err) |
||||
} |
||||
|
||||
storageUserIdentities := make([]storage.UserIdentity, 0, len(userIdentities)) |
||||
for _, u := range userIdentities { |
||||
storageUserIdentities = append(storageUserIdentities, toStorageUserIdentity(u)) |
||||
} |
||||
return storageUserIdentities, nil |
||||
} |
||||
@ -0,0 +1,150 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package db |
||||
|
||||
import ( |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"entgo.io/ent" |
||||
"entgo.io/ent/dialect/sql" |
||||
"github.com/dexidp/dex/storage/ent/db/authsession" |
||||
) |
||||
|
||||
// AuthSession is the model entity for the AuthSession schema.
|
||||
type AuthSession struct { |
||||
config `json:"-"` |
||||
// ID of the ent.
|
||||
ID string `json:"id,omitempty"` |
||||
// ClientStates holds the value of the "client_states" field.
|
||||
ClientStates []byte `json:"client_states,omitempty"` |
||||
// CreatedAt holds the value of the "created_at" field.
|
||||
CreatedAt time.Time `json:"created_at,omitempty"` |
||||
// LastActivity holds the value of the "last_activity" field.
|
||||
LastActivity time.Time `json:"last_activity,omitempty"` |
||||
// IPAddress holds the value of the "ip_address" field.
|
||||
IPAddress string `json:"ip_address,omitempty"` |
||||
// UserAgent holds the value of the "user_agent" field.
|
||||
UserAgent string `json:"user_agent,omitempty"` |
||||
selectValues sql.SelectValues |
||||
} |
||||
|
||||
// scanValues returns the types for scanning values from sql.Rows.
|
||||
func (*AuthSession) scanValues(columns []string) ([]any, error) { |
||||
values := make([]any, len(columns)) |
||||
for i := range columns { |
||||
switch columns[i] { |
||||
case authsession.FieldClientStates: |
||||
values[i] = new([]byte) |
||||
case authsession.FieldID, authsession.FieldIPAddress, authsession.FieldUserAgent: |
||||
values[i] = new(sql.NullString) |
||||
case authsession.FieldCreatedAt, authsession.FieldLastActivity: |
||||
values[i] = new(sql.NullTime) |
||||
default: |
||||
values[i] = new(sql.UnknownType) |
||||
} |
||||
} |
||||
return values, nil |
||||
} |
||||
|
||||
// assignValues assigns the values that were returned from sql.Rows (after scanning)
|
||||
// to the AuthSession fields.
|
||||
func (_m *AuthSession) assignValues(columns []string, values []any) error { |
||||
if m, n := len(values), len(columns); m < n { |
||||
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) |
||||
} |
||||
for i := range columns { |
||||
switch columns[i] { |
||||
case authsession.FieldID: |
||||
if value, ok := values[i].(*sql.NullString); !ok { |
||||
return fmt.Errorf("unexpected type %T for field id", values[i]) |
||||
} else if value.Valid { |
||||
_m.ID = value.String |
||||
} |
||||
case authsession.FieldClientStates: |
||||
if value, ok := values[i].(*[]byte); !ok { |
||||
return fmt.Errorf("unexpected type %T for field client_states", values[i]) |
||||
} else if value != nil { |
||||
_m.ClientStates = *value |
||||
} |
||||
case authsession.FieldCreatedAt: |
||||
if value, ok := values[i].(*sql.NullTime); !ok { |
||||
return fmt.Errorf("unexpected type %T for field created_at", values[i]) |
||||
} else if value.Valid { |
||||
_m.CreatedAt = value.Time |
||||
} |
||||
case authsession.FieldLastActivity: |
||||
if value, ok := values[i].(*sql.NullTime); !ok { |
||||
return fmt.Errorf("unexpected type %T for field last_activity", values[i]) |
||||
} else if value.Valid { |
||||
_m.LastActivity = value.Time |
||||
} |
||||
case authsession.FieldIPAddress: |
||||
if value, ok := values[i].(*sql.NullString); !ok { |
||||
return fmt.Errorf("unexpected type %T for field ip_address", values[i]) |
||||
} else if value.Valid { |
||||
_m.IPAddress = value.String |
||||
} |
||||
case authsession.FieldUserAgent: |
||||
if value, ok := values[i].(*sql.NullString); !ok { |
||||
return fmt.Errorf("unexpected type %T for field user_agent", values[i]) |
||||
} else if value.Valid { |
||||
_m.UserAgent = value.String |
||||
} |
||||
default: |
||||
_m.selectValues.Set(columns[i], values[i]) |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Value returns the ent.Value that was dynamically selected and assigned to the AuthSession.
|
||||
// This includes values selected through modifiers, order, etc.
|
||||
func (_m *AuthSession) Value(name string) (ent.Value, error) { |
||||
return _m.selectValues.Get(name) |
||||
} |
||||
|
||||
// Update returns a builder for updating this AuthSession.
|
||||
// Note that you need to call AuthSession.Unwrap() before calling this method if this AuthSession
|
||||
// was returned from a transaction, and the transaction was committed or rolled back.
|
||||
func (_m *AuthSession) Update() *AuthSessionUpdateOne { |
||||
return NewAuthSessionClient(_m.config).UpdateOne(_m) |
||||
} |
||||
|
||||
// Unwrap unwraps the AuthSession entity that was returned from a transaction after it was closed,
|
||||
// so that all future queries will be executed through the driver which created the transaction.
|
||||
func (_m *AuthSession) Unwrap() *AuthSession { |
||||
_tx, ok := _m.config.driver.(*txDriver) |
||||
if !ok { |
||||
panic("db: AuthSession is not a transactional entity") |
||||
} |
||||
_m.config.driver = _tx.drv |
||||
return _m |
||||
} |
||||
|
||||
// String implements the fmt.Stringer.
|
||||
func (_m *AuthSession) String() string { |
||||
var builder strings.Builder |
||||
builder.WriteString("AuthSession(") |
||||
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) |
||||
builder.WriteString("client_states=") |
||||
builder.WriteString(fmt.Sprintf("%v", _m.ClientStates)) |
||||
builder.WriteString(", ") |
||||
builder.WriteString("created_at=") |
||||
builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) |
||||
builder.WriteString(", ") |
||||
builder.WriteString("last_activity=") |
||||
builder.WriteString(_m.LastActivity.Format(time.ANSIC)) |
||||
builder.WriteString(", ") |
||||
builder.WriteString("ip_address=") |
||||
builder.WriteString(_m.IPAddress) |
||||
builder.WriteString(", ") |
||||
builder.WriteString("user_agent=") |
||||
builder.WriteString(_m.UserAgent) |
||||
builder.WriteByte(')') |
||||
return builder.String() |
||||
} |
||||
|
||||
// AuthSessions is a parsable slice of AuthSession.
|
||||
type AuthSessions []*AuthSession |
||||
@ -0,0 +1,83 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package authsession |
||||
|
||||
import ( |
||||
"entgo.io/ent/dialect/sql" |
||||
) |
||||
|
||||
const ( |
||||
// Label holds the string label denoting the authsession type in the database.
|
||||
Label = "auth_session" |
||||
// FieldID holds the string denoting the id field in the database.
|
||||
FieldID = "id" |
||||
// FieldClientStates holds the string denoting the client_states field in the database.
|
||||
FieldClientStates = "client_states" |
||||
// FieldCreatedAt holds the string denoting the created_at field in the database.
|
||||
FieldCreatedAt = "created_at" |
||||
// FieldLastActivity holds the string denoting the last_activity field in the database.
|
||||
FieldLastActivity = "last_activity" |
||||
// FieldIPAddress holds the string denoting the ip_address field in the database.
|
||||
FieldIPAddress = "ip_address" |
||||
// FieldUserAgent holds the string denoting the user_agent field in the database.
|
||||
FieldUserAgent = "user_agent" |
||||
// Table holds the table name of the authsession in the database.
|
||||
Table = "auth_sessions" |
||||
) |
||||
|
||||
// Columns holds all SQL columns for authsession fields.
|
||||
var Columns = []string{ |
||||
FieldID, |
||||
FieldClientStates, |
||||
FieldCreatedAt, |
||||
FieldLastActivity, |
||||
FieldIPAddress, |
||||
FieldUserAgent, |
||||
} |
||||
|
||||
// ValidColumn reports if the column name is valid (part of the table columns).
|
||||
func ValidColumn(column string) bool { |
||||
for i := range Columns { |
||||
if column == Columns[i] { |
||||
return true |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
var ( |
||||
// DefaultIPAddress holds the default value on creation for the "ip_address" field.
|
||||
DefaultIPAddress string |
||||
// DefaultUserAgent holds the default value on creation for the "user_agent" field.
|
||||
DefaultUserAgent string |
||||
// IDValidator is a validator for the "id" field. It is called by the builders before save.
|
||||
IDValidator func(string) error |
||||
) |
||||
|
||||
// OrderOption defines the ordering options for the AuthSession queries.
|
||||
type OrderOption func(*sql.Selector) |
||||
|
||||
// ByID orders the results by the id field.
|
||||
func ByID(opts ...sql.OrderTermOption) OrderOption { |
||||
return sql.OrderByField(FieldID, opts...).ToFunc() |
||||
} |
||||
|
||||
// ByCreatedAt orders the results by the created_at field.
|
||||
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { |
||||
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() |
||||
} |
||||
|
||||
// ByLastActivity orders the results by the last_activity field.
|
||||
func ByLastActivity(opts ...sql.OrderTermOption) OrderOption { |
||||
return sql.OrderByField(FieldLastActivity, opts...).ToFunc() |
||||
} |
||||
|
||||
// ByIPAddress orders the results by the ip_address field.
|
||||
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { |
||||
return sql.OrderByField(FieldIPAddress, opts...).ToFunc() |
||||
} |
||||
|
||||
// ByUserAgent orders the results by the user_agent field.
|
||||
func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { |
||||
return sql.OrderByField(FieldUserAgent, opts...).ToFunc() |
||||
} |
||||
@ -0,0 +1,355 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package authsession |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"entgo.io/ent/dialect/sql" |
||||
"github.com/dexidp/dex/storage/ent/db/predicate" |
||||
) |
||||
|
||||
// ID filters vertices based on their ID field.
|
||||
func ID(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldID, id)) |
||||
} |
||||
|
||||
// IDEQ applies the EQ predicate on the ID field.
|
||||
func IDEQ(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldID, id)) |
||||
} |
||||
|
||||
// IDNEQ applies the NEQ predicate on the ID field.
|
||||
func IDNEQ(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldID, id)) |
||||
} |
||||
|
||||
// IDIn applies the In predicate on the ID field.
|
||||
func IDIn(ids ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldID, ids...)) |
||||
} |
||||
|
||||
// IDNotIn applies the NotIn predicate on the ID field.
|
||||
func IDNotIn(ids ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldID, ids...)) |
||||
} |
||||
|
||||
// IDGT applies the GT predicate on the ID field.
|
||||
func IDGT(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldID, id)) |
||||
} |
||||
|
||||
// IDGTE applies the GTE predicate on the ID field.
|
||||
func IDGTE(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldID, id)) |
||||
} |
||||
|
||||
// IDLT applies the LT predicate on the ID field.
|
||||
func IDLT(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldID, id)) |
||||
} |
||||
|
||||
// IDLTE applies the LTE predicate on the ID field.
|
||||
func IDLTE(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldID, id)) |
||||
} |
||||
|
||||
// IDEqualFold applies the EqualFold predicate on the ID field.
|
||||
func IDEqualFold(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEqualFold(FieldID, id)) |
||||
} |
||||
|
||||
// IDContainsFold applies the ContainsFold predicate on the ID field.
|
||||
func IDContainsFold(id string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldContainsFold(FieldID, id)) |
||||
} |
||||
|
||||
// ClientStates applies equality check predicate on the "client_states" field. It's identical to ClientStatesEQ.
|
||||
func ClientStates(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v)) |
||||
} |
||||
|
||||
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
|
||||
func CreatedAt(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// LastActivity applies equality check predicate on the "last_activity" field. It's identical to LastActivityEQ.
|
||||
func LastActivity(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
|
||||
func IPAddress(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// UserAgent applies equality check predicate on the "user_agent" field. It's identical to UserAgentEQ.
|
||||
func UserAgent(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// ClientStatesEQ applies the EQ predicate on the "client_states" field.
|
||||
func ClientStatesEQ(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v)) |
||||
} |
||||
|
||||
// ClientStatesNEQ applies the NEQ predicate on the "client_states" field.
|
||||
func ClientStatesNEQ(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldClientStates, v)) |
||||
} |
||||
|
||||
// ClientStatesIn applies the In predicate on the "client_states" field.
|
||||
func ClientStatesIn(vs ...[]byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldClientStates, vs...)) |
||||
} |
||||
|
||||
// ClientStatesNotIn applies the NotIn predicate on the "client_states" field.
|
||||
func ClientStatesNotIn(vs ...[]byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldClientStates, vs...)) |
||||
} |
||||
|
||||
// ClientStatesGT applies the GT predicate on the "client_states" field.
|
||||
func ClientStatesGT(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldClientStates, v)) |
||||
} |
||||
|
||||
// ClientStatesGTE applies the GTE predicate on the "client_states" field.
|
||||
func ClientStatesGTE(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldClientStates, v)) |
||||
} |
||||
|
||||
// ClientStatesLT applies the LT predicate on the "client_states" field.
|
||||
func ClientStatesLT(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldClientStates, v)) |
||||
} |
||||
|
||||
// ClientStatesLTE applies the LTE predicate on the "client_states" field.
|
||||
func ClientStatesLTE(v []byte) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldClientStates, v)) |
||||
} |
||||
|
||||
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
|
||||
func CreatedAtEQ(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
|
||||
func CreatedAtNEQ(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// CreatedAtIn applies the In predicate on the "created_at" field.
|
||||
func CreatedAtIn(vs ...time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldCreatedAt, vs...)) |
||||
} |
||||
|
||||
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
|
||||
func CreatedAtNotIn(vs ...time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) |
||||
} |
||||
|
||||
// CreatedAtGT applies the GT predicate on the "created_at" field.
|
||||
func CreatedAtGT(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
|
||||
func CreatedAtGTE(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// CreatedAtLT applies the LT predicate on the "created_at" field.
|
||||
func CreatedAtLT(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
|
||||
func CreatedAtLTE(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldCreatedAt, v)) |
||||
} |
||||
|
||||
// LastActivityEQ applies the EQ predicate on the "last_activity" field.
|
||||
func LastActivityEQ(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// LastActivityNEQ applies the NEQ predicate on the "last_activity" field.
|
||||
func LastActivityNEQ(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// LastActivityIn applies the In predicate on the "last_activity" field.
|
||||
func LastActivityIn(vs ...time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldLastActivity, vs...)) |
||||
} |
||||
|
||||
// LastActivityNotIn applies the NotIn predicate on the "last_activity" field.
|
||||
func LastActivityNotIn(vs ...time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldLastActivity, vs...)) |
||||
} |
||||
|
||||
// LastActivityGT applies the GT predicate on the "last_activity" field.
|
||||
func LastActivityGT(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// LastActivityGTE applies the GTE predicate on the "last_activity" field.
|
||||
func LastActivityGTE(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// LastActivityLT applies the LT predicate on the "last_activity" field.
|
||||
func LastActivityLT(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// LastActivityLTE applies the LTE predicate on the "last_activity" field.
|
||||
func LastActivityLTE(v time.Time) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldLastActivity, v)) |
||||
} |
||||
|
||||
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
|
||||
func IPAddressEQ(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
|
||||
func IPAddressNEQ(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressIn applies the In predicate on the "ip_address" field.
|
||||
func IPAddressIn(vs ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldIPAddress, vs...)) |
||||
} |
||||
|
||||
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
|
||||
func IPAddressNotIn(vs ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldIPAddress, vs...)) |
||||
} |
||||
|
||||
// IPAddressGT applies the GT predicate on the "ip_address" field.
|
||||
func IPAddressGT(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
|
||||
func IPAddressGTE(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressLT applies the LT predicate on the "ip_address" field.
|
||||
func IPAddressLT(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
|
||||
func IPAddressLTE(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressContains applies the Contains predicate on the "ip_address" field.
|
||||
func IPAddressContains(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldContains(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
|
||||
func IPAddressHasPrefix(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldHasPrefix(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
|
||||
func IPAddressHasSuffix(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldHasSuffix(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
|
||||
func IPAddressEqualFold(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEqualFold(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
|
||||
func IPAddressContainsFold(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldContainsFold(FieldIPAddress, v)) |
||||
} |
||||
|
||||
// UserAgentEQ applies the EQ predicate on the "user_agent" field.
|
||||
func UserAgentEQ(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentNEQ applies the NEQ predicate on the "user_agent" field.
|
||||
func UserAgentNEQ(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNEQ(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentIn applies the In predicate on the "user_agent" field.
|
||||
func UserAgentIn(vs ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldIn(FieldUserAgent, vs...)) |
||||
} |
||||
|
||||
// UserAgentNotIn applies the NotIn predicate on the "user_agent" field.
|
||||
func UserAgentNotIn(vs ...string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldNotIn(FieldUserAgent, vs...)) |
||||
} |
||||
|
||||
// UserAgentGT applies the GT predicate on the "user_agent" field.
|
||||
func UserAgentGT(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGT(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentGTE applies the GTE predicate on the "user_agent" field.
|
||||
func UserAgentGTE(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldGTE(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentLT applies the LT predicate on the "user_agent" field.
|
||||
func UserAgentLT(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLT(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentLTE applies the LTE predicate on the "user_agent" field.
|
||||
func UserAgentLTE(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldLTE(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentContains applies the Contains predicate on the "user_agent" field.
|
||||
func UserAgentContains(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldContains(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentHasPrefix applies the HasPrefix predicate on the "user_agent" field.
|
||||
func UserAgentHasPrefix(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldHasPrefix(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentHasSuffix applies the HasSuffix predicate on the "user_agent" field.
|
||||
func UserAgentHasSuffix(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldHasSuffix(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentEqualFold applies the EqualFold predicate on the "user_agent" field.
|
||||
func UserAgentEqualFold(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldEqualFold(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// UserAgentContainsFold applies the ContainsFold predicate on the "user_agent" field.
|
||||
func UserAgentContainsFold(v string) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.FieldContainsFold(FieldUserAgent, v)) |
||||
} |
||||
|
||||
// And groups predicates with the AND operator between them.
|
||||
func And(predicates ...predicate.AuthSession) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.AndPredicates(predicates...)) |
||||
} |
||||
|
||||
// Or groups predicates with the OR operator between them.
|
||||
func Or(predicates ...predicate.AuthSession) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.OrPredicates(predicates...)) |
||||
} |
||||
|
||||
// Not applies the not operator on the given predicate.
|
||||
func Not(p predicate.AuthSession) predicate.AuthSession { |
||||
return predicate.AuthSession(sql.NotPredicates(p)) |
||||
} |
||||
@ -0,0 +1,282 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package db |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"time" |
||||
|
||||
"entgo.io/ent/dialect/sql/sqlgraph" |
||||
"entgo.io/ent/schema/field" |
||||
"github.com/dexidp/dex/storage/ent/db/authsession" |
||||
) |
||||
|
||||
// AuthSessionCreate is the builder for creating a AuthSession entity.
|
||||
type AuthSessionCreate struct { |
||||
config |
||||
mutation *AuthSessionMutation |
||||
hooks []Hook |
||||
} |
||||
|
||||
// SetClientStates sets the "client_states" field.
|
||||
func (_c *AuthSessionCreate) SetClientStates(v []byte) *AuthSessionCreate { |
||||
_c.mutation.SetClientStates(v) |
||||
return _c |
||||
} |
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (_c *AuthSessionCreate) SetCreatedAt(v time.Time) *AuthSessionCreate { |
||||
_c.mutation.SetCreatedAt(v) |
||||
return _c |
||||
} |
||||
|
||||
// SetLastActivity sets the "last_activity" field.
|
||||
func (_c *AuthSessionCreate) SetLastActivity(v time.Time) *AuthSessionCreate { |
||||
_c.mutation.SetLastActivity(v) |
||||
return _c |
||||
} |
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_c *AuthSessionCreate) SetIPAddress(v string) *AuthSessionCreate { |
||||
_c.mutation.SetIPAddress(v) |
||||
return _c |
||||
} |
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_c *AuthSessionCreate) SetNillableIPAddress(v *string) *AuthSessionCreate { |
||||
if v != nil { |
||||
_c.SetIPAddress(*v) |
||||
} |
||||
return _c |
||||
} |
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_c *AuthSessionCreate) SetUserAgent(v string) *AuthSessionCreate { |
||||
_c.mutation.SetUserAgent(v) |
||||
return _c |
||||
} |
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_c *AuthSessionCreate) SetNillableUserAgent(v *string) *AuthSessionCreate { |
||||
if v != nil { |
||||
_c.SetUserAgent(*v) |
||||
} |
||||
return _c |
||||
} |
||||
|
||||
// SetID sets the "id" field.
|
||||
func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate { |
||||
_c.mutation.SetID(v) |
||||
return _c |
||||
} |
||||
|
||||
// Mutation returns the AuthSessionMutation object of the builder.
|
||||
func (_c *AuthSessionCreate) Mutation() *AuthSessionMutation { |
||||
return _c.mutation |
||||
} |
||||
|
||||
// Save creates the AuthSession in the database.
|
||||
func (_c *AuthSessionCreate) Save(ctx context.Context) (*AuthSession, error) { |
||||
_c.defaults() |
||||
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) |
||||
} |
||||
|
||||
// SaveX calls Save and panics if Save returns an error.
|
||||
func (_c *AuthSessionCreate) SaveX(ctx context.Context) *AuthSession { |
||||
v, err := _c.Save(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// Exec executes the query.
|
||||
func (_c *AuthSessionCreate) Exec(ctx context.Context) error { |
||||
_, err := _c.Save(ctx) |
||||
return err |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_c *AuthSessionCreate) ExecX(ctx context.Context) { |
||||
if err := _c.Exec(ctx); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
|
||||
// defaults sets the default values of the builder before save.
|
||||
func (_c *AuthSessionCreate) defaults() { |
||||
if _, ok := _c.mutation.IPAddress(); !ok { |
||||
v := authsession.DefaultIPAddress |
||||
_c.mutation.SetIPAddress(v) |
||||
} |
||||
if _, ok := _c.mutation.UserAgent(); !ok { |
||||
v := authsession.DefaultUserAgent |
||||
_c.mutation.SetUserAgent(v) |
||||
} |
||||
} |
||||
|
||||
// check runs all checks and user-defined validators on the builder.
|
||||
func (_c *AuthSessionCreate) check() error { |
||||
if _, ok := _c.mutation.ClientStates(); !ok { |
||||
return &ValidationError{Name: "client_states", err: errors.New(`db: missing required field "AuthSession.client_states"`)} |
||||
} |
||||
if _, ok := _c.mutation.CreatedAt(); !ok { |
||||
return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "AuthSession.created_at"`)} |
||||
} |
||||
if _, ok := _c.mutation.LastActivity(); !ok { |
||||
return &ValidationError{Name: "last_activity", err: errors.New(`db: missing required field "AuthSession.last_activity"`)} |
||||
} |
||||
if _, ok := _c.mutation.IPAddress(); !ok { |
||||
return &ValidationError{Name: "ip_address", err: errors.New(`db: missing required field "AuthSession.ip_address"`)} |
||||
} |
||||
if _, ok := _c.mutation.UserAgent(); !ok { |
||||
return &ValidationError{Name: "user_agent", err: errors.New(`db: missing required field "AuthSession.user_agent"`)} |
||||
} |
||||
if v, ok := _c.mutation.ID(); ok { |
||||
if err := authsession.IDValidator(v); err != nil { |
||||
return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthSession.id": %w`, err)} |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (_c *AuthSessionCreate) sqlSave(ctx context.Context) (*AuthSession, error) { |
||||
if err := _c.check(); err != nil { |
||||
return nil, err |
||||
} |
||||
_node, _spec := _c.createSpec() |
||||
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { |
||||
if sqlgraph.IsConstraintError(err) { |
||||
err = &ConstraintError{msg: err.Error(), wrap: err} |
||||
} |
||||
return nil, err |
||||
} |
||||
if _spec.ID.Value != nil { |
||||
if id, ok := _spec.ID.Value.(string); ok { |
||||
_node.ID = id |
||||
} else { |
||||
return nil, fmt.Errorf("unexpected AuthSession.ID type: %T", _spec.ID.Value) |
||||
} |
||||
} |
||||
_c.mutation.id = &_node.ID |
||||
_c.mutation.done = true |
||||
return _node, nil |
||||
} |
||||
|
||||
func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) { |
||||
var ( |
||||
_node = &AuthSession{config: _c.config} |
||||
_spec = sqlgraph.NewCreateSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) |
||||
) |
||||
if id, ok := _c.mutation.ID(); ok { |
||||
_node.ID = id |
||||
_spec.ID.Value = id |
||||
} |
||||
if value, ok := _c.mutation.ClientStates(); ok { |
||||
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) |
||||
_node.ClientStates = value |
||||
} |
||||
if value, ok := _c.mutation.CreatedAt(); ok { |
||||
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) |
||||
_node.CreatedAt = value |
||||
} |
||||
if value, ok := _c.mutation.LastActivity(); ok { |
||||
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) |
||||
_node.LastActivity = value |
||||
} |
||||
if value, ok := _c.mutation.IPAddress(); ok { |
||||
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value) |
||||
_node.IPAddress = value |
||||
} |
||||
if value, ok := _c.mutation.UserAgent(); ok { |
||||
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value) |
||||
_node.UserAgent = value |
||||
} |
||||
return _node, _spec |
||||
} |
||||
|
||||
// AuthSessionCreateBulk is the builder for creating many AuthSession entities in bulk.
|
||||
type AuthSessionCreateBulk struct { |
||||
config |
||||
err error |
||||
builders []*AuthSessionCreate |
||||
} |
||||
|
||||
// Save creates the AuthSession entities in the database.
|
||||
func (_c *AuthSessionCreateBulk) Save(ctx context.Context) ([]*AuthSession, error) { |
||||
if _c.err != nil { |
||||
return nil, _c.err |
||||
} |
||||
specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) |
||||
nodes := make([]*AuthSession, len(_c.builders)) |
||||
mutators := make([]Mutator, len(_c.builders)) |
||||
for i := range _c.builders { |
||||
func(i int, root context.Context) { |
||||
builder := _c.builders[i] |
||||
builder.defaults() |
||||
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { |
||||
mutation, ok := m.(*AuthSessionMutation) |
||||
if !ok { |
||||
return nil, fmt.Errorf("unexpected mutation type %T", m) |
||||
} |
||||
if err := builder.check(); err != nil { |
||||
return nil, err |
||||
} |
||||
builder.mutation = mutation |
||||
var err error |
||||
nodes[i], specs[i] = builder.createSpec() |
||||
if i < len(mutators)-1 { |
||||
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) |
||||
} else { |
||||
spec := &sqlgraph.BatchCreateSpec{Nodes: specs} |
||||
// Invoke the actual operation on the latest mutation in the chain.
|
||||
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { |
||||
if sqlgraph.IsConstraintError(err) { |
||||
err = &ConstraintError{msg: err.Error(), wrap: err} |
||||
} |
||||
} |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
mutation.id = &nodes[i].ID |
||||
mutation.done = true |
||||
return nodes[i], nil |
||||
}) |
||||
for i := len(builder.hooks) - 1; i >= 0; i-- { |
||||
mut = builder.hooks[i](mut) |
||||
} |
||||
mutators[i] = mut |
||||
}(i, ctx) |
||||
} |
||||
if len(mutators) > 0 { |
||||
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
return nodes, nil |
||||
} |
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_c *AuthSessionCreateBulk) SaveX(ctx context.Context) []*AuthSession { |
||||
v, err := _c.Save(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// Exec executes the query.
|
||||
func (_c *AuthSessionCreateBulk) Exec(ctx context.Context) error { |
||||
_, err := _c.Save(ctx) |
||||
return err |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_c *AuthSessionCreateBulk) ExecX(ctx context.Context) { |
||||
if err := _c.Exec(ctx); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
@ -0,0 +1,88 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package db |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"entgo.io/ent/dialect/sql" |
||||
"entgo.io/ent/dialect/sql/sqlgraph" |
||||
"entgo.io/ent/schema/field" |
||||
"github.com/dexidp/dex/storage/ent/db/authsession" |
||||
"github.com/dexidp/dex/storage/ent/db/predicate" |
||||
) |
||||
|
||||
// AuthSessionDelete is the builder for deleting a AuthSession entity.
|
||||
type AuthSessionDelete struct { |
||||
config |
||||
hooks []Hook |
||||
mutation *AuthSessionMutation |
||||
} |
||||
|
||||
// Where appends a list predicates to the AuthSessionDelete builder.
|
||||
func (_d *AuthSessionDelete) Where(ps ...predicate.AuthSession) *AuthSessionDelete { |
||||
_d.mutation.Where(ps...) |
||||
return _d |
||||
} |
||||
|
||||
// Exec executes the deletion query and returns how many vertices were deleted.
|
||||
func (_d *AuthSessionDelete) Exec(ctx context.Context) (int, error) { |
||||
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *AuthSessionDelete) ExecX(ctx context.Context) int { |
||||
n, err := _d.Exec(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return n |
||||
} |
||||
|
||||
func (_d *AuthSessionDelete) sqlExec(ctx context.Context) (int, error) { |
||||
_spec := sqlgraph.NewDeleteSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) |
||||
if ps := _d.mutation.predicates; len(ps) > 0 { |
||||
_spec.Predicate = func(selector *sql.Selector) { |
||||
for i := range ps { |
||||
ps[i](selector) |
||||
} |
||||
} |
||||
} |
||||
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) |
||||
if err != nil && sqlgraph.IsConstraintError(err) { |
||||
err = &ConstraintError{msg: err.Error(), wrap: err} |
||||
} |
||||
_d.mutation.done = true |
||||
return affected, err |
||||
} |
||||
|
||||
// AuthSessionDeleteOne is the builder for deleting a single AuthSession entity.
|
||||
type AuthSessionDeleteOne struct { |
||||
_d *AuthSessionDelete |
||||
} |
||||
|
||||
// Where appends a list predicates to the AuthSessionDelete builder.
|
||||
func (_d *AuthSessionDeleteOne) Where(ps ...predicate.AuthSession) *AuthSessionDeleteOne { |
||||
_d._d.mutation.Where(ps...) |
||||
return _d |
||||
} |
||||
|
||||
// Exec executes the deletion query.
|
||||
func (_d *AuthSessionDeleteOne) Exec(ctx context.Context) error { |
||||
n, err := _d._d.Exec(ctx) |
||||
switch { |
||||
case err != nil: |
||||
return err |
||||
case n == 0: |
||||
return &NotFoundError{authsession.Label} |
||||
default: |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_d *AuthSessionDeleteOne) ExecX(ctx context.Context) { |
||||
if err := _d.Exec(ctx); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
@ -0,0 +1,527 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package db |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"math" |
||||
|
||||
"entgo.io/ent" |
||||
"entgo.io/ent/dialect/sql" |
||||
"entgo.io/ent/dialect/sql/sqlgraph" |
||||
"entgo.io/ent/schema/field" |
||||
"github.com/dexidp/dex/storage/ent/db/authsession" |
||||
"github.com/dexidp/dex/storage/ent/db/predicate" |
||||
) |
||||
|
||||
// AuthSessionQuery is the builder for querying AuthSession entities.
|
||||
type AuthSessionQuery struct { |
||||
config |
||||
ctx *QueryContext |
||||
order []authsession.OrderOption |
||||
inters []Interceptor |
||||
predicates []predicate.AuthSession |
||||
// intermediate query (i.e. traversal path).
|
||||
sql *sql.Selector |
||||
path func(context.Context) (*sql.Selector, error) |
||||
} |
||||
|
||||
// Where adds a new predicate for the AuthSessionQuery builder.
|
||||
func (_q *AuthSessionQuery) Where(ps ...predicate.AuthSession) *AuthSessionQuery { |
||||
_q.predicates = append(_q.predicates, ps...) |
||||
return _q |
||||
} |
||||
|
||||
// Limit the number of records to be returned by this query.
|
||||
func (_q *AuthSessionQuery) Limit(limit int) *AuthSessionQuery { |
||||
_q.ctx.Limit = &limit |
||||
return _q |
||||
} |
||||
|
||||
// Offset to start from.
|
||||
func (_q *AuthSessionQuery) Offset(offset int) *AuthSessionQuery { |
||||
_q.ctx.Offset = &offset |
||||
return _q |
||||
} |
||||
|
||||
// Unique configures the query builder to filter duplicate records on query.
|
||||
// By default, unique is set to true, and can be disabled using this method.
|
||||
func (_q *AuthSessionQuery) Unique(unique bool) *AuthSessionQuery { |
||||
_q.ctx.Unique = &unique |
||||
return _q |
||||
} |
||||
|
||||
// Order specifies how the records should be ordered.
|
||||
func (_q *AuthSessionQuery) Order(o ...authsession.OrderOption) *AuthSessionQuery { |
||||
_q.order = append(_q.order, o...) |
||||
return _q |
||||
} |
||||
|
||||
// First returns the first AuthSession entity from the query.
|
||||
// Returns a *NotFoundError when no AuthSession was found.
|
||||
func (_q *AuthSessionQuery) First(ctx context.Context) (*AuthSession, error) { |
||||
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(nodes) == 0 { |
||||
return nil, &NotFoundError{authsession.Label} |
||||
} |
||||
return nodes[0], nil |
||||
} |
||||
|
||||
// FirstX is like First, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) FirstX(ctx context.Context) *AuthSession { |
||||
node, err := _q.First(ctx) |
||||
if err != nil && !IsNotFound(err) { |
||||
panic(err) |
||||
} |
||||
return node |
||||
} |
||||
|
||||
// FirstID returns the first AuthSession ID from the query.
|
||||
// Returns a *NotFoundError when no AuthSession ID was found.
|
||||
func (_q *AuthSessionQuery) FirstID(ctx context.Context) (id string, err error) { |
||||
var ids []string |
||||
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { |
||||
return |
||||
} |
||||
if len(ids) == 0 { |
||||
err = &NotFoundError{authsession.Label} |
||||
return |
||||
} |
||||
return ids[0], nil |
||||
} |
||||
|
||||
// FirstIDX is like FirstID, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) FirstIDX(ctx context.Context) string { |
||||
id, err := _q.FirstID(ctx) |
||||
if err != nil && !IsNotFound(err) { |
||||
panic(err) |
||||
} |
||||
return id |
||||
} |
||||
|
||||
// Only returns a single AuthSession entity found by the query, ensuring it only returns one.
|
||||
// Returns a *NotSingularError when more than one AuthSession entity is found.
|
||||
// Returns a *NotFoundError when no AuthSession entities are found.
|
||||
func (_q *AuthSessionQuery) Only(ctx context.Context) (*AuthSession, error) { |
||||
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
switch len(nodes) { |
||||
case 1: |
||||
return nodes[0], nil |
||||
case 0: |
||||
return nil, &NotFoundError{authsession.Label} |
||||
default: |
||||
return nil, &NotSingularError{authsession.Label} |
||||
} |
||||
} |
||||
|
||||
// OnlyX is like Only, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) OnlyX(ctx context.Context) *AuthSession { |
||||
node, err := _q.Only(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return node |
||||
} |
||||
|
||||
// OnlyID is like Only, but returns the only AuthSession ID in the query.
|
||||
// Returns a *NotSingularError when more than one AuthSession ID is found.
|
||||
// Returns a *NotFoundError when no entities are found.
|
||||
func (_q *AuthSessionQuery) OnlyID(ctx context.Context) (id string, err error) { |
||||
var ids []string |
||||
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { |
||||
return |
||||
} |
||||
switch len(ids) { |
||||
case 1: |
||||
id = ids[0] |
||||
case 0: |
||||
err = &NotFoundError{authsession.Label} |
||||
default: |
||||
err = &NotSingularError{authsession.Label} |
||||
} |
||||
return |
||||
} |
||||
|
||||
// OnlyIDX is like OnlyID, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) OnlyIDX(ctx context.Context) string { |
||||
id, err := _q.OnlyID(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return id |
||||
} |
||||
|
||||
// All executes the query and returns a list of AuthSessions.
|
||||
func (_q *AuthSessionQuery) All(ctx context.Context) ([]*AuthSession, error) { |
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) |
||||
if err := _q.prepareQuery(ctx); err != nil { |
||||
return nil, err |
||||
} |
||||
qr := querierAll[[]*AuthSession, *AuthSessionQuery]() |
||||
return withInterceptors[[]*AuthSession](ctx, _q, qr, _q.inters) |
||||
} |
||||
|
||||
// AllX is like All, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) AllX(ctx context.Context) []*AuthSession { |
||||
nodes, err := _q.All(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return nodes |
||||
} |
||||
|
||||
// IDs executes the query and returns a list of AuthSession IDs.
|
||||
func (_q *AuthSessionQuery) IDs(ctx context.Context) (ids []string, err error) { |
||||
if _q.ctx.Unique == nil && _q.path != nil { |
||||
_q.Unique(true) |
||||
} |
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) |
||||
if err = _q.Select(authsession.FieldID).Scan(ctx, &ids); err != nil { |
||||
return nil, err |
||||
} |
||||
return ids, nil |
||||
} |
||||
|
||||
// IDsX is like IDs, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) IDsX(ctx context.Context) []string { |
||||
ids, err := _q.IDs(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return ids |
||||
} |
||||
|
||||
// Count returns the count of the given query.
|
||||
func (_q *AuthSessionQuery) Count(ctx context.Context) (int, error) { |
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) |
||||
if err := _q.prepareQuery(ctx); err != nil { |
||||
return 0, err |
||||
} |
||||
return withInterceptors[int](ctx, _q, querierCount[*AuthSessionQuery](), _q.inters) |
||||
} |
||||
|
||||
// CountX is like Count, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) CountX(ctx context.Context) int { |
||||
count, err := _q.Count(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return count |
||||
} |
||||
|
||||
// Exist returns true if the query has elements in the graph.
|
||||
func (_q *AuthSessionQuery) Exist(ctx context.Context) (bool, error) { |
||||
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) |
||||
switch _, err := _q.FirstID(ctx); { |
||||
case IsNotFound(err): |
||||
return false, nil |
||||
case err != nil: |
||||
return false, fmt.Errorf("db: check existence: %w", err) |
||||
default: |
||||
return true, nil |
||||
} |
||||
} |
||||
|
||||
// ExistX is like Exist, but panics if an error occurs.
|
||||
func (_q *AuthSessionQuery) ExistX(ctx context.Context) bool { |
||||
exist, err := _q.Exist(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return exist |
||||
} |
||||
|
||||
// Clone returns a duplicate of the AuthSessionQuery builder, including all associated steps. It can be
|
||||
// used to prepare common query builders and use them differently after the clone is made.
|
||||
func (_q *AuthSessionQuery) Clone() *AuthSessionQuery { |
||||
if _q == nil { |
||||
return nil |
||||
} |
||||
return &AuthSessionQuery{ |
||||
config: _q.config, |
||||
ctx: _q.ctx.Clone(), |
||||
order: append([]authsession.OrderOption{}, _q.order...), |
||||
inters: append([]Interceptor{}, _q.inters...), |
||||
predicates: append([]predicate.AuthSession{}, _q.predicates...), |
||||
// clone intermediate query.
|
||||
sql: _q.sql.Clone(), |
||||
path: _q.path, |
||||
} |
||||
} |
||||
|
||||
// GroupBy is used to group vertices by one or more fields/columns.
|
||||
// It is often used with aggregate functions, like: count, max, mean, min, sum.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// ClientStates []byte `json:"client_states,omitempty"`
|
||||
// Count int `json:"count,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.AuthSession.Query().
|
||||
// GroupBy(authsession.FieldClientStates).
|
||||
// Aggregate(db.Count()).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *AuthSessionQuery) GroupBy(field string, fields ...string) *AuthSessionGroupBy { |
||||
_q.ctx.Fields = append([]string{field}, fields...) |
||||
grbuild := &AuthSessionGroupBy{build: _q} |
||||
grbuild.flds = &_q.ctx.Fields |
||||
grbuild.label = authsession.Label |
||||
grbuild.scan = grbuild.Scan |
||||
return grbuild |
||||
} |
||||
|
||||
// Select allows the selection one or more fields/columns for the given query,
|
||||
// instead of selecting all fields in the entity.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// var v []struct {
|
||||
// ClientStates []byte `json:"client_states,omitempty"`
|
||||
// }
|
||||
//
|
||||
// client.AuthSession.Query().
|
||||
// Select(authsession.FieldClientStates).
|
||||
// Scan(ctx, &v)
|
||||
func (_q *AuthSessionQuery) Select(fields ...string) *AuthSessionSelect { |
||||
_q.ctx.Fields = append(_q.ctx.Fields, fields...) |
||||
sbuild := &AuthSessionSelect{AuthSessionQuery: _q} |
||||
sbuild.label = authsession.Label |
||||
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan |
||||
return sbuild |
||||
} |
||||
|
||||
// Aggregate returns a AuthSessionSelect configured with the given aggregations.
|
||||
func (_q *AuthSessionQuery) Aggregate(fns ...AggregateFunc) *AuthSessionSelect { |
||||
return _q.Select().Aggregate(fns...) |
||||
} |
||||
|
||||
func (_q *AuthSessionQuery) prepareQuery(ctx context.Context) error { |
||||
for _, inter := range _q.inters { |
||||
if inter == nil { |
||||
return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)") |
||||
} |
||||
if trv, ok := inter.(Traverser); ok { |
||||
if err := trv.Traverse(ctx, _q); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
} |
||||
for _, f := range _q.ctx.Fields { |
||||
if !authsession.ValidColumn(f) { |
||||
return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} |
||||
} |
||||
} |
||||
if _q.path != nil { |
||||
prev, err := _q.path(ctx) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
_q.sql = prev |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (_q *AuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthSession, error) { |
||||
var ( |
||||
nodes = []*AuthSession{} |
||||
_spec = _q.querySpec() |
||||
) |
||||
_spec.ScanValues = func(columns []string) ([]any, error) { |
||||
return (*AuthSession).scanValues(nil, columns) |
||||
} |
||||
_spec.Assign = func(columns []string, values []any) error { |
||||
node := &AuthSession{config: _q.config} |
||||
nodes = append(nodes, node) |
||||
return node.assignValues(columns, values) |
||||
} |
||||
for i := range hooks { |
||||
hooks[i](ctx, _spec) |
||||
} |
||||
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { |
||||
return nil, err |
||||
} |
||||
if len(nodes) == 0 { |
||||
return nodes, nil |
||||
} |
||||
return nodes, nil |
||||
} |
||||
|
||||
func (_q *AuthSessionQuery) sqlCount(ctx context.Context) (int, error) { |
||||
_spec := _q.querySpec() |
||||
_spec.Node.Columns = _q.ctx.Fields |
||||
if len(_q.ctx.Fields) > 0 { |
||||
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique |
||||
} |
||||
return sqlgraph.CountNodes(ctx, _q.driver, _spec) |
||||
} |
||||
|
||||
func (_q *AuthSessionQuery) querySpec() *sqlgraph.QuerySpec { |
||||
_spec := sqlgraph.NewQuerySpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) |
||||
_spec.From = _q.sql |
||||
if unique := _q.ctx.Unique; unique != nil { |
||||
_spec.Unique = *unique |
||||
} else if _q.path != nil { |
||||
_spec.Unique = true |
||||
} |
||||
if fields := _q.ctx.Fields; len(fields) > 0 { |
||||
_spec.Node.Columns = make([]string, 0, len(fields)) |
||||
_spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID) |
||||
for i := range fields { |
||||
if fields[i] != authsession.FieldID { |
||||
_spec.Node.Columns = append(_spec.Node.Columns, fields[i]) |
||||
} |
||||
} |
||||
} |
||||
if ps := _q.predicates; len(ps) > 0 { |
||||
_spec.Predicate = func(selector *sql.Selector) { |
||||
for i := range ps { |
||||
ps[i](selector) |
||||
} |
||||
} |
||||
} |
||||
if limit := _q.ctx.Limit; limit != nil { |
||||
_spec.Limit = *limit |
||||
} |
||||
if offset := _q.ctx.Offset; offset != nil { |
||||
_spec.Offset = *offset |
||||
} |
||||
if ps := _q.order; len(ps) > 0 { |
||||
_spec.Order = func(selector *sql.Selector) { |
||||
for i := range ps { |
||||
ps[i](selector) |
||||
} |
||||
} |
||||
} |
||||
return _spec |
||||
} |
||||
|
||||
func (_q *AuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { |
||||
builder := sql.Dialect(_q.driver.Dialect()) |
||||
t1 := builder.Table(authsession.Table) |
||||
columns := _q.ctx.Fields |
||||
if len(columns) == 0 { |
||||
columns = authsession.Columns |
||||
} |
||||
selector := builder.Select(t1.Columns(columns...)...).From(t1) |
||||
if _q.sql != nil { |
||||
selector = _q.sql |
||||
selector.Select(selector.Columns(columns...)...) |
||||
} |
||||
if _q.ctx.Unique != nil && *_q.ctx.Unique { |
||||
selector.Distinct() |
||||
} |
||||
for _, p := range _q.predicates { |
||||
p(selector) |
||||
} |
||||
for _, p := range _q.order { |
||||
p(selector) |
||||
} |
||||
if offset := _q.ctx.Offset; offset != nil { |
||||
// limit is mandatory for offset clause. We start
|
||||
// with default value, and override it below if needed.
|
||||
selector.Offset(*offset).Limit(math.MaxInt32) |
||||
} |
||||
if limit := _q.ctx.Limit; limit != nil { |
||||
selector.Limit(*limit) |
||||
} |
||||
return selector |
||||
} |
||||
|
||||
// AuthSessionGroupBy is the group-by builder for AuthSession entities.
|
||||
type AuthSessionGroupBy struct { |
||||
selector |
||||
build *AuthSessionQuery |
||||
} |
||||
|
||||
// Aggregate adds the given aggregation functions to the group-by query.
|
||||
func (_g *AuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *AuthSessionGroupBy { |
||||
_g.fns = append(_g.fns, fns...) |
||||
return _g |
||||
} |
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_g *AuthSessionGroupBy) Scan(ctx context.Context, v any) error { |
||||
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) |
||||
if err := _g.build.prepareQuery(ctx); err != nil { |
||||
return err |
||||
} |
||||
return scanWithInterceptors[*AuthSessionQuery, *AuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) |
||||
} |
||||
|
||||
func (_g *AuthSessionGroupBy) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error { |
||||
selector := root.sqlQuery(ctx).Select() |
||||
aggregation := make([]string, 0, len(_g.fns)) |
||||
for _, fn := range _g.fns { |
||||
aggregation = append(aggregation, fn(selector)) |
||||
} |
||||
if len(selector.SelectedColumns()) == 0 { |
||||
columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) |
||||
for _, f := range *_g.flds { |
||||
columns = append(columns, selector.C(f)) |
||||
} |
||||
columns = append(columns, aggregation...) |
||||
selector.Select(columns...) |
||||
} |
||||
selector.GroupBy(selector.Columns(*_g.flds...)...) |
||||
if err := selector.Err(); err != nil { |
||||
return err |
||||
} |
||||
rows := &sql.Rows{} |
||||
query, args := selector.Query() |
||||
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { |
||||
return err |
||||
} |
||||
defer rows.Close() |
||||
return sql.ScanSlice(rows, v) |
||||
} |
||||
|
||||
// AuthSessionSelect is the builder for selecting fields of AuthSession entities.
|
||||
type AuthSessionSelect struct { |
||||
*AuthSessionQuery |
||||
selector |
||||
} |
||||
|
||||
// Aggregate adds the given aggregation functions to the selector query.
|
||||
func (_s *AuthSessionSelect) Aggregate(fns ...AggregateFunc) *AuthSessionSelect { |
||||
_s.fns = append(_s.fns, fns...) |
||||
return _s |
||||
} |
||||
|
||||
// Scan applies the selector query and scans the result into the given value.
|
||||
func (_s *AuthSessionSelect) Scan(ctx context.Context, v any) error { |
||||
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) |
||||
if err := _s.prepareQuery(ctx); err != nil { |
||||
return err |
||||
} |
||||
return scanWithInterceptors[*AuthSessionQuery, *AuthSessionSelect](ctx, _s.AuthSessionQuery, _s, _s.inters, v) |
||||
} |
||||
|
||||
func (_s *AuthSessionSelect) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error { |
||||
selector := root.sqlQuery(ctx) |
||||
aggregation := make([]string, 0, len(_s.fns)) |
||||
for _, fn := range _s.fns { |
||||
aggregation = append(aggregation, fn(selector)) |
||||
} |
||||
switch n := len(*_s.selector.flds); { |
||||
case n == 0 && len(aggregation) > 0: |
||||
selector.Select(aggregation...) |
||||
case n != 0 && len(aggregation) > 0: |
||||
selector.AppendSelect(aggregation...) |
||||
} |
||||
rows := &sql.Rows{} |
||||
query, args := selector.Query() |
||||
if err := _s.driver.Query(ctx, query, args, rows); err != nil { |
||||
return err |
||||
} |
||||
defer rows.Close() |
||||
return sql.ScanSlice(rows, v) |
||||
} |
||||
@ -0,0 +1,330 @@
|
||||
// Code generated by ent, DO NOT EDIT.
|
||||
|
||||
package db |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"time" |
||||
|
||||
"entgo.io/ent/dialect/sql" |
||||
"entgo.io/ent/dialect/sql/sqlgraph" |
||||
"entgo.io/ent/schema/field" |
||||
"github.com/dexidp/dex/storage/ent/db/authsession" |
||||
"github.com/dexidp/dex/storage/ent/db/predicate" |
||||
) |
||||
|
||||
// AuthSessionUpdate is the builder for updating AuthSession entities.
|
||||
type AuthSessionUpdate struct { |
||||
config |
||||
hooks []Hook |
||||
mutation *AuthSessionMutation |
||||
} |
||||
|
||||
// Where appends a list predicates to the AuthSessionUpdate builder.
|
||||
func (_u *AuthSessionUpdate) Where(ps ...predicate.AuthSession) *AuthSessionUpdate { |
||||
_u.mutation.Where(ps...) |
||||
return _u |
||||
} |
||||
|
||||
// SetClientStates sets the "client_states" field.
|
||||
func (_u *AuthSessionUpdate) SetClientStates(v []byte) *AuthSessionUpdate { |
||||
_u.mutation.SetClientStates(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (_u *AuthSessionUpdate) SetCreatedAt(v time.Time) *AuthSessionUpdate { |
||||
_u.mutation.SetCreatedAt(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdate) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdate { |
||||
if v != nil { |
||||
_u.SetCreatedAt(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetLastActivity sets the "last_activity" field.
|
||||
func (_u *AuthSessionUpdate) SetLastActivity(v time.Time) *AuthSessionUpdate { |
||||
_u.mutation.SetLastActivity(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableLastActivity sets the "last_activity" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdate) SetNillableLastActivity(v *time.Time) *AuthSessionUpdate { |
||||
if v != nil { |
||||
_u.SetLastActivity(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_u *AuthSessionUpdate) SetIPAddress(v string) *AuthSessionUpdate { |
||||
_u.mutation.SetIPAddress(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdate) SetNillableIPAddress(v *string) *AuthSessionUpdate { |
||||
if v != nil { |
||||
_u.SetIPAddress(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_u *AuthSessionUpdate) SetUserAgent(v string) *AuthSessionUpdate { |
||||
_u.mutation.SetUserAgent(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdate) SetNillableUserAgent(v *string) *AuthSessionUpdate { |
||||
if v != nil { |
||||
_u.SetUserAgent(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// Mutation returns the AuthSessionMutation object of the builder.
|
||||
func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation { |
||||
return _u.mutation |
||||
} |
||||
|
||||
// Save executes the query and returns the number of nodes affected by the update operation.
|
||||
func (_u *AuthSessionUpdate) Save(ctx context.Context) (int, error) { |
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) |
||||
} |
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *AuthSessionUpdate) SaveX(ctx context.Context) int { |
||||
affected, err := _u.Save(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return affected |
||||
} |
||||
|
||||
// Exec executes the query.
|
||||
func (_u *AuthSessionUpdate) Exec(ctx context.Context) error { |
||||
_, err := _u.Save(ctx) |
||||
return err |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *AuthSessionUpdate) ExecX(ctx context.Context) { |
||||
if err := _u.Exec(ctx); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
|
||||
func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { |
||||
_spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) |
||||
if ps := _u.mutation.predicates; len(ps) > 0 { |
||||
_spec.Predicate = func(selector *sql.Selector) { |
||||
for i := range ps { |
||||
ps[i](selector) |
||||
} |
||||
} |
||||
} |
||||
if value, ok := _u.mutation.ClientStates(); ok { |
||||
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) |
||||
} |
||||
if value, ok := _u.mutation.CreatedAt(); ok { |
||||
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) |
||||
} |
||||
if value, ok := _u.mutation.LastActivity(); ok { |
||||
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) |
||||
} |
||||
if value, ok := _u.mutation.IPAddress(); ok { |
||||
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value) |
||||
} |
||||
if value, ok := _u.mutation.UserAgent(); ok { |
||||
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value) |
||||
} |
||||
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { |
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok { |
||||
err = &NotFoundError{authsession.Label} |
||||
} else if sqlgraph.IsConstraintError(err) { |
||||
err = &ConstraintError{msg: err.Error(), wrap: err} |
||||
} |
||||
return 0, err |
||||
} |
||||
_u.mutation.done = true |
||||
return _node, nil |
||||
} |
||||
|
||||
// AuthSessionUpdateOne is the builder for updating a single AuthSession entity.
|
||||
type AuthSessionUpdateOne struct { |
||||
config |
||||
fields []string |
||||
hooks []Hook |
||||
mutation *AuthSessionMutation |
||||
} |
||||
|
||||
// SetClientStates sets the "client_states" field.
|
||||
func (_u *AuthSessionUpdateOne) SetClientStates(v []byte) *AuthSessionUpdateOne { |
||||
_u.mutation.SetClientStates(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetCreatedAt sets the "created_at" field.
|
||||
func (_u *AuthSessionUpdateOne) SetCreatedAt(v time.Time) *AuthSessionUpdateOne { |
||||
_u.mutation.SetCreatedAt(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdateOne) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdateOne { |
||||
if v != nil { |
||||
_u.SetCreatedAt(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetLastActivity sets the "last_activity" field.
|
||||
func (_u *AuthSessionUpdateOne) SetLastActivity(v time.Time) *AuthSessionUpdateOne { |
||||
_u.mutation.SetLastActivity(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableLastActivity sets the "last_activity" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdateOne) SetNillableLastActivity(v *time.Time) *AuthSessionUpdateOne { |
||||
if v != nil { |
||||
_u.SetLastActivity(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetIPAddress sets the "ip_address" field.
|
||||
func (_u *AuthSessionUpdateOne) SetIPAddress(v string) *AuthSessionUpdateOne { |
||||
_u.mutation.SetIPAddress(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdateOne) SetNillableIPAddress(v *string) *AuthSessionUpdateOne { |
||||
if v != nil { |
||||
_u.SetIPAddress(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// SetUserAgent sets the "user_agent" field.
|
||||
func (_u *AuthSessionUpdateOne) SetUserAgent(v string) *AuthSessionUpdateOne { |
||||
_u.mutation.SetUserAgent(v) |
||||
return _u |
||||
} |
||||
|
||||
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
|
||||
func (_u *AuthSessionUpdateOne) SetNillableUserAgent(v *string) *AuthSessionUpdateOne { |
||||
if v != nil { |
||||
_u.SetUserAgent(*v) |
||||
} |
||||
return _u |
||||
} |
||||
|
||||
// Mutation returns the AuthSessionMutation object of the builder.
|
||||
func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation { |
||||
return _u.mutation |
||||
} |
||||
|
||||
// Where appends a list predicates to the AuthSessionUpdate builder.
|
||||
func (_u *AuthSessionUpdateOne) Where(ps ...predicate.AuthSession) *AuthSessionUpdateOne { |
||||
_u.mutation.Where(ps...) |
||||
return _u |
||||
} |
||||
|
||||
// Select allows selecting one or more fields (columns) of the returned entity.
|
||||
// The default is selecting all fields defined in the entity schema.
|
||||
func (_u *AuthSessionUpdateOne) Select(field string, fields ...string) *AuthSessionUpdateOne { |
||||
_u.fields = append([]string{field}, fields...) |
||||
return _u |
||||
} |
||||
|
||||
// Save executes the query and returns the updated AuthSession entity.
|
||||
func (_u *AuthSessionUpdateOne) Save(ctx context.Context) (*AuthSession, error) { |
||||
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) |
||||
} |
||||
|
||||
// SaveX is like Save, but panics if an error occurs.
|
||||
func (_u *AuthSessionUpdateOne) SaveX(ctx context.Context) *AuthSession { |
||||
node, err := _u.Save(ctx) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return node |
||||
} |
||||
|
||||
// Exec executes the query on the entity.
|
||||
func (_u *AuthSessionUpdateOne) Exec(ctx context.Context) error { |
||||
_, err := _u.Save(ctx) |
||||
return err |
||||
} |
||||
|
||||
// ExecX is like Exec, but panics if an error occurs.
|
||||
func (_u *AuthSessionUpdateOne) ExecX(ctx context.Context) { |
||||
if err := _u.Exec(ctx); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
||||
|
||||
func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession, err error) { |
||||
_spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) |
||||
id, ok := _u.mutation.ID() |
||||
if !ok { |
||||
return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "AuthSession.id" for update`)} |
||||
} |
||||
_spec.Node.ID.Value = id |
||||
if fields := _u.fields; len(fields) > 0 { |
||||
_spec.Node.Columns = make([]string, 0, len(fields)) |
||||
_spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID) |
||||
for _, f := range fields { |
||||
if !authsession.ValidColumn(f) { |
||||
return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} |
||||
} |
||||
if f != authsession.FieldID { |
||||
_spec.Node.Columns = append(_spec.Node.Columns, f) |
||||
} |
||||
} |
||||
} |
||||
if ps := _u.mutation.predicates; len(ps) > 0 { |
||||
_spec.Predicate = func(selector *sql.Selector) { |
||||
for i := range ps { |
||||
ps[i](selector) |
||||
} |
||||
} |
||||
} |
||||
if value, ok := _u.mutation.ClientStates(); ok { |
||||
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) |
||||
} |
||||
if value, ok := _u.mutation.CreatedAt(); ok { |
||||
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) |
||||
} |
||||
if value, ok := _u.mutation.LastActivity(); ok { |
||||
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) |
||||
} |
||||
if value, ok := _u.mutation.IPAddress(); ok { |
||||
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value) |
||||
} |
||||
if value, ok := _u.mutation.UserAgent(); ok { |
||||
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value) |
||||
} |
||||
_node = &AuthSession{config: _u.config} |
||||
_spec.Assign = _node.assignValues |
||||
_spec.ScanValues = _node.scanValues |
||||
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { |
||||
if _, ok := err.(*sqlgraph.NotFoundError); ok { |
||||
err = &NotFoundError{authsession.Label} |
||||
} else if sqlgraph.IsConstraintError(err) { |
||||
err = &ConstraintError{msg: err.Error(), wrap: err} |
||||
} |
||||
return nil, err |
||||
} |
||||
_u.mutation.done = true |
||||
return _node, nil |
||||
} |
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue