Browse Source

feat: implement AuthSession CRUD operations (#4646)

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
master
Maksim Nabokikh 8 hours ago committed by GitHub
parent
commit
6b9ce00e11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 100
      storage/conformance/conformance.go
  2. 108
      storage/ent/client/authsession.go
  3. 22
      storage/ent/client/types.go
  4. 150
      storage/ent/db/authsession.go
  5. 83
      storage/ent/db/authsession/authsession.go
  6. 355
      storage/ent/db/authsession/where.go
  7. 282
      storage/ent/db/authsession_create.go
  8. 88
      storage/ent/db/authsession_delete.go
  9. 527
      storage/ent/db/authsession_query.go
  10. 330
      storage/ent/db/authsession_update.go
  11. 155
      storage/ent/db/client.go
  12. 2
      storage/ent/db/ent.go
  13. 12
      storage/ent/db/hook/hook.go
  14. 16
      storage/ent/db/migrate/schema.go
  15. 550
      storage/ent/db/mutation.go
  16. 3
      storage/ent/db/predicate/predicate.go
  17. 15
      storage/ent/db/runtime.go
  18. 3
      storage/ent/db/tx.go
  19. 37
      storage/ent/schema/authsession.go
  20. 60
      storage/etcd/etcd.go
  21. 36
      storage/etcd/types.go
  22. 50
      storage/kubernetes/storage.go
  23. 69
      storage/kubernetes/types.go
  24. 57
      storage/memory/memory.go
  25. 130
      storage/sql/crud.go
  26. 13
      storage/sql/migrate.go
  27. 25
      storage/storage.go

100
storage/conformance/conformance.go

@ -52,6 +52,7 @@ func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
{"DeviceRequestCRUD", testDeviceRequestCRUD},
{"DeviceTokenCRUD", testDeviceTokenCRUD},
{"UserIdentityCRUD", testUserIdentityCRUD},
{"AuthSessionCRUD", testAuthSessionCRUD},
})
}
@ -1166,3 +1167,102 @@ func testUserIdentityCRUD(t *testing.T, s storage.Storage) {
_, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID)
mustBeErrNotFound(t, "user identity", err)
}
func testAuthSessionCRUD(t *testing.T, s storage.Storage) {
ctx := t.Context()
now := time.Now().UTC().Round(time.Millisecond)
session := storage.AuthSession{
ID: storage.NewID(),
ClientStates: map[string]*storage.ClientAuthState{
"client1": {
UserID: "user1",
ConnectorID: "conn1",
Active: true,
ExpiresAt: now.Add(24 * time.Hour),
LastActivity: now,
LastTokenIssuedAt: now,
},
},
CreatedAt: now,
LastActivity: now,
IPAddress: "192.168.1.1",
UserAgent: "TestBrowser/1.0",
}
// Create.
if err := s.CreateAuthSession(ctx, session); err != nil {
t.Fatalf("create auth session: %v", err)
}
// Duplicate create should return ErrAlreadyExists.
err := s.CreateAuthSession(ctx, session)
mustBeErrAlreadyExists(t, "auth session", err)
// Get and compare.
got, err := s.GetAuthSession(ctx, session.ID)
if err != nil {
t.Fatalf("get auth session: %v", err)
}
got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond)
got.LastActivity = got.LastActivity.UTC().Round(time.Millisecond)
for _, cs := range got.ClientStates {
cs.ExpiresAt = cs.ExpiresAt.UTC().Round(time.Millisecond)
cs.LastActivity = cs.LastActivity.UTC().Round(time.Millisecond)
cs.LastTokenIssuedAt = cs.LastTokenIssuedAt.UTC().Round(time.Millisecond)
}
if diff := pretty.Compare(session, got); diff != "" {
t.Errorf("auth session retrieved from storage did not match: %s", diff)
}
// Update: add a new client state.
newNow := now.Add(time.Minute)
if err := s.UpdateAuthSession(ctx, session.ID, func(old storage.AuthSession) (storage.AuthSession, error) {
old.ClientStates["client2"] = &storage.ClientAuthState{
UserID: "user2",
ConnectorID: "conn2",
Active: true,
ExpiresAt: newNow.Add(24 * time.Hour),
LastActivity: newNow,
}
old.LastActivity = newNow
return old, nil
}); err != nil {
t.Fatalf("update auth session: %v", err)
}
// Get and verify update.
got, err = s.GetAuthSession(ctx, session.ID)
if err != nil {
t.Fatalf("get auth session after update: %v", err)
}
if len(got.ClientStates) != 2 {
t.Fatalf("expected 2 client states, got %d", len(got.ClientStates))
}
if got.ClientStates["client2"] == nil {
t.Fatal("expected client2 state to exist")
}
if got.ClientStates["client2"].UserID != "user2" {
t.Errorf("expected client2 user_id to be user2, got %s", got.ClientStates["client2"].UserID)
}
// List and verify.
sessions, err := s.ListAuthSessions(ctx)
if err != nil {
t.Fatalf("list auth sessions: %v", err)
}
if len(sessions) != 1 {
t.Fatalf("expected 1 auth session, got %d", len(sessions))
}
// Delete.
if err := s.DeleteAuthSession(ctx, session.ID); err != nil {
t.Fatalf("delete auth session: %v", err)
}
// Get deleted should return ErrNotFound.
_, err = s.GetAuthSession(ctx, session.ID)
mustBeErrNotFound(t, "auth session", err)
}

108
storage/ent/client/authsession.go

@ -0,0 +1,108 @@
package client
import (
"context"
"encoding/json"
"fmt"
"github.com/dexidp/dex/storage"
)
// CreateAuthSession saves provided auth session into the database.
func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSession) error {
if session.ClientStates == nil {
session.ClientStates = make(map[string]*storage.ClientAuthState)
}
encodedStates, err := json.Marshal(session.ClientStates)
if err != nil {
return fmt.Errorf("encode client states auth session: %w", err)
}
_, err = d.client.AuthSession.Create().
SetID(session.ID).
SetClientStates(encodedStates).
SetCreatedAt(session.CreatedAt).
SetLastActivity(session.LastActivity).
SetIPAddress(session.IPAddress).
SetUserAgent(session.UserAgent).
Save(ctx)
if err != nil {
return convertDBError("create auth session: %w", err)
}
return nil
}
// GetAuthSession extracts an auth session from the database by session ID.
func (d *Database) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) {
authSession, err := d.client.AuthSession.Get(ctx, sessionID)
if err != nil {
return storage.AuthSession{}, convertDBError("get auth session: %w", err)
}
return toStorageAuthSession(authSession), nil
}
// ListAuthSessions extracts all auth sessions from the database.
func (d *Database) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) {
authSessions, err := d.client.AuthSession.Query().All(ctx)
if err != nil {
return nil, convertDBError("list auth sessions: %w", err)
}
storageAuthSessions := make([]storage.AuthSession, 0, len(authSessions))
for _, s := range authSessions {
storageAuthSessions = append(storageAuthSessions, toStorageAuthSession(s))
}
return storageAuthSessions, nil
}
// DeleteAuthSession deletes an auth session from the database by session ID.
func (d *Database) DeleteAuthSession(ctx context.Context, sessionID string) error {
err := d.client.AuthSession.DeleteOneID(sessionID).Exec(ctx)
if err != nil {
return convertDBError("delete auth session: %w", err)
}
return nil
}
// UpdateAuthSession changes an auth session using an updater function.
func (d *Database) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error {
tx, err := d.BeginTx(ctx)
if err != nil {
return convertDBError("update auth session tx: %w", err)
}
authSession, err := tx.AuthSession.Get(ctx, sessionID)
if err != nil {
return rollback(tx, "update auth session database: %w", err)
}
newSession, err := updater(toStorageAuthSession(authSession))
if err != nil {
return rollback(tx, "update auth session updating: %w", err)
}
if newSession.ClientStates == nil {
newSession.ClientStates = make(map[string]*storage.ClientAuthState)
}
encodedStates, err := json.Marshal(newSession.ClientStates)
if err != nil {
return rollback(tx, "encode client states auth session: %w", err)
}
_, err = tx.AuthSession.UpdateOneID(sessionID).
SetClientStates(encodedStates).
SetLastActivity(newSession.LastActivity).
SetIPAddress(newSession.IPAddress).
SetUserAgent(newSession.UserAgent).
Save(ctx)
if err != nil {
return rollback(tx, "update auth session updating: %w", err)
}
if err = tx.Commit(); err != nil {
return rollback(tx, "update auth session commit: %w", err)
}
return nil
}

22
storage/ent/client/types.go

@ -196,6 +196,28 @@ func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity {
return s
}
func toStorageAuthSession(s *db.AuthSession) storage.AuthSession {
result := storage.AuthSession{
ID: s.ID,
CreatedAt: s.CreatedAt,
LastActivity: s.LastActivity,
IPAddress: s.IPAddress,
UserAgent: s.UserAgent,
}
if s.ClientStates != nil {
if err := json.Unmarshal(s.ClientStates, &result.ClientStates); err != nil {
panic(err)
}
if result.ClientStates == nil {
result.ClientStates = make(map[string]*storage.ClientAuthState)
}
} else {
result.ClientStates = make(map[string]*storage.ClientAuthState)
}
return result
}
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken {
return storage.DeviceToken{
DeviceCode: t.DeviceCode,

150
storage/ent/db/authsession.go

@ -0,0 +1,150 @@
// Code generated by ent, DO NOT EDIT.
package db
import (
"fmt"
"strings"
"time"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"github.com/dexidp/dex/storage/ent/db/authsession"
)
// AuthSession is the model entity for the AuthSession schema.
type AuthSession struct {
config `json:"-"`
// ID of the ent.
ID string `json:"id,omitempty"`
// ClientStates holds the value of the "client_states" field.
ClientStates []byte `json:"client_states,omitempty"`
// CreatedAt holds the value of the "created_at" field.
CreatedAt time.Time `json:"created_at,omitempty"`
// LastActivity holds the value of the "last_activity" field.
LastActivity time.Time `json:"last_activity,omitempty"`
// IPAddress holds the value of the "ip_address" field.
IPAddress string `json:"ip_address,omitempty"`
// UserAgent holds the value of the "user_agent" field.
UserAgent string `json:"user_agent,omitempty"`
selectValues sql.SelectValues
}
// scanValues returns the types for scanning values from sql.Rows.
func (*AuthSession) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
case authsession.FieldClientStates:
values[i] = new([]byte)
case authsession.FieldID, authsession.FieldIPAddress, authsession.FieldUserAgent:
values[i] = new(sql.NullString)
case authsession.FieldCreatedAt, authsession.FieldLastActivity:
values[i] = new(sql.NullTime)
default:
values[i] = new(sql.UnknownType)
}
}
return values, nil
}
// assignValues assigns the values that were returned from sql.Rows (after scanning)
// to the AuthSession fields.
func (_m *AuthSession) assignValues(columns []string, values []any) error {
if m, n := len(values), len(columns); m < n {
return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
}
for i := range columns {
switch columns[i] {
case authsession.FieldID:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field id", values[i])
} else if value.Valid {
_m.ID = value.String
}
case authsession.FieldClientStates:
if value, ok := values[i].(*[]byte); !ok {
return fmt.Errorf("unexpected type %T for field client_states", values[i])
} else if value != nil {
_m.ClientStates = *value
}
case authsession.FieldCreatedAt:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field created_at", values[i])
} else if value.Valid {
_m.CreatedAt = value.Time
}
case authsession.FieldLastActivity:
if value, ok := values[i].(*sql.NullTime); !ok {
return fmt.Errorf("unexpected type %T for field last_activity", values[i])
} else if value.Valid {
_m.LastActivity = value.Time
}
case authsession.FieldIPAddress:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field ip_address", values[i])
} else if value.Valid {
_m.IPAddress = value.String
}
case authsession.FieldUserAgent:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field user_agent", values[i])
} else if value.Valid {
_m.UserAgent = value.String
}
default:
_m.selectValues.Set(columns[i], values[i])
}
}
return nil
}
// Value returns the ent.Value that was dynamically selected and assigned to the AuthSession.
// This includes values selected through modifiers, order, etc.
func (_m *AuthSession) Value(name string) (ent.Value, error) {
return _m.selectValues.Get(name)
}
// Update returns a builder for updating this AuthSession.
// Note that you need to call AuthSession.Unwrap() before calling this method if this AuthSession
// was returned from a transaction, and the transaction was committed or rolled back.
func (_m *AuthSession) Update() *AuthSessionUpdateOne {
return NewAuthSessionClient(_m.config).UpdateOne(_m)
}
// Unwrap unwraps the AuthSession entity that was returned from a transaction after it was closed,
// so that all future queries will be executed through the driver which created the transaction.
func (_m *AuthSession) Unwrap() *AuthSession {
_tx, ok := _m.config.driver.(*txDriver)
if !ok {
panic("db: AuthSession is not a transactional entity")
}
_m.config.driver = _tx.drv
return _m
}
// String implements the fmt.Stringer.
func (_m *AuthSession) String() string {
var builder strings.Builder
builder.WriteString("AuthSession(")
builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
builder.WriteString("client_states=")
builder.WriteString(fmt.Sprintf("%v", _m.ClientStates))
builder.WriteString(", ")
builder.WriteString("created_at=")
builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("last_activity=")
builder.WriteString(_m.LastActivity.Format(time.ANSIC))
builder.WriteString(", ")
builder.WriteString("ip_address=")
builder.WriteString(_m.IPAddress)
builder.WriteString(", ")
builder.WriteString("user_agent=")
builder.WriteString(_m.UserAgent)
builder.WriteByte(')')
return builder.String()
}
// AuthSessions is a parsable slice of AuthSession.
type AuthSessions []*AuthSession

