From 5a4395fd12de61478fc1b701e5d1fda29ce883c6 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Sat, 14 Mar 2026 12:58:18 +0100 Subject: [PATCH] feat: add UserIdentity entity and CRUD operations (#4643) Signed-off-by: maksim.nabokikh Signed-off-by: Maksim Nabokikh Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pkg/featureflags/set.go | 3 + storage/conformance/conformance.go | 82 ++ storage/ent/client/offlinesession.go | 8 +- storage/ent/client/types.go | 33 + storage/ent/client/useridentity.go | 130 +++ storage/ent/client/utils.go | 8 +- storage/ent/db/client.go | 150 ++- storage/ent/db/ent.go | 2 + storage/ent/db/hook/hook.go | 12 + storage/ent/db/migrate/schema.go | 23 + storage/ent/db/mutation.go | 967 ++++++++++++++++++++ storage/ent/db/predicate/predicate.go | 3 + storage/ent/db/runtime.go | 35 + storage/ent/db/tx.go | 3 + storage/ent/db/useridentity.go | 232 +++++ storage/ent/db/useridentity/useridentity.go | 144 +++ storage/ent/db/useridentity/where.go | 705 ++++++++++++++ storage/ent/db/useridentity_create.go | 416 +++++++++ storage/ent/db/useridentity_delete.go | 88 ++ storage/ent/db/useridentity_query.go | 527 +++++++++++ storage/ent/db/useridentity_update.go | 629 +++++++++++++ storage/ent/schema/useridentity.go | 56 ++ storage/etcd/etcd.go | 60 ++ storage/etcd/types.go | 40 + storage/kubernetes/storage.go | 66 ++ storage/kubernetes/types.go | 76 ++ storage/memory/memory.go | 88 +- storage/sql/crud.go | 151 +++ storage/sql/migrate.go | 20 + storage/storage.go | 16 + 30 files changed, 4754 insertions(+), 19 deletions(-) create mode 100644 storage/ent/client/useridentity.go create mode 100644 storage/ent/db/useridentity.go create mode 100644 storage/ent/db/useridentity/useridentity.go create mode 100644 storage/ent/db/useridentity/where.go create mode 100644 storage/ent/db/useridentity_create.go create mode 100644 storage/ent/db/useridentity_delete.go create mode 100644 storage/ent/db/useridentity_query.go create mode 100644 storage/ent/db/useridentity_update.go create mode 100644 storage/ent/schema/useridentity.go diff --git a/pkg/featureflags/set.go b/pkg/featureflags/set.go index dcea1ca7..d3942979 100644 --- a/pkg/featureflags/set.go +++ b/pkg/featureflags/set.go @@ -21,4 +21,7 @@ var ( // ClientCredentialGrantEnabledByDefault enables the client_credentials grant type by default // without requiring explicit configuration in oauth2.grantTypes. ClientCredentialGrantEnabledByDefault = newFlag("client_credential_grant_enabled_by_default", false) + + // SessionsEnabled enables experimental auth sessions support. + SessionsEnabled = newFlag("sessions_enabled", false) ) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 92a4d140..ec3cae11 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -51,6 +51,7 @@ func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { {"TimezoneSupport", testTimezones}, {"DeviceRequestCRUD", testDeviceRequestCRUD}, {"DeviceTokenCRUD", testDeviceTokenCRUD}, + {"UserIdentityCRUD", testUserIdentityCRUD}, }) } @@ -1084,3 +1085,84 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) } } + +func testUserIdentityCRUD(t *testing.T, s storage.Storage) { + ctx := t.Context() + + now := time.Now().UTC().Round(time.Millisecond) + + u1 := storage.UserIdentity{ + UserID: "user1", + ConnectorID: "conn1", + Claims: storage.Claims{ + UserID: "user1", + Username: "jane", + Email: "jane@example.com", + EmailVerified: true, + Groups: []string{"a", "b"}, + }, + Consents: make(map[string][]string), + CreatedAt: now, + LastLogin: now, + BlockedUntil: time.Unix(0, 0).UTC(), + } + + // Create with empty Consents map. + if err := s.CreateUserIdentity(ctx, u1); err != nil { + t.Fatalf("create user identity: %v", err) + } + + // Duplicate create should return ErrAlreadyExists. + err := s.CreateUserIdentity(ctx, u1) + mustBeErrAlreadyExists(t, "user identity", err) + + // Get and compare. + got, err := s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) + if err != nil { + t.Fatalf("get user identity: %v", err) + } + + got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond) + got.LastLogin = got.LastLogin.UTC().Round(time.Millisecond) + got.BlockedUntil = got.BlockedUntil.UTC().Round(time.Millisecond) + u1.BlockedUntil = u1.BlockedUntil.UTC().Round(time.Millisecond) + if diff := pretty.Compare(u1, got); diff != "" { + t.Errorf("user identity retrieved from storage did not match: %s", diff) + } + + // Update: add consent entry. + if err := s.UpdateUserIdentity(ctx, u1.UserID, u1.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { + old.Consents["client1"] = []string{"openid", "email"} + return old, nil + }); err != nil { + t.Fatalf("update user identity: %v", err) + } + + // Get and verify updated consents. + got, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) + if err != nil { + t.Fatalf("get user identity after update: %v", err) + } + wantConsents := map[string][]string{"client1": {"openid", "email"}} + if diff := pretty.Compare(wantConsents, got.Consents); diff != "" { + t.Errorf("user identity consents did not match after update: %s", diff) + } + + // List and verify. + identities, err := s.ListUserIdentities(ctx) + if err != nil { + t.Fatalf("list user identities: %v", err) + } + if len(identities) != 1 { + t.Fatalf("expected 1 user identity, got %d", len(identities)) + } + + // Delete. + if err := s.DeleteUserIdentity(ctx, u1.UserID, u1.ConnectorID); err != nil { + t.Fatalf("delete user identity: %v", err) + } + + // Get deleted should return ErrNotFound. + _, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) + mustBeErrNotFound(t, "user identity", err) +} diff --git a/storage/ent/client/offlinesession.go b/storage/ent/client/offlinesession.go index 9d608cb6..c8f18433 100644 --- a/storage/ent/client/offlinesession.go +++ b/storage/ent/client/offlinesession.go @@ -15,7 +15,7 @@ func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.Of return fmt.Errorf("encode refresh offline session: %w", err) } - id := offlineSessionID(session.UserID, session.ConnID, d.hasher) + id := compositeKeyID(session.UserID, session.ConnID, d.hasher) _, err = d.client.OfflineSession.Create(). SetID(id). SetUserID(session.UserID). @@ -31,7 +31,7 @@ func (d *Database) CreateOfflineSessions(ctx context.Context, session storage.Of // GetOfflineSessions extracts an offline session from the database by user id and connector id. func (d *Database) GetOfflineSessions(ctx context.Context, userID, connID string) (storage.OfflineSessions, error) { - id := offlineSessionID(userID, connID, d.hasher) + id := compositeKeyID(userID, connID, d.hasher) offlineSession, err := d.client.OfflineSession.Get(ctx, id) if err != nil { @@ -42,7 +42,7 @@ func (d *Database) GetOfflineSessions(ctx context.Context, userID, connID string // DeleteOfflineSessions deletes an offline session from the database by user id and connector id. func (d *Database) DeleteOfflineSessions(ctx context.Context, userID, connID string) error { - id := offlineSessionID(userID, connID, d.hasher) + id := compositeKeyID(userID, connID, d.hasher) err := d.client.OfflineSession.DeleteOneID(id).Exec(ctx) if err != nil { @@ -53,7 +53,7 @@ func (d *Database) DeleteOfflineSessions(ctx context.Context, userID, connID str // UpdateOfflineSessions changes an offline session by user id and connector id using an updater function. func (d *Database) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { - id := offlineSessionID(userID, connID, d.hasher) + id := compositeKeyID(userID, connID, d.hasher) tx, err := d.BeginTx(ctx) if err != nil { diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index a59fb6f5..535ab5f8 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -163,6 +163,39 @@ func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest { } } +func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity { + s := storage.UserIdentity{ + UserID: u.UserID, + ConnectorID: u.ConnectorID, + Claims: storage.Claims{ + UserID: u.ClaimsUserID, + Username: u.ClaimsUsername, + PreferredUsername: u.ClaimsPreferredUsername, + Email: u.ClaimsEmail, + EmailVerified: u.ClaimsEmailVerified, + Groups: u.ClaimsGroups, + }, + CreatedAt: u.CreatedAt, + LastLogin: u.LastLogin, + BlockedUntil: u.BlockedUntil, + } + + if u.Consents != nil { + if err := json.Unmarshal(u.Consents, &s.Consents); err != nil { + // Correctness of json structure is guaranteed on uploading + panic(err) + } + if s.Consents == nil { + // Ensure Consents is non-nil even if JSON was "null". + s.Consents = make(map[string][]string) + } + } else { + // Server code assumes this will be non-nil. + s.Consents = make(map[string][]string) + } + return s +} + func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { return storage.DeviceToken{ DeviceCode: t.DeviceCode, diff --git a/storage/ent/client/useridentity.go b/storage/ent/client/useridentity.go new file mode 100644 index 00000000..1cf87919 --- /dev/null +++ b/storage/ent/client/useridentity.go @@ -0,0 +1,130 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/dexidp/dex/storage" +) + +// CreateUserIdentity saves provided user identity into the database. +func (d *Database) CreateUserIdentity(ctx context.Context, identity storage.UserIdentity) error { + if identity.Consents == nil { + identity.Consents = make(map[string][]string) + } + encodedConsents, err := json.Marshal(identity.Consents) + if err != nil { + return fmt.Errorf("encode consents user identity: %w", err) + } + + id := compositeKeyID(identity.UserID, identity.ConnectorID, d.hasher) + _, err = d.client.UserIdentity.Create(). + SetID(id). + SetUserID(identity.UserID). + SetConnectorID(identity.ConnectorID). + SetClaimsUserID(identity.Claims.UserID). + SetClaimsUsername(identity.Claims.Username). + SetClaimsPreferredUsername(identity.Claims.PreferredUsername). + SetClaimsEmail(identity.Claims.Email). + SetClaimsEmailVerified(identity.Claims.EmailVerified). + SetClaimsGroups(identity.Claims.Groups). + SetConsents(encodedConsents). + SetCreatedAt(identity.CreatedAt). + SetLastLogin(identity.LastLogin). + SetBlockedUntil(identity.BlockedUntil). + Save(ctx) + if err != nil { + return convertDBError("create user identity: %w", err) + } + return nil +} + +// GetUserIdentity extracts a user identity from the database by user id and connector id. +func (d *Database) GetUserIdentity(ctx context.Context, userID, connectorID string) (storage.UserIdentity, error) { + id := compositeKeyID(userID, connectorID, d.hasher) + + userIdentity, err := d.client.UserIdentity.Get(ctx, id) + if err != nil { + return storage.UserIdentity{}, convertDBError("get user identity: %w", err) + } + return toStorageUserIdentity(userIdentity), nil +} + +// DeleteUserIdentity deletes a user identity from the database by user id and connector id. +func (d *Database) DeleteUserIdentity(ctx context.Context, userID, connectorID string) error { + id := compositeKeyID(userID, connectorID, d.hasher) + + err := d.client.UserIdentity.DeleteOneID(id).Exec(ctx) + if err != nil { + return convertDBError("delete user identity: %w", err) + } + return nil +} + +// UpdateUserIdentity changes a user identity by user id and connector id using an updater function. +func (d *Database) UpdateUserIdentity(ctx context.Context, userID string, connectorID string, updater func(u storage.UserIdentity) (storage.UserIdentity, error)) error { + id := compositeKeyID(userID, connectorID, d.hasher) + + tx, err := d.BeginTx(ctx) + if err != nil { + return convertDBError("update user identity tx: %w", err) + } + + userIdentity, err := tx.UserIdentity.Get(ctx, id) + if err != nil { + return rollback(tx, "update user identity database: %w", err) + } + + newUserIdentity, err := updater(toStorageUserIdentity(userIdentity)) + if err != nil { + return rollback(tx, "update user identity updating: %w", err) + } + + if newUserIdentity.Consents == nil { + newUserIdentity.Consents = make(map[string][]string) + } + + encodedConsents, err := json.Marshal(newUserIdentity.Consents) + if err != nil { + return rollback(tx, "encode consents user identity: %w", err) + } + + _, err = tx.UserIdentity.UpdateOneID(id). + SetUserID(newUserIdentity.UserID). + SetConnectorID(newUserIdentity.ConnectorID). + SetClaimsUserID(newUserIdentity.Claims.UserID). + SetClaimsUsername(newUserIdentity.Claims.Username). + SetClaimsPreferredUsername(newUserIdentity.Claims.PreferredUsername). + SetClaimsEmail(newUserIdentity.Claims.Email). + SetClaimsEmailVerified(newUserIdentity.Claims.EmailVerified). + SetClaimsGroups(newUserIdentity.Claims.Groups). + SetConsents(encodedConsents). + SetCreatedAt(newUserIdentity.CreatedAt). + SetLastLogin(newUserIdentity.LastLogin). + SetBlockedUntil(newUserIdentity.BlockedUntil). + Save(ctx) + if err != nil { + return rollback(tx, "update user identity uploading: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update user identity commit: %w", err) + } + + return nil +} + +// ListUserIdentities lists all user identities in the database. +func (d *Database) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, error) { + userIdentities, err := d.client.UserIdentity.Query().All(ctx) + if err != nil { + return nil, convertDBError("list user identities: %w", err) + } + + storageUserIdentities := make([]storage.UserIdentity, 0, len(userIdentities)) + for _, u := range userIdentities { + storageUserIdentities = append(storageUserIdentities, toStorageUserIdentity(u)) + } + return storageUserIdentities, nil +} diff --git a/storage/ent/client/utils.go b/storage/ent/client/utils.go index 65c037ac..950e612c 100644 --- a/storage/ent/client/utils.go +++ b/storage/ent/client/utils.go @@ -32,13 +32,13 @@ func convertDBError(t string, err error) error { return fmt.Errorf(t, err) } -// compose hashed id from user and connection id to use it as primary key +// compositeKeyID composes a hashed id from two key parts to use as primary key. // ent doesn't support multi-key primary yet // https://github.com/facebook/ent/issues/400 -func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string { +func compositeKeyID(first string, second string, hasher func() hash.Hash) string { h := hasher() - h.Write([]byte(userID)) - h.Write([]byte(connID)) + h.Write([]byte(first)) + h.Write([]byte(second)) return fmt.Sprintf("%x", h.Sum(nil)) } diff --git a/storage/ent/db/client.go b/storage/ent/db/client.go index 4fb28cb3..efb6da6b 100644 --- a/storage/ent/db/client.go +++ b/storage/ent/db/client.go @@ -24,6 +24,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/offlinesession" "github.com/dexidp/dex/storage/ent/db/password" "github.com/dexidp/dex/storage/ent/db/refreshtoken" + "github.com/dexidp/dex/storage/ent/db/useridentity" ) // Client is the client that holds all ent builders. @@ -51,6 +52,8 @@ type Client struct { Password *PasswordClient // RefreshToken is the client for interacting with the RefreshToken builders. RefreshToken *RefreshTokenClient + // UserIdentity is the client for interacting with the UserIdentity builders. + UserIdentity *UserIdentityClient } // NewClient creates a new client configured with the given options. @@ -72,6 +75,7 @@ func (c *Client) init() { c.OfflineSession = NewOfflineSessionClient(c.config) c.Password = NewPasswordClient(c.config) c.RefreshToken = NewRefreshTokenClient(c.config) + c.UserIdentity = NewUserIdentityClient(c.config) } type ( @@ -174,6 +178,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { OfflineSession: NewOfflineSessionClient(cfg), Password: NewPasswordClient(cfg), RefreshToken: NewRefreshTokenClient(cfg), + UserIdentity: NewUserIdentityClient(cfg), }, nil } @@ -203,6 +208,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) OfflineSession: NewOfflineSessionClient(cfg), Password: NewPasswordClient(cfg), RefreshToken: NewRefreshTokenClient(cfg), + UserIdentity: NewUserIdentityClient(cfg), }, nil } @@ -233,7 +239,7 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys, - c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, + c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity, } { n.Use(hooks...) } @@ -244,7 +250,7 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys, - c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, + c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity, } { n.Intercept(interceptors...) } @@ -273,6 +279,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Password.mutate(ctx, m) case *RefreshTokenMutation: return c.RefreshToken.mutate(ctx, m) + case *UserIdentityMutation: + return c.UserIdentity.mutate(ctx, m) default: return nil, fmt.Errorf("db: unknown mutation type %T", m) } @@ -1608,14 +1616,148 @@ func (c *RefreshTokenClient) mutate(ctx context.Context, m *RefreshTokenMutation } } +// UserIdentityClient is a client for the UserIdentity schema. +type UserIdentityClient struct { + config +} + +// NewUserIdentityClient returns a client for the UserIdentity from the given config. +func NewUserIdentityClient(c config) *UserIdentityClient { + return &UserIdentityClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `useridentity.Hooks(f(g(h())))`. +func (c *UserIdentityClient) Use(hooks ...Hook) { + c.hooks.UserIdentity = append(c.hooks.UserIdentity, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `useridentity.Intercept(f(g(h())))`. +func (c *UserIdentityClient) Intercept(interceptors ...Interceptor) { + c.inters.UserIdentity = append(c.inters.UserIdentity, interceptors...) +} + +// Create returns a builder for creating a UserIdentity entity. +func (c *UserIdentityClient) Create() *UserIdentityCreate { + mutation := newUserIdentityMutation(c.config, OpCreate) + return &UserIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of UserIdentity entities. +func (c *UserIdentityClient) CreateBulk(builders ...*UserIdentityCreate) *UserIdentityCreateBulk { + return &UserIdentityCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *UserIdentityClient) MapCreateBulk(slice any, setFunc func(*UserIdentityCreate, int)) *UserIdentityCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &UserIdentityCreateBulk{err: fmt.Errorf("calling to UserIdentityClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*UserIdentityCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &UserIdentityCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for UserIdentity. +func (c *UserIdentityClient) Update() *UserIdentityUpdate { + mutation := newUserIdentityMutation(c.config, OpUpdate) + return &UserIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *UserIdentityClient) UpdateOne(_m *UserIdentity) *UserIdentityUpdateOne { + mutation := newUserIdentityMutation(c.config, OpUpdateOne, withUserIdentity(_m)) + return &UserIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *UserIdentityClient) UpdateOneID(id string) *UserIdentityUpdateOne { + mutation := newUserIdentityMutation(c.config, OpUpdateOne, withUserIdentityID(id)) + return &UserIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for UserIdentity. +func (c *UserIdentityClient) Delete() *UserIdentityDelete { + mutation := newUserIdentityMutation(c.config, OpDelete) + return &UserIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *UserIdentityClient) DeleteOne(_m *UserIdentity) *UserIdentityDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *UserIdentityClient) DeleteOneID(id string) *UserIdentityDeleteOne { + builder := c.Delete().Where(useridentity.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &UserIdentityDeleteOne{builder} +} + +// Query returns a query builder for UserIdentity. +func (c *UserIdentityClient) Query() *UserIdentityQuery { + return &UserIdentityQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeUserIdentity}, + inters: c.Interceptors(), + } +} + +// Get returns a UserIdentity entity by its id. +func (c *UserIdentityClient) Get(ctx context.Context, id string) (*UserIdentity, error) { + return c.Query().Where(useridentity.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *UserIdentityClient) GetX(ctx context.Context, id string) *UserIdentity { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *UserIdentityClient) Hooks() []Hook { + return c.hooks.UserIdentity +} + +// Interceptors returns the client interceptors. +func (c *UserIdentityClient) Interceptors() []Interceptor { + return c.inters.UserIdentity +} + +func (c *UserIdentityClient) mutate(ctx context.Context, m *UserIdentityMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&UserIdentityCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&UserIdentityUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&UserIdentityUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&UserIdentityDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("db: unknown UserIdentity mutation op: %q", m.Op()) + } +} + // hooks and interceptors per client, for fast access. type ( hooks struct { AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys, - OAuth2Client, OfflineSession, Password, RefreshToken []ent.Hook + OAuth2Client, OfflineSession, Password, RefreshToken, UserIdentity []ent.Hook } inters struct { AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys, - OAuth2Client, OfflineSession, Password, RefreshToken []ent.Interceptor + OAuth2Client, OfflineSession, Password, RefreshToken, + UserIdentity []ent.Interceptor } ) diff --git a/storage/ent/db/ent.go b/storage/ent/db/ent.go index 06bee261..fe965a79 100644 --- a/storage/ent/db/ent.go +++ b/storage/ent/db/ent.go @@ -22,6 +22,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/offlinesession" "github.com/dexidp/dex/storage/ent/db/password" "github.com/dexidp/dex/storage/ent/db/refreshtoken" + "github.com/dexidp/dex/storage/ent/db/useridentity" ) // ent aliases to avoid import conflicts in user's code. @@ -92,6 +93,7 @@ func checkColumn(t, c string) error { offlinesession.Table: offlinesession.ValidColumn, password.Table: password.ValidColumn, refreshtoken.Table: refreshtoken.ValidColumn, + useridentity.Table: useridentity.ValidColumn, }) }) return columnCheck(t, c) diff --git a/storage/ent/db/hook/hook.go b/storage/ent/db/hook/hook.go index 12cb91c6..008dc7ee 100644 --- a/storage/ent/db/hook/hook.go +++ b/storage/ent/db/hook/hook.go @@ -129,6 +129,18 @@ func (f RefreshTokenFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, return nil, fmt.Errorf("unexpected mutation type %T. expect *db.RefreshTokenMutation", m) } +// The UserIdentityFunc type is an adapter to allow the use of ordinary +// function as UserIdentity mutator. +type UserIdentityFunc func(context.Context, *db.UserIdentityMutation) (db.Value, error) + +// Mutate calls f(ctx, m). +func (f UserIdentityFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, error) { + if mv, ok := m.(*db.UserIdentityMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *db.UserIdentityMutation", m) +} + // Condition is a hook condition function. type Condition func(context.Context, db.Mutation) bool diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index b59e455b..1b9b61c9 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -200,6 +200,28 @@ var ( Columns: RefreshTokensColumns, PrimaryKey: []*schema.Column{RefreshTokensColumns[0]}, } + // UserIdentitiesColumns holds the columns for the "user_identities" table. + UserIdentitiesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "user_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "connector_id", Type: field.TypeString, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_user_id", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_preferred_username", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_email", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "claims_email_verified", Type: field.TypeBool, Default: false}, + {Name: "claims_groups", Type: field.TypeJSON, Nullable: true}, + {Name: "consents", Type: field.TypeBytes}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "last_login", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "blocked_until", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + } + // UserIdentitiesTable holds the schema information for the "user_identities" table. + UserIdentitiesTable = &schema.Table{ + Name: "user_identities", + Columns: UserIdentitiesColumns, + PrimaryKey: []*schema.Column{UserIdentitiesColumns[0]}, + } // Tables holds all the tables in the schema. Tables = []*schema.Table{ AuthCodesTable, @@ -212,6 +234,7 @@ var ( OfflineSessionsTable, PasswordsTable, RefreshTokensTable, + UserIdentitiesTable, } ) diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 73c7f095..8625471a 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -23,6 +23,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/password" "github.com/dexidp/dex/storage/ent/db/predicate" "github.com/dexidp/dex/storage/ent/db/refreshtoken" + "github.com/dexidp/dex/storage/ent/db/useridentity" jose "github.com/go-jose/go-jose/v4" ) @@ -45,6 +46,7 @@ const ( TypeOfflineSession = "OfflineSession" TypePassword = "Password" TypeRefreshToken = "RefreshToken" + TypeUserIdentity = "UserIdentity" ) // AuthCodeMutation represents an operation that mutates the AuthCode nodes in the graph. @@ -8437,3 +8439,968 @@ func (m *RefreshTokenMutation) ClearEdge(name string) error { func (m *RefreshTokenMutation) ResetEdge(name string) error { return fmt.Errorf("unknown RefreshToken edge %s", name) } + +// UserIdentityMutation represents an operation that mutates the UserIdentity nodes in the graph. +type UserIdentityMutation struct { + config + op Op + typ string + id *string + user_id *string + connector_id *string + claims_user_id *string + claims_username *string + claims_preferred_username *string + claims_email *string + claims_email_verified *bool + claims_groups *[]string + appendclaims_groups []string + consents *[]byte + created_at *time.Time + last_login *time.Time + blocked_until *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*UserIdentity, error) + predicates []predicate.UserIdentity +} + +var _ ent.Mutation = (*UserIdentityMutation)(nil) + +// useridentityOption allows management of the mutation configuration using functional options. +type useridentityOption func(*UserIdentityMutation) + +// newUserIdentityMutation creates new mutation for the UserIdentity entity. +func newUserIdentityMutation(c config, op Op, opts ...useridentityOption) *UserIdentityMutation { + m := &UserIdentityMutation{ + config: c, + op: op, + typ: TypeUserIdentity, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withUserIdentityID sets the ID field of the mutation. +func withUserIdentityID(id string) useridentityOption { + return func(m *UserIdentityMutation) { + var ( + err error + once sync.Once + value *UserIdentity + ) + m.oldValue = func(ctx context.Context) (*UserIdentity, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().UserIdentity.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withUserIdentity sets the old UserIdentity of the mutation. +func withUserIdentity(node *UserIdentity) useridentityOption { + return func(m *UserIdentityMutation) { + m.oldValue = func(context.Context) (*UserIdentity, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m UserIdentityMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m UserIdentityMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("db: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of UserIdentity entities. +func (m *UserIdentityMutation) SetID(id string) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *UserIdentityMutation) ID() (id string, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *UserIdentityMutation) IDs(ctx context.Context) ([]string, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []string{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().UserIdentity.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetUserID sets the "user_id" field. +func (m *UserIdentityMutation) SetUserID(s string) { + m.user_id = &s +} + +// UserID returns the value of the "user_id" field in the mutation. +func (m *UserIdentityMutation) UserID() (r string, exists bool) { + v := m.user_id + if v == nil { + return + } + return *v, true +} + +// OldUserID returns the old "user_id" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldUserID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserID: %w", err) + } + return oldValue.UserID, nil +} + +// ResetUserID resets all changes to the "user_id" field. +func (m *UserIdentityMutation) ResetUserID() { + m.user_id = nil +} + +// SetConnectorID sets the "connector_id" field. +func (m *UserIdentityMutation) SetConnectorID(s string) { + m.connector_id = &s +} + +// ConnectorID returns the value of the "connector_id" field in the mutation. +func (m *UserIdentityMutation) ConnectorID() (r string, exists bool) { + v := m.connector_id + if v == nil { + return + } + return *v, true +} + +// OldConnectorID returns the old "connector_id" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldConnectorID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConnectorID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConnectorID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConnectorID: %w", err) + } + return oldValue.ConnectorID, nil +} + +// ResetConnectorID resets all changes to the "connector_id" field. +func (m *UserIdentityMutation) ResetConnectorID() { + m.connector_id = nil +} + +// SetClaimsUserID sets the "claims_user_id" field. +func (m *UserIdentityMutation) SetClaimsUserID(s string) { + m.claims_user_id = &s +} + +// ClaimsUserID returns the value of the "claims_user_id" field in the mutation. +func (m *UserIdentityMutation) ClaimsUserID() (r string, exists bool) { + v := m.claims_user_id + if v == nil { + return + } + return *v, true +} + +// OldClaimsUserID returns the old "claims_user_id" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsUserID(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsUserID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsUserID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsUserID: %w", err) + } + return oldValue.ClaimsUserID, nil +} + +// ResetClaimsUserID resets all changes to the "claims_user_id" field. +func (m *UserIdentityMutation) ResetClaimsUserID() { + m.claims_user_id = nil +} + +// SetClaimsUsername sets the "claims_username" field. +func (m *UserIdentityMutation) SetClaimsUsername(s string) { + m.claims_username = &s +} + +// ClaimsUsername returns the value of the "claims_username" field in the mutation. +func (m *UserIdentityMutation) ClaimsUsername() (r string, exists bool) { + v := m.claims_username + if v == nil { + return + } + return *v, true +} + +// OldClaimsUsername returns the old "claims_username" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsUsername(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsUsername is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsUsername requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsUsername: %w", err) + } + return oldValue.ClaimsUsername, nil +} + +// ResetClaimsUsername resets all changes to the "claims_username" field. +func (m *UserIdentityMutation) ResetClaimsUsername() { + m.claims_username = nil +} + +// SetClaimsPreferredUsername sets the "claims_preferred_username" field. +func (m *UserIdentityMutation) SetClaimsPreferredUsername(s string) { + m.claims_preferred_username = &s +} + +// ClaimsPreferredUsername returns the value of the "claims_preferred_username" field in the mutation. +func (m *UserIdentityMutation) ClaimsPreferredUsername() (r string, exists bool) { + v := m.claims_preferred_username + if v == nil { + return + } + return *v, true +} + +// OldClaimsPreferredUsername returns the old "claims_preferred_username" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsPreferredUsername(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsPreferredUsername is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsPreferredUsername requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsPreferredUsername: %w", err) + } + return oldValue.ClaimsPreferredUsername, nil +} + +// ResetClaimsPreferredUsername resets all changes to the "claims_preferred_username" field. +func (m *UserIdentityMutation) ResetClaimsPreferredUsername() { + m.claims_preferred_username = nil +} + +// SetClaimsEmail sets the "claims_email" field. +func (m *UserIdentityMutation) SetClaimsEmail(s string) { + m.claims_email = &s +} + +// ClaimsEmail returns the value of the "claims_email" field in the mutation. +func (m *UserIdentityMutation) ClaimsEmail() (r string, exists bool) { + v := m.claims_email + if v == nil { + return + } + return *v, true +} + +// OldClaimsEmail returns the old "claims_email" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsEmail(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsEmail is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsEmail requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsEmail: %w", err) + } + return oldValue.ClaimsEmail, nil +} + +// ResetClaimsEmail resets all changes to the "claims_email" field. +func (m *UserIdentityMutation) ResetClaimsEmail() { + m.claims_email = nil +} + +// SetClaimsEmailVerified sets the "claims_email_verified" field. +func (m *UserIdentityMutation) SetClaimsEmailVerified(b bool) { + m.claims_email_verified = &b +} + +// ClaimsEmailVerified returns the value of the "claims_email_verified" field in the mutation. +func (m *UserIdentityMutation) ClaimsEmailVerified() (r bool, exists bool) { + v := m.claims_email_verified + if v == nil { + return + } + return *v, true +} + +// OldClaimsEmailVerified returns the old "claims_email_verified" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsEmailVerified(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsEmailVerified is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsEmailVerified requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsEmailVerified: %w", err) + } + return oldValue.ClaimsEmailVerified, nil +} + +// ResetClaimsEmailVerified resets all changes to the "claims_email_verified" field. +func (m *UserIdentityMutation) ResetClaimsEmailVerified() { + m.claims_email_verified = nil +} + +// SetClaimsGroups sets the "claims_groups" field. +func (m *UserIdentityMutation) SetClaimsGroups(s []string) { + m.claims_groups = &s + m.appendclaims_groups = nil +} + +// ClaimsGroups returns the value of the "claims_groups" field in the mutation. +func (m *UserIdentityMutation) ClaimsGroups() (r []string, exists bool) { + v := m.claims_groups + if v == nil { + return + } + return *v, true +} + +// OldClaimsGroups returns the old "claims_groups" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldClaimsGroups(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaimsGroups is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaimsGroups requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaimsGroups: %w", err) + } + return oldValue.ClaimsGroups, nil +} + +// AppendClaimsGroups adds s to the "claims_groups" field. +func (m *UserIdentityMutation) AppendClaimsGroups(s []string) { + m.appendclaims_groups = append(m.appendclaims_groups, s...) +} + +// AppendedClaimsGroups returns the list of values that were appended to the "claims_groups" field in this mutation. +func (m *UserIdentityMutation) AppendedClaimsGroups() ([]string, bool) { + if len(m.appendclaims_groups) == 0 { + return nil, false + } + return m.appendclaims_groups, true +} + +// ClearClaimsGroups clears the value of the "claims_groups" field. +func (m *UserIdentityMutation) ClearClaimsGroups() { + m.claims_groups = nil + m.appendclaims_groups = nil + m.clearedFields[useridentity.FieldClaimsGroups] = struct{}{} +} + +// ClaimsGroupsCleared returns if the "claims_groups" field was cleared in this mutation. +func (m *UserIdentityMutation) ClaimsGroupsCleared() bool { + _, ok := m.clearedFields[useridentity.FieldClaimsGroups] + return ok +} + +// ResetClaimsGroups resets all changes to the "claims_groups" field. +func (m *UserIdentityMutation) ResetClaimsGroups() { + m.claims_groups = nil + m.appendclaims_groups = nil + delete(m.clearedFields, useridentity.FieldClaimsGroups) +} + +// SetConsents sets the "consents" field. +func (m *UserIdentityMutation) SetConsents(b []byte) { + m.consents = &b +} + +// Consents returns the value of the "consents" field in the mutation. +func (m *UserIdentityMutation) Consents() (r []byte, exists bool) { + v := m.consents + if v == nil { + return + } + return *v, true +} + +// OldConsents returns the old "consents" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldConsents(ctx context.Context) (v []byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldConsents is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldConsents requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldConsents: %w", err) + } + return oldValue.Consents, nil +} + +// ResetConsents resets all changes to the "consents" field. +func (m *UserIdentityMutation) ResetConsents() { + m.consents = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *UserIdentityMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *UserIdentityMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *UserIdentityMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetLastLogin sets the "last_login" field. +func (m *UserIdentityMutation) SetLastLogin(t time.Time) { + m.last_login = &t +} + +// LastLogin returns the value of the "last_login" field in the mutation. +func (m *UserIdentityMutation) LastLogin() (r time.Time, exists bool) { + v := m.last_login + if v == nil { + return + } + return *v, true +} + +// OldLastLogin returns the old "last_login" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldLastLogin(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastLogin is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastLogin requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastLogin: %w", err) + } + return oldValue.LastLogin, nil +} + +// ResetLastLogin resets all changes to the "last_login" field. +func (m *UserIdentityMutation) ResetLastLogin() { + m.last_login = nil +} + +// SetBlockedUntil sets the "blocked_until" field. +func (m *UserIdentityMutation) SetBlockedUntil(t time.Time) { + m.blocked_until = &t +} + +// BlockedUntil returns the value of the "blocked_until" field in the mutation. +func (m *UserIdentityMutation) BlockedUntil() (r time.Time, exists bool) { + v := m.blocked_until + if v == nil { + return + } + return *v, true +} + +// OldBlockedUntil returns the old "blocked_until" field's value of the UserIdentity entity. +// If the UserIdentity object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserIdentityMutation) OldBlockedUntil(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldBlockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldBlockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldBlockedUntil: %w", err) + } + return oldValue.BlockedUntil, nil +} + +// ResetBlockedUntil resets all changes to the "blocked_until" field. +func (m *UserIdentityMutation) ResetBlockedUntil() { + m.blocked_until = nil +} + +// Where appends a list predicates to the UserIdentityMutation builder. +func (m *UserIdentityMutation) Where(ps ...predicate.UserIdentity) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the UserIdentityMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *UserIdentityMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.UserIdentity, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *UserIdentityMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *UserIdentityMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (UserIdentity). +func (m *UserIdentityMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *UserIdentityMutation) Fields() []string { + fields := make([]string, 0, 12) + if m.user_id != nil { + fields = append(fields, useridentity.FieldUserID) + } + if m.connector_id != nil { + fields = append(fields, useridentity.FieldConnectorID) + } + if m.claims_user_id != nil { + fields = append(fields, useridentity.FieldClaimsUserID) + } + if m.claims_username != nil { + fields = append(fields, useridentity.FieldClaimsUsername) + } + if m.claims_preferred_username != nil { + fields = append(fields, useridentity.FieldClaimsPreferredUsername) + } + if m.claims_email != nil { + fields = append(fields, useridentity.FieldClaimsEmail) + } + if m.claims_email_verified != nil { + fields = append(fields, useridentity.FieldClaimsEmailVerified) + } + if m.claims_groups != nil { + fields = append(fields, useridentity.FieldClaimsGroups) + } + if m.consents != nil { + fields = append(fields, useridentity.FieldConsents) + } + if m.created_at != nil { + fields = append(fields, useridentity.FieldCreatedAt) + } + if m.last_login != nil { + fields = append(fields, useridentity.FieldLastLogin) + } + if m.blocked_until != nil { + fields = append(fields, useridentity.FieldBlockedUntil) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *UserIdentityMutation) Field(name string) (ent.Value, bool) { + switch name { + case useridentity.FieldUserID: + return m.UserID() + case useridentity.FieldConnectorID: + return m.ConnectorID() + case useridentity.FieldClaimsUserID: + return m.ClaimsUserID() + case useridentity.FieldClaimsUsername: + return m.ClaimsUsername() + case useridentity.FieldClaimsPreferredUsername: + return m.ClaimsPreferredUsername() + case useridentity.FieldClaimsEmail: + return m.ClaimsEmail() + case useridentity.FieldClaimsEmailVerified: + return m.ClaimsEmailVerified() + case useridentity.FieldClaimsGroups: + return m.ClaimsGroups() + case useridentity.FieldConsents: + return m.Consents() + case useridentity.FieldCreatedAt: + return m.CreatedAt() + case useridentity.FieldLastLogin: + return m.LastLogin() + case useridentity.FieldBlockedUntil: + return m.BlockedUntil() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *UserIdentityMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case useridentity.FieldUserID: + return m.OldUserID(ctx) + case useridentity.FieldConnectorID: + return m.OldConnectorID(ctx) + case useridentity.FieldClaimsUserID: + return m.OldClaimsUserID(ctx) + case useridentity.FieldClaimsUsername: + return m.OldClaimsUsername(ctx) + case useridentity.FieldClaimsPreferredUsername: + return m.OldClaimsPreferredUsername(ctx) + case useridentity.FieldClaimsEmail: + return m.OldClaimsEmail(ctx) + case useridentity.FieldClaimsEmailVerified: + return m.OldClaimsEmailVerified(ctx) + case useridentity.FieldClaimsGroups: + return m.OldClaimsGroups(ctx) + case useridentity.FieldConsents: + return m.OldConsents(ctx) + case useridentity.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case useridentity.FieldLastLogin: + return m.OldLastLogin(ctx) + case useridentity.FieldBlockedUntil: + return m.OldBlockedUntil(ctx) + } + return nil, fmt.Errorf("unknown UserIdentity field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserIdentityMutation) SetField(name string, value ent.Value) error { + switch name { + case useridentity.FieldUserID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserID(v) + return nil + case useridentity.FieldConnectorID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConnectorID(v) + return nil + case useridentity.FieldClaimsUserID: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsUserID(v) + return nil + case useridentity.FieldClaimsUsername: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsUsername(v) + return nil + case useridentity.FieldClaimsPreferredUsername: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsPreferredUsername(v) + return nil + case useridentity.FieldClaimsEmail: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsEmail(v) + return nil + case useridentity.FieldClaimsEmailVerified: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsEmailVerified(v) + return nil + case useridentity.FieldClaimsGroups: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaimsGroups(v) + return nil + case useridentity.FieldConsents: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetConsents(v) + return nil + case useridentity.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case useridentity.FieldLastLogin: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastLogin(v) + return nil + case useridentity.FieldBlockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetBlockedUntil(v) + return nil + } + return fmt.Errorf("unknown UserIdentity field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *UserIdentityMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *UserIdentityMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *UserIdentityMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown UserIdentity numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UserIdentityMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(useridentity.FieldClaimsGroups) { + fields = append(fields, useridentity.FieldClaimsGroups) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UserIdentityMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UserIdentityMutation) ClearField(name string) error { + switch name { + case useridentity.FieldClaimsGroups: + m.ClearClaimsGroups() + return nil + } + return fmt.Errorf("unknown UserIdentity nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *UserIdentityMutation) ResetField(name string) error { + switch name { + case useridentity.FieldUserID: + m.ResetUserID() + return nil + case useridentity.FieldConnectorID: + m.ResetConnectorID() + return nil + case useridentity.FieldClaimsUserID: + m.ResetClaimsUserID() + return nil + case useridentity.FieldClaimsUsername: + m.ResetClaimsUsername() + return nil + case useridentity.FieldClaimsPreferredUsername: + m.ResetClaimsPreferredUsername() + return nil + case useridentity.FieldClaimsEmail: + m.ResetClaimsEmail() + return nil + case useridentity.FieldClaimsEmailVerified: + m.ResetClaimsEmailVerified() + return nil + case useridentity.FieldClaimsGroups: + m.ResetClaimsGroups() + return nil + case useridentity.FieldConsents: + m.ResetConsents() + return nil + case useridentity.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case useridentity.FieldLastLogin: + m.ResetLastLogin() + return nil + case useridentity.FieldBlockedUntil: + m.ResetBlockedUntil() + return nil + } + return fmt.Errorf("unknown UserIdentity field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UserIdentityMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UserIdentityMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UserIdentityMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UserIdentityMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UserIdentityMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *UserIdentityMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UserIdentityMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown UserIdentity unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UserIdentityMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown UserIdentity edge %s", name) +} diff --git a/storage/ent/db/predicate/predicate.go b/storage/ent/db/predicate/predicate.go index ed07a071..90ddffcb 100644 --- a/storage/ent/db/predicate/predicate.go +++ b/storage/ent/db/predicate/predicate.go @@ -35,3 +35,6 @@ type Password func(*sql.Selector) // RefreshToken is the predicate function for refreshtoken builders. type RefreshToken func(*sql.Selector) + +// UserIdentity is the predicate function for useridentity builders. +type UserIdentity func(*sql.Selector) diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index fdb47bd7..b6f34aef 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -15,6 +15,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/offlinesession" "github.com/dexidp/dex/storage/ent/db/password" "github.com/dexidp/dex/storage/ent/db/refreshtoken" + "github.com/dexidp/dex/storage/ent/db/useridentity" "github.com/dexidp/dex/storage/ent/schema" ) @@ -274,4 +275,38 @@ func init() { refreshtokenDescID := refreshtokenFields[0].Descriptor() // refreshtoken.IDValidator is a validator for the "id" field. It is called by the builders before save. refreshtoken.IDValidator = refreshtokenDescID.Validators[0].(func(string) error) + useridentityFields := schema.UserIdentity{}.Fields() + _ = useridentityFields + // useridentityDescUserID is the schema descriptor for user_id field. + useridentityDescUserID := useridentityFields[1].Descriptor() + // useridentity.UserIDValidator is a validator for the "user_id" field. It is called by the builders before save. + useridentity.UserIDValidator = useridentityDescUserID.Validators[0].(func(string) error) + // useridentityDescConnectorID is the schema descriptor for connector_id field. + useridentityDescConnectorID := useridentityFields[2].Descriptor() + // useridentity.ConnectorIDValidator is a validator for the "connector_id" field. It is called by the builders before save. + useridentity.ConnectorIDValidator = useridentityDescConnectorID.Validators[0].(func(string) error) + // useridentityDescClaimsUserID is the schema descriptor for claims_user_id field. + useridentityDescClaimsUserID := useridentityFields[3].Descriptor() + // useridentity.DefaultClaimsUserID holds the default value on creation for the claims_user_id field. + useridentity.DefaultClaimsUserID = useridentityDescClaimsUserID.Default.(string) + // useridentityDescClaimsUsername is the schema descriptor for claims_username field. + useridentityDescClaimsUsername := useridentityFields[4].Descriptor() + // useridentity.DefaultClaimsUsername holds the default value on creation for the claims_username field. + useridentity.DefaultClaimsUsername = useridentityDescClaimsUsername.Default.(string) + // useridentityDescClaimsPreferredUsername is the schema descriptor for claims_preferred_username field. + useridentityDescClaimsPreferredUsername := useridentityFields[5].Descriptor() + // useridentity.DefaultClaimsPreferredUsername holds the default value on creation for the claims_preferred_username field. + useridentity.DefaultClaimsPreferredUsername = useridentityDescClaimsPreferredUsername.Default.(string) + // useridentityDescClaimsEmail is the schema descriptor for claims_email field. + useridentityDescClaimsEmail := useridentityFields[6].Descriptor() + // useridentity.DefaultClaimsEmail holds the default value on creation for the claims_email field. + useridentity.DefaultClaimsEmail = useridentityDescClaimsEmail.Default.(string) + // useridentityDescClaimsEmailVerified is the schema descriptor for claims_email_verified field. + useridentityDescClaimsEmailVerified := useridentityFields[7].Descriptor() + // useridentity.DefaultClaimsEmailVerified holds the default value on creation for the claims_email_verified field. + useridentity.DefaultClaimsEmailVerified = useridentityDescClaimsEmailVerified.Default.(bool) + // useridentityDescID is the schema descriptor for id field. + useridentityDescID := useridentityFields[0].Descriptor() + // useridentity.IDValidator is a validator for the "id" field. It is called by the builders before save. + useridentity.IDValidator = useridentityDescID.Validators[0].(func(string) error) } diff --git a/storage/ent/db/tx.go b/storage/ent/db/tx.go index 42ba241a..77be8a8f 100644 --- a/storage/ent/db/tx.go +++ b/storage/ent/db/tx.go @@ -32,6 +32,8 @@ type Tx struct { Password *PasswordClient // RefreshToken is the client for interacting with the RefreshToken builders. RefreshToken *RefreshTokenClient + // UserIdentity is the client for interacting with the UserIdentity builders. + UserIdentity *UserIdentityClient // lazily loaded. client *Client @@ -173,6 +175,7 @@ func (tx *Tx) init() { tx.OfflineSession = NewOfflineSessionClient(tx.config) tx.Password = NewPasswordClient(tx.config) tx.RefreshToken = NewRefreshTokenClient(tx.config) + tx.UserIdentity = NewUserIdentityClient(tx.config) } // txDriver wraps the given dialect.Tx with a nop dialect.Driver implementation. diff --git a/storage/ent/db/useridentity.go b/storage/ent/db/useridentity.go new file mode 100644 index 00000000..7127299b --- /dev/null +++ b/storage/ent/db/useridentity.go @@ -0,0 +1,232 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/dexidp/dex/storage/ent/db/useridentity" +) + +// UserIdentity is the model entity for the UserIdentity schema. +type UserIdentity struct { + config `json:"-"` + // ID of the ent. + ID string `json:"id,omitempty"` + // UserID holds the value of the "user_id" field. + UserID string `json:"user_id,omitempty"` + // ConnectorID holds the value of the "connector_id" field. + ConnectorID string `json:"connector_id,omitempty"` + // ClaimsUserID holds the value of the "claims_user_id" field. + ClaimsUserID string `json:"claims_user_id,omitempty"` + // ClaimsUsername holds the value of the "claims_username" field. + ClaimsUsername string `json:"claims_username,omitempty"` + // ClaimsPreferredUsername holds the value of the "claims_preferred_username" field. + ClaimsPreferredUsername string `json:"claims_preferred_username,omitempty"` + // ClaimsEmail holds the value of the "claims_email" field. + ClaimsEmail string `json:"claims_email,omitempty"` + // ClaimsEmailVerified holds the value of the "claims_email_verified" field. + ClaimsEmailVerified bool `json:"claims_email_verified,omitempty"` + // ClaimsGroups holds the value of the "claims_groups" field. + ClaimsGroups []string `json:"claims_groups,omitempty"` + // Consents holds the value of the "consents" field. + Consents []byte `json:"consents,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // LastLogin holds the value of the "last_login" field. + LastLogin time.Time `json:"last_login,omitempty"` + // BlockedUntil holds the value of the "blocked_until" field. + BlockedUntil time.Time `json:"blocked_until,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UserIdentity) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case useridentity.FieldClaimsGroups, useridentity.FieldConsents: + values[i] = new([]byte) + case useridentity.FieldClaimsEmailVerified: + values[i] = new(sql.NullBool) + case useridentity.FieldID, useridentity.FieldUserID, useridentity.FieldConnectorID, useridentity.FieldClaimsUserID, useridentity.FieldClaimsUsername, useridentity.FieldClaimsPreferredUsername, useridentity.FieldClaimsEmail: + values[i] = new(sql.NullString) + case useridentity.FieldCreatedAt, useridentity.FieldLastLogin, useridentity.FieldBlockedUntil: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UserIdentity fields. +func (_m *UserIdentity) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case useridentity.FieldID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value.Valid { + _m.ID = value.String + } + case useridentity.FieldUserID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field user_id", values[i]) + } else if value.Valid { + _m.UserID = value.String + } + case useridentity.FieldConnectorID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field connector_id", values[i]) + } else if value.Valid { + _m.ConnectorID = value.String + } + case useridentity.FieldClaimsUserID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field claims_user_id", values[i]) + } else if value.Valid { + _m.ClaimsUserID = value.String + } + case useridentity.FieldClaimsUsername: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field claims_username", values[i]) + } else if value.Valid { + _m.ClaimsUsername = value.String + } + case useridentity.FieldClaimsPreferredUsername: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field claims_preferred_username", values[i]) + } else if value.Valid { + _m.ClaimsPreferredUsername = value.String + } + case useridentity.FieldClaimsEmail: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field claims_email", values[i]) + } else if value.Valid { + _m.ClaimsEmail = value.String + } + case useridentity.FieldClaimsEmailVerified: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field claims_email_verified", values[i]) + } else if value.Valid { + _m.ClaimsEmailVerified = value.Bool + } + case useridentity.FieldClaimsGroups: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field claims_groups", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ClaimsGroups); err != nil { + return fmt.Errorf("unmarshal field claims_groups: %w", err) + } + } + case useridentity.FieldConsents: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field consents", values[i]) + } else if value != nil { + _m.Consents = *value + } + case useridentity.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case useridentity.FieldLastLogin: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_login", values[i]) + } else if value.Valid { + _m.LastLogin = value.Time + } + case useridentity.FieldBlockedUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field blocked_until", values[i]) + } else if value.Valid { + _m.BlockedUntil = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UserIdentity. +// This includes values selected through modifiers, order, etc. +func (_m *UserIdentity) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this UserIdentity. +// Note that you need to call UserIdentity.Unwrap() before calling this method if this UserIdentity +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *UserIdentity) Update() *UserIdentityUpdateOne { + return NewUserIdentityClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UserIdentity entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UserIdentity) Unwrap() *UserIdentity { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("db: UserIdentity is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *UserIdentity) String() string { + var builder strings.Builder + builder.WriteString("UserIdentity(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("user_id=") + builder.WriteString(_m.UserID) + builder.WriteString(", ") + builder.WriteString("connector_id=") + builder.WriteString(_m.ConnectorID) + builder.WriteString(", ") + builder.WriteString("claims_user_id=") + builder.WriteString(_m.ClaimsUserID) + builder.WriteString(", ") + builder.WriteString("claims_username=") + builder.WriteString(_m.ClaimsUsername) + builder.WriteString(", ") + builder.WriteString("claims_preferred_username=") + builder.WriteString(_m.ClaimsPreferredUsername) + builder.WriteString(", ") + builder.WriteString("claims_email=") + builder.WriteString(_m.ClaimsEmail) + builder.WriteString(", ") + builder.WriteString("claims_email_verified=") + builder.WriteString(fmt.Sprintf("%v", _m.ClaimsEmailVerified)) + builder.WriteString(", ") + builder.WriteString("claims_groups=") + builder.WriteString(fmt.Sprintf("%v", _m.ClaimsGroups)) + builder.WriteString(", ") + builder.WriteString("consents=") + builder.WriteString(fmt.Sprintf("%v", _m.Consents)) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("last_login=") + builder.WriteString(_m.LastLogin.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("blocked_until=") + builder.WriteString(_m.BlockedUntil.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// UserIdentities is a parsable slice of UserIdentity. +type UserIdentities []*UserIdentity diff --git a/storage/ent/db/useridentity/useridentity.go b/storage/ent/db/useridentity/useridentity.go new file mode 100644 index 00000000..f08d74ec --- /dev/null +++ b/storage/ent/db/useridentity/useridentity.go @@ -0,0 +1,144 @@ +// Code generated by ent, DO NOT EDIT. + +package useridentity + +import ( + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the useridentity type in the database. + Label = "user_identity" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldUserID holds the string denoting the user_id field in the database. + FieldUserID = "user_id" + // FieldConnectorID holds the string denoting the connector_id field in the database. + FieldConnectorID = "connector_id" + // FieldClaimsUserID holds the string denoting the claims_user_id field in the database. + FieldClaimsUserID = "claims_user_id" + // FieldClaimsUsername holds the string denoting the claims_username field in the database. + FieldClaimsUsername = "claims_username" + // FieldClaimsPreferredUsername holds the string denoting the claims_preferred_username field in the database. + FieldClaimsPreferredUsername = "claims_preferred_username" + // FieldClaimsEmail holds the string denoting the claims_email field in the database. + FieldClaimsEmail = "claims_email" + // FieldClaimsEmailVerified holds the string denoting the claims_email_verified field in the database. + FieldClaimsEmailVerified = "claims_email_verified" + // FieldClaimsGroups holds the string denoting the claims_groups field in the database. + FieldClaimsGroups = "claims_groups" + // FieldConsents holds the string denoting the consents field in the database. + FieldConsents = "consents" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldLastLogin holds the string denoting the last_login field in the database. + FieldLastLogin = "last_login" + // FieldBlockedUntil holds the string denoting the blocked_until field in the database. + FieldBlockedUntil = "blocked_until" + // Table holds the table name of the useridentity in the database. + Table = "user_identities" +) + +// Columns holds all SQL columns for useridentity fields. +var Columns = []string{ + FieldID, + FieldUserID, + FieldConnectorID, + FieldClaimsUserID, + FieldClaimsUsername, + FieldClaimsPreferredUsername, + FieldClaimsEmail, + FieldClaimsEmailVerified, + FieldClaimsGroups, + FieldConsents, + FieldCreatedAt, + FieldLastLogin, + FieldBlockedUntil, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // UserIDValidator is a validator for the "user_id" field. It is called by the builders before save. + UserIDValidator func(string) error + // ConnectorIDValidator is a validator for the "connector_id" field. It is called by the builders before save. + ConnectorIDValidator func(string) error + // DefaultClaimsUserID holds the default value on creation for the "claims_user_id" field. + DefaultClaimsUserID string + // DefaultClaimsUsername holds the default value on creation for the "claims_username" field. + DefaultClaimsUsername string + // DefaultClaimsPreferredUsername holds the default value on creation for the "claims_preferred_username" field. + DefaultClaimsPreferredUsername string + // DefaultClaimsEmail holds the default value on creation for the "claims_email" field. + DefaultClaimsEmail string + // DefaultClaimsEmailVerified holds the default value on creation for the "claims_email_verified" field. + DefaultClaimsEmailVerified bool + // IDValidator is a validator for the "id" field. It is called by the builders before save. + IDValidator func(string) error +) + +// OrderOption defines the ordering options for the UserIdentity queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByUserID orders the results by the user_id field. +func ByUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserID, opts...).ToFunc() +} + +// ByConnectorID orders the results by the connector_id field. +func ByConnectorID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldConnectorID, opts...).ToFunc() +} + +// ByClaimsUserID orders the results by the claims_user_id field. +func ByClaimsUserID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaimsUserID, opts...).ToFunc() +} + +// ByClaimsUsername orders the results by the claims_username field. +func ByClaimsUsername(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaimsUsername, opts...).ToFunc() +} + +// ByClaimsPreferredUsername orders the results by the claims_preferred_username field. +func ByClaimsPreferredUsername(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaimsPreferredUsername, opts...).ToFunc() +} + +// ByClaimsEmail orders the results by the claims_email field. +func ByClaimsEmail(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaimsEmail, opts...).ToFunc() +} + +// ByClaimsEmailVerified orders the results by the claims_email_verified field. +func ByClaimsEmailVerified(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaimsEmailVerified, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByLastLogin orders the results by the last_login field. +func ByLastLogin(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastLogin, opts...).ToFunc() +} + +// ByBlockedUntil orders the results by the blocked_until field. +func ByBlockedUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldBlockedUntil, opts...).ToFunc() +} diff --git a/storage/ent/db/useridentity/where.go b/storage/ent/db/useridentity/where.go new file mode 100644 index 00000000..201d340f --- /dev/null +++ b/storage/ent/db/useridentity/where.go @@ -0,0 +1,705 @@ +// Code generated by ent, DO NOT EDIT. + +package useridentity + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/dexidp/dex/storage/ent/db/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldID, id)) +} + +// IDEqualFold applies the EqualFold predicate on the ID field. +func IDEqualFold(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldID, id)) +} + +// IDContainsFold applies the ContainsFold predicate on the ID field. +func IDContainsFold(id string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldID, id)) +} + +// UserID applies equality check predicate on the "user_id" field. It's identical to UserIDEQ. +func UserID(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// ConnectorID applies equality check predicate on the "connector_id" field. It's identical to ConnectorIDEQ. +func ConnectorID(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldConnectorID, v)) +} + +// ClaimsUserID applies equality check predicate on the "claims_user_id" field. It's identical to ClaimsUserIDEQ. +func ClaimsUserID(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsUserID, v)) +} + +// ClaimsUsername applies equality check predicate on the "claims_username" field. It's identical to ClaimsUsernameEQ. +func ClaimsUsername(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsUsername, v)) +} + +// ClaimsPreferredUsername applies equality check predicate on the "claims_preferred_username" field. It's identical to ClaimsPreferredUsernameEQ. +func ClaimsPreferredUsername(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsPreferredUsername, v)) +} + +// ClaimsEmail applies equality check predicate on the "claims_email" field. It's identical to ClaimsEmailEQ. +func ClaimsEmail(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsEmail, v)) +} + +// ClaimsEmailVerified applies equality check predicate on the "claims_email_verified" field. It's identical to ClaimsEmailVerifiedEQ. +func ClaimsEmailVerified(v bool) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsEmailVerified, v)) +} + +// Consents applies equality check predicate on the "consents" field. It's identical to ConsentsEQ. +func Consents(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldConsents, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// LastLogin applies equality check predicate on the "last_login" field. It's identical to LastLoginEQ. +func LastLogin(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldLastLogin, v)) +} + +// BlockedUntil applies equality check predicate on the "blocked_until" field. It's identical to BlockedUntilEQ. +func BlockedUntil(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldBlockedUntil, v)) +} + +// UserIDEQ applies the EQ predicate on the "user_id" field. +func UserIDEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldUserID, v)) +} + +// UserIDNEQ applies the NEQ predicate on the "user_id" field. +func UserIDNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldUserID, v)) +} + +// UserIDIn applies the In predicate on the "user_id" field. +func UserIDIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldUserID, vs...)) +} + +// UserIDNotIn applies the NotIn predicate on the "user_id" field. +func UserIDNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldUserID, vs...)) +} + +// UserIDGT applies the GT predicate on the "user_id" field. +func UserIDGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldUserID, v)) +} + +// UserIDGTE applies the GTE predicate on the "user_id" field. +func UserIDGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldUserID, v)) +} + +// UserIDLT applies the LT predicate on the "user_id" field. +func UserIDLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldUserID, v)) +} + +// UserIDLTE applies the LTE predicate on the "user_id" field. +func UserIDLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldUserID, v)) +} + +// UserIDContains applies the Contains predicate on the "user_id" field. +func UserIDContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldUserID, v)) +} + +// UserIDHasPrefix applies the HasPrefix predicate on the "user_id" field. +func UserIDHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldUserID, v)) +} + +// UserIDHasSuffix applies the HasSuffix predicate on the "user_id" field. +func UserIDHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldUserID, v)) +} + +// UserIDEqualFold applies the EqualFold predicate on the "user_id" field. +func UserIDEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldUserID, v)) +} + +// UserIDContainsFold applies the ContainsFold predicate on the "user_id" field. +func UserIDContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldUserID, v)) +} + +// ConnectorIDEQ applies the EQ predicate on the "connector_id" field. +func ConnectorIDEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldConnectorID, v)) +} + +// ConnectorIDNEQ applies the NEQ predicate on the "connector_id" field. +func ConnectorIDNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldConnectorID, v)) +} + +// ConnectorIDIn applies the In predicate on the "connector_id" field. +func ConnectorIDIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldConnectorID, vs...)) +} + +// ConnectorIDNotIn applies the NotIn predicate on the "connector_id" field. +func ConnectorIDNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldConnectorID, vs...)) +} + +// ConnectorIDGT applies the GT predicate on the "connector_id" field. +func ConnectorIDGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldConnectorID, v)) +} + +// ConnectorIDGTE applies the GTE predicate on the "connector_id" field. +func ConnectorIDGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldConnectorID, v)) +} + +// ConnectorIDLT applies the LT predicate on the "connector_id" field. +func ConnectorIDLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldConnectorID, v)) +} + +// ConnectorIDLTE applies the LTE predicate on the "connector_id" field. +func ConnectorIDLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldConnectorID, v)) +} + +// ConnectorIDContains applies the Contains predicate on the "connector_id" field. +func ConnectorIDContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldConnectorID, v)) +} + +// ConnectorIDHasPrefix applies the HasPrefix predicate on the "connector_id" field. +func ConnectorIDHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldConnectorID, v)) +} + +// ConnectorIDHasSuffix applies the HasSuffix predicate on the "connector_id" field. +func ConnectorIDHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldConnectorID, v)) +} + +// ConnectorIDEqualFold applies the EqualFold predicate on the "connector_id" field. +func ConnectorIDEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldConnectorID, v)) +} + +// ConnectorIDContainsFold applies the ContainsFold predicate on the "connector_id" field. +func ConnectorIDContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldConnectorID, v)) +} + +// ClaimsUserIDEQ applies the EQ predicate on the "claims_user_id" field. +func ClaimsUserIDEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsUserID, v)) +} + +// ClaimsUserIDNEQ applies the NEQ predicate on the "claims_user_id" field. +func ClaimsUserIDNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldClaimsUserID, v)) +} + +// ClaimsUserIDIn applies the In predicate on the "claims_user_id" field. +func ClaimsUserIDIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldClaimsUserID, vs...)) +} + +// ClaimsUserIDNotIn applies the NotIn predicate on the "claims_user_id" field. +func ClaimsUserIDNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldClaimsUserID, vs...)) +} + +// ClaimsUserIDGT applies the GT predicate on the "claims_user_id" field. +func ClaimsUserIDGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldClaimsUserID, v)) +} + +// ClaimsUserIDGTE applies the GTE predicate on the "claims_user_id" field. +func ClaimsUserIDGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldClaimsUserID, v)) +} + +// ClaimsUserIDLT applies the LT predicate on the "claims_user_id" field. +func ClaimsUserIDLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldClaimsUserID, v)) +} + +// ClaimsUserIDLTE applies the LTE predicate on the "claims_user_id" field. +func ClaimsUserIDLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldClaimsUserID, v)) +} + +// ClaimsUserIDContains applies the Contains predicate on the "claims_user_id" field. +func ClaimsUserIDContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldClaimsUserID, v)) +} + +// ClaimsUserIDHasPrefix applies the HasPrefix predicate on the "claims_user_id" field. +func ClaimsUserIDHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldClaimsUserID, v)) +} + +// ClaimsUserIDHasSuffix applies the HasSuffix predicate on the "claims_user_id" field. +func ClaimsUserIDHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldClaimsUserID, v)) +} + +// ClaimsUserIDEqualFold applies the EqualFold predicate on the "claims_user_id" field. +func ClaimsUserIDEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldClaimsUserID, v)) +} + +// ClaimsUserIDContainsFold applies the ContainsFold predicate on the "claims_user_id" field. +func ClaimsUserIDContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldClaimsUserID, v)) +} + +// ClaimsUsernameEQ applies the EQ predicate on the "claims_username" field. +func ClaimsUsernameEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsUsername, v)) +} + +// ClaimsUsernameNEQ applies the NEQ predicate on the "claims_username" field. +func ClaimsUsernameNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldClaimsUsername, v)) +} + +// ClaimsUsernameIn applies the In predicate on the "claims_username" field. +func ClaimsUsernameIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldClaimsUsername, vs...)) +} + +// ClaimsUsernameNotIn applies the NotIn predicate on the "claims_username" field. +func ClaimsUsernameNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldClaimsUsername, vs...)) +} + +// ClaimsUsernameGT applies the GT predicate on the "claims_username" field. +func ClaimsUsernameGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldClaimsUsername, v)) +} + +// ClaimsUsernameGTE applies the GTE predicate on the "claims_username" field. +func ClaimsUsernameGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldClaimsUsername, v)) +} + +// ClaimsUsernameLT applies the LT predicate on the "claims_username" field. +func ClaimsUsernameLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldClaimsUsername, v)) +} + +// ClaimsUsernameLTE applies the LTE predicate on the "claims_username" field. +func ClaimsUsernameLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldClaimsUsername, v)) +} + +// ClaimsUsernameContains applies the Contains predicate on the "claims_username" field. +func ClaimsUsernameContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldClaimsUsername, v)) +} + +// ClaimsUsernameHasPrefix applies the HasPrefix predicate on the "claims_username" field. +func ClaimsUsernameHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldClaimsUsername, v)) +} + +// ClaimsUsernameHasSuffix applies the HasSuffix predicate on the "claims_username" field. +func ClaimsUsernameHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldClaimsUsername, v)) +} + +// ClaimsUsernameEqualFold applies the EqualFold predicate on the "claims_username" field. +func ClaimsUsernameEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldClaimsUsername, v)) +} + +// ClaimsUsernameContainsFold applies the ContainsFold predicate on the "claims_username" field. +func ClaimsUsernameContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldClaimsUsername, v)) +} + +// ClaimsPreferredUsernameEQ applies the EQ predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameNEQ applies the NEQ predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameIn applies the In predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldClaimsPreferredUsername, vs...)) +} + +// ClaimsPreferredUsernameNotIn applies the NotIn predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldClaimsPreferredUsername, vs...)) +} + +// ClaimsPreferredUsernameGT applies the GT predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameGTE applies the GTE predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameLT applies the LT predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameLTE applies the LTE predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameContains applies the Contains predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameHasPrefix applies the HasPrefix predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameHasSuffix applies the HasSuffix predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameEqualFold applies the EqualFold predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldClaimsPreferredUsername, v)) +} + +// ClaimsPreferredUsernameContainsFold applies the ContainsFold predicate on the "claims_preferred_username" field. +func ClaimsPreferredUsernameContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldClaimsPreferredUsername, v)) +} + +// ClaimsEmailEQ applies the EQ predicate on the "claims_email" field. +func ClaimsEmailEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsEmail, v)) +} + +// ClaimsEmailNEQ applies the NEQ predicate on the "claims_email" field. +func ClaimsEmailNEQ(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldClaimsEmail, v)) +} + +// ClaimsEmailIn applies the In predicate on the "claims_email" field. +func ClaimsEmailIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldClaimsEmail, vs...)) +} + +// ClaimsEmailNotIn applies the NotIn predicate on the "claims_email" field. +func ClaimsEmailNotIn(vs ...string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldClaimsEmail, vs...)) +} + +// ClaimsEmailGT applies the GT predicate on the "claims_email" field. +func ClaimsEmailGT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldClaimsEmail, v)) +} + +// ClaimsEmailGTE applies the GTE predicate on the "claims_email" field. +func ClaimsEmailGTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldClaimsEmail, v)) +} + +// ClaimsEmailLT applies the LT predicate on the "claims_email" field. +func ClaimsEmailLT(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldClaimsEmail, v)) +} + +// ClaimsEmailLTE applies the LTE predicate on the "claims_email" field. +func ClaimsEmailLTE(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldClaimsEmail, v)) +} + +// ClaimsEmailContains applies the Contains predicate on the "claims_email" field. +func ClaimsEmailContains(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContains(FieldClaimsEmail, v)) +} + +// ClaimsEmailHasPrefix applies the HasPrefix predicate on the "claims_email" field. +func ClaimsEmailHasPrefix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasPrefix(FieldClaimsEmail, v)) +} + +// ClaimsEmailHasSuffix applies the HasSuffix predicate on the "claims_email" field. +func ClaimsEmailHasSuffix(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldHasSuffix(FieldClaimsEmail, v)) +} + +// ClaimsEmailEqualFold applies the EqualFold predicate on the "claims_email" field. +func ClaimsEmailEqualFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEqualFold(FieldClaimsEmail, v)) +} + +// ClaimsEmailContainsFold applies the ContainsFold predicate on the "claims_email" field. +func ClaimsEmailContainsFold(v string) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldContainsFold(FieldClaimsEmail, v)) +} + +// ClaimsEmailVerifiedEQ applies the EQ predicate on the "claims_email_verified" field. +func ClaimsEmailVerifiedEQ(v bool) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldClaimsEmailVerified, v)) +} + +// ClaimsEmailVerifiedNEQ applies the NEQ predicate on the "claims_email_verified" field. +func ClaimsEmailVerifiedNEQ(v bool) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldClaimsEmailVerified, v)) +} + +// ClaimsGroupsIsNil applies the IsNil predicate on the "claims_groups" field. +func ClaimsGroupsIsNil() predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIsNull(FieldClaimsGroups)) +} + +// ClaimsGroupsNotNil applies the NotNil predicate on the "claims_groups" field. +func ClaimsGroupsNotNil() predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotNull(FieldClaimsGroups)) +} + +// ConsentsEQ applies the EQ predicate on the "consents" field. +func ConsentsEQ(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldConsents, v)) +} + +// ConsentsNEQ applies the NEQ predicate on the "consents" field. +func ConsentsNEQ(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldConsents, v)) +} + +// ConsentsIn applies the In predicate on the "consents" field. +func ConsentsIn(vs ...[]byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldConsents, vs...)) +} + +// ConsentsNotIn applies the NotIn predicate on the "consents" field. +func ConsentsNotIn(vs ...[]byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldConsents, vs...)) +} + +// ConsentsGT applies the GT predicate on the "consents" field. +func ConsentsGT(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldConsents, v)) +} + +// ConsentsGTE applies the GTE predicate on the "consents" field. +func ConsentsGTE(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldConsents, v)) +} + +// ConsentsLT applies the LT predicate on the "consents" field. +func ConsentsLT(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldConsents, v)) +} + +// ConsentsLTE applies the LTE predicate on the "consents" field. +func ConsentsLTE(v []byte) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldConsents, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldCreatedAt, v)) +} + +// LastLoginEQ applies the EQ predicate on the "last_login" field. +func LastLoginEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldLastLogin, v)) +} + +// LastLoginNEQ applies the NEQ predicate on the "last_login" field. +func LastLoginNEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldLastLogin, v)) +} + +// LastLoginIn applies the In predicate on the "last_login" field. +func LastLoginIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldLastLogin, vs...)) +} + +// LastLoginNotIn applies the NotIn predicate on the "last_login" field. +func LastLoginNotIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldLastLogin, vs...)) +} + +// LastLoginGT applies the GT predicate on the "last_login" field. +func LastLoginGT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldLastLogin, v)) +} + +// LastLoginGTE applies the GTE predicate on the "last_login" field. +func LastLoginGTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldLastLogin, v)) +} + +// LastLoginLT applies the LT predicate on the "last_login" field. +func LastLoginLT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldLastLogin, v)) +} + +// LastLoginLTE applies the LTE predicate on the "last_login" field. +func LastLoginLTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldLastLogin, v)) +} + +// BlockedUntilEQ applies the EQ predicate on the "blocked_until" field. +func BlockedUntilEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldEQ(FieldBlockedUntil, v)) +} + +// BlockedUntilNEQ applies the NEQ predicate on the "blocked_until" field. +func BlockedUntilNEQ(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNEQ(FieldBlockedUntil, v)) +} + +// BlockedUntilIn applies the In predicate on the "blocked_until" field. +func BlockedUntilIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldIn(FieldBlockedUntil, vs...)) +} + +// BlockedUntilNotIn applies the NotIn predicate on the "blocked_until" field. +func BlockedUntilNotIn(vs ...time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldNotIn(FieldBlockedUntil, vs...)) +} + +// BlockedUntilGT applies the GT predicate on the "blocked_until" field. +func BlockedUntilGT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGT(FieldBlockedUntil, v)) +} + +// BlockedUntilGTE applies the GTE predicate on the "blocked_until" field. +func BlockedUntilGTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldGTE(FieldBlockedUntil, v)) +} + +// BlockedUntilLT applies the LT predicate on the "blocked_until" field. +func BlockedUntilLT(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLT(FieldBlockedUntil, v)) +} + +// BlockedUntilLTE applies the LTE predicate on the "blocked_until" field. +func BlockedUntilLTE(v time.Time) predicate.UserIdentity { + return predicate.UserIdentity(sql.FieldLTE(FieldBlockedUntil, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UserIdentity) predicate.UserIdentity { + return predicate.UserIdentity(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UserIdentity) predicate.UserIdentity { + return predicate.UserIdentity(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.UserIdentity) predicate.UserIdentity { + return predicate.UserIdentity(sql.NotPredicates(p)) +} diff --git a/storage/ent/db/useridentity_create.go b/storage/ent/db/useridentity_create.go new file mode 100644 index 00000000..336d5c30 --- /dev/null +++ b/storage/ent/db/useridentity_create.go @@ -0,0 +1,416 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/useridentity" +) + +// UserIdentityCreate is the builder for creating a UserIdentity entity. +type UserIdentityCreate struct { + config + mutation *UserIdentityMutation + hooks []Hook +} + +// SetUserID sets the "user_id" field. +func (_c *UserIdentityCreate) SetUserID(v string) *UserIdentityCreate { + _c.mutation.SetUserID(v) + return _c +} + +// SetConnectorID sets the "connector_id" field. +func (_c *UserIdentityCreate) SetConnectorID(v string) *UserIdentityCreate { + _c.mutation.SetConnectorID(v) + return _c +} + +// SetClaimsUserID sets the "claims_user_id" field. +func (_c *UserIdentityCreate) SetClaimsUserID(v string) *UserIdentityCreate { + _c.mutation.SetClaimsUserID(v) + return _c +} + +// SetNillableClaimsUserID sets the "claims_user_id" field if the given value is not nil. +func (_c *UserIdentityCreate) SetNillableClaimsUserID(v *string) *UserIdentityCreate { + if v != nil { + _c.SetClaimsUserID(*v) + } + return _c +} + +// SetClaimsUsername sets the "claims_username" field. +func (_c *UserIdentityCreate) SetClaimsUsername(v string) *UserIdentityCreate { + _c.mutation.SetClaimsUsername(v) + return _c +} + +// SetNillableClaimsUsername sets the "claims_username" field if the given value is not nil. +func (_c *UserIdentityCreate) SetNillableClaimsUsername(v *string) *UserIdentityCreate { + if v != nil { + _c.SetClaimsUsername(*v) + } + return _c +} + +// SetClaimsPreferredUsername sets the "claims_preferred_username" field. +func (_c *UserIdentityCreate) SetClaimsPreferredUsername(v string) *UserIdentityCreate { + _c.mutation.SetClaimsPreferredUsername(v) + return _c +} + +// SetNillableClaimsPreferredUsername sets the "claims_preferred_username" field if the given value is not nil. +func (_c *UserIdentityCreate) SetNillableClaimsPreferredUsername(v *string) *UserIdentityCreate { + if v != nil { + _c.SetClaimsPreferredUsername(*v) + } + return _c +} + +// SetClaimsEmail sets the "claims_email" field. +func (_c *UserIdentityCreate) SetClaimsEmail(v string) *UserIdentityCreate { + _c.mutation.SetClaimsEmail(v) + return _c +} + +// SetNillableClaimsEmail sets the "claims_email" field if the given value is not nil. +func (_c *UserIdentityCreate) SetNillableClaimsEmail(v *string) *UserIdentityCreate { + if v != nil { + _c.SetClaimsEmail(*v) + } + return _c +} + +// SetClaimsEmailVerified sets the "claims_email_verified" field. +func (_c *UserIdentityCreate) SetClaimsEmailVerified(v bool) *UserIdentityCreate { + _c.mutation.SetClaimsEmailVerified(v) + return _c +} + +// SetNillableClaimsEmailVerified sets the "claims_email_verified" field if the given value is not nil. +func (_c *UserIdentityCreate) SetNillableClaimsEmailVerified(v *bool) *UserIdentityCreate { + if v != nil { + _c.SetClaimsEmailVerified(*v) + } + return _c +} + +// SetClaimsGroups sets the "claims_groups" field. +func (_c *UserIdentityCreate) SetClaimsGroups(v []string) *UserIdentityCreate { + _c.mutation.SetClaimsGroups(v) + return _c +} + +// SetConsents sets the "consents" field. +func (_c *UserIdentityCreate) SetConsents(v []byte) *UserIdentityCreate { + _c.mutation.SetConsents(v) + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *UserIdentityCreate) SetCreatedAt(v time.Time) *UserIdentityCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetLastLogin sets the "last_login" field. +func (_c *UserIdentityCreate) SetLastLogin(v time.Time) *UserIdentityCreate { + _c.mutation.SetLastLogin(v) + return _c +} + +// SetBlockedUntil sets the "blocked_until" field. +func (_c *UserIdentityCreate) SetBlockedUntil(v time.Time) *UserIdentityCreate { + _c.mutation.SetBlockedUntil(v) + return _c +} + +// SetID sets the "id" field. +func (_c *UserIdentityCreate) SetID(v string) *UserIdentityCreate { + _c.mutation.SetID(v) + return _c +} + +// Mutation returns the UserIdentityMutation object of the builder. +func (_c *UserIdentityCreate) Mutation() *UserIdentityMutation { + return _c.mutation +} + +// Save creates the UserIdentity in the database. +func (_c *UserIdentityCreate) Save(ctx context.Context) (*UserIdentity, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UserIdentityCreate) SaveX(ctx context.Context) *UserIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserIdentityCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserIdentityCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *UserIdentityCreate) defaults() { + if _, ok := _c.mutation.ClaimsUserID(); !ok { + v := useridentity.DefaultClaimsUserID + _c.mutation.SetClaimsUserID(v) + } + if _, ok := _c.mutation.ClaimsUsername(); !ok { + v := useridentity.DefaultClaimsUsername + _c.mutation.SetClaimsUsername(v) + } + if _, ok := _c.mutation.ClaimsPreferredUsername(); !ok { + v := useridentity.DefaultClaimsPreferredUsername + _c.mutation.SetClaimsPreferredUsername(v) + } + if _, ok := _c.mutation.ClaimsEmail(); !ok { + v := useridentity.DefaultClaimsEmail + _c.mutation.SetClaimsEmail(v) + } + if _, ok := _c.mutation.ClaimsEmailVerified(); !ok { + v := useridentity.DefaultClaimsEmailVerified + _c.mutation.SetClaimsEmailVerified(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UserIdentityCreate) check() error { + if _, ok := _c.mutation.UserID(); !ok { + return &ValidationError{Name: "user_id", err: errors.New(`db: missing required field "UserIdentity.user_id"`)} + } + if v, ok := _c.mutation.UserID(); ok { + if err := useridentity.UserIDValidator(v); err != nil { + return &ValidationError{Name: "user_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.user_id": %w`, err)} + } + } + if _, ok := _c.mutation.ConnectorID(); !ok { + return &ValidationError{Name: "connector_id", err: errors.New(`db: missing required field "UserIdentity.connector_id"`)} + } + if v, ok := _c.mutation.ConnectorID(); ok { + if err := useridentity.ConnectorIDValidator(v); err != nil { + return &ValidationError{Name: "connector_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.connector_id": %w`, err)} + } + } + if _, ok := _c.mutation.ClaimsUserID(); !ok { + return &ValidationError{Name: "claims_user_id", err: errors.New(`db: missing required field "UserIdentity.claims_user_id"`)} + } + if _, ok := _c.mutation.ClaimsUsername(); !ok { + return &ValidationError{Name: "claims_username", err: errors.New(`db: missing required field "UserIdentity.claims_username"`)} + } + if _, ok := _c.mutation.ClaimsPreferredUsername(); !ok { + return &ValidationError{Name: "claims_preferred_username", err: errors.New(`db: missing required field "UserIdentity.claims_preferred_username"`)} + } + if _, ok := _c.mutation.ClaimsEmail(); !ok { + return &ValidationError{Name: "claims_email", err: errors.New(`db: missing required field "UserIdentity.claims_email"`)} + } + if _, ok := _c.mutation.ClaimsEmailVerified(); !ok { + return &ValidationError{Name: "claims_email_verified", err: errors.New(`db: missing required field "UserIdentity.claims_email_verified"`)} + } + if _, ok := _c.mutation.Consents(); !ok { + return &ValidationError{Name: "consents", err: errors.New(`db: missing required field "UserIdentity.consents"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "UserIdentity.created_at"`)} + } + if _, ok := _c.mutation.LastLogin(); !ok { + return &ValidationError{Name: "last_login", err: errors.New(`db: missing required field "UserIdentity.last_login"`)} + } + if _, ok := _c.mutation.BlockedUntil(); !ok { + return &ValidationError{Name: "blocked_until", err: errors.New(`db: missing required field "UserIdentity.blocked_until"`)} + } + if v, ok := _c.mutation.ID(); ok { + if err := useridentity.IDValidator(v); err != nil { + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.id": %w`, err)} + } + } + return nil +} + +func (_c *UserIdentityCreate) sqlSave(ctx context.Context) (*UserIdentity, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(string); ok { + _node.ID = id + } else { + return nil, fmt.Errorf("unexpected UserIdentity.ID type: %T", _spec.ID.Value) + } + } + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UserIdentityCreate) createSpec() (*UserIdentity, *sqlgraph.CreateSpec) { + var ( + _node = &UserIdentity{config: _c.config} + _spec = sqlgraph.NewCreateSpec(useridentity.Table, sqlgraph.NewFieldSpec(useridentity.FieldID, field.TypeString)) + ) + if id, ok := _c.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if value, ok := _c.mutation.UserID(); ok { + _spec.SetField(useridentity.FieldUserID, field.TypeString, value) + _node.UserID = value + } + if value, ok := _c.mutation.ConnectorID(); ok { + _spec.SetField(useridentity.FieldConnectorID, field.TypeString, value) + _node.ConnectorID = value + } + if value, ok := _c.mutation.ClaimsUserID(); ok { + _spec.SetField(useridentity.FieldClaimsUserID, field.TypeString, value) + _node.ClaimsUserID = value + } + if value, ok := _c.mutation.ClaimsUsername(); ok { + _spec.SetField(useridentity.FieldClaimsUsername, field.TypeString, value) + _node.ClaimsUsername = value + } + if value, ok := _c.mutation.ClaimsPreferredUsername(); ok { + _spec.SetField(useridentity.FieldClaimsPreferredUsername, field.TypeString, value) + _node.ClaimsPreferredUsername = value + } + if value, ok := _c.mutation.ClaimsEmail(); ok { + _spec.SetField(useridentity.FieldClaimsEmail, field.TypeString, value) + _node.ClaimsEmail = value + } + if value, ok := _c.mutation.ClaimsEmailVerified(); ok { + _spec.SetField(useridentity.FieldClaimsEmailVerified, field.TypeBool, value) + _node.ClaimsEmailVerified = value + } + if value, ok := _c.mutation.ClaimsGroups(); ok { + _spec.SetField(useridentity.FieldClaimsGroups, field.TypeJSON, value) + _node.ClaimsGroups = value + } + if value, ok := _c.mutation.Consents(); ok { + _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) + _node.Consents = value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.LastLogin(); ok { + _spec.SetField(useridentity.FieldLastLogin, field.TypeTime, value) + _node.LastLogin = value + } + if value, ok := _c.mutation.BlockedUntil(); ok { + _spec.SetField(useridentity.FieldBlockedUntil, field.TypeTime, value) + _node.BlockedUntil = value + } + return _node, _spec +} + +// UserIdentityCreateBulk is the builder for creating many UserIdentity entities in bulk. +type UserIdentityCreateBulk struct { + config + err error + builders []*UserIdentityCreate +} + +// Save creates the UserIdentity entities in the database. +func (_c *UserIdentityCreateBulk) Save(ctx context.Context) ([]*UserIdentity, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*UserIdentity, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*UserIdentityMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *UserIdentityCreateBulk) SaveX(ctx context.Context) []*UserIdentity { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UserIdentityCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UserIdentityCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/useridentity_delete.go b/storage/ent/db/useridentity_delete.go new file mode 100644 index 00000000..0bc51e24 --- /dev/null +++ b/storage/ent/db/useridentity_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/predicate" + "github.com/dexidp/dex/storage/ent/db/useridentity" +) + +// UserIdentityDelete is the builder for deleting a UserIdentity entity. +type UserIdentityDelete struct { + config + hooks []Hook + mutation *UserIdentityMutation +} + +// Where appends a list predicates to the UserIdentityDelete builder. +func (_d *UserIdentityDelete) Where(ps ...predicate.UserIdentity) *UserIdentityDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *UserIdentityDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserIdentityDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *UserIdentityDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(useridentity.Table, sqlgraph.NewFieldSpec(useridentity.FieldID, field.TypeString)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// UserIdentityDeleteOne is the builder for deleting a single UserIdentity entity. +type UserIdentityDeleteOne struct { + _d *UserIdentityDelete +} + +// Where appends a list predicates to the UserIdentityDelete builder. +func (_d *UserIdentityDeleteOne) Where(ps ...predicate.UserIdentity) *UserIdentityDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *UserIdentityDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{useridentity.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *UserIdentityDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/useridentity_query.go b/storage/ent/db/useridentity_query.go new file mode 100644 index 00000000..3e509038 --- /dev/null +++ b/storage/ent/db/useridentity_query.go @@ -0,0 +1,527 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/predicate" + "github.com/dexidp/dex/storage/ent/db/useridentity" +) + +// UserIdentityQuery is the builder for querying UserIdentity entities. +type UserIdentityQuery struct { + config + ctx *QueryContext + order []useridentity.OrderOption + inters []Interceptor + predicates []predicate.UserIdentity + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the UserIdentityQuery builder. +func (_q *UserIdentityQuery) Where(ps ...predicate.UserIdentity) *UserIdentityQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UserIdentityQuery) Limit(limit int) *UserIdentityQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UserIdentityQuery) Offset(offset int) *UserIdentityQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UserIdentityQuery) Unique(unique bool) *UserIdentityQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UserIdentityQuery) Order(o ...useridentity.OrderOption) *UserIdentityQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first UserIdentity entity from the query. +// Returns a *NotFoundError when no UserIdentity was found. +func (_q *UserIdentityQuery) First(ctx context.Context) (*UserIdentity, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{useridentity.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *UserIdentityQuery) FirstX(ctx context.Context) *UserIdentity { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UserIdentity ID from the query. +// Returns a *NotFoundError when no UserIdentity ID was found. +func (_q *UserIdentityQuery) FirstID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{useridentity.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UserIdentityQuery) FirstIDX(ctx context.Context) string { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UserIdentity entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UserIdentity entity is found. +// Returns a *NotFoundError when no UserIdentity entities are found. +func (_q *UserIdentityQuery) Only(ctx context.Context) (*UserIdentity, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{useridentity.Label} + default: + return nil, &NotSingularError{useridentity.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UserIdentityQuery) OnlyX(ctx context.Context) *UserIdentity { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UserIdentity ID in the query. +// Returns a *NotSingularError when more than one UserIdentity ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *UserIdentityQuery) OnlyID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{useridentity.Label} + default: + err = &NotSingularError{useridentity.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UserIdentityQuery) OnlyIDX(ctx context.Context) string { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UserIdentities. +func (_q *UserIdentityQuery) All(ctx context.Context) ([]*UserIdentity, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UserIdentity, *UserIdentityQuery]() + return withInterceptors[[]*UserIdentity](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UserIdentityQuery) AllX(ctx context.Context) []*UserIdentity { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UserIdentity IDs. +func (_q *UserIdentityQuery) IDs(ctx context.Context) (ids []string, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(useridentity.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UserIdentityQuery) IDsX(ctx context.Context) []string { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *UserIdentityQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UserIdentityQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UserIdentityQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UserIdentityQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("db: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UserIdentityQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UserIdentityQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *UserIdentityQuery) Clone() *UserIdentityQuery { + if _q == nil { + return nil + } + return &UserIdentityQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]useridentity.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.UserIdentity{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// UserID string `json:"user_id,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.UserIdentity.Query(). +// GroupBy(useridentity.FieldUserID). +// Aggregate(db.Count()). +// Scan(ctx, &v) +func (_q *UserIdentityQuery) GroupBy(field string, fields ...string) *UserIdentityGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &UserIdentityGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = useridentity.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// UserID string `json:"user_id,omitempty"` +// } +// +// client.UserIdentity.Query(). +// Select(useridentity.FieldUserID). +// Scan(ctx, &v) +func (_q *UserIdentityQuery) Select(fields ...string) *UserIdentitySelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &UserIdentitySelect{UserIdentityQuery: _q} + sbuild.label = useridentity.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a UserIdentitySelect configured with the given aggregations. +func (_q *UserIdentityQuery) Aggregate(fns ...AggregateFunc) *UserIdentitySelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *UserIdentityQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !useridentity.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UserIdentityQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UserIdentity, error) { + var ( + nodes = []*UserIdentity{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UserIdentity).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UserIdentity{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *UserIdentityQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UserIdentityQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(useridentity.Table, useridentity.Columns, sqlgraph.NewFieldSpec(useridentity.FieldID, field.TypeString)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, useridentity.FieldID) + for i := range fields { + if fields[i] != useridentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UserIdentityQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(useridentity.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = useridentity.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// UserIdentityGroupBy is the group-by builder for UserIdentity entities. +type UserIdentityGroupBy struct { + selector + build *UserIdentityQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *UserIdentityGroupBy) Aggregate(fns ...AggregateFunc) *UserIdentityGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *UserIdentityGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserIdentityQuery, *UserIdentityGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *UserIdentityGroupBy) sqlScan(ctx context.Context, root *UserIdentityQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// UserIdentitySelect is the builder for selecting fields of UserIdentity entities. +type UserIdentitySelect struct { + *UserIdentityQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *UserIdentitySelect) Aggregate(fns ...AggregateFunc) *UserIdentitySelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *UserIdentitySelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*UserIdentityQuery, *UserIdentitySelect](ctx, _s.UserIdentityQuery, _s, _s.inters, v) +} + +func (_s *UserIdentitySelect) sqlScan(ctx context.Context, root *UserIdentityQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/storage/ent/db/useridentity_update.go b/storage/ent/db/useridentity_update.go new file mode 100644 index 00000000..27ee0d3a --- /dev/null +++ b/storage/ent/db/useridentity_update.go @@ -0,0 +1,629 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/predicate" + "github.com/dexidp/dex/storage/ent/db/useridentity" +) + +// UserIdentityUpdate is the builder for updating UserIdentity entities. +type UserIdentityUpdate struct { + config + hooks []Hook + mutation *UserIdentityMutation +} + +// Where appends a list predicates to the UserIdentityUpdate builder. +func (_u *UserIdentityUpdate) Where(ps ...predicate.UserIdentity) *UserIdentityUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUserID sets the "user_id" field. +func (_u *UserIdentityUpdate) SetUserID(v string) *UserIdentityUpdate { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableUserID(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetConnectorID sets the "connector_id" field. +func (_u *UserIdentityUpdate) SetConnectorID(v string) *UserIdentityUpdate { + _u.mutation.SetConnectorID(v) + return _u +} + +// SetNillableConnectorID sets the "connector_id" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableConnectorID(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetConnectorID(*v) + } + return _u +} + +// SetClaimsUserID sets the "claims_user_id" field. +func (_u *UserIdentityUpdate) SetClaimsUserID(v string) *UserIdentityUpdate { + _u.mutation.SetClaimsUserID(v) + return _u +} + +// SetNillableClaimsUserID sets the "claims_user_id" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableClaimsUserID(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetClaimsUserID(*v) + } + return _u +} + +// SetClaimsUsername sets the "claims_username" field. +func (_u *UserIdentityUpdate) SetClaimsUsername(v string) *UserIdentityUpdate { + _u.mutation.SetClaimsUsername(v) + return _u +} + +// SetNillableClaimsUsername sets the "claims_username" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableClaimsUsername(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetClaimsUsername(*v) + } + return _u +} + +// SetClaimsPreferredUsername sets the "claims_preferred_username" field. +func (_u *UserIdentityUpdate) SetClaimsPreferredUsername(v string) *UserIdentityUpdate { + _u.mutation.SetClaimsPreferredUsername(v) + return _u +} + +// SetNillableClaimsPreferredUsername sets the "claims_preferred_username" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableClaimsPreferredUsername(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetClaimsPreferredUsername(*v) + } + return _u +} + +// SetClaimsEmail sets the "claims_email" field. +func (_u *UserIdentityUpdate) SetClaimsEmail(v string) *UserIdentityUpdate { + _u.mutation.SetClaimsEmail(v) + return _u +} + +// SetNillableClaimsEmail sets the "claims_email" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableClaimsEmail(v *string) *UserIdentityUpdate { + if v != nil { + _u.SetClaimsEmail(*v) + } + return _u +} + +// SetClaimsEmailVerified sets the "claims_email_verified" field. +func (_u *UserIdentityUpdate) SetClaimsEmailVerified(v bool) *UserIdentityUpdate { + _u.mutation.SetClaimsEmailVerified(v) + return _u +} + +// SetNillableClaimsEmailVerified sets the "claims_email_verified" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableClaimsEmailVerified(v *bool) *UserIdentityUpdate { + if v != nil { + _u.SetClaimsEmailVerified(*v) + } + return _u +} + +// SetClaimsGroups sets the "claims_groups" field. +func (_u *UserIdentityUpdate) SetClaimsGroups(v []string) *UserIdentityUpdate { + _u.mutation.SetClaimsGroups(v) + return _u +} + +// AppendClaimsGroups appends value to the "claims_groups" field. +func (_u *UserIdentityUpdate) AppendClaimsGroups(v []string) *UserIdentityUpdate { + _u.mutation.AppendClaimsGroups(v) + return _u +} + +// ClearClaimsGroups clears the value of the "claims_groups" field. +func (_u *UserIdentityUpdate) ClearClaimsGroups() *UserIdentityUpdate { + _u.mutation.ClearClaimsGroups() + return _u +} + +// SetConsents sets the "consents" field. +func (_u *UserIdentityUpdate) SetConsents(v []byte) *UserIdentityUpdate { + _u.mutation.SetConsents(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *UserIdentityUpdate) SetCreatedAt(v time.Time) *UserIdentityUpdate { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableCreatedAt(v *time.Time) *UserIdentityUpdate { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetLastLogin sets the "last_login" field. +func (_u *UserIdentityUpdate) SetLastLogin(v time.Time) *UserIdentityUpdate { + _u.mutation.SetLastLogin(v) + return _u +} + +// SetNillableLastLogin sets the "last_login" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableLastLogin(v *time.Time) *UserIdentityUpdate { + if v != nil { + _u.SetLastLogin(*v) + } + return _u +} + +// SetBlockedUntil sets the "blocked_until" field. +func (_u *UserIdentityUpdate) SetBlockedUntil(v time.Time) *UserIdentityUpdate { + _u.mutation.SetBlockedUntil(v) + return _u +} + +// SetNillableBlockedUntil sets the "blocked_until" field if the given value is not nil. +func (_u *UserIdentityUpdate) SetNillableBlockedUntil(v *time.Time) *UserIdentityUpdate { + if v != nil { + _u.SetBlockedUntil(*v) + } + return _u +} + +// Mutation returns the UserIdentityMutation object of the builder. +func (_u *UserIdentityUpdate) Mutation() *UserIdentityMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UserIdentityUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserIdentityUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UserIdentityUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserIdentityUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserIdentityUpdate) check() error { + if v, ok := _u.mutation.UserID(); ok { + if err := useridentity.UserIDValidator(v); err != nil { + return &ValidationError{Name: "user_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.user_id": %w`, err)} + } + } + if v, ok := _u.mutation.ConnectorID(); ok { + if err := useridentity.ConnectorIDValidator(v); err != nil { + return &ValidationError{Name: "connector_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.connector_id": %w`, err)} + } + } + return nil +} + +func (_u *UserIdentityUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(useridentity.Table, useridentity.Columns, sqlgraph.NewFieldSpec(useridentity.FieldID, field.TypeString)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UserID(); ok { + _spec.SetField(useridentity.FieldUserID, field.TypeString, value) + } + if value, ok := _u.mutation.ConnectorID(); ok { + _spec.SetField(useridentity.FieldConnectorID, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsUserID(); ok { + _spec.SetField(useridentity.FieldClaimsUserID, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsUsername(); ok { + _spec.SetField(useridentity.FieldClaimsUsername, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsPreferredUsername(); ok { + _spec.SetField(useridentity.FieldClaimsPreferredUsername, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsEmail(); ok { + _spec.SetField(useridentity.FieldClaimsEmail, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsEmailVerified(); ok { + _spec.SetField(useridentity.FieldClaimsEmailVerified, field.TypeBool, value) + } + if value, ok := _u.mutation.ClaimsGroups(); ok { + _spec.SetField(useridentity.FieldClaimsGroups, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedClaimsGroups(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, useridentity.FieldClaimsGroups, value) + }) + } + if _u.mutation.ClaimsGroupsCleared() { + _spec.ClearField(useridentity.FieldClaimsGroups, field.TypeJSON) + } + if value, ok := _u.mutation.Consents(); ok { + _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.LastLogin(); ok { + _spec.SetField(useridentity.FieldLastLogin, field.TypeTime, value) + } + if value, ok := _u.mutation.BlockedUntil(); ok { + _spec.SetField(useridentity.FieldBlockedUntil, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{useridentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// UserIdentityUpdateOne is the builder for updating a single UserIdentity entity. +type UserIdentityUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UserIdentityMutation +} + +// SetUserID sets the "user_id" field. +func (_u *UserIdentityUpdateOne) SetUserID(v string) *UserIdentityUpdateOne { + _u.mutation.SetUserID(v) + return _u +} + +// SetNillableUserID sets the "user_id" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableUserID(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetUserID(*v) + } + return _u +} + +// SetConnectorID sets the "connector_id" field. +func (_u *UserIdentityUpdateOne) SetConnectorID(v string) *UserIdentityUpdateOne { + _u.mutation.SetConnectorID(v) + return _u +} + +// SetNillableConnectorID sets the "connector_id" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableConnectorID(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetConnectorID(*v) + } + return _u +} + +// SetClaimsUserID sets the "claims_user_id" field. +func (_u *UserIdentityUpdateOne) SetClaimsUserID(v string) *UserIdentityUpdateOne { + _u.mutation.SetClaimsUserID(v) + return _u +} + +// SetNillableClaimsUserID sets the "claims_user_id" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableClaimsUserID(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetClaimsUserID(*v) + } + return _u +} + +// SetClaimsUsername sets the "claims_username" field. +func (_u *UserIdentityUpdateOne) SetClaimsUsername(v string) *UserIdentityUpdateOne { + _u.mutation.SetClaimsUsername(v) + return _u +} + +// SetNillableClaimsUsername sets the "claims_username" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableClaimsUsername(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetClaimsUsername(*v) + } + return _u +} + +// SetClaimsPreferredUsername sets the "claims_preferred_username" field. +func (_u *UserIdentityUpdateOne) SetClaimsPreferredUsername(v string) *UserIdentityUpdateOne { + _u.mutation.SetClaimsPreferredUsername(v) + return _u +} + +// SetNillableClaimsPreferredUsername sets the "claims_preferred_username" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableClaimsPreferredUsername(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetClaimsPreferredUsername(*v) + } + return _u +} + +// SetClaimsEmail sets the "claims_email" field. +func (_u *UserIdentityUpdateOne) SetClaimsEmail(v string) *UserIdentityUpdateOne { + _u.mutation.SetClaimsEmail(v) + return _u +} + +// SetNillableClaimsEmail sets the "claims_email" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableClaimsEmail(v *string) *UserIdentityUpdateOne { + if v != nil { + _u.SetClaimsEmail(*v) + } + return _u +} + +// SetClaimsEmailVerified sets the "claims_email_verified" field. +func (_u *UserIdentityUpdateOne) SetClaimsEmailVerified(v bool) *UserIdentityUpdateOne { + _u.mutation.SetClaimsEmailVerified(v) + return _u +} + +// SetNillableClaimsEmailVerified sets the "claims_email_verified" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableClaimsEmailVerified(v *bool) *UserIdentityUpdateOne { + if v != nil { + _u.SetClaimsEmailVerified(*v) + } + return _u +} + +// SetClaimsGroups sets the "claims_groups" field. +func (_u *UserIdentityUpdateOne) SetClaimsGroups(v []string) *UserIdentityUpdateOne { + _u.mutation.SetClaimsGroups(v) + return _u +} + +// AppendClaimsGroups appends value to the "claims_groups" field. +func (_u *UserIdentityUpdateOne) AppendClaimsGroups(v []string) *UserIdentityUpdateOne { + _u.mutation.AppendClaimsGroups(v) + return _u +} + +// ClearClaimsGroups clears the value of the "claims_groups" field. +func (_u *UserIdentityUpdateOne) ClearClaimsGroups() *UserIdentityUpdateOne { + _u.mutation.ClearClaimsGroups() + return _u +} + +// SetConsents sets the "consents" field. +func (_u *UserIdentityUpdateOne) SetConsents(v []byte) *UserIdentityUpdateOne { + _u.mutation.SetConsents(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *UserIdentityUpdateOne) SetCreatedAt(v time.Time) *UserIdentityUpdateOne { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableCreatedAt(v *time.Time) *UserIdentityUpdateOne { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetLastLogin sets the "last_login" field. +func (_u *UserIdentityUpdateOne) SetLastLogin(v time.Time) *UserIdentityUpdateOne { + _u.mutation.SetLastLogin(v) + return _u +} + +// SetNillableLastLogin sets the "last_login" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableLastLogin(v *time.Time) *UserIdentityUpdateOne { + if v != nil { + _u.SetLastLogin(*v) + } + return _u +} + +// SetBlockedUntil sets the "blocked_until" field. +func (_u *UserIdentityUpdateOne) SetBlockedUntil(v time.Time) *UserIdentityUpdateOne { + _u.mutation.SetBlockedUntil(v) + return _u +} + +// SetNillableBlockedUntil sets the "blocked_until" field if the given value is not nil. +func (_u *UserIdentityUpdateOne) SetNillableBlockedUntil(v *time.Time) *UserIdentityUpdateOne { + if v != nil { + _u.SetBlockedUntil(*v) + } + return _u +} + +// Mutation returns the UserIdentityMutation object of the builder. +func (_u *UserIdentityUpdateOne) Mutation() *UserIdentityMutation { + return _u.mutation +} + +// Where appends a list predicates to the UserIdentityUpdate builder. +func (_u *UserIdentityUpdateOne) Where(ps ...predicate.UserIdentity) *UserIdentityUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *UserIdentityUpdateOne) Select(field string, fields ...string) *UserIdentityUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UserIdentity entity. +func (_u *UserIdentityUpdateOne) Save(ctx context.Context) (*UserIdentity, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UserIdentityUpdateOne) SaveX(ctx context.Context) *UserIdentity { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UserIdentityUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UserIdentityUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UserIdentityUpdateOne) check() error { + if v, ok := _u.mutation.UserID(); ok { + if err := useridentity.UserIDValidator(v); err != nil { + return &ValidationError{Name: "user_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.user_id": %w`, err)} + } + } + if v, ok := _u.mutation.ConnectorID(); ok { + if err := useridentity.ConnectorIDValidator(v); err != nil { + return &ValidationError{Name: "connector_id", err: fmt.Errorf(`db: validator failed for field "UserIdentity.connector_id": %w`, err)} + } + } + return nil +} + +func (_u *UserIdentityUpdateOne) sqlSave(ctx context.Context) (_node *UserIdentity, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(useridentity.Table, useridentity.Columns, sqlgraph.NewFieldSpec(useridentity.FieldID, field.TypeString)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "UserIdentity.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, useridentity.FieldID) + for _, f := range fields { + if !useridentity.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != useridentity.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UserID(); ok { + _spec.SetField(useridentity.FieldUserID, field.TypeString, value) + } + if value, ok := _u.mutation.ConnectorID(); ok { + _spec.SetField(useridentity.FieldConnectorID, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsUserID(); ok { + _spec.SetField(useridentity.FieldClaimsUserID, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsUsername(); ok { + _spec.SetField(useridentity.FieldClaimsUsername, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsPreferredUsername(); ok { + _spec.SetField(useridentity.FieldClaimsPreferredUsername, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsEmail(); ok { + _spec.SetField(useridentity.FieldClaimsEmail, field.TypeString, value) + } + if value, ok := _u.mutation.ClaimsEmailVerified(); ok { + _spec.SetField(useridentity.FieldClaimsEmailVerified, field.TypeBool, value) + } + if value, ok := _u.mutation.ClaimsGroups(); ok { + _spec.SetField(useridentity.FieldClaimsGroups, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedClaimsGroups(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, useridentity.FieldClaimsGroups, value) + }) + } + if _u.mutation.ClaimsGroupsCleared() { + _spec.ClearField(useridentity.FieldClaimsGroups, field.TypeJSON) + } + if value, ok := _u.mutation.Consents(); ok { + _spec.SetField(useridentity.FieldConsents, field.TypeBytes, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(useridentity.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.LastLogin(); ok { + _spec.SetField(useridentity.FieldLastLogin, field.TypeTime, value) + } + if value, ok := _u.mutation.BlockedUntil(); ok { + _spec.SetField(useridentity.FieldBlockedUntil, field.TypeTime, value) + } + _node = &UserIdentity{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{useridentity.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/storage/ent/schema/useridentity.go b/storage/ent/schema/useridentity.go new file mode 100644 index 00000000..a4928240 --- /dev/null +++ b/storage/ent/schema/useridentity.go @@ -0,0 +1,56 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" +) + +// UserIdentity holds the schema definition for the UserIdentity entity. +type UserIdentity struct { + ent.Schema +} + +// Fields of the UserIdentity. +func (UserIdentity) Fields() []ent.Field { + return []ent.Field{ + // Using id field here because it's impossible to create multi-key primary yet + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Text("user_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("connector_id"). + SchemaType(textSchema). + NotEmpty(), + field.Text("claims_user_id"). + SchemaType(textSchema). + Default(""), + field.Text("claims_username"). + SchemaType(textSchema). + Default(""), + field.Text("claims_preferred_username"). + SchemaType(textSchema). + Default(""), + field.Text("claims_email"). + SchemaType(textSchema). + Default(""), + field.Bool("claims_email_verified"). + Default(false), + field.JSON("claims_groups", []string{}). + Optional(), + field.Bytes("consents"), + field.Time("created_at"). + SchemaType(timeSchema), + field.Time("last_login"). + SchemaType(timeSchema), + field.Time("blocked_until"). + SchemaType(timeSchema), + } +} + +// Edges of the UserIdentity. +func (UserIdentity) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index 8ccf502f..d2248ea2 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -24,6 +24,7 @@ const ( keysName = "openid-connect-keys" deviceRequestPrefix = "device_req/" deviceTokenPrefix = "device_token/" + userIdentityPrefix = "user_identity/" // defaultStorageTimeout will be applied to all storage's operations. defaultStorageTimeout = 5 * time.Second @@ -366,6 +367,61 @@ func (c *conn) DeleteOfflineSessions(ctx context.Context, userID string, connID return c.deleteKey(ctx, keySession(userID, connID)) } +func (c *conn) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) error { + return c.txnCreate(ctx, keyUserIdentity(u.UserID, u.ConnectorID), fromStorageUserIdentity(u)) +} + +func (c *conn) GetUserIdentity(ctx context.Context, userID, connectorID string) (u storage.UserIdentity, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + var ui UserIdentity + if err = c.getKey(ctx, keyUserIdentity(userID, connectorID), &ui); err != nil { + return + } + return toStorageUserIdentity(ui), nil +} + +func (c *conn) UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u storage.UserIdentity) (storage.UserIdentity, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyUserIdentity(userID, connectorID), func(currentValue []byte) ([]byte, error) { + var current UserIdentity + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageUserIdentity(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageUserIdentity(updated)) + }) +} + +func (c *conn) DeleteUserIdentity(ctx context.Context, userID, connectorID string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyUserIdentity(userID, connectorID)) +} + +func (c *conn) ListUserIdentities(ctx context.Context) (identities []storage.UserIdentity, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, userIdentityPrefix, clientv3.WithPrefix()) + if err != nil { + return identities, err + } + for _, v := range res.Kvs { + var ui UserIdentity + if err = json.Unmarshal(v.Value, &ui); err != nil { + return identities, err + } + identities = append(identities, toStorageUserIdentity(ui)) + } + return identities, nil +} + func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) } @@ -557,6 +613,10 @@ func keySession(userID, connID string) string { return offlineSessionPrefix + strings.ToLower(userID+"|"+connID) } +func keyUserIdentity(userID, connectorID string) string { + return userIdentityPrefix + strings.ToLower(userID+"|"+connectorID) +} + func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index b3756604..ea4d216c 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -256,6 +256,46 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { return s } +// UserIdentity is a mirrored struct from storage with JSON struct tags +type UserIdentity struct { + UserID string `json:"user_id,omitempty"` + ConnectorID string `json:"connector_id,omitempty"` + Claims Claims `json:"claims,omitempty"` + Consents map[string][]string `json:"consents,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastLogin time.Time `json:"last_login"` + BlockedUntil time.Time `json:"blocked_until"` +} + +func fromStorageUserIdentity(u storage.UserIdentity) UserIdentity { + return UserIdentity{ + UserID: u.UserID, + ConnectorID: u.ConnectorID, + Claims: fromStorageClaims(u.Claims), + Consents: u.Consents, + CreatedAt: u.CreatedAt, + LastLogin: u.LastLogin, + BlockedUntil: u.BlockedUntil, + } +} + +func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { + s := storage.UserIdentity{ + UserID: u.UserID, + ConnectorID: u.ConnectorID, + Claims: toStorageClaims(u.Claims), + Consents: u.Consents, + CreatedAt: u.CreatedAt, + LastLogin: u.LastLogin, + BlockedUntil: u.BlockedUntil, + } + if s.Consents == nil { + // Server code assumes this will be non-nil. + s.Consents = make(map[string][]string) + } + return s +} + // DeviceRequest is a mirrored struct from storage with JSON struct tags type DeviceRequest struct { UserCode string `json:"user_code"` diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index 028ced7e..fd756845 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -25,6 +25,7 @@ const ( kindConnector = "Connector" kindDeviceRequest = "DeviceRequest" kindDeviceToken = "DeviceToken" + kindUserIdentity = "UserIdentity" ) const ( @@ -38,6 +39,7 @@ const ( resourceConnector = "connectors" resourceDeviceRequest = "devicerequests" resourceDeviceToken = "devicetokens" + resourceUserIdentity = "useridentities" ) const ( @@ -743,6 +745,70 @@ func (cli *client) UpdateDeviceToken(ctx context.Context, deviceCode string, upd }) } +func (cli *client) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) error { + return cli.post(resourceUserIdentity, cli.fromStorageUserIdentity(u)) +} + +func (cli *client) GetUserIdentity(ctx context.Context, userID, connectorID string) (storage.UserIdentity, error) { + u, err := cli.getUserIdentity(userID, connectorID) + if err != nil { + return storage.UserIdentity{}, err + } + return toStorageUserIdentity(u), nil +} + +func (cli *client) getUserIdentity(userID, connectorID string) (u UserIdentity, err error) { + name := cli.offlineTokenName(userID, connectorID) + if err = cli.get(resourceUserIdentity, name, &u); err != nil { + return UserIdentity{}, err + } + if userID != u.UserID || connectorID != u.ConnectorID { + return UserIdentity{}, fmt.Errorf("get user identity: wrong identity retrieved") + } + return u, nil +} + +func (cli *client) UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(old storage.UserIdentity) (storage.UserIdentity, error)) error { + return retryOnConflict(ctx, func() error { + u, err := cli.getUserIdentity(userID, connectorID) + if err != nil { + return err + } + + updated, err := updater(toStorageUserIdentity(u)) + if err != nil { + return err + } + + newUserIdentity := cli.fromStorageUserIdentity(updated) + newUserIdentity.ObjectMeta = u.ObjectMeta + return cli.put(resourceUserIdentity, u.ObjectMeta.Name, newUserIdentity) + }) +} + +func (cli *client) DeleteUserIdentity(ctx context.Context, userID, connectorID string) error { + // Check for hash collision. + u, err := cli.getUserIdentity(userID, connectorID) + if err != nil { + return err + } + return cli.delete(resourceUserIdentity, u.ObjectMeta.Name) +} + +func (cli *client) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, error) { + var userIdentityList UserIdentityList + if err := cli.list(resourceUserIdentity, &userIdentityList); err != nil { + return nil, fmt.Errorf("failed to list user identities: %v", err) + } + + userIdentities := make([]storage.UserIdentity, len(userIdentityList.UserIdentities)) + for i, u := range userIdentityList.UserIdentities { + userIdentities[i] = toStorageUserIdentity(u) + } + + return userIdentities, nil +} + func isKubernetesAPIConflictError(err error) bool { if httpErr, ok := err.(httpError); ok { if httpErr.StatusCode() == http.StatusConflict { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index abcd907e..44ab4943 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -226,6 +226,23 @@ func customResourceDefinitions(apiVersion string) []k8sapi.CustomResourceDefinit }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "useridentities.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: version, + Versions: versions, + Scope: scope, + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "useridentities", + Singular: "useridentity", + Kind: "UserIdentity", + }, + }, + }, } } @@ -872,3 +889,62 @@ func toStorageDeviceToken(t DeviceToken) storage.DeviceToken { }, } } + +// UserIdentity is a mirrored struct from storage with JSON struct tags and Kubernetes +// type metadata. +type UserIdentity struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + UserID string `json:"userID,omitempty"` + ConnectorID string `json:"connectorID,omitempty"` + Claims Claims `json:"claims,omitempty"` + Consents map[string][]string `json:"consents,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` + LastLogin time.Time `json:"lastLogin,omitempty"` + BlockedUntil time.Time `json:"blockedUntil,omitempty"` +} + +// UserIdentityList is a list of UserIdentities. +type UserIdentityList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + UserIdentities []UserIdentity `json:"items"` +} + +func (cli *client) fromStorageUserIdentity(u storage.UserIdentity) UserIdentity { + return UserIdentity{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindUserIdentity, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: cli.offlineTokenName(u.UserID, u.ConnectorID), + Namespace: cli.namespace, + }, + UserID: u.UserID, + ConnectorID: u.ConnectorID, + Claims: fromStorageClaims(u.Claims), + Consents: u.Consents, + CreatedAt: u.CreatedAt, + LastLogin: u.LastLogin, + BlockedUntil: u.BlockedUntil, + } +} + +func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { + s := storage.UserIdentity{ + UserID: u.UserID, + ConnectorID: u.ConnectorID, + Claims: toStorageClaims(u.Claims), + Consents: u.Consents, + CreatedAt: u.CreatedAt, + LastLogin: u.LastLogin, + BlockedUntil: u.BlockedUntil, + } + if s.Consents == nil { + // Server code assumes this will be non-nil. + s.Consents = make(map[string][]string) + } + return s +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index eff75e71..ecf3a410 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -21,7 +21,8 @@ func New(logger *slog.Logger) storage.Storage { refreshTokens: make(map[string]storage.RefreshToken), authReqs: make(map[string]storage.AuthRequest), passwords: make(map[string]storage.Password), - offlineSessions: make(map[offlineSessionID]storage.OfflineSessions), + offlineSessions: make(map[compositeKeyID]storage.OfflineSessions), + userIdentities: make(map[compositeKeyID]storage.UserIdentity), connectors: make(map[string]storage.Connector), deviceRequests: make(map[string]storage.DeviceRequest), deviceTokens: make(map[string]storage.DeviceToken), @@ -48,7 +49,8 @@ type memStorage struct { refreshTokens map[string]storage.RefreshToken authReqs map[string]storage.AuthRequest passwords map[string]storage.Password - offlineSessions map[offlineSessionID]storage.OfflineSessions + offlineSessions map[compositeKeyID]storage.OfflineSessions + userIdentities map[compositeKeyID]storage.UserIdentity connectors map[string]storage.Connector deviceRequests map[string]storage.DeviceRequest deviceTokens map[string]storage.DeviceToken @@ -58,7 +60,7 @@ type memStorage struct { logger *slog.Logger } -type offlineSessionID struct { +type compositeKeyID struct { userID string connID string } @@ -158,7 +160,7 @@ func (s *memStorage) CreatePassword(ctx context.Context, p storage.Password) (er } func (s *memStorage) CreateOfflineSessions(ctx context.Context, o storage.OfflineSessions) (err error) { - id := offlineSessionID{ + id := compositeKeyID{ userID: o.UserID, connID: o.ConnID, } @@ -172,6 +174,78 @@ func (s *memStorage) CreateOfflineSessions(ctx context.Context, o storage.Offlin return } +func (s *memStorage) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) (err error) { + id := compositeKeyID{ + userID: u.UserID, + connID: u.ConnectorID, + } + s.tx(func() { + if _, ok := s.userIdentities[id]; ok { + err = storage.ErrAlreadyExists + } else { + s.userIdentities[id] = u + } + }) + return +} + +func (s *memStorage) GetUserIdentity(ctx context.Context, userID, connectorID string) (u storage.UserIdentity, err error) { + id := compositeKeyID{ + userID: userID, + connID: connectorID, + } + s.tx(func() { + var ok bool + if u, ok = s.userIdentities[id]; !ok { + err = storage.ErrNotFound + return + } + }) + return +} + +func (s *memStorage) UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u storage.UserIdentity) (storage.UserIdentity, error)) (err error) { + id := compositeKeyID{ + userID: userID, + connID: connectorID, + } + s.tx(func() { + r, ok := s.userIdentities[id] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.userIdentities[id] = r + } + }) + return +} + +func (s *memStorage) DeleteUserIdentity(ctx context.Context, userID, connectorID string) (err error) { + id := compositeKeyID{ + userID: userID, + connID: connectorID, + } + s.tx(func() { + if _, ok := s.userIdentities[id]; !ok { + err = storage.ErrNotFound + return + } + delete(s.userIdentities, id) + }) + return +} + +func (s *memStorage) ListUserIdentities(ctx context.Context) (identities []storage.UserIdentity, err error) { + s.tx(func() { + for _, u := range s.userIdentities { + identities = append(identities, u) + } + }) + return +} + func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Connector) (err error) { s.tx(func() { if _, ok := s.connectors[connector.ID]; ok { @@ -243,7 +317,7 @@ func (s *memStorage) GetAuthRequest(ctx context.Context, id string) (req storage } func (s *memStorage) GetOfflineSessions(ctx context.Context, userID string, connID string) (o storage.OfflineSessions, err error) { - id := offlineSessionID{ + id := compositeKeyID{ userID: userID, connID: connID, } @@ -360,7 +434,7 @@ func (s *memStorage) DeleteAuthRequest(ctx context.Context, id string) (err erro } func (s *memStorage) DeleteOfflineSessions(ctx context.Context, userID string, connID string) (err error) { - id := offlineSessionID{ + id := compositeKeyID{ userID: userID, connID: connID, } @@ -453,7 +527,7 @@ func (s *memStorage) UpdateRefreshToken(ctx context.Context, id string, updater } func (s *memStorage) UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(o storage.OfflineSessions) (storage.OfflineSessions, error)) (err error) { - id := offlineSessionID{ + id := compositeKeyID{ userID: userID, connID: connID, } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 33e69fc0..04c9be3b 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -775,6 +775,157 @@ func scanOfflineSessions(s scanner) (o storage.OfflineSessions, err error) { return o, nil } +func (c *conn) CreateUserIdentity(ctx context.Context, u storage.UserIdentity) error { + _, err := c.Exec(` + insert into user_identity ( + user_id, connector_id, + claims_user_id, claims_username, claims_preferred_username, + claims_email, claims_email_verified, claims_groups, + consents, + created_at, last_login, blocked_until + ) + values ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 + ); + `, + u.UserID, u.ConnectorID, + u.Claims.UserID, u.Claims.Username, u.Claims.PreferredUsername, + u.Claims.Email, u.Claims.EmailVerified, encoder(u.Claims.Groups), + encoder(u.Consents), + u.CreatedAt, u.LastLogin, u.BlockedUntil, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert user identity: %v", err) + } + return nil +} + +func (c *conn) UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u storage.UserIdentity) (storage.UserIdentity, error)) error { + return c.ExecTx(func(tx *trans) error { + u, err := getUserIdentity(ctx, tx, userID, connectorID) + if err != nil { + return err + } + + newIdentity, err := updater(u) + if err != nil { + return err + } + _, err = tx.Exec(` + update user_identity + set + claims_user_id = $1, + claims_username = $2, + claims_preferred_username = $3, + claims_email = $4, + claims_email_verified = $5, + claims_groups = $6, + consents = $7, + created_at = $8, + last_login = $9, + blocked_until = $10 + where user_id = $11 AND connector_id = $12; + `, + newIdentity.Claims.UserID, newIdentity.Claims.Username, newIdentity.Claims.PreferredUsername, + newIdentity.Claims.Email, newIdentity.Claims.EmailVerified, encoder(newIdentity.Claims.Groups), + encoder(newIdentity.Consents), + newIdentity.CreatedAt, newIdentity.LastLogin, newIdentity.BlockedUntil, + u.UserID, u.ConnectorID, + ) + if err != nil { + return fmt.Errorf("update user identity: %v", err) + } + return nil + }) +} + +func (c *conn) GetUserIdentity(ctx context.Context, userID, connectorID string) (storage.UserIdentity, error) { + return getUserIdentity(ctx, c, userID, connectorID) +} + +func getUserIdentity(ctx context.Context, q querier, userID, connectorID string) (storage.UserIdentity, error) { + return scanUserIdentity(q.QueryRow(` + select + user_id, connector_id, + claims_user_id, claims_username, claims_preferred_username, + claims_email, claims_email_verified, claims_groups, + consents, + created_at, last_login, blocked_until + from user_identity + where user_id = $1 AND connector_id = $2; + `, userID, connectorID)) +} + +func (c *conn) ListUserIdentities(ctx context.Context) ([]storage.UserIdentity, error) { + rows, err := c.Query(` + select + user_id, connector_id, + claims_user_id, claims_username, claims_preferred_username, + claims_email, claims_email_verified, claims_groups, + consents, + created_at, last_login, blocked_until + from user_identity; + `) + if err != nil { + return nil, fmt.Errorf("query: %v", err) + } + defer rows.Close() + + var identities []storage.UserIdentity + for rows.Next() { + u, err := scanUserIdentity(rows) + if err != nil { + return nil, err + } + identities = append(identities, u) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan: %v", err) + } + return identities, nil +} + +func scanUserIdentity(s scanner) (u storage.UserIdentity, err error) { + err = s.Scan( + &u.UserID, &u.ConnectorID, + &u.Claims.UserID, &u.Claims.Username, &u.Claims.PreferredUsername, + &u.Claims.Email, &u.Claims.EmailVerified, decoder(&u.Claims.Groups), + decoder(&u.Consents), + &u.CreatedAt, &u.LastLogin, &u.BlockedUntil, + ) + if err != nil { + if err == sql.ErrNoRows { + return u, storage.ErrNotFound + } + return u, fmt.Errorf("select user identity: %v", err) + } + if u.Consents == nil { + u.Consents = make(map[string][]string) + } + return u, nil +} + +func (c *conn) DeleteUserIdentity(ctx context.Context, userID, connectorID string) error { + result, err := c.Exec(`delete from user_identity where user_id = $1 AND connector_id = $2`, userID, connectorID) + if err != nil { + return fmt.Errorf("delete user_identity: user_id = %s, connector_id = %s: %w", userID, connectorID, err) + } + + // For now mandate that the driver implements RowsAffected. If we ever need to support + // a driver that doesn't implement this, we can run this in a transaction with a get beforehand. + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %v", err) + } + if n < 1 { + return storage.ErrNotFound + } + return nil +} + func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { grantTypes, err := json.Marshal(connector.GrantTypes) if err != nil { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 440acfd5..053e9e6a 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -389,4 +389,24 @@ var migrations = []migration{ add column allowed_connectors bytea;`, }, }, + { + stmts: []string{ + ` + create table user_identity ( + user_id text not null, + connector_id text not null, + claims_user_id text not null, + claims_username text not null, + claims_preferred_username text not null default '', + claims_email text not null, + claims_email_verified boolean not null, + claims_groups bytea not null, + consents bytea not null, + created_at timestamptz not null, + last_login timestamptz not null, + blocked_until timestamptz not null, + PRIMARY KEY (user_id, connector_id) + );`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index e7bddef9..3e332a80 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -83,6 +83,7 @@ type Storage interface { CreateRefresh(ctx context.Context, r RefreshToken) error CreatePassword(ctx context.Context, p Password) error CreateOfflineSessions(ctx context.Context, s OfflineSessions) error + CreateUserIdentity(ctx context.Context, u UserIdentity) error CreateConnector(ctx context.Context, c Connector) error CreateDeviceRequest(ctx context.Context, d DeviceRequest) error CreateDeviceToken(ctx context.Context, d DeviceToken) error @@ -96,6 +97,7 @@ type Storage interface { GetRefresh(ctx context.Context, id string) (RefreshToken, error) GetPassword(ctx context.Context, email string) (Password, error) GetOfflineSessions(ctx context.Context, userID string, connID string) (OfflineSessions, error) + GetUserIdentity(ctx context.Context, userID, connectorID string) (UserIdentity, error) GetConnector(ctx context.Context, id string) (Connector, error) GetDeviceRequest(ctx context.Context, userCode string) (DeviceRequest, error) GetDeviceToken(ctx context.Context, deviceCode string) (DeviceToken, error) @@ -104,6 +106,7 @@ type Storage interface { ListRefreshTokens(ctx context.Context) ([]RefreshToken, error) ListPasswords(ctx context.Context) ([]Password, error) ListConnectors(ctx context.Context) ([]Connector, error) + ListUserIdentities(ctx context.Context) ([]UserIdentity, error) // Delete methods MUST be atomic. DeleteAuthRequest(ctx context.Context, id string) error @@ -112,6 +115,7 @@ type Storage interface { DeleteRefresh(ctx context.Context, id string) error DeletePassword(ctx context.Context, email string) error DeleteOfflineSessions(ctx context.Context, userID string, connID string) error + DeleteUserIdentity(ctx context.Context, userID, connectorID string) error DeleteConnector(ctx context.Context, id string) error // Update methods take a function for updating an object then performs that update within @@ -134,6 +138,7 @@ type Storage interface { UpdateRefreshToken(ctx context.Context, id string, updater func(r RefreshToken) (RefreshToken, error)) error UpdatePassword(ctx context.Context, email string, updater func(p Password) (Password, error)) error UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error + UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u UserIdentity) (UserIdentity, error)) error UpdateConnector(ctx context.Context, id string, updater func(c Connector) (Connector, error)) error UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error @@ -320,6 +325,17 @@ type RefreshTokenRef struct { LastUsed time.Time } +// UserIdentity represents persistent per-user identity data. +type UserIdentity struct { + UserID string + ConnectorID string + Claims Claims + Consents map[string][]string // clientID -> approved scopes + CreatedAt time.Time + LastLogin time.Time + BlockedUntil time.Time +} + // OfflineSessions objects are sessions pertaining to users with refresh tokens. type OfflineSessions struct { // UserID of an end user who has logged into the server.