mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.3 KiB
108 lines
3.3 KiB
package client |
|
|
|
import ( |
|
"context" |
|
"encoding/json" |
|
"fmt" |
|
|
|
"github.com/dexidp/dex/storage" |
|
) |
|
|
|
// CreateAuthSession saves provided auth session into the database. |
|
func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSession) error { |
|
if session.ClientStates == nil { |
|
session.ClientStates = make(map[string]*storage.ClientAuthState) |
|
} |
|
encodedStates, err := json.Marshal(session.ClientStates) |
|
if err != nil { |
|
return fmt.Errorf("encode client states auth session: %w", err) |
|
} |
|
|
|
_, err = d.client.AuthSession.Create(). |
|
SetID(session.ID). |
|
SetClientStates(encodedStates). |
|
SetCreatedAt(session.CreatedAt). |
|
SetLastActivity(session.LastActivity). |
|
SetIPAddress(session.IPAddress). |
|
SetUserAgent(session.UserAgent). |
|
Save(ctx) |
|
if err != nil { |
|
return convertDBError("create auth session: %w", err) |
|
} |
|
return nil |
|
} |
|
|
|
// GetAuthSession extracts an auth session from the database by session ID. |
|
func (d *Database) GetAuthSession(ctx context.Context, sessionID string) (storage.AuthSession, error) { |
|
authSession, err := d.client.AuthSession.Get(ctx, sessionID) |
|
if err != nil { |
|
return storage.AuthSession{}, convertDBError("get auth session: %w", err) |
|
} |
|
return toStorageAuthSession(authSession), nil |
|
} |
|
|
|
// ListAuthSessions extracts all auth sessions from the database. |
|
func (d *Database) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { |
|
authSessions, err := d.client.AuthSession.Query().All(ctx) |
|
if err != nil { |
|
return nil, convertDBError("list auth sessions: %w", err) |
|
} |
|
|
|
storageAuthSessions := make([]storage.AuthSession, 0, len(authSessions)) |
|
for _, s := range authSessions { |
|
storageAuthSessions = append(storageAuthSessions, toStorageAuthSession(s)) |
|
} |
|
return storageAuthSessions, nil |
|
} |
|
|
|
// DeleteAuthSession deletes an auth session from the database by session ID. |
|
func (d *Database) DeleteAuthSession(ctx context.Context, sessionID string) error { |
|
err := d.client.AuthSession.DeleteOneID(sessionID).Exec(ctx) |
|
if err != nil { |
|
return convertDBError("delete auth session: %w", err) |
|
} |
|
return nil |
|
} |
|
|
|
// UpdateAuthSession changes an auth session using an updater function. |
|
func (d *Database) UpdateAuthSession(ctx context.Context, sessionID string, updater func(s storage.AuthSession) (storage.AuthSession, error)) error { |
|
tx, err := d.BeginTx(ctx) |
|
if err != nil { |
|
return convertDBError("update auth session tx: %w", err) |
|
} |
|
|
|
authSession, err := tx.AuthSession.Get(ctx, sessionID) |
|
if err != nil { |
|
return rollback(tx, "update auth session database: %w", err) |
|
} |
|
|
|
newSession, err := updater(toStorageAuthSession(authSession)) |
|
if err != nil { |
|
return rollback(tx, "update auth session updating: %w", err) |
|
} |
|
|
|
if newSession.ClientStates == nil { |
|
newSession.ClientStates = make(map[string]*storage.ClientAuthState) |
|
} |
|
|
|
encodedStates, err := json.Marshal(newSession.ClientStates) |
|
if err != nil { |
|
return rollback(tx, "encode client states auth session: %w", err) |
|
} |
|
|
|
_, err = tx.AuthSession.UpdateOneID(sessionID). |
|
SetClientStates(encodedStates). |
|
SetLastActivity(newSession.LastActivity). |
|
SetIPAddress(newSession.IPAddress). |
|
SetUserAgent(newSession.UserAgent). |
|
Save(ctx) |
|
if err != nil { |
|
return rollback(tx, "update auth session updating: %w", err) |
|
} |
|
|
|
if err = tx.Commit(); err != nil { |
|
return rollback(tx, "update auth session commit: %w", err) |
|
} |
|
|
|
return nil |
|
}
|
|
|