83
storage/ent/db/authsession/authsession.go

@ -0,0 +1,83 @@
// Code generated by ent, DO NOT EDIT.
package authsession
import (
"entgo.io/ent/dialect/sql"
)
const (
// Label holds the string label denoting the authsession type in the database.
Label = "auth_session"
// FieldID holds the string denoting the id field in the database.
FieldID = "id"
// FieldClientStates holds the string denoting the client_states field in the database.
FieldClientStates = "client_states"
// FieldCreatedAt holds the string denoting the created_at field in the database.
FieldCreatedAt = "created_at"
// FieldLastActivity holds the string denoting the last_activity field in the database.
FieldLastActivity = "last_activity"
// FieldIPAddress holds the string denoting the ip_address field in the database.
FieldIPAddress = "ip_address"
// FieldUserAgent holds the string denoting the user_agent field in the database.
FieldUserAgent = "user_agent"
// Table holds the table name of the authsession in the database.
Table = "auth_sessions"
)
// Columns holds all SQL columns for authsession fields.
var Columns = []string{
FieldID,
FieldClientStates,
FieldCreatedAt,
FieldLastActivity,
FieldIPAddress,
FieldUserAgent,
}
// ValidColumn reports if the column name is valid (part of the table columns).
func ValidColumn(column string) bool {
for i := range Columns {
if column == Columns[i] {
return true
}
}
return false
}
var (
// DefaultIPAddress holds the default value on creation for the "ip_address" field.
DefaultIPAddress string
// DefaultUserAgent holds the default value on creation for the "user_agent" field.
DefaultUserAgent string
// IDValidator is a validator for the "id" field. It is called by the builders before save.
IDValidator func(string) error
)
// OrderOption defines the ordering options for the AuthSession queries.
type OrderOption func(*sql.Selector)
// ByID orders the results by the id field.
func ByID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldID, opts...).ToFunc()
}
// ByCreatedAt orders the results by the created_at field.
func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
}
// ByLastActivity orders the results by the last_activity field.
func ByLastActivity(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldLastActivity, opts...).ToFunc()
}
// ByIPAddress orders the results by the ip_address field.
func ByIPAddress(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldIPAddress, opts...).ToFunc()
}
// ByUserAgent orders the results by the user_agent field.
func ByUserAgent(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldUserAgent, opts...).ToFunc()
}

355
storage/ent/db/authsession/where.go

@ -0,0 +1,355 @@
// Code generated by ent, DO NOT EDIT.
package authsession
import (
"time"
"entgo.io/ent/dialect/sql"
"github.com/dexidp/dex/storage/ent/db/predicate"
)
// ID filters vertices based on their ID field.
func ID(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldID, id))
}
// IDEQ applies the EQ predicate on the ID field.
func IDEQ(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldID, id))
}
// IDNEQ applies the NEQ predicate on the ID field.
func IDNEQ(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldID, id))
}
// IDIn applies the In predicate on the ID field.
func IDIn(ids ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldID, ids...))
}
// IDNotIn applies the NotIn predicate on the ID field.
func IDNotIn(ids ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldID, ids...))
}
// IDGT applies the GT predicate on the ID field.
func IDGT(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldID, id))
}
// IDGTE applies the GTE predicate on the ID field.
func IDGTE(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldID, id))
}
// IDLT applies the LT predicate on the ID field.
func IDLT(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldID, id))
}
// IDLTE applies the LTE predicate on the ID field.
func IDLTE(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldID, id))
}
// IDEqualFold applies the EqualFold predicate on the ID field.
func IDEqualFold(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEqualFold(FieldID, id))
}
// IDContainsFold applies the ContainsFold predicate on the ID field.
func IDContainsFold(id string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldContainsFold(FieldID, id))
}
// ClientStates applies equality check predicate on the "client_states" field. It's identical to ClientStatesEQ.
func ClientStates(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v))
}
// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
func CreatedAt(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v))
}
// LastActivity applies equality check predicate on the "last_activity" field. It's identical to LastActivityEQ.
func LastActivity(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v))
}
// IPAddress applies equality check predicate on the "ip_address" field. It's identical to IPAddressEQ.
func IPAddress(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v))
}
// UserAgent applies equality check predicate on the "user_agent" field. It's identical to UserAgentEQ.
func UserAgent(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v))
}
// ClientStatesEQ applies the EQ predicate on the "client_states" field.
func ClientStatesEQ(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldClientStates, v))
}
// ClientStatesNEQ applies the NEQ predicate on the "client_states" field.
func ClientStatesNEQ(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldClientStates, v))
}
// ClientStatesIn applies the In predicate on the "client_states" field.
func ClientStatesIn(vs ...[]byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldClientStates, vs...))
}
// ClientStatesNotIn applies the NotIn predicate on the "client_states" field.
func ClientStatesNotIn(vs ...[]byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldClientStates, vs...))
}
// ClientStatesGT applies the GT predicate on the "client_states" field.
func ClientStatesGT(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldClientStates, v))
}
// ClientStatesGTE applies the GTE predicate on the "client_states" field.
func ClientStatesGTE(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldClientStates, v))
}
// ClientStatesLT applies the LT predicate on the "client_states" field.
func ClientStatesLT(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldClientStates, v))
}
// ClientStatesLTE applies the LTE predicate on the "client_states" field.
func ClientStatesLTE(v []byte) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldClientStates, v))
}
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldCreatedAt, v))
}
// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
func CreatedAtNEQ(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldCreatedAt, v))
}
// CreatedAtIn applies the In predicate on the "created_at" field.
func CreatedAtIn(vs ...time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldCreatedAt, vs...))
}
// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
func CreatedAtNotIn(vs ...time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldCreatedAt, vs...))
}
// CreatedAtGT applies the GT predicate on the "created_at" field.
func CreatedAtGT(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldCreatedAt, v))
}
// CreatedAtGTE applies the GTE predicate on the "created_at" field.
func CreatedAtGTE(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldCreatedAt, v))
}
// CreatedAtLT applies the LT predicate on the "created_at" field.
func CreatedAtLT(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldCreatedAt, v))
}
// CreatedAtLTE applies the LTE predicate on the "created_at" field.
func CreatedAtLTE(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldCreatedAt, v))
}
// LastActivityEQ applies the EQ predicate on the "last_activity" field.
func LastActivityEQ(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldLastActivity, v))
}
// LastActivityNEQ applies the NEQ predicate on the "last_activity" field.
func LastActivityNEQ(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldLastActivity, v))
}
// LastActivityIn applies the In predicate on the "last_activity" field.
func LastActivityIn(vs ...time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldLastActivity, vs...))
}
// LastActivityNotIn applies the NotIn predicate on the "last_activity" field.
func LastActivityNotIn(vs ...time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldLastActivity, vs...))
}
// LastActivityGT applies the GT predicate on the "last_activity" field.
func LastActivityGT(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldLastActivity, v))
}
// LastActivityGTE applies the GTE predicate on the "last_activity" field.
func LastActivityGTE(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldLastActivity, v))
}
// LastActivityLT applies the LT predicate on the "last_activity" field.
func LastActivityLT(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldLastActivity, v))
}
// LastActivityLTE applies the LTE predicate on the "last_activity" field.
func LastActivityLTE(v time.Time) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldLastActivity, v))
}
// IPAddressEQ applies the EQ predicate on the "ip_address" field.
func IPAddressEQ(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldIPAddress, v))
}
// IPAddressNEQ applies the NEQ predicate on the "ip_address" field.
func IPAddressNEQ(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldIPAddress, v))
}
// IPAddressIn applies the In predicate on the "ip_address" field.
func IPAddressIn(vs ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldIPAddress, vs...))
}
// IPAddressNotIn applies the NotIn predicate on the "ip_address" field.
func IPAddressNotIn(vs ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldIPAddress, vs...))
}
// IPAddressGT applies the GT predicate on the "ip_address" field.
func IPAddressGT(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldIPAddress, v))
}
// IPAddressGTE applies the GTE predicate on the "ip_address" field.
func IPAddressGTE(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldIPAddress, v))
}
// IPAddressLT applies the LT predicate on the "ip_address" field.
func IPAddressLT(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldIPAddress, v))
}
// IPAddressLTE applies the LTE predicate on the "ip_address" field.
func IPAddressLTE(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldIPAddress, v))
}
// IPAddressContains applies the Contains predicate on the "ip_address" field.
func IPAddressContains(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldContains(FieldIPAddress, v))
}
// IPAddressHasPrefix applies the HasPrefix predicate on the "ip_address" field.
func IPAddressHasPrefix(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldHasPrefix(FieldIPAddress, v))
}
// IPAddressHasSuffix applies the HasSuffix predicate on the "ip_address" field.
func IPAddressHasSuffix(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldHasSuffix(FieldIPAddress, v))
}
// IPAddressEqualFold applies the EqualFold predicate on the "ip_address" field.
func IPAddressEqualFold(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEqualFold(FieldIPAddress, v))
}
// IPAddressContainsFold applies the ContainsFold predicate on the "ip_address" field.
func IPAddressContainsFold(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldContainsFold(FieldIPAddress, v))
}
// UserAgentEQ applies the EQ predicate on the "user_agent" field.
func UserAgentEQ(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v))
}
// UserAgentNEQ applies the NEQ predicate on the "user_agent" field.
func UserAgentNEQ(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNEQ(FieldUserAgent, v))
}
// UserAgentIn applies the In predicate on the "user_agent" field.
func UserAgentIn(vs ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldIn(FieldUserAgent, vs...))
}
// UserAgentNotIn applies the NotIn predicate on the "user_agent" field.
func UserAgentNotIn(vs ...string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldNotIn(FieldUserAgent, vs...))
}
// UserAgentGT applies the GT predicate on the "user_agent" field.
func UserAgentGT(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGT(FieldUserAgent, v))
}
// UserAgentGTE applies the GTE predicate on the "user_agent" field.
func UserAgentGTE(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldGTE(FieldUserAgent, v))
}
// UserAgentLT applies the LT predicate on the "user_agent" field.
func UserAgentLT(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLT(FieldUserAgent, v))
}
// UserAgentLTE applies the LTE predicate on the "user_agent" field.
func UserAgentLTE(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldLTE(FieldUserAgent, v))
}
// UserAgentContains applies the Contains predicate on the "user_agent" field.
func UserAgentContains(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldContains(FieldUserAgent, v))
}
// UserAgentHasPrefix applies the HasPrefix predicate on the "user_agent" field.
func UserAgentHasPrefix(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldHasPrefix(FieldUserAgent, v))
}
// UserAgentHasSuffix applies the HasSuffix predicate on the "user_agent" field.
func UserAgentHasSuffix(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldHasSuffix(FieldUserAgent, v))
}
// UserAgentEqualFold applies the EqualFold predicate on the "user_agent" field.
func UserAgentEqualFold(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldEqualFold(FieldUserAgent, v))
}
// UserAgentContainsFold applies the ContainsFold predicate on the "user_agent" field.
func UserAgentContainsFold(v string) predicate.AuthSession {
return predicate.AuthSession(sql.FieldContainsFold(FieldUserAgent, v))
}
// And groups predicates with the AND operator between them.
func And(predicates ...predicate.AuthSession) predicate.AuthSession {
return predicate.AuthSession(sql.AndPredicates(predicates...))
}
// Or groups predicates with the OR operator between them.
func Or(predicates ...predicate.AuthSession) predicate.AuthSession {
return predicate.AuthSession(sql.OrPredicates(predicates...))
}
// Not applies the not operator on the given predicate.
func Not(p predicate.AuthSession) predicate.AuthSession {
return predicate.AuthSession(sql.NotPredicates(p))
}

282
storage/ent/db/authsession_create.go

