From 6b9ce00e119edfe47cfcf1709cf5640f7b9c75d9 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Mon, 16 Mar 2026 17:06:53 +0100 Subject: [PATCH] feat: implement AuthSession CRUD operations (#4646) Signed-off-by: maksim.nabokikh --- storage/conformance/conformance.go | 100 ++++ storage/ent/client/authsession.go | 108 +++++ storage/ent/client/types.go | 22 + storage/ent/db/authsession.go | 150 ++++++ storage/ent/db/authsession/authsession.go | 83 ++++ storage/ent/db/authsession/where.go | 355 ++++++++++++++ storage/ent/db/authsession_create.go | 282 +++++++++++ storage/ent/db/authsession_delete.go | 88 ++++ storage/ent/db/authsession_query.go | 527 +++++++++++++++++++++ storage/ent/db/authsession_update.go | 330 +++++++++++++ storage/ent/db/client.go | 155 +++++- storage/ent/db/ent.go | 2 + storage/ent/db/hook/hook.go | 12 + storage/ent/db/migrate/schema.go | 16 + storage/ent/db/mutation.go | 550 ++++++++++++++++++++++ storage/ent/db/predicate/predicate.go | 3 + storage/ent/db/runtime.go | 15 + storage/ent/db/tx.go | 3 + storage/ent/schema/authsession.go | 37 ++ storage/etcd/etcd.go | 60 +++ storage/etcd/types.go | 36 ++ storage/kubernetes/storage.go | 50 ++ storage/kubernetes/types.go | 69 +++ storage/memory/memory.go | 57 +++ storage/sql/crud.go | 130 +++++ storage/sql/migrate.go | 13 + storage/storage.go | 25 + 27 files changed, 3272 insertions(+), 6 deletions(-) create mode 100644 storage/ent/client/authsession.go create mode 100644 storage/ent/db/authsession.go create mode 100644 storage/ent/db/authsession/authsession.go create mode 100644 storage/ent/db/authsession/where.go create mode 100644 storage/ent/db/authsession_create.go create mode 100644 storage/ent/db/authsession_delete.go create mode 100644 storage/ent/db/authsession_query.go create mode 100644 storage/ent/db/authsession_update.go create mode 100644 storage/ent/schema/authsession.go diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index ec3cae11..94b23745 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -52,6 +52,7 @@ func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { {"DeviceRequestCRUD", testDeviceRequestCRUD}, {"DeviceTokenCRUD", testDeviceTokenCRUD}, {"UserIdentityCRUD", testUserIdentityCRUD}, + {"AuthSessionCRUD", testAuthSessionCRUD}, }) } @@ -1166,3 +1167,102 @@ func testUserIdentityCRUD(t *testing.T, s storage.Storage) { _, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) mustBeErrNotFound(t, "user identity", err) } + +func testAuthSessionCRUD(t *testing.T, s storage.Storage) { + ctx := t.Context() + + now := time.Now().UTC().Round(time.Millisecond) + + session := storage.AuthSession{ + ID: storage.NewID(), + ClientStates: map[string]*storage.ClientAuthState{ + "client1": { + UserID: "user1", + ConnectorID: "conn1", + Active: true, + ExpiresAt: now.Add(24 * time.Hour), + LastActivity: now, + LastTokenIssuedAt: now, + }, + }, + CreatedAt: now, + LastActivity: now, + IPAddress: "192.168.1.1", + UserAgent: "TestBrowser/1.0", + } + + // Create. + if err := s.CreateAuthSession(ctx, session); err != nil { + t.Fatalf("create auth session: %v", err) + } + + // Duplicate create should return ErrAlreadyExists. + err := s.CreateAuthSession(ctx, session) + mustBeErrAlreadyExists(t, "auth session", err) + + // Get and compare. + got, err := s.GetAuthSession(ctx, session.ID) + if err != nil { + t.Fatalf("get auth session: %v", err) + } + + got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond) + got.LastActivity = got.LastActivity.UTC().Round(time.Millisecond) + for _, cs := range got.ClientStates { + cs.ExpiresAt = cs.ExpiresAt.UTC().Round(time.Millisecond) + cs.LastActivity = cs.LastActivity.UTC().Round(time.Millisecond) + cs.LastTokenIssuedAt = cs.LastTokenIssuedAt.UTC().Round(time.Millisecond) + } + if diff := pretty.Compare(session, got); diff != "" { + t.Errorf("auth session retrieved from storage did not match: %s", diff) + } + + // Update: add a new client state. + newNow := now.Add(time.Minute) + if err := s.UpdateAuthSession(ctx, session.ID, func(old storage.AuthSession) (storage.AuthSession, error) { + old.ClientStates["client2"] = &storage.ClientAuthState{ + UserID: "user2", + ConnectorID: "conn2", + Active: true, + ExpiresAt: newNow.Add(24 * time.Hour), + LastActivity: newNow, + } + old.LastActivity = newNow + return old, nil + }); err != nil { + t.Fatalf("update auth session: %v", err) + } + + // Get and verify update. + got, err = s.GetAuthSession(ctx, session.ID) + if err != nil { + t.Fatalf("get auth session after update: %v", err) + } + if len(got.ClientStates) != 2 { + t.Fatalf("expected 2 client states, got %d", len(got.ClientStates)) + } + if got.ClientStates["client2"] == nil { + t.Fatal("expected client2 state to exist") + } + if got.ClientStates["client2"].UserID != "user2" { + t.Errorf("expected client2 user_id to be user2, got %s", got.ClientStates["client2"].UserID) + } + + // List and verify. + sessions, err := s.ListAuthSessions(ctx) + if err != nil { + t.Fatalf("list auth sessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("expected 1 auth session, got %d", len(sessions)) + } + + // Delete. + if err := s.DeleteAuthSession(ctx, session.ID); err != nil { + t.Fatalf("delete auth session: %v", err) + } + + // Get deleted should return ErrNotFound. + _, err = s.GetAuthSession(ctx, session.ID) + mustBeErrNotFound(t, "auth session", err) +} diff --git a/storage/ent/client/authsession.go b/storage/ent/client/authsession.go new file mode 100644 index 00000000..439bdbe3 --- /dev/null +++ b/storage/ent/client/authsession.go @@ -0,0 +1,108 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/dexidp/dex/storage" +) + +// CreateAuthSession saves provided auth session into the database. +func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSession) error { + if session.ClientStates == nil { + session.ClientStates = make(map[string]*storage.ClientAuthState) + } + encodedStates, err := json.Marshal(session.ClientStates) + if err != nil { + return fmt.Errorf("encode client states auth session: %w", err) + } + + _, err = d.client.AuthSession.Create(). + SetID(session.ID). + SetClientStates(encodedStates). + SetCreatedAt(session.CreatedAt). + SetLastActivity(session.LastActivity). + SetIPAddress(session.IPAddress). + SetUserAgent(session.UserAgent). + Save(ctx) + if err != nil { + return convertDBError("create auth session: %w", err) + } + return nil +} + +// GetAuthSession extracts an auth session from the database by session ID. +func (d *Database) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { + authSession, err := d.client.AuthSession.Get(ctx, sessionID) + if err != nil { + return storage.AuthSession{}, convertDBError("get auth session: %w", err) + } + return toStorageAuthSession(authSession), nil +} + +// ListAuthSessions extracts all auth sessions from the database. +func (d *Database) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + authSessions, err := d.client.AuthSession.Query().All(ctx) + if err != nil { + return nil, convertDBError("list auth sessions: %w", err) + } + + storageAuthSessions := make([]storage.AuthSession, 0, len(authSessions)) + for _, s := range authSessions { + storageAuthSessions = append(storageAuthSessions, toStorageAuthSession(s)) + } + return storageAuthSessions, nil +} + +// DeleteAuthSession deletes an auth session from the database by session ID. +func (d *Database) DeleteAuthSession(ctx context.Context, sessionID string) error { + err := d.client.AuthSession.DeleteOneID(sessionID).Exec(ctx) + if err != nil { + return convertDBError("delete auth session: %w", err) + } + return nil +} + +// UpdateAuthSession changes an auth session using an updater function. +func (d *Database) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error { + tx, err := d.BeginTx(ctx) + if err != nil { + return convertDBError("update auth session tx: %w", err) + } + + authSession, err := tx.AuthSession.Get(ctx, sessionID) + if err != nil { + return rollback(tx, "update auth session database: %w", err) + } + + newSession, err := updater(toStorageAuthSession(authSession)) + if err != nil { + return rollback(tx, "update auth session updating: %w", err) + } + + if newSession.ClientStates == nil { + newSession.ClientStates = make(map[string]*storage.ClientAuthState) + } + + encodedStates, err := json.Marshal(newSession.ClientStates) + if err != nil { + return rollback(tx, "encode client states auth session: %w", err) + } + + _, err = tx.AuthSession.UpdateOneID(sessionID). + SetClientStates(encodedStates). + SetLastActivity(newSession.LastActivity). + SetIPAddress(newSession.IPAddress). + SetUserAgent(newSession.UserAgent). + Save(ctx) + if err != nil { + return rollback(tx, "update auth session updating: %w", err) + } + + if err = tx.Commit(); err != nil { + return rollback(tx, "update auth session commit: %w", err) + } + + return nil +} diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index 535ab5f8..f8e99c4a 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -196,6 +196,28 @@ func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity { return s } +func toStorageAuthSession(s *db.AuthSession) storage.AuthSession { + result := storage.AuthSession{ + ID: s.ID, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + } + + if s.ClientStates != nil { + if err := json.Unmarshal(s.ClientStates, &result.ClientStates); err != nil { + panic(err) + } + if result.ClientStates == nil { + result.ClientStates = make(map[string]*storage.ClientAuthState) + } + } else { + result.ClientStates = make(map[string]*storage.ClientAuthState) + } + return result +} + func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { return storage.DeviceToken{ DeviceCode: t.DeviceCode, diff --git a/storage/ent/db/authsession.go b/storage/ent/db/authsession.go new file mode 100644 index 00000000..b81479c7 --- /dev/null +++ b/storage/ent/db/authsession.go @@ -0,0 +1,150 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/dexidp/dex/storage/ent/db/authsession" +) + +// AuthSession is the model entity for the AuthSession schema. +type AuthSession struct { + config `json:"-"` + // ID of the ent. + ID string `json:"id,omitempty"` + // ClientStates holds the value of the "client_states" field. + ClientStates []byte `json:"client_states,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // LastActivity holds the value of the "last_activity" field. + LastActivity time.Time `json:"last_activity,omitempty"` + // IPAddress holds the value of the "ip_address" field. + IPAddress string `json:"ip_address,omitempty"` + // UserAgent holds the value of the "user_agent" field. + UserAgent string `json:"user_agent,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*AuthSession) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case authsession.FieldClientStates: + values[i] = new([]byte) + case authsession.FieldID, authsession.FieldIPAddress, authsession.FieldUserAgent: + values[i] = new(sql.NullString) + case authsession.FieldCreatedAt, authsession.FieldLastActivity: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the AuthSession fields. +func (_m *AuthSession) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case authsession.FieldID: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field id", values[i]) + } else if value.Valid { + _m.ID = value.String + } + case authsession.FieldClientStates: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field client_states", values[i]) + } else if value != nil { + _m.ClientStates = *value + } + case authsession.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case authsession.FieldLastActivity: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_activity", values[i]) + } else if value.Valid { + _m.LastActivity = value.Time + } + case authsession.FieldIPAddress: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field ip_address", values[i]) + } else if value.Valid { + _m.IPAddress = value.String + } + case authsession.FieldUserAgent: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field user_agent", values[i]) + } else if value.Valid { + _m.UserAgent = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the AuthSession. +// This includes values selected through modifiers, order, etc. +func (_m *AuthSession) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this AuthSession. +// Note that you need to call AuthSession.Unwrap() before calling this method if this AuthSession +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *AuthSession) Update() *AuthSessionUpdateOne { + return NewAuthSessionClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the AuthSession entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *AuthSession) Unwrap() *AuthSession { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("db: AuthSession is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *AuthSession) String() string { + var builder strings.Builder + builder.WriteString("AuthSession(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("client_states=") + builder.WriteString(fmt.Sprintf("%v", _m.ClientStates)) + builder.WriteString(", ") + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("last_activity=") + builder.WriteString(_m.LastActivity.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("ip_address=") + builder.WriteString(_m.IPAddress) + builder.WriteString(", ") + builder.WriteString("user_agent=") + builder.WriteString(_m.UserAgent) + builder.WriteByte(')') + return builder.String() +} + +// AuthSessions is a parsable slice of AuthSession. +type AuthSessions []*AuthSession diff --git a/storage/ent/db/authsession/authsession.go b/storage/ent/db/authsession/authsession.go new file mode 100644 index 00000000..e2548f90 --- /dev/null +++ b/storage/ent/db/authsession/authsession.go @@ -0,0 +1,83 @@ +// Code generated by ent, DO NOT EDIT. + +package authsession + +import ( + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the authsession type in the database. + Label = "auth_session" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldClientStates holds the string denoting the client_states field in the database. + FieldClientStates = "client_states" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldLastActivity holds the string denoting the last_activity field in the database. + FieldLastActivity = "last_activity" + // FieldIPAddress holds the string denoting the ip_address field in the database. + FieldIPAddress = "ip_address" + // FieldUserAgent holds the string denoting the user_agent field in the database. + FieldUserAgent = "user_agent" + // Table holds the table name of the authsession in the database. + Table = "auth_sessions" +) + +// Columns holds all SQL columns for authsession fields. +var Columns = []string{ + FieldID, + FieldClientStates, + FieldCreatedAt, + FieldLastActivity, + FieldIPAddress, + FieldUserAgent, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultIPAddress holds the default value on creation for the "ip_address" field. + DefaultIPAddress string + // DefaultUserAgent holds the default value on creation for the "user_agent" field. + DefaultUserAgent string + // IDValidator is a validator for the "id" field. It is called by the builders before save. + IDValidator func(string) error +) + +// OrderOption defines the ordering options for the AuthSession queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByLastActivity orders the results by the last_activity field. +func ByLastActivity(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastActivity, opts...).ToFunc() +} + +// ByIPAddress orders the results by the ip_address field. +func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIPAddress, opts...).ToFunc() +} + +// ByUserAgent orders the results by the user_agent field. +func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUserAgent, opts...).ToFunc() +} diff --git a/storage/ent/db/authsession/where.go b/storage/ent/db/authsession/where.go new file mode 100644 index 00000000..a4f52894 --- /dev/null +++ b/storage/ent/db/authsession/where.go @@ -0,0 +1,355 @@ +// Code generated by ent, DO NOT EDIT. + +package authsession + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/dexidp/dex/storage/ent/db/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldID, id)) +} + +// IDEqualFold applies the EqualFold predicate on the ID field. +func IDEqualFold(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEqualFold(FieldID, id)) +} + +// IDContainsFold applies the ContainsFold predicate on the ID field. +func IDContainsFold(id string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldContainsFold(FieldID, id)) +} + +// ClientStates applies equality check predicate on the "client_states" field. It's identical to ClientStatesEQ. +func ClientStates(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// LastActivity applies equality check predicate on the "last_activity" field. It's identical to LastActivityEQ. +func LastActivity(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v)) +} + +// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ. +func IPAddress(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v)) +} + +// UserAgent applies equality check predicate on the "user_agent" field. It's identical to UserAgentEQ. +func UserAgent(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v)) +} + +// ClientStatesEQ applies the EQ predicate on the "client_states" field. +func ClientStatesEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v)) +} + +// ClientStatesNEQ applies the NEQ predicate on the "client_states" field. +func ClientStatesNEQ(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldClientStates, v)) +} + +// ClientStatesIn applies the In predicate on the "client_states" field. +func ClientStatesIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldClientStates, vs...)) +} + +// ClientStatesNotIn applies the NotIn predicate on the "client_states" field. +func ClientStatesNotIn(vs ...[]byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldClientStates, vs...)) +} + +// ClientStatesGT applies the GT predicate on the "client_states" field. +func ClientStatesGT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldClientStates, v)) +} + +// ClientStatesGTE applies the GTE predicate on the "client_states" field. +func ClientStatesGTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldClientStates, v)) +} + +// ClientStatesLT applies the LT predicate on the "client_states" field. +func ClientStatesLT(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldClientStates, v)) +} + +// ClientStatesLTE applies the LTE predicate on the "client_states" field. +func ClientStatesLTE(v []byte) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldClientStates, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldCreatedAt, v)) +} + +// LastActivityEQ applies the EQ predicate on the "last_activity" field. +func LastActivityEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v)) +} + +// LastActivityNEQ applies the NEQ predicate on the "last_activity" field. +func LastActivityNEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldLastActivity, v)) +} + +// LastActivityIn applies the In predicate on the "last_activity" field. +func LastActivityIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldLastActivity, vs...)) +} + +// LastActivityNotIn applies the NotIn predicate on the "last_activity" field. +func LastActivityNotIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldLastActivity, vs...)) +} + +// LastActivityGT applies the GT predicate on the "last_activity" field. +func LastActivityGT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldLastActivity, v)) +} + +// LastActivityGTE applies the GTE predicate on the "last_activity" field. +func LastActivityGTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldLastActivity, v)) +} + +// LastActivityLT applies the LT predicate on the "last_activity" field. +func LastActivityLT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldLastActivity, v)) +} + +// LastActivityLTE applies the LTE predicate on the "last_activity" field. +func LastActivityLTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldLastActivity, v)) +} + +// IPAddressEQ applies the EQ predicate on the "ip_address" field. +func IPAddressEQ(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v)) +} + +// IPAddressNEQ applies the NEQ predicate on the "ip_address" field. +func IPAddressNEQ(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldIPAddress, v)) +} + +// IPAddressIn applies the In predicate on the "ip_address" field. +func IPAddressIn(vs ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldIPAddress, vs...)) +} + +// IPAddressNotIn applies the NotIn predicate on the "ip_address" field. +func IPAddressNotIn(vs ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldIPAddress, vs...)) +} + +// IPAddressGT applies the GT predicate on the "ip_address" field. +func IPAddressGT(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldIPAddress, v)) +} + +// IPAddressGTE applies the GTE predicate on the "ip_address" field. +func IPAddressGTE(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldIPAddress, v)) +} + +// IPAddressLT applies the LT predicate on the "ip_address" field. +func IPAddressLT(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldIPAddress, v)) +} + +// IPAddressLTE applies the LTE predicate on the "ip_address" field. +func IPAddressLTE(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldIPAddress, v)) +} + +// IPAddressContains applies the Contains predicate on the "ip_address" field. +func IPAddressContains(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldContains(FieldIPAddress, v)) +} + +// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field. +func IPAddressHasPrefix(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldHasPrefix(FieldIPAddress, v)) +} + +// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field. +func IPAddressHasSuffix(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldHasSuffix(FieldIPAddress, v)) +} + +// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field. +func IPAddressEqualFold(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEqualFold(FieldIPAddress, v)) +} + +// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field. +func IPAddressContainsFold(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldContainsFold(FieldIPAddress, v)) +} + +// UserAgentEQ applies the EQ predicate on the "user_agent" field. +func UserAgentEQ(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v)) +} + +// UserAgentNEQ applies the NEQ predicate on the "user_agent" field. +func UserAgentNEQ(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldUserAgent, v)) +} + +// UserAgentIn applies the In predicate on the "user_agent" field. +func UserAgentIn(vs ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldUserAgent, vs...)) +} + +// UserAgentNotIn applies the NotIn predicate on the "user_agent" field. +func UserAgentNotIn(vs ...string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldUserAgent, vs...)) +} + +// UserAgentGT applies the GT predicate on the "user_agent" field. +func UserAgentGT(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldUserAgent, v)) +} + +// UserAgentGTE applies the GTE predicate on the "user_agent" field. +func UserAgentGTE(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldUserAgent, v)) +} + +// UserAgentLT applies the LT predicate on the "user_agent" field. +func UserAgentLT(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldUserAgent, v)) +} + +// UserAgentLTE applies the LTE predicate on the "user_agent" field. +func UserAgentLTE(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldUserAgent, v)) +} + +// UserAgentContains applies the Contains predicate on the "user_agent" field. +func UserAgentContains(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldContains(FieldUserAgent, v)) +} + +// UserAgentHasPrefix applies the HasPrefix predicate on the "user_agent" field. +func UserAgentHasPrefix(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldHasPrefix(FieldUserAgent, v)) +} + +// UserAgentHasSuffix applies the HasSuffix predicate on the "user_agent" field. +func UserAgentHasSuffix(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldHasSuffix(FieldUserAgent, v)) +} + +// UserAgentEqualFold applies the EqualFold predicate on the "user_agent" field. +func UserAgentEqualFold(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEqualFold(FieldUserAgent, v)) +} + +// UserAgentContainsFold applies the ContainsFold predicate on the "user_agent" field. +func UserAgentContainsFold(v string) predicate.AuthSession { + return predicate.AuthSession(sql.FieldContainsFold(FieldUserAgent, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.AuthSession) predicate.AuthSession { + return predicate.AuthSession(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.AuthSession) predicate.AuthSession { + return predicate.AuthSession(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.AuthSession) predicate.AuthSession { + return predicate.AuthSession(sql.NotPredicates(p)) +} diff --git a/storage/ent/db/authsession_create.go b/storage/ent/db/authsession_create.go new file mode 100644 index 00000000..a680d675 --- /dev/null +++ b/storage/ent/db/authsession_create.go @@ -0,0 +1,282 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/authsession" +) + +// AuthSessionCreate is the builder for creating a AuthSession entity. +type AuthSessionCreate struct { + config + mutation *AuthSessionMutation + hooks []Hook +} + +// SetClientStates sets the "client_states" field. +func (_c *AuthSessionCreate) SetClientStates(v []byte) *AuthSessionCreate { + _c.mutation.SetClientStates(v) + return _c +} + +// SetCreatedAt sets the "created_at" field. +func (_c *AuthSessionCreate) SetCreatedAt(v time.Time) *AuthSessionCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetLastActivity sets the "last_activity" field. +func (_c *AuthSessionCreate) SetLastActivity(v time.Time) *AuthSessionCreate { + _c.mutation.SetLastActivity(v) + return _c +} + +// SetIPAddress sets the "ip_address" field. +func (_c *AuthSessionCreate) SetIPAddress(v string) *AuthSessionCreate { + _c.mutation.SetIPAddress(v) + return _c +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_c *AuthSessionCreate) SetNillableIPAddress(v *string) *AuthSessionCreate { + if v != nil { + _c.SetIPAddress(*v) + } + return _c +} + +// SetUserAgent sets the "user_agent" field. +func (_c *AuthSessionCreate) SetUserAgent(v string) *AuthSessionCreate { + _c.mutation.SetUserAgent(v) + return _c +} + +// SetNillableUserAgent sets the "user_agent" field if the given value is not nil. +func (_c *AuthSessionCreate) SetNillableUserAgent(v *string) *AuthSessionCreate { + if v != nil { + _c.SetUserAgent(*v) + } + return _c +} + +// SetID sets the "id" field. +func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate { + _c.mutation.SetID(v) + return _c +} + +// Mutation returns the AuthSessionMutation object of the builder. +func (_c *AuthSessionCreate) Mutation() *AuthSessionMutation { + return _c.mutation +} + +// Save creates the AuthSession in the database. +func (_c *AuthSessionCreate) Save(ctx context.Context) (*AuthSession, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *AuthSessionCreate) SaveX(ctx context.Context) *AuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthSessionCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthSessionCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *AuthSessionCreate) defaults() { + if _, ok := _c.mutation.IPAddress(); !ok { + v := authsession.DefaultIPAddress + _c.mutation.SetIPAddress(v) + } + if _, ok := _c.mutation.UserAgent(); !ok { + v := authsession.DefaultUserAgent + _c.mutation.SetUserAgent(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *AuthSessionCreate) check() error { + if _, ok := _c.mutation.ClientStates(); !ok { + return &ValidationError{Name: "client_states", err: errors.New(`db: missing required field "AuthSession.client_states"`)} + } + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "AuthSession.created_at"`)} + } + if _, ok := _c.mutation.LastActivity(); !ok { + return &ValidationError{Name: "last_activity", err: errors.New(`db: missing required field "AuthSession.last_activity"`)} + } + if _, ok := _c.mutation.IPAddress(); !ok { + return &ValidationError{Name: "ip_address", err: errors.New(`db: missing required field "AuthSession.ip_address"`)} + } + if _, ok := _c.mutation.UserAgent(); !ok { + return &ValidationError{Name: "user_agent", err: errors.New(`db: missing required field "AuthSession.user_agent"`)} + } + if v, ok := _c.mutation.ID(); ok { + if err := authsession.IDValidator(v); err != nil { + return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthSession.id": %w`, err)} + } + } + return nil +} + +func (_c *AuthSessionCreate) sqlSave(ctx context.Context) (*AuthSession, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + if _spec.ID.Value != nil { + if id, ok := _spec.ID.Value.(string); ok { + _node.ID = id + } else { + return nil, fmt.Errorf("unexpected AuthSession.ID type: %T", _spec.ID.Value) + } + } + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) { + var ( + _node = &AuthSession{config: _c.config} + _spec = sqlgraph.NewCreateSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) + ) + if id, ok := _c.mutation.ID(); ok { + _node.ID = id + _spec.ID.Value = id + } + if value, ok := _c.mutation.ClientStates(); ok { + _spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) + _node.ClientStates = value + } + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.LastActivity(); ok { + _spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) + _node.LastActivity = value + } + if value, ok := _c.mutation.IPAddress(); ok { + _spec.SetField(authsession.FieldIPAddress, field.TypeString, value) + _node.IPAddress = value + } + if value, ok := _c.mutation.UserAgent(); ok { + _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) + _node.UserAgent = value + } + return _node, _spec +} + +// AuthSessionCreateBulk is the builder for creating many AuthSession entities in bulk. +type AuthSessionCreateBulk struct { + config + err error + builders []*AuthSessionCreate +} + +// Save creates the AuthSession entities in the database. +func (_c *AuthSessionCreateBulk) Save(ctx context.Context) ([]*AuthSession, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*AuthSession, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*AuthSessionMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *AuthSessionCreateBulk) SaveX(ctx context.Context) []*AuthSession { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *AuthSessionCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *AuthSessionCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/authsession_delete.go b/storage/ent/db/authsession_delete.go new file mode 100644 index 00000000..63116fb4 --- /dev/null +++ b/storage/ent/db/authsession_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/authsession" + "github.com/dexidp/dex/storage/ent/db/predicate" +) + +// AuthSessionDelete is the builder for deleting a AuthSession entity. +type AuthSessionDelete struct { + config + hooks []Hook + mutation *AuthSessionMutation +} + +// Where appends a list predicates to the AuthSessionDelete builder. +func (_d *AuthSessionDelete) Where(ps ...predicate.AuthSession) *AuthSessionDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *AuthSessionDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthSessionDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *AuthSessionDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// AuthSessionDeleteOne is the builder for deleting a single AuthSession entity. +type AuthSessionDeleteOne struct { + _d *AuthSessionDelete +} + +// Where appends a list predicates to the AuthSessionDelete builder. +func (_d *AuthSessionDeleteOne) Where(ps ...predicate.AuthSession) *AuthSessionDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *AuthSessionDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{authsession.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *AuthSessionDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/storage/ent/db/authsession_query.go b/storage/ent/db/authsession_query.go new file mode 100644 index 00000000..dc3528f9 --- /dev/null +++ b/storage/ent/db/authsession_query.go @@ -0,0 +1,527 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/authsession" + "github.com/dexidp/dex/storage/ent/db/predicate" +) + +// AuthSessionQuery is the builder for querying AuthSession entities. +type AuthSessionQuery struct { + config + ctx *QueryContext + order []authsession.OrderOption + inters []Interceptor + predicates []predicate.AuthSession + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the AuthSessionQuery builder. +func (_q *AuthSessionQuery) Where(ps ...predicate.AuthSession) *AuthSessionQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *AuthSessionQuery) Limit(limit int) *AuthSessionQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *AuthSessionQuery) Offset(offset int) *AuthSessionQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *AuthSessionQuery) Unique(unique bool) *AuthSessionQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *AuthSessionQuery) Order(o ...authsession.OrderOption) *AuthSessionQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first AuthSession entity from the query. +// Returns a *NotFoundError when no AuthSession was found. +func (_q *AuthSessionQuery) First(ctx context.Context) (*AuthSession, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{authsession.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *AuthSessionQuery) FirstX(ctx context.Context) *AuthSession { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first AuthSession ID from the query. +// Returns a *NotFoundError when no AuthSession ID was found. +func (_q *AuthSessionQuery) FirstID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{authsession.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *AuthSessionQuery) FirstIDX(ctx context.Context) string { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single AuthSession entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one AuthSession entity is found. +// Returns a *NotFoundError when no AuthSession entities are found. +func (_q *AuthSessionQuery) Only(ctx context.Context) (*AuthSession, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{authsession.Label} + default: + return nil, &NotSingularError{authsession.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *AuthSessionQuery) OnlyX(ctx context.Context) *AuthSession { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only AuthSession ID in the query. +// Returns a *NotSingularError when more than one AuthSession ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *AuthSessionQuery) OnlyID(ctx context.Context) (id string, err error) { + var ids []string + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{authsession.Label} + default: + err = &NotSingularError{authsession.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *AuthSessionQuery) OnlyIDX(ctx context.Context) string { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of AuthSessions. +func (_q *AuthSessionQuery) All(ctx context.Context) ([]*AuthSession, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*AuthSession, *AuthSessionQuery]() + return withInterceptors[[]*AuthSession](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *AuthSessionQuery) AllX(ctx context.Context) []*AuthSession { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of AuthSession IDs. +func (_q *AuthSessionQuery) IDs(ctx context.Context) (ids []string, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(authsession.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *AuthSessionQuery) IDsX(ctx context.Context) []string { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *AuthSessionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*AuthSessionQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *AuthSessionQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *AuthSessionQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("db: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *AuthSessionQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the AuthSessionQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *AuthSessionQuery) Clone() *AuthSessionQuery { + if _q == nil { + return nil + } + return &AuthSessionQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]authsession.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.AuthSession{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// ClientStates []byte `json:"client_states,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.AuthSession.Query(). +// GroupBy(authsession.FieldClientStates). +// Aggregate(db.Count()). +// Scan(ctx, &v) +func (_q *AuthSessionQuery) GroupBy(field string, fields ...string) *AuthSessionGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &AuthSessionGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = authsession.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// ClientStates []byte `json:"client_states,omitempty"` +// } +// +// client.AuthSession.Query(). +// Select(authsession.FieldClientStates). +// Scan(ctx, &v) +func (_q *AuthSessionQuery) Select(fields ...string) *AuthSessionSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &AuthSessionSelect{AuthSessionQuery: _q} + sbuild.label = authsession.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AuthSessionSelect configured with the given aggregations. +func (_q *AuthSessionQuery) Aggregate(fns ...AggregateFunc) *AuthSessionSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *AuthSessionQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !authsession.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *AuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthSession, error) { + var ( + nodes = []*AuthSession{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*AuthSession).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &AuthSession{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *AuthSessionQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *AuthSessionQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID) + for i := range fields { + if fields[i] != authsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *AuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(authsession.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = authsession.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// AuthSessionGroupBy is the group-by builder for AuthSession entities. +type AuthSessionGroupBy struct { + selector + build *AuthSessionQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *AuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *AuthSessionGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *AuthSessionGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthSessionQuery, *AuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *AuthSessionGroupBy) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// AuthSessionSelect is the builder for selecting fields of AuthSession entities. +type AuthSessionSelect struct { + *AuthSessionQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *AuthSessionSelect) Aggregate(fns ...AggregateFunc) *AuthSessionSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *AuthSessionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*AuthSessionQuery, *AuthSessionSelect](ctx, _s.AuthSessionQuery, _s, _s.inters, v) +} + +func (_s *AuthSessionSelect) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/storage/ent/db/authsession_update.go b/storage/ent/db/authsession_update.go new file mode 100644 index 00000000..e91999bd --- /dev/null +++ b/storage/ent/db/authsession_update.go @@ -0,0 +1,330 @@ +// Code generated by ent, DO NOT EDIT. + +package db + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/dexidp/dex/storage/ent/db/authsession" + "github.com/dexidp/dex/storage/ent/db/predicate" +) + +// AuthSessionUpdate is the builder for updating AuthSession entities. +type AuthSessionUpdate struct { + config + hooks []Hook + mutation *AuthSessionMutation +} + +// Where appends a list predicates to the AuthSessionUpdate builder. +func (_u *AuthSessionUpdate) Where(ps ...predicate.AuthSession) *AuthSessionUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetClientStates sets the "client_states" field. +func (_u *AuthSessionUpdate) SetClientStates(v []byte) *AuthSessionUpdate { + _u.mutation.SetClientStates(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *AuthSessionUpdate) SetCreatedAt(v time.Time) *AuthSessionUpdate { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdate { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetLastActivity sets the "last_activity" field. +func (_u *AuthSessionUpdate) SetLastActivity(v time.Time) *AuthSessionUpdate { + _u.mutation.SetLastActivity(v) + return _u +} + +// SetNillableLastActivity sets the "last_activity" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableLastActivity(v *time.Time) *AuthSessionUpdate { + if v != nil { + _u.SetLastActivity(*v) + } + return _u +} + +// SetIPAddress sets the "ip_address" field. +func (_u *AuthSessionUpdate) SetIPAddress(v string) *AuthSessionUpdate { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableIPAddress(v *string) *AuthSessionUpdate { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// SetUserAgent sets the "user_agent" field. +func (_u *AuthSessionUpdate) SetUserAgent(v string) *AuthSessionUpdate { + _u.mutation.SetUserAgent(v) + return _u +} + +// SetNillableUserAgent sets the "user_agent" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableUserAgent(v *string) *AuthSessionUpdate { + if v != nil { + _u.SetUserAgent(*v) + } + return _u +} + +// Mutation returns the AuthSessionMutation object of the builder. +func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *AuthSessionUpdate) Save(ctx context.Context) (int, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthSessionUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *AuthSessionUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthSessionUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) { + _spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.ClientStates(); ok { + _spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.LastActivity(); ok { + _spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) + } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(authsession.FieldIPAddress, field.TypeString, value) + } + if value, ok := _u.mutation.UserAgent(); ok { + _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// AuthSessionUpdateOne is the builder for updating a single AuthSession entity. +type AuthSessionUpdateOne struct { + config + fields []string + hooks []Hook + mutation *AuthSessionMutation +} + +// SetClientStates sets the "client_states" field. +func (_u *AuthSessionUpdateOne) SetClientStates(v []byte) *AuthSessionUpdateOne { + _u.mutation.SetClientStates(v) + return _u +} + +// SetCreatedAt sets the "created_at" field. +func (_u *AuthSessionUpdateOne) SetCreatedAt(v time.Time) *AuthSessionUpdateOne { + _u.mutation.SetCreatedAt(v) + return _u +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdateOne { + if v != nil { + _u.SetCreatedAt(*v) + } + return _u +} + +// SetLastActivity sets the "last_activity" field. +func (_u *AuthSessionUpdateOne) SetLastActivity(v time.Time) *AuthSessionUpdateOne { + _u.mutation.SetLastActivity(v) + return _u +} + +// SetNillableLastActivity sets the "last_activity" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableLastActivity(v *time.Time) *AuthSessionUpdateOne { + if v != nil { + _u.SetLastActivity(*v) + } + return _u +} + +// SetIPAddress sets the "ip_address" field. +func (_u *AuthSessionUpdateOne) SetIPAddress(v string) *AuthSessionUpdateOne { + _u.mutation.SetIPAddress(v) + return _u +} + +// SetNillableIPAddress sets the "ip_address" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableIPAddress(v *string) *AuthSessionUpdateOne { + if v != nil { + _u.SetIPAddress(*v) + } + return _u +} + +// SetUserAgent sets the "user_agent" field. +func (_u *AuthSessionUpdateOne) SetUserAgent(v string) *AuthSessionUpdateOne { + _u.mutation.SetUserAgent(v) + return _u +} + +// SetNillableUserAgent sets the "user_agent" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableUserAgent(v *string) *AuthSessionUpdateOne { + if v != nil { + _u.SetUserAgent(*v) + } + return _u +} + +// Mutation returns the AuthSessionMutation object of the builder. +func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation { + return _u.mutation +} + +// Where appends a list predicates to the AuthSessionUpdate builder. +func (_u *AuthSessionUpdateOne) Where(ps ...predicate.AuthSession) *AuthSessionUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *AuthSessionUpdateOne) Select(field string, fields ...string) *AuthSessionUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated AuthSession entity. +func (_u *AuthSessionUpdateOne) Save(ctx context.Context) (*AuthSession, error) { + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *AuthSessionUpdateOne) SaveX(ctx context.Context) *AuthSession { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *AuthSessionUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *AuthSessionUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession, err error) { + _spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "AuthSession.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID) + for _, f := range fields { + if !authsession.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)} + } + if f != authsession.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.ClientStates(); ok { + _spec.SetField(authsession.FieldClientStates, field.TypeBytes, value) + } + if value, ok := _u.mutation.CreatedAt(); ok { + _spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.LastActivity(); ok { + _spec.SetField(authsession.FieldLastActivity, field.TypeTime, value) + } + if value, ok := _u.mutation.IPAddress(); ok { + _spec.SetField(authsession.FieldIPAddress, field.TypeString, value) + } + if value, ok := _u.mutation.UserAgent(); ok { + _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) + } + _node = &AuthSession{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{authsession.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/storage/ent/db/client.go b/storage/ent/db/client.go index efb6da6b..ffd70add 100644 --- a/storage/ent/db/client.go +++ b/storage/ent/db/client.go @@ -16,6 +16,7 @@ import ( "entgo.io/ent/dialect/sql" "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/authsession" "github.com/dexidp/dex/storage/ent/db/connector" "github.com/dexidp/dex/storage/ent/db/devicerequest" "github.com/dexidp/dex/storage/ent/db/devicetoken" @@ -36,6 +37,8 @@ type Client struct { AuthCode *AuthCodeClient // AuthRequest is the client for interacting with the AuthRequest builders. AuthRequest *AuthRequestClient + // AuthSession is the client for interacting with the AuthSession builders. + AuthSession *AuthSessionClient // Connector is the client for interacting with the Connector builders. Connector *ConnectorClient // DeviceRequest is the client for interacting with the DeviceRequest builders. @@ -67,6 +70,7 @@ func (c *Client) init() { c.Schema = migrate.NewSchema(c.driver) c.AuthCode = NewAuthCodeClient(c.config) c.AuthRequest = NewAuthRequestClient(c.config) + c.AuthSession = NewAuthSessionClient(c.config) c.Connector = NewConnectorClient(c.config) c.DeviceRequest = NewDeviceRequestClient(c.config) c.DeviceToken = NewDeviceTokenClient(c.config) @@ -170,6 +174,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { config: cfg, AuthCode: NewAuthCodeClient(cfg), AuthRequest: NewAuthRequestClient(cfg), + AuthSession: NewAuthSessionClient(cfg), Connector: NewConnectorClient(cfg), DeviceRequest: NewDeviceRequestClient(cfg), DeviceToken: NewDeviceTokenClient(cfg), @@ -200,6 +205,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) config: cfg, AuthCode: NewAuthCodeClient(cfg), AuthRequest: NewAuthRequestClient(cfg), + AuthSession: NewAuthSessionClient(cfg), Connector: NewConnectorClient(cfg), DeviceRequest: NewDeviceRequestClient(cfg), DeviceToken: NewDeviceTokenClient(cfg), @@ -238,8 +244,9 @@ func (c *Client) Close() error { // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ - c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys, - c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity, + c.AuthCode, c.AuthRequest, c.AuthSession, c.Connector, c.DeviceRequest, + c.DeviceToken, c.Keys, c.OAuth2Client, c.OfflineSession, c.Password, + c.RefreshToken, c.UserIdentity, } { n.Use(hooks...) } @@ -249,8 +256,9 @@ func (c *Client) Use(hooks ...Hook) { // In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ - c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys, - c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity, + c.AuthCode, c.AuthRequest, c.AuthSession, c.Connector, c.DeviceRequest, + c.DeviceToken, c.Keys, c.OAuth2Client, c.OfflineSession, c.Password, + c.RefreshToken, c.UserIdentity, } { n.Intercept(interceptors...) } @@ -263,6 +271,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.AuthCode.mutate(ctx, m) case *AuthRequestMutation: return c.AuthRequest.mutate(ctx, m) + case *AuthSessionMutation: + return c.AuthSession.mutate(ctx, m) case *ConnectorMutation: return c.Connector.mutate(ctx, m) case *DeviceRequestMutation: @@ -552,6 +562,139 @@ func (c *AuthRequestClient) mutate(ctx context.Context, m *AuthRequestMutation) } } +// AuthSessionClient is a client for the AuthSession schema. +type AuthSessionClient struct { + config +} + +// NewAuthSessionClient returns a client for the AuthSession from the given config. +func NewAuthSessionClient(c config) *AuthSessionClient { + return &AuthSessionClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `authsession.Hooks(f(g(h())))`. +func (c *AuthSessionClient) Use(hooks ...Hook) { + c.hooks.AuthSession = append(c.hooks.AuthSession, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `authsession.Intercept(f(g(h())))`. +func (c *AuthSessionClient) Intercept(interceptors ...Interceptor) { + c.inters.AuthSession = append(c.inters.AuthSession, interceptors...) +} + +// Create returns a builder for creating a AuthSession entity. +func (c *AuthSessionClient) Create() *AuthSessionCreate { + mutation := newAuthSessionMutation(c.config, OpCreate) + return &AuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of AuthSession entities. +func (c *AuthSessionClient) CreateBulk(builders ...*AuthSessionCreate) *AuthSessionCreateBulk { + return &AuthSessionCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AuthSessionClient) MapCreateBulk(slice any, setFunc func(*AuthSessionCreate, int)) *AuthSessionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AuthSessionCreateBulk{err: fmt.Errorf("calling to AuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AuthSessionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AuthSessionCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for AuthSession. +func (c *AuthSessionClient) Update() *AuthSessionUpdate { + mutation := newAuthSessionMutation(c.config, OpUpdate) + return &AuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *AuthSessionClient) UpdateOne(_m *AuthSession) *AuthSessionUpdateOne { + mutation := newAuthSessionMutation(c.config, OpUpdateOne, withAuthSession(_m)) + return &AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *AuthSessionClient) UpdateOneID(id string) *AuthSessionUpdateOne { + mutation := newAuthSessionMutation(c.config, OpUpdateOne, withAuthSessionID(id)) + return &AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for AuthSession. +func (c *AuthSessionClient) Delete() *AuthSessionDelete { + mutation := newAuthSessionMutation(c.config, OpDelete) + return &AuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *AuthSessionClient) DeleteOne(_m *AuthSession) *AuthSessionDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *AuthSessionClient) DeleteOneID(id string) *AuthSessionDeleteOne { + builder := c.Delete().Where(authsession.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &AuthSessionDeleteOne{builder} +} + +// Query returns a query builder for AuthSession. +func (c *AuthSessionClient) Query() *AuthSessionQuery { + return &AuthSessionQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeAuthSession}, + inters: c.Interceptors(), + } +} + +// Get returns a AuthSession entity by its id. +func (c *AuthSessionClient) Get(ctx context.Context, id string) (*AuthSession, error) { + return c.Query().Where(authsession.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *AuthSessionClient) GetX(ctx context.Context, id string) *AuthSession { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *AuthSessionClient) Hooks() []Hook { + return c.hooks.AuthSession +} + +// Interceptors returns the client interceptors. +func (c *AuthSessionClient) Interceptors() []Interceptor { + return c.inters.AuthSession +} + +func (c *AuthSessionClient) mutate(ctx context.Context, m *AuthSessionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("db: unknown AuthSession mutation op: %q", m.Op()) + } +} + // ConnectorClient is a client for the Connector schema. type ConnectorClient struct { config @@ -1752,11 +1895,11 @@ func (c *UserIdentityClient) mutate(ctx context.Context, m *UserIdentityMutation // hooks and interceptors per client, for fast access. type ( hooks struct { - AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys, + AuthCode, AuthRequest, AuthSession, Connector, DeviceRequest, DeviceToken, Keys, OAuth2Client, OfflineSession, Password, RefreshToken, UserIdentity []ent.Hook } inters struct { - AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys, + AuthCode, AuthRequest, AuthSession, Connector, DeviceRequest, DeviceToken, Keys, OAuth2Client, OfflineSession, Password, RefreshToken, UserIdentity []ent.Interceptor } diff --git a/storage/ent/db/ent.go b/storage/ent/db/ent.go index fe965a79..f73e0090 100644 --- a/storage/ent/db/ent.go +++ b/storage/ent/db/ent.go @@ -14,6 +14,7 @@ import ( "entgo.io/ent/dialect/sql/sqlgraph" "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/authsession" "github.com/dexidp/dex/storage/ent/db/connector" "github.com/dexidp/dex/storage/ent/db/devicerequest" "github.com/dexidp/dex/storage/ent/db/devicetoken" @@ -85,6 +86,7 @@ func checkColumn(t, c string) error { columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ authcode.Table: authcode.ValidColumn, authrequest.Table: authrequest.ValidColumn, + authsession.Table: authsession.ValidColumn, connector.Table: connector.ValidColumn, devicerequest.Table: devicerequest.ValidColumn, devicetoken.Table: devicetoken.ValidColumn, diff --git a/storage/ent/db/hook/hook.go b/storage/ent/db/hook/hook.go index 008dc7ee..1c8780d7 100644 --- a/storage/ent/db/hook/hook.go +++ b/storage/ent/db/hook/hook.go @@ -33,6 +33,18 @@ func (f AuthRequestFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, e return nil, fmt.Errorf("unexpected mutation type %T. expect *db.AuthRequestMutation", m) } +// The AuthSessionFunc type is an adapter to allow the use of ordinary +// function as AuthSession mutator. +type AuthSessionFunc func(context.Context, *db.AuthSessionMutation) (db.Value, error) + +// Mutate calls f(ctx, m). +func (f AuthSessionFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, error) { + if mv, ok := m.(*db.AuthSessionMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *db.AuthSessionMutation", m) +} + // The ConnectorFunc type is an adapter to allow the use of ordinary // function as Connector mutator. type ConnectorFunc func(context.Context, *db.ConnectorMutation) (db.Value, error) diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index 1b9b61c9..786598c0 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -63,6 +63,21 @@ var ( Columns: AuthRequestsColumns, PrimaryKey: []*schema.Column{AuthRequestsColumns[0]}, } + // AuthSessionsColumns holds the columns for the "auth_sessions" table. + AuthSessionsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "client_states", Type: field.TypeBytes}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "last_activity", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "ip_address", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "user_agent", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + } + // AuthSessionsTable holds the schema information for the "auth_sessions" table. + AuthSessionsTable = &schema.Table{ + Name: "auth_sessions", + Columns: AuthSessionsColumns, + PrimaryKey: []*schema.Column{AuthSessionsColumns[0]}, + } // ConnectorsColumns holds the columns for the "connectors" table. ConnectorsColumns = []*schema.Column{ {Name: "id", Type: field.TypeString, Unique: true, Size: 100, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, @@ -226,6 +241,7 @@ var ( Tables = []*schema.Table{ AuthCodesTable, AuthRequestsTable, + AuthSessionsTable, ConnectorsTable, DeviceRequestsTable, DeviceTokensTable, diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 8625471a..748022c9 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -14,6 +14,7 @@ import ( "github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/authsession" "github.com/dexidp/dex/storage/ent/db/connector" "github.com/dexidp/dex/storage/ent/db/devicerequest" "github.com/dexidp/dex/storage/ent/db/devicetoken" @@ -38,6 +39,7 @@ const ( // Node types. TypeAuthCode = "AuthCode" TypeAuthRequest = "AuthRequest" + TypeAuthSession = "AuthSession" TypeConnector = "Connector" TypeDeviceRequest = "DeviceRequest" TypeDeviceToken = "DeviceToken" @@ -2719,6 +2721,554 @@ func (m *AuthRequestMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AuthRequest edge %s", name) } +// AuthSessionMutation represents an operation that mutates the AuthSession nodes in the graph. +type AuthSessionMutation struct { + config + op Op + typ string + id *string + client_states *[]byte + created_at *time.Time + last_activity *time.Time + ip_address *string + user_agent *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*AuthSession, error) + predicates []predicate.AuthSession +} + +var _ ent.Mutation = (*AuthSessionMutation)(nil) + +// authsessionOption allows management of the mutation configuration using functional options. +type authsessionOption func(*AuthSessionMutation) + +// newAuthSessionMutation creates new mutation for the AuthSession entity. +func newAuthSessionMutation(c config, op Op, opts ...authsessionOption) *AuthSessionMutation { + m := &AuthSessionMutation{ + config: c, + op: op, + typ: TypeAuthSession, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withAuthSessionID sets the ID field of the mutation. +func withAuthSessionID(id string) authsessionOption { + return func(m *AuthSessionMutation) { + var ( + err error + once sync.Once + value *AuthSession + ) + m.oldValue = func(ctx context.Context) (*AuthSession, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().AuthSession.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withAuthSession sets the old AuthSession of the mutation. +func withAuthSession(node *AuthSession) authsessionOption { + return func(m *AuthSessionMutation) { + m.oldValue = func(context.Context) (*AuthSession, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m AuthSessionMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m AuthSessionMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("db: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// SetID sets the value of the id field. Note that this +// operation is only accepted on creation of AuthSession entities. +func (m *AuthSessionMutation) SetID(id string) { + m.id = &id +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *AuthSessionMutation) ID() (id string, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *AuthSessionMutation) IDs(ctx context.Context) ([]string, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []string{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().AuthSession.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetClientStates sets the "client_states" field. +func (m *AuthSessionMutation) SetClientStates(b []byte) { + m.client_states = &b +} + +// ClientStates returns the value of the "client_states" field in the mutation. +func (m *AuthSessionMutation) ClientStates() (r []byte, exists bool) { + v := m.client_states + if v == nil { + return + } + return *v, true +} + +// OldClientStates returns the old "client_states" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldClientStates(ctx context.Context) (v []byte, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClientStates is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClientStates requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClientStates: %w", err) + } + return oldValue.ClientStates, nil +} + +// ResetClientStates resets all changes to the "client_states" field. +func (m *AuthSessionMutation) ResetClientStates() { + m.client_states = nil +} + +// SetCreatedAt sets the "created_at" field. +func (m *AuthSessionMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *AuthSessionMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *AuthSessionMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetLastActivity sets the "last_activity" field. +func (m *AuthSessionMutation) SetLastActivity(t time.Time) { + m.last_activity = &t +} + +// LastActivity returns the value of the "last_activity" field in the mutation. +func (m *AuthSessionMutation) LastActivity() (r time.Time, exists bool) { + v := m.last_activity + if v == nil { + return + } + return *v, true +} + +// OldLastActivity returns the old "last_activity" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldLastActivity(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastActivity is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastActivity requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastActivity: %w", err) + } + return oldValue.LastActivity, nil +} + +// ResetLastActivity resets all changes to the "last_activity" field. +func (m *AuthSessionMutation) ResetLastActivity() { + m.last_activity = nil +} + +// SetIPAddress sets the "ip_address" field. +func (m *AuthSessionMutation) SetIPAddress(s string) { + m.ip_address = &s +} + +// IPAddress returns the value of the "ip_address" field in the mutation. +func (m *AuthSessionMutation) IPAddress() (r string, exists bool) { + v := m.ip_address + if v == nil { + return + } + return *v, true +} + +// OldIPAddress returns the old "ip_address" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldIPAddress(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIPAddress is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIPAddress requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIPAddress: %w", err) + } + return oldValue.IPAddress, nil +} + +// ResetIPAddress resets all changes to the "ip_address" field. +func (m *AuthSessionMutation) ResetIPAddress() { + m.ip_address = nil +} + +// SetUserAgent sets the "user_agent" field. +func (m *AuthSessionMutation) SetUserAgent(s string) { + m.user_agent = &s +} + +// UserAgent returns the value of the "user_agent" field in the mutation. +func (m *AuthSessionMutation) UserAgent() (r string, exists bool) { + v := m.user_agent + if v == nil { + return + } + return *v, true +} + +// OldUserAgent returns the old "user_agent" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldUserAgent(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUserAgent is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUserAgent requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUserAgent: %w", err) + } + return oldValue.UserAgent, nil +} + +// ResetUserAgent resets all changes to the "user_agent" field. +func (m *AuthSessionMutation) ResetUserAgent() { + m.user_agent = nil +} + +// Where appends a list predicates to the AuthSessionMutation builder. +func (m *AuthSessionMutation) Where(ps ...predicate.AuthSession) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the AuthSessionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AuthSessionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AuthSession, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *AuthSessionMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *AuthSessionMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (AuthSession). +func (m *AuthSessionMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *AuthSessionMutation) Fields() []string { + fields := make([]string, 0, 5) + if m.client_states != nil { + fields = append(fields, authsession.FieldClientStates) + } + if m.created_at != nil { + fields = append(fields, authsession.FieldCreatedAt) + } + if m.last_activity != nil { + fields = append(fields, authsession.FieldLastActivity) + } + if m.ip_address != nil { + fields = append(fields, authsession.FieldIPAddress) + } + if m.user_agent != nil { + fields = append(fields, authsession.FieldUserAgent) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *AuthSessionMutation) Field(name string) (ent.Value, bool) { + switch name { + case authsession.FieldClientStates: + return m.ClientStates() + case authsession.FieldCreatedAt: + return m.CreatedAt() + case authsession.FieldLastActivity: + return m.LastActivity() + case authsession.FieldIPAddress: + return m.IPAddress() + case authsession.FieldUserAgent: + return m.UserAgent() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *AuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case authsession.FieldClientStates: + return m.OldClientStates(ctx) + case authsession.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case authsession.FieldLastActivity: + return m.OldLastActivity(ctx) + case authsession.FieldIPAddress: + return m.OldIPAddress(ctx) + case authsession.FieldUserAgent: + return m.OldUserAgent(ctx) + } + return nil, fmt.Errorf("unknown AuthSession field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthSessionMutation) SetField(name string, value ent.Value) error { + switch name { + case authsession.FieldClientStates: + v, ok := value.([]byte) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClientStates(v) + return nil + case authsession.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case authsession.FieldLastActivity: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastActivity(v) + return nil + case authsession.FieldIPAddress: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIPAddress(v) + return nil + case authsession.FieldUserAgent: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUserAgent(v) + return nil + } + return fmt.Errorf("unknown AuthSession field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *AuthSessionMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *AuthSessionMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *AuthSessionMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown AuthSession numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *AuthSessionMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *AuthSessionMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *AuthSessionMutation) ClearField(name string) error { + return fmt.Errorf("unknown AuthSession nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *AuthSessionMutation) ResetField(name string) error { + switch name { + case authsession.FieldClientStates: + m.ResetClientStates() + return nil + case authsession.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case authsession.FieldLastActivity: + m.ResetLastActivity() + return nil + case authsession.FieldIPAddress: + m.ResetIPAddress() + return nil + case authsession.FieldUserAgent: + m.ResetUserAgent() + return nil + } + return fmt.Errorf("unknown AuthSession field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *AuthSessionMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *AuthSessionMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *AuthSessionMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *AuthSessionMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *AuthSessionMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *AuthSessionMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *AuthSessionMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown AuthSession unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *AuthSessionMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown AuthSession edge %s", name) +} + // ConnectorMutation represents an operation that mutates the Connector nodes in the graph. type ConnectorMutation struct { config diff --git a/storage/ent/db/predicate/predicate.go b/storage/ent/db/predicate/predicate.go index 90ddffcb..9e977a2a 100644 --- a/storage/ent/db/predicate/predicate.go +++ b/storage/ent/db/predicate/predicate.go @@ -12,6 +12,9 @@ type AuthCode func(*sql.Selector) // AuthRequest is the predicate function for authrequest builders. type AuthRequest func(*sql.Selector) +// AuthSession is the predicate function for authsession builders. +type AuthSession func(*sql.Selector) + // Connector is the predicate function for connector builders. type Connector func(*sql.Selector) diff --git a/storage/ent/db/runtime.go b/storage/ent/db/runtime.go index b6f34aef..98c12ecc 100644 --- a/storage/ent/db/runtime.go +++ b/storage/ent/db/runtime.go @@ -7,6 +7,7 @@ import ( "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/authsession" "github.com/dexidp/dex/storage/ent/db/connector" "github.com/dexidp/dex/storage/ent/db/devicerequest" "github.com/dexidp/dex/storage/ent/db/devicetoken" @@ -87,6 +88,20 @@ func init() { authrequestDescID := authrequestFields[0].Descriptor() // authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save. authrequest.IDValidator = authrequestDescID.Validators[0].(func(string) error) + authsessionFields := schema.AuthSession{}.Fields() + _ = authsessionFields + // authsessionDescIPAddress is the schema descriptor for ip_address field. + authsessionDescIPAddress := authsessionFields[4].Descriptor() + // authsession.DefaultIPAddress holds the default value on creation for the ip_address field. + authsession.DefaultIPAddress = authsessionDescIPAddress.Default.(string) + // authsessionDescUserAgent is the schema descriptor for user_agent field. + authsessionDescUserAgent := authsessionFields[5].Descriptor() + // authsession.DefaultUserAgent holds the default value on creation for the user_agent field. + authsession.DefaultUserAgent = authsessionDescUserAgent.Default.(string) + // authsessionDescID is the schema descriptor for id field. + authsessionDescID := authsessionFields[0].Descriptor() + // authsession.IDValidator is a validator for the "id" field. It is called by the builders before save. + authsession.IDValidator = authsessionDescID.Validators[0].(func(string) error) connectorFields := schema.Connector{}.Fields() _ = connectorFields // connectorDescType is the schema descriptor for type field. diff --git a/storage/ent/db/tx.go b/storage/ent/db/tx.go index 77be8a8f..94f27935 100644 --- a/storage/ent/db/tx.go +++ b/storage/ent/db/tx.go @@ -16,6 +16,8 @@ type Tx struct { AuthCode *AuthCodeClient // AuthRequest is the client for interacting with the AuthRequest builders. AuthRequest *AuthRequestClient + // AuthSession is the client for interacting with the AuthSession builders. + AuthSession *AuthSessionClient // Connector is the client for interacting with the Connector builders. Connector *ConnectorClient // DeviceRequest is the client for interacting with the DeviceRequest builders. @@ -167,6 +169,7 @@ func (tx *Tx) Client() *Client { func (tx *Tx) init() { tx.AuthCode = NewAuthCodeClient(tx.config) tx.AuthRequest = NewAuthRequestClient(tx.config) + tx.AuthSession = NewAuthSessionClient(tx.config) tx.Connector = NewConnectorClient(tx.config) tx.DeviceRequest = NewDeviceRequestClient(tx.config) tx.DeviceToken = NewDeviceTokenClient(tx.config) diff --git a/storage/ent/schema/authsession.go b/storage/ent/schema/authsession.go new file mode 100644 index 00000000..f0e57110 --- /dev/null +++ b/storage/ent/schema/authsession.go @@ -0,0 +1,37 @@ +package schema + +import ( + "entgo.io/ent" + "entgo.io/ent/schema/field" +) + +// AuthSession holds the schema definition for the AuthSession entity. +type AuthSession struct { + ent.Schema +} + +// Fields of the AuthSession. +func (AuthSession) Fields() []ent.Field { + return []ent.Field{ + field.Text("id"). + SchemaType(textSchema). + NotEmpty(). + Unique(), + field.Bytes("client_states"), + field.Time("created_at"). + SchemaType(timeSchema), + field.Time("last_activity"). + SchemaType(timeSchema), + field.Text("ip_address"). + SchemaType(textSchema). + Default(""), + field.Text("user_agent"). + SchemaType(textSchema). + Default(""), + } +} + +// Edges of the AuthSession. +func (AuthSession) Edges() []ent.Edge { + return []ent.Edge{} +} diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index d2248ea2..c05f5631 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -25,6 +25,7 @@ const ( deviceRequestPrefix = "device_req/" deviceTokenPrefix = "device_token/" userIdentityPrefix = "user_identity/" + authSessionPrefix = "auth_session/" // defaultStorageTimeout will be applied to all storage's operations. defaultStorageTimeout = 5 * time.Second @@ -422,6 +423,61 @@ func (c *conn) ListUserIdentities(ctx context.Context) (identities []storage.Use return identities, nil } +func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) error { + return c.txnCreate(ctx, keyAuthSession(s.ID), fromStorageAuthSession(s)) +} + +func (c *conn) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + var s AuthSession + if err := c.getKey(ctx, keyAuthSession(sessionID), &s); err != nil { + return storage.AuthSession{}, err + } + return toStorageAuthSession(s), nil +} + +func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + return c.txnUpdate(ctx, keyAuthSession(sessionID), func(currentValue []byte) ([]byte, error) { + var current AuthSession + if len(currentValue) > 0 { + if err := json.Unmarshal(currentValue, ¤t); err != nil { + return nil, err + } + } + updated, err := updater(toStorageAuthSession(current)) + if err != nil { + return nil, err + } + return json.Marshal(fromStorageAuthSession(updated)) + }) +} + +func (c *conn) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, authSessionPrefix, clientv3.WithPrefix()) + if err != nil { + return sessions, err + } + for _, v := range res.Kvs { + var s AuthSession + if err = json.Unmarshal(v.Value, &s); err != nil { + return sessions, err + } + sessions = append(sessions, toStorageAuthSession(s)) + } + return sessions, nil +} + +func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + return c.deleteKey(ctx, keyAuthSession(sessionID)) +} + func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) } @@ -617,6 +673,10 @@ func keyUserIdentity(userID, connectorID string) string { return userIdentityPrefix + strings.ToLower(userID+"|"+connectorID) } +func keyAuthSession(sessionID string) string { + return strings.ToLower(authSessionPrefix + sessionID) +} + func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error { return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d)) } diff --git a/storage/etcd/types.go b/storage/etcd/types.go index ea4d216c..3624de32 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -296,6 +296,42 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { return s } +// AuthSession is a mirrored struct from storage with JSON struct tags. +type AuthSession struct { + ID string `json:"id,omitempty"` + ClientStates map[string]*storage.ClientAuthState `json:"client_states,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastActivity time.Time `json:"last_activity"` + IPAddress string `json:"ip_address,omitempty"` + UserAgent string `json:"user_agent,omitempty"` +} + +func fromStorageAuthSession(s storage.AuthSession) AuthSession { + return AuthSession{ + ID: s.ID, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + } +} + +func toStorageAuthSession(s AuthSession) storage.AuthSession { + result := storage.AuthSession{ + ID: s.ID, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + } + if result.ClientStates == nil { + result.ClientStates = make(map[string]*storage.ClientAuthState) + } + return result +} + // DeviceRequest is a mirrored struct from storage with JSON struct tags type DeviceRequest struct { UserCode string `json:"user_code"` diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index fd756845..55ea7455 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -26,6 +26,7 @@ const ( kindDeviceRequest = "DeviceRequest" kindDeviceToken = "DeviceToken" kindUserIdentity = "UserIdentity" + kindAuthSession = "AuthSession" ) const ( @@ -40,6 +41,7 @@ const ( resourceDeviceRequest = "devicerequests" resourceDeviceToken = "devicetokens" resourceUserIdentity = "useridentities" + resourceAuthSession = "authsessions" ) const ( @@ -809,6 +811,54 @@ func (cli *client) ListUserIdentities(ctx context.Context) ([]storage.UserIdenti return userIdentities, nil } +func (cli *client) CreateAuthSession(ctx context.Context, s storage.AuthSession) error { + return cli.post(resourceAuthSession, cli.fromStorageAuthSession(s)) +} + +func (cli *client) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { + var s AuthSession + if err := cli.get(resourceAuthSession, sessionID, &s); err != nil { + return storage.AuthSession{}, err + } + return toStorageAuthSession(s), nil +} + +func (cli *client) UpdateAuthSession(ctx context.Context, sessionID string, updater func(old storage.AuthSession) (storage.AuthSession, error)) error { + return retryOnConflict(ctx, func() error { + var s AuthSession + if err := cli.get(resourceAuthSession, sessionID, &s); err != nil { + return err + } + + updated, err := updater(toStorageAuthSession(s)) + if err != nil { + return err + } + + newSession := cli.fromStorageAuthSession(updated) + newSession.ObjectMeta = s.ObjectMeta + return cli.put(resourceAuthSession, sessionID, newSession) + }) +} + +func (cli *client) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + var authSessionList AuthSessionList + if err := cli.list(resourceAuthSession, &authSessionList); err != nil { + return nil, fmt.Errorf("failed to list auth sessions: %v", err) + } + + sessions := make([]storage.AuthSession, len(authSessionList.AuthSessions)) + for i, s := range authSessionList.AuthSessions { + sessions[i] = toStorageAuthSession(s) + } + + return sessions, nil +} + +func (cli *client) DeleteAuthSession(ctx context.Context, sessionID string) error { + return cli.delete(resourceAuthSession, sessionID) +} + func isKubernetesAPIConflictError(err error) bool { if httpErr, ok := err.(httpError); ok { if httpErr.StatusCode() == http.StatusConflict { diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 44ab4943..473f59cc 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -243,6 +243,23 @@ func customResourceDefinitions(apiVersion string) []k8sapi.CustomResourceDefinit }, }, }, + { + ObjectMeta: k8sapi.ObjectMeta{ + Name: "authsessions.dex.coreos.com", + }, + TypeMeta: crdMeta, + Spec: k8sapi.CustomResourceDefinitionSpec{ + Group: apiGroup, + Version: version, + Versions: versions, + Scope: scope, + Names: k8sapi.CustomResourceDefinitionNames{ + Plural: "authsessions", + Singular: "authsession", + Kind: "AuthSession", + }, + }, + }, } } @@ -948,3 +965,55 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { } return s } + +// AuthSession is a Kubernetes representation of a storage AuthSession. +type AuthSession struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ObjectMeta `json:"metadata,omitempty"` + + ClientStates map[string]*storage.ClientAuthState `json:"clientStates,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` + LastActivity time.Time `json:"lastActivity,omitempty"` + IPAddress string `json:"ipAddress,omitempty"` + UserAgent string `json:"userAgent,omitempty"` +} + +// AuthSessionList is a list of AuthSessions. +type AuthSessionList struct { + k8sapi.TypeMeta `json:",inline"` + k8sapi.ListMeta `json:"metadata,omitempty"` + AuthSessions []AuthSession `json:"items"` +} + +func (cli *client) fromStorageAuthSession(s storage.AuthSession) AuthSession { + return AuthSession{ + TypeMeta: k8sapi.TypeMeta{ + Kind: kindAuthSession, + APIVersion: cli.apiVersion, + }, + ObjectMeta: k8sapi.ObjectMeta{ + Name: s.ID, + Namespace: cli.namespace, + }, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + } +} + +func toStorageAuthSession(s AuthSession) storage.AuthSession { + result := storage.AuthSession{ + ID: s.ObjectMeta.Name, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + } + if result.ClientStates == nil { + result.ClientStates = make(map[string]*storage.ClientAuthState) + } + return result +} diff --git a/storage/memory/memory.go b/storage/memory/memory.go index ecf3a410..483ed246 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -23,6 +23,7 @@ func New(logger *slog.Logger) storage.Storage { passwords: make(map[string]storage.Password), offlineSessions: make(map[compositeKeyID]storage.OfflineSessions), userIdentities: make(map[compositeKeyID]storage.UserIdentity), + authSessions: make(map[string]storage.AuthSession), connectors: make(map[string]storage.Connector), deviceRequests: make(map[string]storage.DeviceRequest), deviceTokens: make(map[string]storage.DeviceToken), @@ -51,6 +52,7 @@ type memStorage struct { passwords map[string]storage.Password offlineSessions map[compositeKeyID]storage.OfflineSessions userIdentities map[compositeKeyID]storage.UserIdentity + authSessions map[string]storage.AuthSession connectors map[string]storage.Connector deviceRequests map[string]storage.DeviceRequest deviceTokens map[string]storage.DeviceToken @@ -246,6 +248,61 @@ func (s *memStorage) ListUserIdentities(ctx context.Context) (identities []stora return } +func (s *memStorage) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) { + s.tx(func() { + for _, session := range s.authSessions { + sessions = append(sessions, session) + } + }) + return +} + +func (s *memStorage) CreateAuthSession(ctx context.Context, session storage.AuthSession) (err error) { + s.tx(func() { + if _, ok := s.authSessions[session.ID]; ok { + err = storage.ErrAlreadyExists + } else { + s.authSessions[session.ID] = session + } + }) + return +} + +func (s *memStorage) GetAuthSession(ctx context.Context, sessionID string) (session storage.AuthSession, err error) { + s.tx(func() { + var ok bool + if session, ok = s.authSessions[sessionID]; !ok { + err = storage.ErrNotFound + } + }) + return +} + +func (s *memStorage) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) (err error) { + s.tx(func() { + r, ok := s.authSessions[sessionID] + if !ok { + err = storage.ErrNotFound + return + } + if r, err = updater(r); err == nil { + s.authSessions[sessionID] = r + } + }) + return +} + +func (s *memStorage) DeleteAuthSession(ctx context.Context, sessionID string) (err error) { + s.tx(func() { + if _, ok := s.authSessions[sessionID]; !ok { + err = storage.ErrNotFound + return + } + delete(s.authSessions, sessionID) + }) + return +} + func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Connector) (err error) { s.tx(func() { if _, ok := s.connectors[connector.ID]; ok { diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 04c9be3b..ab11713a 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -926,6 +926,136 @@ func (c *conn) DeleteUserIdentity(ctx context.Context, userID, connectorID strin return nil } +func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) error { + _, err := c.Exec(` + insert into auth_session ( + id, client_states, + created_at, last_activity, + ip_address, user_agent + ) + values ($1, $2, $3, $4, $5, $6); + `, + s.ID, encoder(s.ClientStates), + s.CreatedAt, s.LastActivity, + s.IPAddress, s.UserAgent, + ) + if err != nil { + if c.alreadyExistsCheck(err) { + return storage.ErrAlreadyExists + } + return fmt.Errorf("insert auth session: %v", err) + } + return nil +} + +func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error { + return c.ExecTx(func(tx *trans) error { + s, err := getAuthSession(ctx, tx, sessionID) + if err != nil { + return err + } + + newSession, err := updater(s) + if err != nil { + return err + } + _, err = tx.Exec(` + update auth_session + set + client_states = $1, + last_activity = $2, + ip_address = $3, + user_agent = $4 + where id = $5; + `, + encoder(newSession.ClientStates), + newSession.LastActivity, + newSession.IPAddress, newSession.UserAgent, + sessionID, + ) + if err != nil { + return fmt.Errorf("update auth session: %v", err) + } + return nil + }) +} + +func (c *conn) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { + return getAuthSession(ctx, c, sessionID) +} + +func getAuthSession(ctx context.Context, q querier, sessionID string) (storage.AuthSession, error) { + return scanAuthSession(q.QueryRow(` + select + id, client_states, + created_at, last_activity, + ip_address, user_agent + from auth_session + where id = $1; + `, sessionID)) +} + +func scanAuthSession(s scanner) (session storage.AuthSession, err error) { + err = s.Scan( + &session.ID, decoder(&session.ClientStates), + &session.CreatedAt, &session.LastActivity, + &session.IPAddress, &session.UserAgent, + ) + if err != nil { + if err == sql.ErrNoRows { + return session, storage.ErrNotFound + } + return session, fmt.Errorf("select auth session: %v", err) + } + if session.ClientStates == nil { + session.ClientStates = make(map[string]*storage.ClientAuthState) + } + return session, nil +} + +func (c *conn) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + rows, err := c.Query(` + select + id, client_states, + created_at, last_activity, + ip_address, user_agent + from auth_session; + `) + if err != nil { + return nil, fmt.Errorf("query: %v", err) + } + defer rows.Close() + + var sessions []storage.AuthSession + for rows.Next() { + s, err := scanAuthSession(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, s) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan: %v", err) + } + return sessions, nil +} + +func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error { + result, err := c.Exec(`delete from auth_session where id = $1`, sessionID) + if err != nil { + return fmt.Errorf("delete auth_session: id = %s: %w", sessionID, err) + } + + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("rows affected: %v", err) + } + if n < 1 { + return storage.ErrNotFound + } + return nil +} + func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error { grantTypes, err := json.Marshal(connector.GrantTypes) if err != nil { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 053e9e6a..7561d146 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -409,4 +409,17 @@ var migrations = []migration{ );`, }, }, + { + stmts: []string{ + ` + create table auth_session ( + id text not null primary key, + client_states bytea not null, + created_at timestamptz not null, + last_activity timestamptz not null, + ip_address text not null default '', + user_agent text not null default '' + );`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 3e332a80..963c7c67 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -84,6 +84,7 @@ type Storage interface { CreatePassword(ctx context.Context, p Password) error CreateOfflineSessions(ctx context.Context, s OfflineSessions) error CreateUserIdentity(ctx context.Context, u UserIdentity) error + CreateAuthSession(ctx context.Context, s AuthSession) error CreateConnector(ctx context.Context, c Connector) error CreateDeviceRequest(ctx context.Context, d DeviceRequest) error CreateDeviceToken(ctx context.Context, d DeviceToken) error @@ -98,6 +99,7 @@ type Storage interface { GetPassword(ctx context.Context, email string) (Password, error) GetOfflineSessions(ctx context.Context, userID string, connID string) (OfflineSessions, error) GetUserIdentity(ctx context.Context, userID, connectorID string) (UserIdentity, error) + GetAuthSession(ctx context.Context, sessionID string) (AuthSession, error) GetConnector(ctx context.Context, id string) (Connector, error) GetDeviceRequest(ctx context.Context, userCode string) (DeviceRequest, error) GetDeviceToken(ctx context.Context, deviceCode string) (DeviceToken, error) @@ -107,6 +109,7 @@ type Storage interface { ListPasswords(ctx context.Context) ([]Password, error) ListConnectors(ctx context.Context) ([]Connector, error) ListUserIdentities(ctx context.Context) ([]UserIdentity, error) + ListAuthSessions(ctx context.Context) ([]AuthSession, error) // Delete methods MUST be atomic. DeleteAuthRequest(ctx context.Context, id string) error @@ -116,6 +119,7 @@ type Storage interface { DeletePassword(ctx context.Context, email string) error DeleteOfflineSessions(ctx context.Context, userID string, connID string) error DeleteUserIdentity(ctx context.Context, userID, connectorID string) error + DeleteAuthSession(ctx context.Context, sessionID string) error DeleteConnector(ctx context.Context, id string) error // Update methods take a function for updating an object then performs that update within @@ -139,6 +143,7 @@ type Storage interface { UpdatePassword(ctx context.Context, email string, updater func(p Password) (Password, error)) error UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u UserIdentity) (UserIdentity, error)) error + UpdateAuthSession(ctx context.Context, sessionID string, updater func(s AuthSession) (AuthSession, error)) error UpdateConnector(ctx context.Context, id string, updater func(c Connector) (Connector, error)) error UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error @@ -336,6 +341,26 @@ type UserIdentity struct { BlockedUntil time.Time } +// ClientAuthState represents the authentication state for a specific client within a session. +type ClientAuthState struct { + UserID string + ConnectorID string + Active bool + ExpiresAt time.Time + LastActivity time.Time + LastTokenIssuedAt time.Time +} + +// AuthSession represents a browser-bound authentication session. +type AuthSession struct { + ID string + ClientStates map[string]*ClientAuthState // clientID -> auth state + CreatedAt time.Time + LastActivity time.Time + IPAddress string + UserAgent string +} + // OfflineSessions objects are sessions pertaining to users with refresh tokens. type OfflineSessions struct { // UserID of an end user who has logged into the server.