@ -0,0 +1,282 @@
// Code generated by ent, DO NOT EDIT.
package db
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/dexidp/dex/storage/ent/db/authsession"
)
// AuthSessionCreate is the builder for creating a AuthSession entity.
type AuthSessionCreate struct {
config
mutation *AuthSessionMutation
hooks []Hook
}
// SetClientStates sets the "client_states" field.
func (_c *AuthSessionCreate) SetClientStates(v []byte) *AuthSessionCreate {
_c.mutation.SetClientStates(v)
return _c
}
// SetCreatedAt sets the "created_at" field.
func (_c *AuthSessionCreate) SetCreatedAt(v time.Time) *AuthSessionCreate {
_c.mutation.SetCreatedAt(v)
return _c
}
// SetLastActivity sets the "last_activity" field.
func (_c *AuthSessionCreate) SetLastActivity(v time.Time) *AuthSessionCreate {
_c.mutation.SetLastActivity(v)
return _c
}
// SetIPAddress sets the "ip_address" field.
func (_c *AuthSessionCreate) SetIPAddress(v string) *AuthSessionCreate {
_c.mutation.SetIPAddress(v)
return _c
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_c *AuthSessionCreate) SetNillableIPAddress(v *string) *AuthSessionCreate {
if v != nil {
_c.SetIPAddress(*v)
}
return _c
}
// SetUserAgent sets the "user_agent" field.
func (_c *AuthSessionCreate) SetUserAgent(v string) *AuthSessionCreate {
_c.mutation.SetUserAgent(v)
return _c
}
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
func (_c *AuthSessionCreate) SetNillableUserAgent(v *string) *AuthSessionCreate {
if v != nil {
_c.SetUserAgent(*v)
}
return _c
}
// SetID sets the "id" field.
func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate {
_c.mutation.SetID(v)
return _c
}
// Mutation returns the AuthSessionMutation object of the builder.
func (_c *AuthSessionCreate) Mutation() *AuthSessionMutation {
return _c.mutation
}
// Save creates the AuthSession in the database.
func (_c *AuthSessionCreate) Save(ctx context.Context) (*AuthSession, error) {
_c.defaults()
return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
}
// SaveX calls Save and panics if Save returns an error.
func (_c *AuthSessionCreate) SaveX(ctx context.Context) *AuthSession {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AuthSessionCreate) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AuthSessionCreate) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}
// defaults sets the default values of the builder before save.
func (_c *AuthSessionCreate) defaults() {
if _, ok := _c.mutation.IPAddress(); !ok {
v := authsession.DefaultIPAddress
_c.mutation.SetIPAddress(v)
}
if _, ok := _c.mutation.UserAgent(); !ok {
v := authsession.DefaultUserAgent
_c.mutation.SetUserAgent(v)
}
}
// check runs all checks and user-defined validators on the builder.
func (_c *AuthSessionCreate) check() error {
if _, ok := _c.mutation.ClientStates(); !ok {
return &ValidationError{Name: "client_states", err: errors.New(`db: missing required field "AuthSession.client_states"`)}
}
if _, ok := _c.mutation.CreatedAt(); !ok {
return &ValidationError{Name: "created_at", err: errors.New(`db: missing required field "AuthSession.created_at"`)}
}
if _, ok := _c.mutation.LastActivity(); !ok {
return &ValidationError{Name: "last_activity", err: errors.New(`db: missing required field "AuthSession.last_activity"`)}
}
if _, ok := _c.mutation.IPAddress(); !ok {
return &ValidationError{Name: "ip_address", err: errors.New(`db: missing required field "AuthSession.ip_address"`)}
}
if _, ok := _c.mutation.UserAgent(); !ok {
return &ValidationError{Name: "user_agent", err: errors.New(`db: missing required field "AuthSession.user_agent"`)}
}
if v, ok := _c.mutation.ID(); ok {
if err := authsession.IDValidator(v); err != nil {
return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthSession.id": %w`, err)}
}
}
return nil
}
func (_c *AuthSessionCreate) sqlSave(ctx context.Context) (*AuthSession, error) {
if err := _c.check(); err != nil {
return nil, err
}
_node, _spec := _c.createSpec()
if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
if _spec.ID.Value != nil {
if id, ok := _spec.ID.Value.(string); ok {
_node.ID = id
} else {
return nil, fmt.Errorf("unexpected AuthSession.ID type: %T", _spec.ID.Value)
}
}
_c.mutation.id = &_node.ID
_c.mutation.done = true
return _node, nil
}
func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) {
var (
_node = &AuthSession{config: _c.config}
_spec = sqlgraph.NewCreateSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString))
)
if id, ok := _c.mutation.ID(); ok {
_node.ID = id
_spec.ID.Value = id
}
if value, ok := _c.mutation.ClientStates(); ok {
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value)
_node.ClientStates = value
}
if value, ok := _c.mutation.CreatedAt(); ok {
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value)
_node.CreatedAt = value
}
if value, ok := _c.mutation.LastActivity(); ok {
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value)
_node.LastActivity = value
}
if value, ok := _c.mutation.IPAddress(); ok {
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value)
_node.IPAddress = value
}
if value, ok := _c.mutation.UserAgent(); ok {
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value)
_node.UserAgent = value
}
return _node, _spec
}
// AuthSessionCreateBulk is the builder for creating many AuthSession entities in bulk.
type AuthSessionCreateBulk struct {
config
err error
builders []*AuthSessionCreate
}
// Save creates the AuthSession entities in the database.
func (_c *AuthSessionCreateBulk) Save(ctx context.Context) ([]*AuthSession, error) {
if _c.err != nil {
return nil, _c.err
}
specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
nodes := make([]*AuthSession, len(_c.builders))
mutators := make([]Mutator, len(_c.builders))
for i := range _c.builders {
func(i int, root context.Context) {
builder := _c.builders[i]
builder.defaults()
var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
mutation, ok := m.(*AuthSessionMutation)
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
}
if err := builder.check(); err != nil {
return nil, err
}
builder.mutation = mutation
var err error
nodes[i], specs[i] = builder.createSpec()
if i < len(mutators)-1 {
_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
} else {
spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
// Invoke the actual operation on the latest mutation in the chain.
if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
}
}
if err != nil {
return nil, err
}
mutation.id = &nodes[i].ID
mutation.done = true
return nodes[i], nil
})
for i := len(builder.hooks) - 1; i >= 0; i-- {
mut = builder.hooks[i](mut)
}
mutators[i] = mut
}(i, ctx)
}
if len(mutators) > 0 {
if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
return nil, err
}
}
return nodes, nil
}
// SaveX is like Save, but panics if an error occurs.
func (_c *AuthSessionCreateBulk) SaveX(ctx context.Context) []*AuthSession {
v, err := _c.Save(ctx)
if err != nil {
panic(err)
}
return v
}
// Exec executes the query.
func (_c *AuthSessionCreateBulk) Exec(ctx context.Context) error {
_, err := _c.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_c *AuthSessionCreateBulk) ExecX(ctx context.Context) {
if err := _c.Exec(ctx); err != nil {
panic(err)
}
}

88
storage/ent/db/authsession_delete.go

@ -0,0 +1,88 @@
// Code generated by ent, DO NOT EDIT.
package db
import (
"context"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/predicate"
)
// AuthSessionDelete is the builder for deleting a AuthSession entity.
type AuthSessionDelete struct {
config
hooks []Hook
mutation *AuthSessionMutation
}
// Where appends a list predicates to the AuthSessionDelete builder.
func (_d *AuthSessionDelete) Where(ps ...predicate.AuthSession) *AuthSessionDelete {
_d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query and returns how many vertices were deleted.
func (_d *AuthSessionDelete) Exec(ctx context.Context) (int, error) {
return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthSessionDelete) ExecX(ctx context.Context) int {
n, err := _d.Exec(ctx)
if err != nil {
panic(err)
}
return n
}
func (_d *AuthSessionDelete) sqlExec(ctx context.Context) (int, error) {
_spec := sqlgraph.NewDeleteSpec(authsession.Table, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString))
if ps := _d.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
if err != nil && sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
_d.mutation.done = true
return affected, err
}
// AuthSessionDeleteOne is the builder for deleting a single AuthSession entity.
type AuthSessionDeleteOne struct {
_d *AuthSessionDelete
}
// Where appends a list predicates to the AuthSessionDelete builder.
func (_d *AuthSessionDeleteOne) Where(ps ...predicate.AuthSession) *AuthSessionDeleteOne {
_d._d.mutation.Where(ps...)
return _d
}
// Exec executes the deletion query.
func (_d *AuthSessionDeleteOne) Exec(ctx context.Context) error {
n, err := _d._d.Exec(ctx)
switch {
case err != nil:
return err
case n == 0:
return &NotFoundError{authsession.Label}
default:
return nil
}
}
// ExecX is like Exec, but panics if an error occurs.
func (_d *AuthSessionDeleteOne) ExecX(ctx context.Context) {
if err := _d.Exec(ctx); err != nil {
panic(err)
}
}

527
storage/ent/db/authsession_query.go

@ -0,0 +1,527 @@
// Code generated by ent, DO NOT EDIT.
package db
import (
"context"
"fmt"
"math"
"entgo.io/ent"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/predicate"
)
// AuthSessionQuery is the builder for querying AuthSession entities.
type AuthSessionQuery struct {
config
ctx *QueryContext
order []authsession.OrderOption
inters []Interceptor
predicates []predicate.AuthSession
// intermediate query (i.e. traversal path).
sql *sql.Selector
path func(context.Context) (*sql.Selector, error)
}
// Where adds a new predicate for the AuthSessionQuery builder.
func (_q *AuthSessionQuery) Where(ps ...predicate.AuthSession) *AuthSessionQuery {
_q.predicates = append(_q.predicates, ps...)
return _q
}
// Limit the number of records to be returned by this query.
func (_q *AuthSessionQuery) Limit(limit int) *AuthSessionQuery {
_q.ctx.Limit = &limit
return _q
}
// Offset to start from.
func (_q *AuthSessionQuery) Offset(offset int) *AuthSessionQuery {
_q.ctx.Offset = &offset
return _q
}
// Unique configures the query builder to filter duplicate records on query.
// By default, unique is set to true, and can be disabled using this method.
func (_q *AuthSessionQuery) Unique(unique bool) *AuthSessionQuery {
_q.ctx.Unique = &unique
return _q
}
// Order specifies how the records should be ordered.
func (_q *AuthSessionQuery) Order(o ...authsession.OrderOption) *AuthSessionQuery {
_q.order = append(_q.order, o...)
return _q
}
// First returns the first AuthSession entity from the query.
// Returns a *NotFoundError when no AuthSession was found.
func (_q *AuthSessionQuery) First(ctx context.Context) (*AuthSession, error) {
nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
if err != nil {
return nil, err
}
if len(nodes) == 0 {
return nil, &NotFoundError{authsession.Label}
}
return nodes[0], nil
}
// FirstX is like First, but panics if an error occurs.
func (_q *AuthSessionQuery) FirstX(ctx context.Context) *AuthSession {
node, err := _q.First(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return node
}
// FirstID returns the first AuthSession ID from the query.
// Returns a *NotFoundError when no AuthSession ID was found.
func (_q *AuthSessionQuery) FirstID(ctx context.Context) (id string, err error) {
var ids []string
if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
return
}
if len(ids) == 0 {
err = &NotFoundError{authsession.Label}
return
}
return ids[0], nil
}
// FirstIDX is like FirstID, but panics if an error occurs.
func (_q *AuthSessionQuery) FirstIDX(ctx context.Context) string {
id, err := _q.FirstID(ctx)
if err != nil && !IsNotFound(err) {
panic(err)
}
return id
}
// Only returns a single AuthSession entity found by the query, ensuring it only returns one.
// Returns a *NotSingularError when more than one AuthSession entity is found.
// Returns a *NotFoundError when no AuthSession entities are found.
func (_q *AuthSessionQuery) Only(ctx context.Context) (*AuthSession, error) {
nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
if err != nil {
return nil, err
}
switch len(nodes) {
case 1:
return nodes[0], nil
case 0:
return nil, &NotFoundError{authsession.Label}
default:
return nil, &NotSingularError{authsession.Label}
}
}
// OnlyX is like Only, but panics if an error occurs.
func (_q *AuthSessionQuery) OnlyX(ctx context.Context) *AuthSession {
node, err := _q.Only(ctx)
if err != nil {
panic(err)
}
return node
}
// OnlyID is like Only, but returns the only AuthSession ID in the query.
// Returns a *NotSingularError when more than one AuthSession ID is found.
// Returns a *NotFoundError when no entities are found.
func (_q *AuthSessionQuery) OnlyID(ctx context.Context) (id string, err error) {
var ids []string
if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
return
}
switch len(ids) {
case 1:
id = ids[0]
case 0:
err = &NotFoundError{authsession.Label}
default:
err = &NotSingularError{authsession.Label}
}
return
}
// OnlyIDX is like OnlyID, but panics if an error occurs.
func (_q *AuthSessionQuery) OnlyIDX(ctx context.Context) string {
id, err := _q.OnlyID(ctx)
if err != nil {
panic(err)
}
return id
}
// All executes the query and returns a list of AuthSessions.
func (_q *AuthSessionQuery) All(ctx context.Context) ([]*AuthSession, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
if err := _q.prepareQuery(ctx); err != nil {
return nil, err
}
qr := querierAll[[]*AuthSession, *AuthSessionQuery]()
return withInterceptors[[]*AuthSession](ctx, _q, qr, _q.inters)
}
// AllX is like All, but panics if an error occurs.
func (_q *AuthSessionQuery) AllX(ctx context.Context) []*AuthSession {
nodes, err := _q.All(ctx)
if err != nil {
panic(err)
}
return nodes
}
// IDs executes the query and returns a list of AuthSession IDs.
func (_q *AuthSessionQuery) IDs(ctx context.Context) (ids []string, err error) {
if _q.ctx.Unique == nil && _q.path != nil {
_q.Unique(true)
}
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
if err = _q.Select(authsession.FieldID).Scan(ctx, &ids); err != nil {
return nil, err
}
return ids, nil
}
// IDsX is like IDs, but panics if an error occurs.
func (_q *AuthSessionQuery) IDsX(ctx context.Context) []string {
ids, err := _q.IDs(ctx)
if err != nil {
panic(err)
}
return ids
}
// Count returns the count of the given query.
func (_q *AuthSessionQuery) Count(ctx context.Context) (int, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
if err := _q.prepareQuery(ctx); err != nil {
return 0, err
}
return withInterceptors[int](ctx, _q, querierCount[*AuthSessionQuery](), _q.inters)
}
// CountX is like Count, but panics if an error occurs.
func (_q *AuthSessionQuery) CountX(ctx context.Context) int {
count, err := _q.Count(ctx)
if err != nil {
panic(err)
}
return count
}
// Exist returns true if the query has elements in the graph.
func (_q *AuthSessionQuery) Exist(ctx context.Context) (bool, error) {
ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
switch _, err := _q.FirstID(ctx); {
case IsNotFound(err):
return false, nil
case err != nil:
return false, fmt.Errorf("db: check existence: %w", err)
default:
return true, nil
}
}
// ExistX is like Exist, but panics if an error occurs.
func (_q *AuthSessionQuery) ExistX(ctx context.Context) bool {
exist, err := _q.Exist(ctx)
if err != nil {
panic(err)
}
return exist
}
// Clone returns a duplicate of the AuthSessionQuery builder, including all associated steps. It can be
// used to prepare common query builders and use them differently after the clone is made.
func (_q *AuthSessionQuery) Clone() *AuthSessionQuery {
if _q == nil {
return nil
}
return &AuthSessionQuery{
config: _q.config,
ctx: _q.ctx.Clone(),
order: append([]authsession.OrderOption{}, _q.order...),
inters: append([]Interceptor{}, _q.inters...),
predicates: append([]predicate.AuthSession{}, _q.predicates...),
// clone intermediate query.
sql: _q.sql.Clone(),
path: _q.path,
}
}
// GroupBy is used to group vertices by one or more fields/columns.
// It is often used with aggregate functions, like: count, max, mean, min, sum.
//
// Example:
//
// var v []struct {
// ClientStates []byte `json:"client_states,omitempty"`
// Count int `json:"count,omitempty"`
// }
//
// client.AuthSession.Query().
// GroupBy(authsession.FieldClientStates).
// Aggregate(db.Count()).
// Scan(ctx, &v)
func (_q *AuthSessionQuery) GroupBy(field string, fields ...string) *AuthSessionGroupBy {
_q.ctx.Fields = append([]string{field}, fields...)
grbuild := &AuthSessionGroupBy{build: _q}
grbuild.flds = &_q.ctx.Fields
grbuild.label = authsession.Label
grbuild.scan = grbuild.Scan
return grbuild
}
// Select allows the selection one or more fields/columns for the given query,
// instead of selecting all fields in the entity.
//
// Example:
//
// var v []struct {
// ClientStates []byte `json:"client_states,omitempty"`
// }
//
// client.AuthSession.Query().
// Select(authsession.FieldClientStates).
// Scan(ctx, &v)
func (_q *AuthSessionQuery) Select(fields ...string) *AuthSessionSelect {
_q.ctx.Fields = append(_q.ctx.Fields, fields...)
sbuild := &AuthSessionSelect{AuthSessionQuery: _q}
sbuild.label = authsession.Label
sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
return sbuild
}
// Aggregate returns a AuthSessionSelect configured with the given aggregations.
func (_q *AuthSessionQuery) Aggregate(fns ...AggregateFunc) *AuthSessionSelect {
return _q.Select().Aggregate(fns...)
}
func (_q *AuthSessionQuery) prepareQuery(ctx context.Context) error {
for _, inter := range _q.inters {
if inter == nil {
return fmt.Errorf("db: uninitialized interceptor (forgotten import db/runtime?)")
}
if trv, ok := inter.(Traverser); ok {
if err := trv.Traverse(ctx, _q); err != nil {
return err
}
}
}
for _, f := range _q.ctx.Fields {
if !authsession.ValidColumn(f) {
return &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)}
}
}
if _q.path != nil {
prev, err := _q.path(ctx)
if err != nil {
return err
}
_q.sql = prev
}
return nil
}
func (_q *AuthSessionQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*AuthSession, error) {
var (
nodes = []*AuthSession{}
_spec = _q.querySpec()
)
_spec.ScanValues = func(columns []string) ([]any, error) {
return (*AuthSession).scanValues(nil, columns)
}
_spec.Assign = func(columns []string, values []any) error {
node := &AuthSession{config: _q.config}
nodes = append(nodes, node)
return node.assignValues(columns, values)
}
for i := range hooks {
hooks[i](ctx, _spec)
}
if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
return nil, err
}
if len(nodes) == 0 {
return nodes, nil
}
return nodes, nil
}
func (_q *AuthSessionQuery) sqlCount(ctx context.Context) (int, error) {
_spec := _q.querySpec()
_spec.Node.Columns = _q.ctx.Fields
if len(_q.ctx.Fields) > 0 {
_spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
}
return sqlgraph.CountNodes(ctx, _q.driver, _spec)
}
func (_q *AuthSessionQuery) querySpec() *sqlgraph.QuerySpec {
_spec := sqlgraph.NewQuerySpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString))
_spec.From = _q.sql
if unique := _q.ctx.Unique; unique != nil {
_spec.Unique = *unique
} else if _q.path != nil {
_spec.Unique = true
}
if fields := _q.ctx.Fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID)
for i := range fields {
if fields[i] != authsession.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, fields[i])
}
}
}
if ps := _q.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if limit := _q.ctx.Limit; limit != nil {
_spec.Limit = *limit
}
if offset := _q.ctx.Offset; offset != nil {
_spec.Offset = *offset
}
if ps := _q.order; len(ps) > 0 {
_spec.Order = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
return _spec
}
func (_q *AuthSessionQuery) sqlQuery(ctx context.Context) *sql.Selector {
builder := sql.Dialect(_q.driver.Dialect())
t1 := builder.Table(authsession.Table)
columns := _q.ctx.Fields
if len(columns) == 0 {
columns = authsession.Columns
}
selector := builder.Select(t1.Columns(columns...)...).From(t1)
if _q.sql != nil {
selector = _q.sql
selector.Select(selector.Columns(columns...)...)
}
if _q.ctx.Unique != nil && *_q.ctx.Unique {
selector.Distinct()
}
for _, p := range _q.predicates {
p(selector)
}
for _, p := range _q.order {
p(selector)
}
if offset := _q.ctx.Offset; offset != nil {
// limit is mandatory for offset clause. We start
// with default value, and override it below if needed.
selector.Offset(*offset).Limit(math.MaxInt32)
}
if limit := _q.ctx.Limit; limit != nil {
selector.Limit(*limit)
}
return selector
}
// AuthSessionGroupBy is the group-by builder for AuthSession entities.
type AuthSessionGroupBy struct {
selector
build *AuthSessionQuery
}
// Aggregate adds the given aggregation functions to the group-by query.
func (_g *AuthSessionGroupBy) Aggregate(fns ...AggregateFunc) *AuthSessionGroupBy {
_g.fns = append(_g.fns, fns...)
return _g
}
// Scan applies the selector query and scans the result into the given value.
func (_g *AuthSessionGroupBy) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
if err := _g.build.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthSessionQuery, *AuthSessionGroupBy](ctx, _g.build, _g, _g.build.inters, v)
}
func (_g *AuthSessionGroupBy) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error {
selector := root.sqlQuery(ctx).Select()
aggregation := make([]string, 0, len(_g.fns))
for _, fn := range _g.fns {
aggregation = append(aggregation, fn(selector))
}
if len(selector.SelectedColumns()) == 0 {
columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
for _, f := range *_g.flds {
columns = append(columns, selector.C(f))
}
columns = append(columns, aggregation...)
selector.Select(columns...)
}
selector.GroupBy(selector.Columns(*_g.flds...)...)
if err := selector.Err(); err != nil {
return err
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}
// AuthSessionSelect is the builder for selecting fields of AuthSession entities.
type AuthSessionSelect struct {
*AuthSessionQuery
selector
}
// Aggregate adds the given aggregation functions to the selector query.
func (_s *AuthSessionSelect) Aggregate(fns ...AggregateFunc) *AuthSessionSelect {
_s.fns = append(_s.fns, fns...)
return _s
}
// Scan applies the selector query and scans the result into the given value.
func (_s *AuthSessionSelect) Scan(ctx context.Context, v any) error {
ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
if err := _s.prepareQuery(ctx); err != nil {
return err
}
return scanWithInterceptors[*AuthSessionQuery, *AuthSessionSelect](ctx, _s.AuthSessionQuery, _s, _s.inters, v)
}
func (_s *AuthSessionSelect) sqlScan(ctx context.Context, root *AuthSessionQuery, v any) error {
selector := root.sqlQuery(ctx)
aggregation := make([]string, 0, len(_s.fns))
for _, fn := range _s.fns {
aggregation = append(aggregation, fn(selector))
}
switch n := len(*_s.selector.flds); {
case n == 0 && len(aggregation) > 0:
selector.Select(aggregation...)
case n != 0 && len(aggregation) > 0:
selector.AppendSelect(aggregation...)
}
rows := &sql.Rows{}
query, args := selector.Query()
if err := _s.driver.Query(ctx, query, args, rows); err != nil {
return err
}
defer rows.Close()
return sql.ScanSlice(rows, v)
}

330
storage/ent/db/authsession_update.go

@ -0,0 +1,330 @@
// Code generated by ent, DO NOT EDIT.
package db
import (
"context"
"errors"
"fmt"
"time"
"entgo.io/ent/dialect/sql"
"entgo.io/ent/dialect/sql/sqlgraph"
"entgo.io/ent/schema/field"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/predicate"
)
// AuthSessionUpdate is the builder for updating AuthSession entities.
type AuthSessionUpdate struct {
config
hooks []Hook
mutation *AuthSessionMutation
}
// Where appends a list predicates to the AuthSessionUpdate builder.
func (_u *AuthSessionUpdate) Where(ps ...predicate.AuthSession) *AuthSessionUpdate {
_u.mutation.Where(ps...)
return _u
}
// SetClientStates sets the "client_states" field.
func (_u *AuthSessionUpdate) SetClientStates(v []byte) *AuthSessionUpdate {
_u.mutation.SetClientStates(v)
return _u
}
// SetCreatedAt sets the "created_at" field.
func (_u *AuthSessionUpdate) SetCreatedAt(v time.Time) *AuthSessionUpdate {
_u.mutation.SetCreatedAt(v)
return _u
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_u *AuthSessionUpdate) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdate {
if v != nil {
_u.SetCreatedAt(*v)
}
return _u
}
// SetLastActivity sets the "last_activity" field.
func (_u *AuthSessionUpdate) SetLastActivity(v time.Time) *AuthSessionUpdate {
_u.mutation.SetLastActivity(v)
return _u
}
// SetNillableLastActivity sets the "last_activity" field if the given value is not nil.
func (_u *AuthSessionUpdate) SetNillableLastActivity(v *time.Time) *AuthSessionUpdate {
if v != nil {
_u.SetLastActivity(*v)
}
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *AuthSessionUpdate) SetIPAddress(v string) *AuthSessionUpdate {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *AuthSessionUpdate) SetNillableIPAddress(v *string) *AuthSessionUpdate {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// SetUserAgent sets the "user_agent" field.
func (_u *AuthSessionUpdate) SetUserAgent(v string) *AuthSessionUpdate {
_u.mutation.SetUserAgent(v)
return _u
}
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
func (_u *AuthSessionUpdate) SetNillableUserAgent(v *string) *AuthSessionUpdate {
if v != nil {
_u.SetUserAgent(*v)
}
return _u
}
// Mutation returns the AuthSessionMutation object of the builder.
func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation {
return _u.mutation
}
// Save executes the query and returns the number of nodes affected by the update operation.
func (_u *AuthSessionUpdate) Save(ctx context.Context) (int, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthSessionUpdate) SaveX(ctx context.Context) int {
affected, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return affected
}
// Exec executes the query.
func (_u *AuthSessionUpdate) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthSessionUpdate) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) {
_spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString))
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ClientStates(); ok {
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value)
}
if value, ok := _u.mutation.CreatedAt(); ok {
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.LastActivity(); ok {
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value)
}
if value, ok := _u.mutation.UserAgent(); ok {
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value)
}
if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authsession.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return 0, err
}
_u.mutation.done = true
return _node, nil
}
// AuthSessionUpdateOne is the builder for updating a single AuthSession entity.
type AuthSessionUpdateOne struct {
config
fields []string
hooks []Hook
mutation *AuthSessionMutation
}
// SetClientStates sets the "client_states" field.
func (_u *AuthSessionUpdateOne) SetClientStates(v []byte) *AuthSessionUpdateOne {
_u.mutation.SetClientStates(v)
return _u
}
// SetCreatedAt sets the "created_at" field.
func (_u *AuthSessionUpdateOne) SetCreatedAt(v time.Time) *AuthSessionUpdateOne {
_u.mutation.SetCreatedAt(v)
return _u
}
// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
func (_u *AuthSessionUpdateOne) SetNillableCreatedAt(v *time.Time) *AuthSessionUpdateOne {
if v != nil {
_u.SetCreatedAt(*v)
}
return _u
}
// SetLastActivity sets the "last_activity" field.
func (_u *AuthSessionUpdateOne) SetLastActivity(v time.Time) *AuthSessionUpdateOne {
_u.mutation.SetLastActivity(v)
return _u
}
// SetNillableLastActivity sets the "last_activity" field if the given value is not nil.
func (_u *AuthSessionUpdateOne) SetNillableLastActivity(v *time.Time) *AuthSessionUpdateOne {
if v != nil {
_u.SetLastActivity(*v)
}
return _u
}
// SetIPAddress sets the "ip_address" field.
func (_u *AuthSessionUpdateOne) SetIPAddress(v string) *AuthSessionUpdateOne {
_u.mutation.SetIPAddress(v)
return _u
}
// SetNillableIPAddress sets the "ip_address" field if the given value is not nil.
func (_u *AuthSessionUpdateOne) SetNillableIPAddress(v *string) *AuthSessionUpdateOne {
if v != nil {
_u.SetIPAddress(*v)
}
return _u
}
// SetUserAgent sets the "user_agent" field.
func (_u *AuthSessionUpdateOne) SetUserAgent(v string) *AuthSessionUpdateOne {
_u.mutation.SetUserAgent(v)
return _u
}
// SetNillableUserAgent sets the "user_agent" field if the given value is not nil.
func (_u *AuthSessionUpdateOne) SetNillableUserAgent(v *string) *AuthSessionUpdateOne {
if v != nil {
_u.SetUserAgent(*v)
}
return _u
}
// Mutation returns the AuthSessionMutation object of the builder.
func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation {
return _u.mutation
}
// Where appends a list predicates to the AuthSessionUpdate builder.
func (_u *AuthSessionUpdateOne) Where(ps ...predicate.AuthSession) *AuthSessionUpdateOne {
_u.mutation.Where(ps...)
return _u
}
// Select allows selecting one or more fields (columns) of the returned entity.
// The default is selecting all fields defined in the entity schema.
func (_u *AuthSessionUpdateOne) Select(field string, fields ...string) *AuthSessionUpdateOne {
_u.fields = append([]string{field}, fields...)
return _u
}
// Save executes the query and returns the updated AuthSession entity.
func (_u *AuthSessionUpdateOne) Save(ctx context.Context) (*AuthSession, error) {
return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
}
// SaveX is like Save, but panics if an error occurs.
func (_u *AuthSessionUpdateOne) SaveX(ctx context.Context) *AuthSession {
node, err := _u.Save(ctx)
if err != nil {
panic(err)
}
return node
}
// Exec executes the query on the entity.
func (_u *AuthSessionUpdateOne) Exec(ctx context.Context) error {
_, err := _u.Save(ctx)
return err
}
// ExecX is like Exec, but panics if an error occurs.
func (_u *AuthSessionUpdateOne) ExecX(ctx context.Context) {
if err := _u.Exec(ctx); err != nil {
panic(err)
}
}
func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession, err error) {
_spec := sqlgraph.NewUpdateSpec(authsession.Table, authsession.Columns, sqlgraph.NewFieldSpec(authsession.FieldID, field.TypeString))
id, ok := _u.mutation.ID()
if !ok {
return nil, &ValidationError{Name: "id", err: errors.New(`db: missing "AuthSession.id" for update`)}
}
_spec.Node.ID.Value = id
if fields := _u.fields; len(fields) > 0 {
_spec.Node.Columns = make([]string, 0, len(fields))
_spec.Node.Columns = append(_spec.Node.Columns, authsession.FieldID)
for _, f := range fields {
if !authsession.ValidColumn(f) {
return nil, &ValidationError{Name: f, err: fmt.Errorf("db: invalid field %q for query", f)}
}
if f != authsession.FieldID {
_spec.Node.Columns = append(_spec.Node.Columns, f)
}
}
}
if ps := _u.mutation.predicates; len(ps) > 0 {
_spec.Predicate = func(selector *sql.Selector) {
for i := range ps {
ps[i](selector)
}
}
}
if value, ok := _u.mutation.ClientStates(); ok {
_spec.SetField(authsession.FieldClientStates, field.TypeBytes, value)
}
if value, ok := _u.mutation.CreatedAt(); ok {
_spec.SetField(authsession.FieldCreatedAt, field.TypeTime, value)
}
if value, ok := _u.mutation.LastActivity(); ok {
_spec.SetField(authsession.FieldLastActivity, field.TypeTime, value)
}
if value, ok := _u.mutation.IPAddress(); ok {
_spec.SetField(authsession.FieldIPAddress, field.TypeString, value)
}
if value, ok := _u.mutation.UserAgent(); ok {
_spec.SetField(authsession.FieldUserAgent, field.TypeString, value)
}
_node = &AuthSession{config: _u.config}
_spec.Assign = _node.assignValues
_spec.ScanValues = _node.scanValues
if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
if _, ok := err.(*sqlgraph.NotFoundError); ok {
err = &NotFoundError{authsession.Label}
} else if sqlgraph.IsConstraintError(err) {
err = &ConstraintError{msg: err.Error(), wrap: err}
}
return nil, err
}
_u.mutation.done = true
return _node, nil
}

155
storage/ent/db/client.go

@ -16,6 +16,7 @@ import (
"entgo.io/ent/dialect/sql"
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
@ -36,6 +37,8 @@ type Client struct {
AuthCode *AuthCodeClient
// AuthRequest is the client for interacting with the AuthRequest builders.
AuthRequest *AuthRequestClient
// AuthSession is the client for interacting with the AuthSession builders.
AuthSession *AuthSessionClient
// Connector is the client for interacting with the Connector builders.
Connector *ConnectorClient
// DeviceRequest is the client for interacting with the DeviceRequest builders.
@ -67,6 +70,7 @@ func (c *Client) init() {
c.Schema = migrate.NewSchema(c.driver)
c.AuthCode = NewAuthCodeClient(c.config)
c.AuthRequest = NewAuthRequestClient(c.config)
c.AuthSession = NewAuthSessionClient(c.config)
c.Connector = NewConnectorClient(c.config)
c.DeviceRequest = NewDeviceRequestClient(c.config)
c.DeviceToken = NewDeviceTokenClient(c.config)
@ -170,6 +174,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
config: cfg,
AuthCode: NewAuthCodeClient(cfg),
AuthRequest: NewAuthRequestClient(cfg),
AuthSession: NewAuthSessionClient(cfg),
Connector: NewConnectorClient(cfg),
DeviceRequest: NewDeviceRequestClient(cfg),
DeviceToken: NewDeviceTokenClient(cfg),
@ -200,6 +205,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
config: cfg,
AuthCode: NewAuthCodeClient(cfg),
AuthRequest: NewAuthRequestClient(cfg),
AuthSession: NewAuthSessionClient(cfg),
Connector: NewConnectorClient(cfg),
DeviceRequest: NewDeviceRequestClient(cfg),
DeviceToken: NewDeviceTokenClient(cfg),
@ -238,8 +244,9 @@ func (c *Client) Close() error {
// In order to add hooks to a specific client, call: `client.Node.Use(...)`.
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys,
c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity,
c.AuthCode, c.AuthRequest, c.AuthSession, c.Connector, c.DeviceRequest,
c.DeviceToken, c.Keys, c.OAuth2Client, c.OfflineSession, c.Password,
c.RefreshToken, c.UserIdentity,
} {
n.Use(hooks...)
}
@ -249,8 +256,9 @@ func (c *Client) Use(hooks ...Hook) {
// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`.
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.AuthCode, c.AuthRequest, c.Connector, c.DeviceRequest, c.DeviceToken, c.Keys,
c.OAuth2Client, c.OfflineSession, c.Password, c.RefreshToken, c.UserIdentity,
c.AuthCode, c.AuthRequest, c.AuthSession, c.Connector, c.DeviceRequest,
c.DeviceToken, c.Keys, c.OAuth2Client, c.OfflineSession, c.Password,
c.RefreshToken, c.UserIdentity,
} {
n.Intercept(interceptors...)
}
@ -263,6 +271,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.AuthCode.mutate(ctx, m)
case *AuthRequestMutation:
return c.AuthRequest.mutate(ctx, m)
case *AuthSessionMutation:
return c.AuthSession.mutate(ctx, m)
case *ConnectorMutation:
return c.Connector.mutate(ctx, m)
case *DeviceRequestMutation:
@ -552,6 +562,139 @@ func (c *AuthRequestClient) mutate(ctx context.Context, m *AuthRequestMutation)
}
}
// AuthSessionClient is a client for the AuthSession schema.
type AuthSessionClient struct {
config
}
// NewAuthSessionClient returns a client for the AuthSession from the given config.
func NewAuthSessionClient(c config) *AuthSessionClient {
return &AuthSessionClient{config: c}
}
// Use adds a list of mutation hooks to the hooks stack.
// A call to `Use(f, g, h)` equals to `authsession.Hooks(f(g(h())))`.
func (c *AuthSessionClient) Use(hooks ...Hook) {
c.hooks.AuthSession = append(c.hooks.AuthSession, hooks...)
}
// Intercept adds a list of query interceptors to the interceptors stack.
// A call to `Intercept(f, g, h)` equals to `authsession.Intercept(f(g(h())))`.
func (c *AuthSessionClient) Intercept(interceptors ...Interceptor) {
c.inters.AuthSession = append(c.inters.AuthSession, interceptors...)
}
// Create returns a builder for creating a AuthSession entity.
func (c *AuthSessionClient) Create() *AuthSessionCreate {
mutation := newAuthSessionMutation(c.config, OpCreate)
return &AuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// CreateBulk returns a builder for creating a bulk of AuthSession entities.
func (c *AuthSessionClient) CreateBulk(builders ...*AuthSessionCreate) *AuthSessionCreateBulk {
return &AuthSessionCreateBulk{config: c.config, builders: builders}
}
// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
// a builder and applies setFunc on it.
func (c *AuthSessionClient) MapCreateBulk(slice any, setFunc func(*AuthSessionCreate, int)) *AuthSessionCreateBulk {
rv := reflect.ValueOf(slice)
if rv.Kind() != reflect.Slice {
return &AuthSessionCreateBulk{err: fmt.Errorf("calling to AuthSessionClient.MapCreateBulk with wrong type %T, need slice", slice)}
}
builders := make([]*AuthSessionCreate, rv.Len())
for i := 0; i < rv.Len(); i++ {
builders[i] = c.Create()
setFunc(builders[i], i)
}
return &AuthSessionCreateBulk{config: c.config, builders: builders}
}
// Update returns an update builder for AuthSession.
func (c *AuthSessionClient) Update() *AuthSessionUpdate {
mutation := newAuthSessionMutation(c.config, OpUpdate)
return &AuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOne returns an update builder for the given entity.
func (c *AuthSessionClient) UpdateOne(_m *AuthSession) *AuthSessionUpdateOne {
mutation := newAuthSessionMutation(c.config, OpUpdateOne, withAuthSession(_m))
return &AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// UpdateOneID returns an update builder for the given id.
func (c *AuthSessionClient) UpdateOneID(id string) *AuthSessionUpdateOne {
mutation := newAuthSessionMutation(c.config, OpUpdateOne, withAuthSessionID(id))
return &AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// Delete returns a delete builder for AuthSession.
func (c *AuthSessionClient) Delete() *AuthSessionDelete {
mutation := newAuthSessionMutation(c.config, OpDelete)
return &AuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
}
// DeleteOne returns a builder for deleting the given entity.
func (c *AuthSessionClient) DeleteOne(_m *AuthSession) *AuthSessionDeleteOne {
return c.DeleteOneID(_m.ID)
}
// DeleteOneID returns a builder for deleting the given entity by its id.
func (c *AuthSessionClient) DeleteOneID(id string) *AuthSessionDeleteOne {
builder := c.Delete().Where(authsession.ID(id))
builder.mutation.id = &id
builder.mutation.op = OpDeleteOne
return &AuthSessionDeleteOne{builder}
}
// Query returns a query builder for AuthSession.
func (c *AuthSessionClient) Query() *AuthSessionQuery {
return &AuthSessionQuery{
config: c.config,
ctx: &QueryContext{Type: TypeAuthSession},
inters: c.Interceptors(),
}
}
// Get returns a AuthSession entity by its id.
func (c *AuthSessionClient) Get(ctx context.Context, id string) (*AuthSession, error) {
return c.Query().Where(authsession.ID(id)).Only(ctx)
}
// GetX is like Get, but panics if an error occurs.
func (c *AuthSessionClient) GetX(ctx context.Context, id string) *AuthSession {
obj, err := c.Get(ctx, id)
if err != nil {
panic(err)
}
return obj
}
// Hooks returns the client hooks.
func (c *AuthSessionClient) Hooks() []Hook {
return c.hooks.AuthSession
}
// Interceptors returns the client interceptors.
func (c *AuthSessionClient) Interceptors() []Interceptor {
return c.inters.AuthSession
}
func (c *AuthSessionClient) mutate(ctx context.Context, m *AuthSessionMutation) (Value, error) {
switch m.Op() {
case OpCreate:
return (&AuthSessionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdate:
return (&AuthSessionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpUpdateOne:
return (&AuthSessionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
case OpDelete, OpDeleteOne:
return (&AuthSessionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
default:
return nil, fmt.Errorf("db: unknown AuthSession mutation op: %q", m.Op())
}
}
// ConnectorClient is a client for the Connector schema.
type ConnectorClient struct {
config
@ -1752,11 +1895,11 @@ func (c *UserIdentityClient) mutate(ctx context.Context, m *UserIdentityMutation
// hooks and interceptors per client, for fast access.
type (
hooks struct {
AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys,
AuthCode, AuthRequest, AuthSession, Connector, DeviceRequest, DeviceToken, Keys,
OAuth2Client, OfflineSession, Password, RefreshToken, UserIdentity []ent.Hook
}
inters struct {
AuthCode, AuthRequest, Connector, DeviceRequest, DeviceToken, Keys,
AuthCode, AuthRequest, AuthSession, Connector, DeviceRequest, DeviceToken, Keys,
OAuth2Client, OfflineSession, Password, RefreshToken,
UserIdentity []ent.Interceptor
}

2
storage/ent/db/ent.go

@ -14,6 +14,7 @@ import (
"entgo.io/ent/dialect/sql/sqlgraph"
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
@ -85,6 +86,7 @@ func checkColumn(t, c string) error {
columnCheck = sql.NewColumnCheck(map[string]func(string) bool{
authcode.Table: authcode.ValidColumn,
authrequest.Table: authrequest.ValidColumn,
authsession.Table: authsession.ValidColumn,
connector.Table: connector.ValidColumn,
devicerequest.Table: devicerequest.ValidColumn,
devicetoken.Table: devicetoken.ValidColumn,

12
storage/ent/db/hook/hook.go

@ -33,6 +33,18 @@ func (f AuthRequestFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, e
return nil, fmt.Errorf("unexpected mutation type %T. expect *db.AuthRequestMutation", m)
}
// The AuthSessionFunc type is an adapter to allow the use of ordinary
// function as AuthSession mutator.
type AuthSessionFunc func(context.Context, *db.AuthSessionMutation) (db.Value, error)
// Mutate calls f(ctx, m).
func (f AuthSessionFunc) Mutate(ctx context.Context, m db.Mutation) (db.Value, error) {
if mv, ok := m.(*db.AuthSessionMutation); ok {
return f(ctx, mv)
}
return nil, fmt.Errorf("unexpected mutation type %T. expect *db.AuthSessionMutation", m)
}
// The ConnectorFunc type is an adapter to allow the use of ordinary
// function as Connector mutator.
type ConnectorFunc func(context.Context, *db.ConnectorMutation) (db.Value, error)

16
storage/ent/db/migrate/schema.go

@ -63,6 +63,21 @@ var (
Columns: AuthRequestsColumns,
PrimaryKey: []*schema.Column{AuthRequestsColumns[0]},
}
// AuthSessionsColumns holds the columns for the "auth_sessions" table.
AuthSessionsColumns = []*schema.Column{
{Name: "id", Type: field.TypeString, Unique: true, Size: 2147483647, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "client_states", Type: field.TypeBytes},
{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},
{Name: "last_activity", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}},
{Name: "ip_address", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
{Name: "user_agent", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
}
// AuthSessionsTable holds the schema information for the "auth_sessions" table.
AuthSessionsTable = &schema.Table{
Name: "auth_sessions",
Columns: AuthSessionsColumns,
PrimaryKey: []*schema.Column{AuthSessionsColumns[0]},
}
// ConnectorsColumns holds the columns for the "connectors" table.
ConnectorsColumns = []*schema.Column{
{Name: "id", Type: field.TypeString, Unique: true, Size: 100, SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}},
@ -226,6 +241,7 @@ var (
Tables = []*schema.Table{
AuthCodesTable,
AuthRequestsTable,
AuthSessionsTable,
ConnectorsTable,
DeviceRequestsTable,
DeviceTokensTable,

550
storage/ent/db/mutation.go

@ -14,6 +14,7 @@ import (
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
@ -38,6 +39,7 @@ const (
// Node types.
TypeAuthCode = "AuthCode"
TypeAuthRequest = "AuthRequest"
TypeAuthSession = "AuthSession"
TypeConnector = "Connector"
TypeDeviceRequest = "DeviceRequest"
TypeDeviceToken = "DeviceToken"
@ -2719,6 +2721,554 @@ func (m *AuthRequestMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AuthRequest edge %s", name)
}
// AuthSessionMutation represents an operation that mutates the AuthSession nodes in the graph.
type AuthSessionMutation struct {
config
op Op
typ string
id *string
client_states *[]byte
created_at *time.Time
last_activity *time.Time
ip_address *string
user_agent *string
clearedFields map[string]struct{}
done bool
oldValue func(context.Context) (*AuthSession, error)
predicates []predicate.AuthSession
}
var _ ent.Mutation = (*AuthSessionMutation)(nil)
// authsessionOption allows management of the mutation configuration using functional options.
type authsessionOption func(*AuthSessionMutation)
// newAuthSessionMutation creates new mutation for the AuthSession entity.
func newAuthSessionMutation(c config, op Op, opts ...authsessionOption) *AuthSessionMutation {
m := &AuthSessionMutation{
config: c,
op: op,
typ: TypeAuthSession,
clearedFields: make(map[string]struct{}),
}
for _, opt := range opts {
opt(m)
}
return m
}
// withAuthSessionID sets the ID field of the mutation.
func withAuthSessionID(id string) authsessionOption {
return func(m *AuthSessionMutation) {
var (
err error
once sync.Once
value *AuthSession
)
m.oldValue = func(ctx context.Context) (*AuthSession, error) {
once.Do(func() {
if m.done {
err = errors.New("querying old values post mutation is not allowed")
} else {
value, err = m.Client().AuthSession.Get(ctx, id)
}
})
return value, err
}
m.id = &id
}
}
// withAuthSession sets the old AuthSession of the mutation.
func withAuthSession(node *AuthSession) authsessionOption {
return func(m *AuthSessionMutation) {
m.oldValue = func(context.Context) (*AuthSession, error) {
return node, nil
}
m.id = &node.ID
}
}
// Client returns a new `ent.Client` from the mutation. If the mutation was
// executed in a transaction (ent.Tx), a transactional client is returned.
func (m AuthSessionMutation) Client() *Client {
client := &Client{config: m.config}
client.init()
return client
}
// Tx returns an `ent.Tx` for mutations that were executed in transactions;
// it returns an error otherwise.
func (m AuthSessionMutation) Tx() (*Tx, error) {
if _, ok := m.driver.(*txDriver); !ok {
return nil, errors.New("db: mutation is not running in a transaction")
}
tx := &Tx{config: m.config}
tx.init()
return tx, nil
}
// SetID sets the value of the id field. Note that this
// operation is only accepted on creation of AuthSession entities.
func (m *AuthSessionMutation) SetID(id string) {
m.id = &id
}
// ID returns the ID value in the mutation. Note that the ID is only available
// if it was provided to the builder or after it was returned from the database.
func (m *AuthSessionMutation) ID() (id string, exists bool) {
if m.id == nil {
return
}
return *m.id, true
}
// IDs queries the database and returns the entity ids that match the mutation's predicate.
// That means, if the mutation is applied within a transaction with an isolation level such
// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
// or updated by the mutation.
func (m *AuthSessionMutation) IDs(ctx context.Context) ([]string, error) {
switch {
case m.op.Is(OpUpdateOne | OpDeleteOne):
id, exists := m.ID()
if exists {
return []string{id}, nil
}
fallthrough
case m.op.Is(OpUpdate | OpDelete):
return m.Client().AuthSession.Query().Where(m.predicates...).IDs(ctx)
default:
return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
}
}
// SetClientStates sets the "client_states" field.
func (m *AuthSessionMutation) SetClientStates(b []byte) {
m.client_states = &b
}
// ClientStates returns the value of the "client_states" field in the mutation.
func (m *AuthSessionMutation) ClientStates() (r []byte, exists bool) {
v := m.client_states
if v == nil {
return
}
return *v, true
}
// OldClientStates returns the old "client_states" field's value of the AuthSession entity.
// If the AuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthSessionMutation) OldClientStates(ctx context.Context) (v []byte, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldClientStates is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldClientStates requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldClientStates: %w", err)
}
return oldValue.ClientStates, nil
}
// ResetClientStates resets all changes to the "client_states" field.
func (m *AuthSessionMutation) ResetClientStates() {
m.client_states = nil
}
// SetCreatedAt sets the "created_at" field.
func (m *AuthSessionMutation) SetCreatedAt(t time.Time) {
m.created_at = &t
}
// CreatedAt returns the value of the "created_at" field in the mutation.
func (m *AuthSessionMutation) CreatedAt() (r time.Time, exists bool) {
v := m.created_at
if v == nil {
return
}
return *v, true
}
// OldCreatedAt returns the old "created_at" field's value of the AuthSession entity.
// If the AuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthSessionMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldCreatedAt requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
}
return oldValue.CreatedAt, nil
}
// ResetCreatedAt resets all changes to the "created_at" field.
func (m *AuthSessionMutation) ResetCreatedAt() {
m.created_at = nil
}
// SetLastActivity sets the "last_activity" field.
func (m *AuthSessionMutation) SetLastActivity(t time.Time) {
m.last_activity = &t
}
// LastActivity returns the value of the "last_activity" field in the mutation.
func (m *AuthSessionMutation) LastActivity() (r time.Time, exists bool) {
v := m.last_activity
if v == nil {
return
}
return *v, true
}
// OldLastActivity returns the old "last_activity" field's value of the AuthSession entity.
// If the AuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthSessionMutation) OldLastActivity(ctx context.Context) (v time.Time, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldLastActivity is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldLastActivity requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldLastActivity: %w", err)
}
return oldValue.LastActivity, nil
}
// ResetLastActivity resets all changes to the "last_activity" field.
func (m *AuthSessionMutation) ResetLastActivity() {
m.last_activity = nil
}
// SetIPAddress sets the "ip_address" field.
func (m *AuthSessionMutation) SetIPAddress(s string) {
m.ip_address = &s
}
// IPAddress returns the value of the "ip_address" field in the mutation.
func (m *AuthSessionMutation) IPAddress() (r string, exists bool) {
v := m.ip_address
if v == nil {
return
}
return *v, true
}
// OldIPAddress returns the old "ip_address" field's value of the AuthSession entity.
// If the AuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthSessionMutation) OldIPAddress(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldIPAddress is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldIPAddress requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldIPAddress: %w", err)
}
return oldValue.IPAddress, nil
}
// ResetIPAddress resets all changes to the "ip_address" field.
func (m *AuthSessionMutation) ResetIPAddress() {
m.ip_address = nil
}
// SetUserAgent sets the "user_agent" field.
func (m *AuthSessionMutation) SetUserAgent(s string) {
m.user_agent = &s
}
// UserAgent returns the value of the "user_agent" field in the mutation.
func (m *AuthSessionMutation) UserAgent() (r string, exists bool) {
v := m.user_agent
if v == nil {
return
}
return *v, true
}
// OldUserAgent returns the old "user_agent" field's value of the AuthSession entity.
// If the AuthSession object wasn't provided to the builder, the object is fetched from the database.
// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
func (m *AuthSessionMutation) OldUserAgent(ctx context.Context) (v string, err error) {
if !m.op.Is(OpUpdateOne) {
return v, errors.New("OldUserAgent is only allowed on UpdateOne operations")
}
if m.id == nil || m.oldValue == nil {
return v, errors.New("OldUserAgent requires an ID field in the mutation")
}
oldValue, err := m.oldValue(ctx)
if err != nil {
return v, fmt.Errorf("querying old value for OldUserAgent: %w", err)
}
return oldValue.UserAgent, nil
}
// ResetUserAgent resets all changes to the "user_agent" field.
func (m *AuthSessionMutation) ResetUserAgent() {
m.user_agent = nil
}
// Where appends a list predicates to the AuthSessionMutation builder.
func (m *AuthSessionMutation) Where(ps ...predicate.AuthSession) {
m.predicates = append(m.predicates, ps...)
}
// WhereP appends storage-level predicates to the AuthSessionMutation builder. Using this method,
// users can use type-assertion to append predicates that do not depend on any generated package.
func (m *AuthSessionMutation) WhereP(ps ...func(*sql.Selector)) {
p := make([]predicate.AuthSession, len(ps))
for i := range ps {
p[i] = ps[i]
}
m.Where(p...)
}
// Op returns the operation name.
func (m *AuthSessionMutation) Op() Op {
return m.op
}
// SetOp allows setting the mutation operation.
func (m *AuthSessionMutation) SetOp(op Op) {
m.op = op
}
// Type returns the node type of this mutation (AuthSession).
func (m *AuthSessionMutation) Type() string {
return m.typ
}
// Fields returns all fields that were changed during this mutation. Note that in
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AuthSessionMutation) Fields() []string {
fields := make([]string, 0, 5)
if m.client_states != nil {
fields = append(fields, authsession.FieldClientStates)
}
if m.created_at != nil {
fields = append(fields, authsession.FieldCreatedAt)
}
if m.last_activity != nil {
fields = append(fields, authsession.FieldLastActivity)
}
if m.ip_address != nil {
fields = append(fields, authsession.FieldIPAddress)
}
if m.user_agent != nil {
fields = append(fields, authsession.FieldUserAgent)
}
return fields
}
// Field returns the value of a field with the given name. The second boolean
// return value indicates that this field was not set, or was not defined in the
// schema.
func (m *AuthSessionMutation) Field(name string) (ent.Value, bool) {
switch name {
case authsession.FieldClientStates:
return m.ClientStates()
case authsession.FieldCreatedAt:
return m.CreatedAt()
case authsession.FieldLastActivity:
return m.LastActivity()
case authsession.FieldIPAddress:
return m.IPAddress()
case authsession.FieldUserAgent:
return m.UserAgent()
}
return nil, false
}
// OldField returns the old value of the field from the database. An error is
// returned if the mutation operation is not UpdateOne, or the query to the
// database failed.
func (m *AuthSessionMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
switch name {
case authsession.FieldClientStates:
return m.OldClientStates(ctx)
case authsession.FieldCreatedAt:
return m.OldCreatedAt(ctx)
case authsession.FieldLastActivity:
return m.OldLastActivity(ctx)
case authsession.FieldIPAddress:
return m.OldIPAddress(ctx)
case authsession.FieldUserAgent:
return m.OldUserAgent(ctx)
}
return nil, fmt.Errorf("unknown AuthSession field %s", name)
}
// SetField sets the value of a field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *AuthSessionMutation) SetField(name string, value ent.Value) error {
switch name {
case authsession.FieldClientStates:
v, ok := value.([]byte)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetClientStates(v)
return nil
case authsession.FieldCreatedAt:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetCreatedAt(v)
return nil
case authsession.FieldLastActivity:
v, ok := value.(time.Time)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetLastActivity(v)
return nil
case authsession.FieldIPAddress:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetIPAddress(v)
return nil
case authsession.FieldUserAgent:
v, ok := value.(string)
if !ok {
return fmt.Errorf("unexpected type %T for field %s", value, name)
}
m.SetUserAgent(v)
return nil
}
return fmt.Errorf("unknown AuthSession field %s", name)
}
// AddedFields returns all numeric fields that were incremented/decremented during
// this mutation.
func (m *AuthSessionMutation) AddedFields() []string {
return nil
}
// AddedField returns the numeric value that was incremented/decremented on a field
// with the given name. The second boolean return value indicates that this field
// was not set, or was not defined in the schema.
func (m *AuthSessionMutation) AddedField(name string) (ent.Value, bool) {
return nil, false
}
// AddField adds the value to the field with the given name. It returns an error if
// the field is not defined in the schema, or if the type mismatched the field
// type.
func (m *AuthSessionMutation) AddField(name string, value ent.Value) error {
switch name {
}
return fmt.Errorf("unknown AuthSession numeric field %s", name)
}
// ClearedFields returns all nullable fields that were cleared during this
// mutation.
func (m *AuthSessionMutation) ClearedFields() []string {
return nil
}
// FieldCleared returns a boolean indicating if a field with the given name was
// cleared in this mutation.
func (m *AuthSessionMutation) FieldCleared(name string) bool {
_, ok := m.clearedFields[name]
return ok
}
// ClearField clears the value of the field with the given name. It returns an
// error if the field is not defined in the schema.
func (m *AuthSessionMutation) ClearField(name string) error {
return fmt.Errorf("unknown AuthSession nullable field %s", name)
}
// ResetField resets all changes in the mutation for the field with the given name.
// It returns an error if the field is not defined in the schema.
func (m *AuthSessionMutation) ResetField(name string) error {
switch name {
case authsession.FieldClientStates:
m.ResetClientStates()
return nil
case authsession.FieldCreatedAt:
m.ResetCreatedAt()
return nil
case authsession.FieldLastActivity:
m.ResetLastActivity()
return nil
case authsession.FieldIPAddress:
m.ResetIPAddress()
return nil
case authsession.FieldUserAgent:
m.ResetUserAgent()
return nil
}
return fmt.Errorf("unknown AuthSession field %s", name)
}
// AddedEdges returns all edge names that were set/added in this mutation.
func (m *AuthSessionMutation) AddedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// AddedIDs returns all IDs (to other nodes) that were added for the given edge
// name in this mutation.
func (m *AuthSessionMutation) AddedIDs(name string) []ent.Value {
return nil
}
// RemovedEdges returns all edge names that were removed in this mutation.
func (m *AuthSessionMutation) RemovedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
// the given name in this mutation.
func (m *AuthSessionMutation) RemovedIDs(name string) []ent.Value {
return nil
}
// ClearedEdges returns all edge names that were cleared in this mutation.
func (m *AuthSessionMutation) ClearedEdges() []string {
edges := make([]string, 0, 0)
return edges
}
// EdgeCleared returns a boolean which indicates if the edge with the given name
// was cleared in this mutation.
func (m *AuthSessionMutation) EdgeCleared(name string) bool {
return false
}
// ClearEdge clears the value of the edge with the given name. It returns an error
// if that edge is not defined in the schema.
func (m *AuthSessionMutation) ClearEdge(name string) error {
return fmt.Errorf("unknown AuthSession unique edge %s", name)
}
// ResetEdge resets all changes to the edge with the given name in this mutation.
// It returns an error if the edge is not defined in the schema.
func (m *AuthSessionMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown AuthSession edge %s", name)
}
// ConnectorMutation represents an operation that mutates the Connector nodes in the graph.
type ConnectorMutation struct {
config

3
storage/ent/db/predicate/predicate.go

@ -12,6 +12,9 @@ type AuthCode func(*sql.Selector)
// AuthRequest is the predicate function for authrequest builders.
type AuthRequest func(*sql.Selector)
// AuthSession is the predicate function for authsession builders.
type AuthSession func(*sql.Selector)
// Connector is the predicate function for connector builders.
type Connector func(*sql.Selector)

15
storage/ent/db/runtime.go

@ -7,6 +7,7 @@ import (
"github.com/dexidp/dex/storage/ent/db/authcode"
"github.com/dexidp/dex/storage/ent/db/authrequest"
"github.com/dexidp/dex/storage/ent/db/authsession"
"github.com/dexidp/dex/storage/ent/db/connector"
"github.com/dexidp/dex/storage/ent/db/devicerequest"
"github.com/dexidp/dex/storage/ent/db/devicetoken"
@ -87,6 +88,20 @@ func init() {
authrequestDescID := authrequestFields[0].Descriptor()
// authrequest.IDValidator is a validator for the "id" field. It is called by the builders before save.
authrequest.IDValidator = authrequestDescID.Validators[0].(func(string) error)
authsessionFields := schema.AuthSession{}.Fields()
_ = authsessionFields
// authsessionDescIPAddress is the schema descriptor for ip_address field.
authsessionDescIPAddress := authsessionFields[4].Descriptor()
// authsession.DefaultIPAddress holds the default value on creation for the ip_address field.
authsession.DefaultIPAddress = authsessionDescIPAddress.Default.(string)
// authsessionDescUserAgent is the schema descriptor for user_agent field.
authsessionDescUserAgent := authsessionFields[5].Descriptor()
// authsession.DefaultUserAgent holds the default value on creation for the user_agent field.
authsession.DefaultUserAgent = authsessionDescUserAgent.Default.(string)
// authsessionDescID is the schema descriptor for id field.
authsessionDescID := authsessionFields[0].Descriptor()
// authsession.IDValidator is a validator for the "id" field. It is called by the builders before save.
authsession.IDValidator = authsessionDescID.Validators[0].(func(string) error)
connectorFields := schema.Connector{}.Fields()
_ = connectorFields
// connectorDescType is the schema descriptor for type field.

3
storage/ent/db/tx.go

@ -16,6 +16,8 @@ type Tx struct {
AuthCode *AuthCodeClient
// AuthRequest is the client for interacting with the AuthRequest builders.
AuthRequest *AuthRequestClient
// AuthSession is the client for interacting with the AuthSession builders.
AuthSession *AuthSessionClient
// Connector is the client for interacting with the Connector builders.
Connector *ConnectorClient
// DeviceRequest is the client for interacting with the DeviceRequest builders.
@ -167,6 +169,7 @@ func (tx *Tx) Client() *Client {
func (tx *Tx) init() {
tx.AuthCode = NewAuthCodeClient(tx.config)
tx.AuthRequest = NewAuthRequestClient(tx.config)
tx.AuthSession = NewAuthSessionClient(tx.config)
tx.Connector = NewConnectorClient(tx.config)
tx.DeviceRequest = NewDeviceRequestClient(tx.config)
tx.DeviceToken = NewDeviceTokenClient(tx.config)

37
storage/ent/schema/authsession.go

@ -0,0 +1,37 @@
package schema
import (
"entgo.io/ent"
"entgo.io/ent/schema/field"
)
// AuthSession holds the schema definition for the AuthSession entity.
type AuthSession struct {
ent.Schema
}
// Fields of the AuthSession.
func (AuthSession) Fields() []ent.Field {
return []ent.Field{
field.Text("id").
SchemaType(textSchema).
NotEmpty().
Unique(),
field.Bytes("client_states"),
field.Time("created_at").
SchemaType(timeSchema),
field.Time("last_activity").
SchemaType(timeSchema),
field.Text("ip_address").
SchemaType(textSchema).
Default(""),
field.Text("user_agent").
SchemaType(textSchema).
Default(""),
}
}
// Edges of the AuthSession.
func (AuthSession) Edges() []ent.Edge {
return []ent.Edge{}
}

60
storage/etcd/etcd.go

@ -25,6 +25,7 @@ const (
deviceRequestPrefix = "device_req/"
deviceTokenPrefix = "device_token/"
userIdentityPrefix = "user_identity/"
authSessionPrefix = "auth_session/"
// defaultStorageTimeout will be applied to all storage's operations.
defaultStorageTimeout = 5 * time.Second
@ -422,6 +423,61 @@ func (c *conn) ListUserIdentities(ctx context.Context) (identities []storage.Use
return identities, nil
}
func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) error {
return c.txnCreate(ctx, keyAuthSession(s.ID), fromStorageAuthSession(s))
}
func (c *conn) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
var s AuthSession
if err := c.getKey(ctx, keyAuthSession(sessionID), &s); err != nil {
return storage.AuthSession{}, err
}
return toStorageAuthSession(s), nil
}
func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.txnUpdate(ctx, keyAuthSession(sessionID), func(currentValue []byte) ([]byte, error) {
var current AuthSession
if len(currentValue) > 0 {
if err := json.Unmarshal(currentValue, &current); err != nil {
return nil, err
}
}
updated, err := updater(toStorageAuthSession(current))
if err != nil {
return nil, err
}
return json.Marshal(fromStorageAuthSession(updated))
})
}
func (c *conn) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
res, err := c.db.Get(ctx, authSessionPrefix, clientv3.WithPrefix())
if err != nil {
return sessions, err
}
for _, v := range res.Kvs {
var s AuthSession
if err = json.Unmarshal(v.Value, &s); err != nil {
return sessions, err
}
sessions = append(sessions, toStorageAuthSession(s))
}
return sessions, nil
}
func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error {
ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout)
defer cancel()
return c.deleteKey(ctx, keyAuthSession(sessionID))
}
func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error {
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector)
}
@ -617,6 +673,10 @@ func keyUserIdentity(userID, connectorID string) string {
return userIdentityPrefix + strings.ToLower(userID+"|"+connectorID)
}
func keyAuthSession(sessionID string) string {
return strings.ToLower(authSessionPrefix + sessionID)
}
func (c *conn) CreateDeviceRequest(ctx context.Context, d storage.DeviceRequest) error {
return c.txnCreate(ctx, keyID(deviceRequestPrefix, d.UserCode), fromStorageDeviceRequest(d))
}

36
storage/etcd/types.go

@ -296,6 +296,42 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity {
return s
}
// AuthSession is a mirrored struct from storage with JSON struct tags.
type AuthSession struct {
ID string `json:"id,omitempty"`
ClientStates map[string]*storage.ClientAuthState `json:"client_states,omitempty"`
CreatedAt time.Time `json:"created_at"`
LastActivity time.Time `json:"last_activity"`
IPAddress string `json:"ip_address,omitempty"`
UserAgent string `json:"user_agent,omitempty"`
}
func fromStorageAuthSession(s storage.AuthSession) AuthSession {
return AuthSession{
ID: s.ID,
ClientStates: s.ClientStates,
CreatedAt: s.CreatedAt,
LastActivity: s.LastActivity,
IPAddress: s.IPAddress,
UserAgent: s.UserAgent,
}
}
func toStorageAuthSession(s AuthSession) storage.AuthSession {
result := storage.AuthSession{
ID: s.ID,
ClientStates: s.ClientStates,
CreatedAt: s.CreatedAt,
LastActivity: s.LastActivity,
IPAddress: s.IPAddress,
UserAgent: s.UserAgent,
}
if result.ClientStates == nil {
result.ClientStates = make(map[string]*storage.ClientAuthState)
}
return result
}
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`

50
storage/kubernetes/storage.go

@ -26,6 +26,7 @@ const (
kindDeviceRequest = "DeviceRequest"
kindDeviceToken = "DeviceToken"
kindUserIdentity = "UserIdentity"
kindAuthSession = "AuthSession"
)
const (
@ -40,6 +41,7 @@ const (
resourceDeviceRequest = "devicerequests"
resourceDeviceToken = "devicetokens"
resourceUserIdentity = "useridentities"
resourceAuthSession = "authsessions"
)
const (
@ -809,6 +811,54 @@ func (cli *client) ListUserIdentities(ctx context.Context) ([]storage.UserIdenti
return userIdentities, nil
}
func (cli *client) CreateAuthSession(ctx context.Context, s storage.AuthSession) error {
return cli.post(resourceAuthSession, cli.fromStorageAuthSession(s))
}
func (cli *client) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) {
var s AuthSession
if err := cli.get(resourceAuthSession, sessionID, &s); err != nil {
return storage.AuthSession{}, err
}
return toStorageAuthSession(s), nil
}
func (cli *client) UpdateAuthSession(ctx context.Context, sessionID string, updater func(old storage.AuthSession) (storage.AuthSession, error)) error {
return retryOnConflict(ctx, func() error {
var s AuthSession
if err := cli.get(resourceAuthSession, sessionID, &s); err != nil {
return err
}
updated, err := updater(toStorageAuthSession(s))
if err != nil {
return err
}
newSession := cli.fromStorageAuthSession(updated)
newSession.ObjectMeta = s.ObjectMeta
return cli.put(resourceAuthSession, sessionID, newSession)
})
}
func (cli *client) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) {
var authSessionList AuthSessionList
if err := cli.list(resourceAuthSession, &authSessionList); err != nil {
return nil, fmt.Errorf("failed to list auth sessions: %v", err)
}
sessions := make([]storage.AuthSession, len(authSessionList.AuthSessions))
for i, s := range authSessionList.AuthSessions {
sessions[i] = toStorageAuthSession(s)
}
return sessions, nil
}
func (cli *client) DeleteAuthSession(ctx context.Context, sessionID string) error {
return cli.delete(resourceAuthSession, sessionID)
}
func isKubernetesAPIConflictError(err error) bool {
if httpErr, ok := err.(httpError); ok {
if httpErr.StatusCode() == http.StatusConflict {

69
storage/kubernetes/types.go

@ -243,6 +243,23 @@ func customResourceDefinitions(apiVersion string) []k8sapi.CustomResourceDefinit
},
},
},
{
ObjectMeta: k8sapi.ObjectMeta{
Name: "authsessions.dex.coreos.com",
},
TypeMeta: crdMeta,
Spec: k8sapi.CustomResourceDefinitionSpec{
Group: apiGroup,
Version: version,
Versions: versions,
Scope: scope,
Names: k8sapi.CustomResourceDefinitionNames{
Plural: "authsessions",
Singular: "authsession",
Kind: "AuthSession",
},
},
},
}
}
@ -948,3 +965,55 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity {
}
return s
}
// AuthSession is a Kubernetes representation of a storage AuthSession.
type AuthSession struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
ClientStates map[string]*storage.ClientAuthState `json:"clientStates,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"`
LastActivity time.Time `json:"lastActivity,omitempty"`
IPAddress string `json:"ipAddress,omitempty"`
UserAgent string `json:"userAgent,omitempty"`
}
// AuthSessionList is a list of AuthSessions.
type AuthSessionList struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ListMeta `json:"metadata,omitempty"`
AuthSessions []AuthSession `json:"items"`
}
func (cli *client) fromStorageAuthSession(s storage.AuthSession) AuthSession {
return AuthSession{
TypeMeta: k8sapi.TypeMeta{
Kind: kindAuthSession,
APIVersion: cli.apiVersion,
},
ObjectMeta: k8sapi.ObjectMeta{
Name: s.ID,
Namespace: cli.namespace,
},
ClientStates: s.ClientStates,
CreatedAt: s.CreatedAt,
LastActivity: s.LastActivity,
IPAddress: s.IPAddress,
UserAgent: s.UserAgent,
}
}
func toStorageAuthSession(s AuthSession) storage.AuthSession {
result := storage.AuthSession{
ID: s.ObjectMeta.Name,
ClientStates: s.ClientStates,
CreatedAt: s.CreatedAt,
LastActivity: s.LastActivity,
IPAddress: s.IPAddress,
UserAgent: s.UserAgent,
}
if result.ClientStates == nil {
result.ClientStates = make(map[string]*storage.ClientAuthState)
}
return result
}

57
storage/memory/memory.go

@ -23,6 +23,7 @@ func New(logger *slog.Logger) storage.Storage {
passwords: make(map[string]storage.Password),
offlineSessions: make(map[compositeKeyID]storage.OfflineSessions),
userIdentities: make(map[compositeKeyID]storage.UserIdentity),
authSessions: make(map[string]storage.AuthSession),
connectors: make(map[string]storage.Connector),
deviceRequests: make(map[string]storage.DeviceRequest),
deviceTokens: make(map[string]storage.DeviceToken),
@ -51,6 +52,7 @@ type memStorage struct {
passwords map[string]storage.Password
offlineSessions map[compositeKeyID]storage.OfflineSessions
userIdentities map[compositeKeyID]storage.UserIdentity
authSessions map[string]storage.AuthSession
connectors map[string]storage.Connector
deviceRequests map[string]storage.DeviceRequest
deviceTokens map[string]storage.DeviceToken
@ -246,6 +248,61 @@ func (s *memStorage) ListUserIdentities(ctx context.Context) (identities []stora
return
}
func (s *memStorage) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) {
s.tx(func() {
for _, session := range s.authSessions {
sessions = append(sessions, session)
}
})
return
}
func (s *memStorage) CreateAuthSession(ctx context.Context, session storage.AuthSession) (err error) {
s.tx(func() {
if _, ok := s.authSessions[session.ID]; ok {
err = storage.ErrAlreadyExists
} else {
s.authSessions[session.ID] = session
}
})
return
}
func (s *memStorage) GetAuthSession(ctx context.Context, sessionID string) (session storage.AuthSession, err error) {
s.tx(func() {
var ok bool
if session, ok = s.authSessions[sessionID]; !ok {
err = storage.ErrNotFound
}
})
return
}
func (s *memStorage) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) (err error) {
s.tx(func() {
r, ok := s.authSessions[sessionID]
if !ok {
err = storage.ErrNotFound
return
}
if r, err = updater(r); err == nil {
s.authSessions[sessionID] = r
}
})
return
}
func (s *memStorage) DeleteAuthSession(ctx context.Context, sessionID string) (err error) {
s.tx(func() {
if _, ok := s.authSessions[sessionID]; !ok {
err = storage.ErrNotFound
return
}
delete(s.authSessions, sessionID)
})
return
}
func (s *memStorage) CreateConnector(ctx context.Context, connector storage.Connector) (err error) {
s.tx(func() {
if _, ok := s.connectors[connector.ID]; ok {

130
storage/sql/crud.go

@ -926,6 +926,136 @@ func (c *conn) DeleteUserIdentity(ctx context.Context, userID, connectorID strin
return nil
}
func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) error {
_, err := c.Exec(`
insert into auth_session (
id, client_states,
created_at, last_activity,
ip_address, user_agent
)
values ($1, $2, $3, $4, $5, $6);
`,
s.ID, encoder(s.ClientStates),
s.CreatedAt, s.LastActivity,
s.IPAddress, s.UserAgent,
)
if err != nil {
if c.alreadyExistsCheck(err) {
return storage.ErrAlreadyExists
}
return fmt.Errorf("insert auth session: %v", err)
}
return nil
}
func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error {
return c.ExecTx(func(tx *trans) error {
s, err := getAuthSession(ctx, tx, sessionID)
if err != nil {
return err
}
newSession, err := updater(s)
if err != nil {
return err
}
_, err = tx.Exec(`
update auth_session
set
client_states = $1,
last_activity = $2,
ip_address = $3,
user_agent = $4
where id = $5;
`,
encoder(newSession.ClientStates),
newSession.LastActivity,
newSession.IPAddress, newSession.UserAgent,
sessionID,
)
if err != nil {
return fmt.Errorf("update auth session: %v", err)
}
return nil
})
}
func (c *conn) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) {
return getAuthSession(ctx, c, sessionID)
}
func getAuthSession(ctx context.Context, q querier, sessionID string) (storage.AuthSession, error) {
return scanAuthSession(q.QueryRow(`
select
id, client_states,
created_at, last_activity,
ip_address, user_agent
from auth_session
where id = $1;
`, sessionID))
}
func scanAuthSession(s scanner) (session storage.AuthSession, err error) {
err = s.Scan(
&session.ID, decoder(&session.ClientStates),
&session.CreatedAt, &session.LastActivity,
&session.IPAddress, &session.UserAgent,
)
if err != nil {
if err == sql.ErrNoRows {
return session, storage.ErrNotFound
}
return session, fmt.Errorf("select auth session: %v", err)
}
if session.ClientStates == nil {
session.ClientStates = make(map[string]*storage.ClientAuthState)
}
return session, nil
}
func (c *conn) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) {
rows, err := c.Query(`
select
id, client_states,
created_at, last_activity,
ip_address, user_agent
from auth_session;
`)
if err != nil {
return nil, fmt.Errorf("query: %v", err)
}
defer rows.Close()
var sessions []storage.AuthSession
for rows.Next() {
s, err := scanAuthSession(rows)
if err != nil {
return nil, err
}
sessions = append(sessions, s)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("scan: %v", err)
}
return sessions, nil
}
func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error {
result, err := c.Exec(`delete from auth_session where id = $1`, sessionID)
if err != nil {
return fmt.Errorf("delete auth_session: id = %s: %w", sessionID, err)
}
n, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("rows affected: %v", err)
}
if n < 1 {
return storage.ErrNotFound
}
return nil
}
func (c *conn) CreateConnector(ctx context.Context, connector storage.Connector) error {
grantTypes, err := json.Marshal(connector.GrantTypes)
if err != nil {

13
storage/sql/migrate.go

@ -409,4 +409,17 @@ var migrations = []migration{
);`,
},
},
{
stmts: []string{
`
create table auth_session (
id text not null primary key,
client_states bytea not null,
created_at timestamptz not null,
last_activity timestamptz not null,
ip_address text not null default '',
user_agent text not null default ''
);`,
},
},
}

25
storage/storage.go

@ -84,6 +84,7 @@ type Storage interface {
CreatePassword(ctx context.Context, p Password) error
CreateOfflineSessions(ctx context.Context, s OfflineSessions) error
CreateUserIdentity(ctx context.Context, u UserIdentity) error
CreateAuthSession(ctx context.Context, s AuthSession) error
CreateConnector(ctx context.Context, c Connector) error
CreateDeviceRequest(ctx context.Context, d DeviceRequest) error
CreateDeviceToken(ctx context.Context, d DeviceToken) error
@ -98,6 +99,7 @@ type Storage interface {
GetPassword(ctx context.Context, email string) (Password, error)
GetOfflineSessions(ctx context.Context, userID string, connID string) (OfflineSessions, error)
GetUserIdentity(ctx context.Context, userID, connectorID string) (UserIdentity, error)
GetAuthSession(ctx context.Context, sessionID string) (AuthSession, error)
GetConnector(ctx context.Context, id string) (Connector, error)
GetDeviceRequest(ctx context.Context, userCode string) (DeviceRequest, error)
GetDeviceToken(ctx context.Context, deviceCode string) (DeviceToken, error)
@ -107,6 +109,7 @@ type Storage interface {
ListPasswords(ctx context.Context) ([]Password, error)
ListConnectors(ctx context.Context) ([]Connector, error)
ListUserIdentities(ctx context.Context) ([]UserIdentity, error)
ListAuthSessions(ctx context.Context) ([]AuthSession, error)
// Delete methods MUST be atomic.
DeleteAuthRequest(ctx context.Context, id string) error
@ -116,6 +119,7 @@ type Storage interface {
DeletePassword(ctx context.Context, email string) error
DeleteOfflineSessions(ctx context.Context, userID string, connID string) error
DeleteUserIdentity(ctx context.Context, userID, connectorID string) error
DeleteAuthSession(ctx context.Context, sessionID string) error
DeleteConnector(ctx context.Context, id string) error
// Update methods take a function for updating an object then performs that update within
@ -139,6 +143,7 @@ type Storage interface {
UpdatePassword(ctx context.Context, email string, updater func(p Password) (Password, error)) error
UpdateOfflineSessions(ctx context.Context, userID string, connID string, updater func(s OfflineSessions) (OfflineSessions, error)) error
UpdateUserIdentity(ctx context.Context, userID, connectorID string, updater func(u UserIdentity) (UserIdentity, error)) error
UpdateAuthSession(ctx context.Context, sessionID string, updater func(s AuthSession) (AuthSession, error)) error
UpdateConnector(ctx context.Context, id string, updater func(c Connector) (Connector, error)) error
UpdateDeviceToken(ctx context.Context, deviceCode string, updater func(t DeviceToken) (DeviceToken, error)) error
@ -336,6 +341,26 @@ type UserIdentity struct {
BlockedUntil time.Time
}
// ClientAuthState represents the authentication state for a specific client within a session.
type ClientAuthState struct {
UserID string
ConnectorID string
Active bool
ExpiresAt time.Time
LastActivity time.Time
LastTokenIssuedAt time.Time
}
// AuthSession represents a browser-bound authentication session.
type AuthSession struct {
ID string
ClientStates map[string]*ClientAuthState // clientID -> auth state
CreatedAt time.Time
LastActivity time.Time
IPAddress string
UserAgent string
}
// OfflineSessions objects are sessions pertaining to users with refresh tokens.
type OfflineSessions struct {
// UserID of an end user who has logged into the server.

Loading…
Cancel
Save