mirror of https://github.com/dexidp/dex.git
31 changed files with 1878 additions and 4 deletions
@ -0,0 +1,52 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateAuthCode saves provided auth code into the database.
|
||||
func (d *Database) CreateAuthCode(code storage.AuthCode) error { |
||||
_, err := d.client.AuthCode.Create(). |
||||
SetID(code.ID). |
||||
SetClientID(code.ClientID). |
||||
SetScopes(code.Scopes). |
||||
SetRedirectURI(code.RedirectURI). |
||||
SetNonce(code.Nonce). |
||||
SetClaimsUserID(code.Claims.UserID). |
||||
SetClaimsEmail(code.Claims.Email). |
||||
SetClaimsEmailVerified(code.Claims.EmailVerified). |
||||
SetClaimsUsername(code.Claims.Username). |
||||
SetClaimsPreferredUsername(code.Claims.PreferredUsername). |
||||
SetClaimsGroups(code.Claims.Groups). |
||||
SetCodeChallenge(code.PKCE.CodeChallenge). |
||||
SetCodeChallengeMethod(code.PKCE.CodeChallengeMethod). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(code.Expiry.UTC()). |
||||
SetConnectorID(code.ConnectorID). |
||||
SetConnectorData(code.ConnectorData). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create auth code: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetAuthCode extracts an auth code from the database by id.
|
||||
func (d *Database) GetAuthCode(id string) (storage.AuthCode, error) { |
||||
authCode, err := d.client.AuthCode.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.AuthCode{}, convertDBError("get auth code: %w", err) |
||||
} |
||||
return toStorageAuthCode(authCode), nil |
||||
} |
||||
|
||||
// DeleteAuthCode deletes an auth code from the database by id.
|
||||
func (d *Database) DeleteAuthCode(id string) error { |
||||
err := d.client.AuthCode.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete auth code: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
@ -0,0 +1,107 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateAuthRequest saves provided auth request into the database.
|
||||
func (d *Database) CreateAuthRequest(authRequest storage.AuthRequest) error { |
||||
_, err := d.client.AuthRequest.Create(). |
||||
SetID(authRequest.ID). |
||||
SetClientID(authRequest.ClientID). |
||||
SetScopes(authRequest.Scopes). |
||||
SetResponseTypes(authRequest.ResponseTypes). |
||||
SetRedirectURI(authRequest.RedirectURI). |
||||
SetState(authRequest.State). |
||||
SetNonce(authRequest.Nonce). |
||||
SetForceApprovalPrompt(authRequest.ForceApprovalPrompt). |
||||
SetLoggedIn(authRequest.LoggedIn). |
||||
SetClaimsUserID(authRequest.Claims.UserID). |
||||
SetClaimsEmail(authRequest.Claims.Email). |
||||
SetClaimsEmailVerified(authRequest.Claims.EmailVerified). |
||||
SetClaimsUsername(authRequest.Claims.Username). |
||||
SetClaimsPreferredUsername(authRequest.Claims.PreferredUsername). |
||||
SetClaimsGroups(authRequest.Claims.Groups). |
||||
SetCodeChallenge(authRequest.PKCE.CodeChallenge). |
||||
SetCodeChallengeMethod(authRequest.PKCE.CodeChallengeMethod). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(authRequest.Expiry.UTC()). |
||||
SetConnectorID(authRequest.ConnectorID). |
||||
SetConnectorData(authRequest.ConnectorData). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create auth request: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetAuthRequest extracts an auth request from the database by id.
|
||||
func (d *Database) GetAuthRequest(id string) (storage.AuthRequest, error) { |
||||
authRequest, err := d.client.AuthRequest.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.AuthRequest{}, convertDBError("get auth request: %w", err) |
||||
} |
||||
return toStorageAuthRequest(authRequest), nil |
||||
} |
||||
|
||||
// DeleteAuthRequest deletes an auth request from the database by id.
|
||||
func (d *Database) DeleteAuthRequest(id string) error { |
||||
err := d.client.AuthRequest.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete auth request: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateAuthRequest changes an auth request by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateAuthRequest(id string, updater func(old storage.AuthRequest) (storage.AuthRequest, error)) error { |
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return fmt.Errorf("update auth request tx: %w", err) |
||||
} |
||||
|
||||
authRequest, err := tx.AuthRequest.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return rollback(tx, "update auth request database: %w", err) |
||||
} |
||||
|
||||
newAuthRequest, err := updater(toStorageAuthRequest(authRequest)) |
||||
if err != nil { |
||||
return rollback(tx, "update auth request updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.AuthRequest.UpdateOneID(newAuthRequest.ID). |
||||
SetClientID(newAuthRequest.ClientID). |
||||
SetScopes(newAuthRequest.Scopes). |
||||
SetResponseTypes(newAuthRequest.ResponseTypes). |
||||
SetRedirectURI(newAuthRequest.RedirectURI). |
||||
SetState(newAuthRequest.State). |
||||
SetNonce(newAuthRequest.Nonce). |
||||
SetForceApprovalPrompt(newAuthRequest.ForceApprovalPrompt). |
||||
SetLoggedIn(newAuthRequest.LoggedIn). |
||||
SetClaimsUserID(newAuthRequest.Claims.UserID). |
||||
SetClaimsEmail(newAuthRequest.Claims.Email). |
||||
SetClaimsEmailVerified(newAuthRequest.Claims.EmailVerified). |
||||
SetClaimsUsername(newAuthRequest.Claims.Username). |
||||
SetClaimsPreferredUsername(newAuthRequest.Claims.PreferredUsername). |
||||
SetClaimsGroups(newAuthRequest.Claims.Groups). |
||||
SetCodeChallenge(newAuthRequest.PKCE.CodeChallenge). |
||||
SetCodeChallengeMethod(newAuthRequest.PKCE.CodeChallengeMethod). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newAuthRequest.Expiry.UTC()). |
||||
SetConnectorID(newAuthRequest.ConnectorID). |
||||
SetConnectorData(newAuthRequest.ConnectorData). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update auth request uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update auth request commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,92 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateClient saves provided oauth2 client settings into the database.
|
||||
func (d *Database) CreateClient(client storage.Client) error { |
||||
_, err := d.client.OAuth2Client.Create(). |
||||
SetID(client.ID). |
||||
SetName(client.Name). |
||||
SetSecret(client.Secret). |
||||
SetPublic(client.Public). |
||||
SetLogoURL(client.LogoURL). |
||||
SetRedirectUris(client.RedirectURIs). |
||||
SetTrustedPeers(client.TrustedPeers). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create oauth2 client: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ListClients extracts an array of oauth2 clients from the database.
|
||||
func (d *Database) ListClients() ([]storage.Client, error) { |
||||
clients, err := d.client.OAuth2Client.Query().All(context.TODO()) |
||||
if err != nil { |
||||
return nil, convertDBError("list clients: %w", err) |
||||
} |
||||
|
||||
storageClients := make([]storage.Client, 0, len(clients)) |
||||
for _, c := range clients { |
||||
storageClients = append(storageClients, toStorageClient(c)) |
||||
} |
||||
return storageClients, nil |
||||
} |
||||
|
||||
// GetClient extracts an oauth2 client from the database by id.
|
||||
func (d *Database) GetClient(id string) (storage.Client, error) { |
||||
client, err := d.client.OAuth2Client.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.Client{}, convertDBError("get client: %w", err) |
||||
} |
||||
return toStorageClient(client), nil |
||||
} |
||||
|
||||
// DeleteClient deletes an oauth2 client from the database by id.
|
||||
func (d *Database) DeleteClient(id string) error { |
||||
err := d.client.OAuth2Client.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete client: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateClient changes an oauth2 client by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { |
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update client tx: %w", err) |
||||
} |
||||
|
||||
client, err := tx.OAuth2Client.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return rollback(tx, "update client database: %w", err) |
||||
} |
||||
|
||||
newClient, err := updater(toStorageClient(client)) |
||||
if err != nil { |
||||
return rollback(tx, "update client updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.OAuth2Client.UpdateOneID(newClient.ID). |
||||
SetName(newClient.Name). |
||||
SetSecret(newClient.Secret). |
||||
SetPublic(newClient.Public). |
||||
SetLogoURL(newClient.LogoURL). |
||||
SetRedirectUris(newClient.RedirectURIs). |
||||
SetTrustedPeers(newClient.TrustedPeers). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update client uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update auth request commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,88 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateConnector saves a connector into the database.
|
||||
func (d *Database) CreateConnector(connector storage.Connector) error { |
||||
_, err := d.client.Connector.Create(). |
||||
SetID(connector.ID). |
||||
SetName(connector.Name). |
||||
SetType(connector.Type). |
||||
SetResourceVersion(connector.ResourceVersion). |
||||
SetConfig(connector.Config). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create connector: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ListConnectors extracts an array of connectors from the database.
|
||||
func (d *Database) ListConnectors() ([]storage.Connector, error) { |
||||
connectors, err := d.client.Connector.Query().All(context.TODO()) |
||||
if err != nil { |
||||
return nil, convertDBError("list connectors: %w", err) |
||||
} |
||||
|
||||
storageConnectors := make([]storage.Connector, 0, len(connectors)) |
||||
for _, c := range connectors { |
||||
storageConnectors = append(storageConnectors, toStorageConnector(c)) |
||||
} |
||||
return storageConnectors, nil |
||||
} |
||||
|
||||
// GetConnector extracts a connector from the database by id.
|
||||
func (d *Database) GetConnector(id string) (storage.Connector, error) { |
||||
connector, err := d.client.Connector.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.Connector{}, convertDBError("get connector: %w", err) |
||||
} |
||||
return toStorageConnector(connector), nil |
||||
} |
||||
|
||||
// DeleteConnector deletes a connector from the database by id.
|
||||
func (d *Database) DeleteConnector(id string) error { |
||||
err := d.client.Connector.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete connector: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateConnector changes a connector by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateConnector(id string, updater func(old storage.Connector) (storage.Connector, error)) error { |
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update connector tx: %w", err) |
||||
} |
||||
|
||||
connector, err := tx.Connector.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return rollback(tx, "update connector database: %w", err) |
||||
} |
||||
|
||||
newConnector, err := updater(toStorageConnector(connector)) |
||||
if err != nil { |
||||
return rollback(tx, "update connector updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.Connector.UpdateOneID(newConnector.ID). |
||||
SetName(newConnector.Name). |
||||
SetType(newConnector.Type). |
||||
SetResourceVersion(newConnector.ResourceVersion). |
||||
SetConfig(newConnector.Config). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update connector uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update connector commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,36 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest" |
||||
) |
||||
|
||||
// CreateDeviceRequest saves provided device request into the database.
|
||||
func (d *Database) CreateDeviceRequest(request storage.DeviceRequest) error { |
||||
_, err := d.client.DeviceRequest.Create(). |
||||
SetClientID(request.ClientID). |
||||
SetClientSecret(request.ClientSecret). |
||||
SetScopes(request.Scopes). |
||||
SetUserCode(request.UserCode). |
||||
SetDeviceCode(request.DeviceCode). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(request.Expiry.UTC()). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create device request: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetDeviceRequest extracts a device request from the database by user code.
|
||||
func (d *Database) GetDeviceRequest(userCode string) (storage.DeviceRequest, error) { |
||||
deviceRequest, err := d.client.DeviceRequest.Query(). |
||||
Where(devicerequest.UserCode(userCode)). |
||||
Only(context.TODO()) |
||||
if err != nil { |
||||
return storage.DeviceRequest{}, convertDBError("get device request: %w", err) |
||||
} |
||||
return toStorageDeviceRequest(deviceRequest), nil |
||||
} |
||||
@ -0,0 +1,76 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken" |
||||
) |
||||
|
||||
// CreateDeviceToken saves provided token into the database.
|
||||
func (d *Database) CreateDeviceToken(token storage.DeviceToken) error { |
||||
_, err := d.client.DeviceToken.Create(). |
||||
SetDeviceCode(token.DeviceCode). |
||||
SetToken([]byte(token.Token)). |
||||
SetPollInterval(token.PollIntervalSeconds). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(token.Expiry.UTC()). |
||||
SetLastRequest(token.LastRequestTime.UTC()). |
||||
SetStatus(token.Status). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create device token: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetDeviceToken extracts a token from the database by device code.
|
||||
func (d *Database) GetDeviceToken(deviceCode string) (storage.DeviceToken, error) { |
||||
deviceToken, err := d.client.DeviceToken.Query(). |
||||
Where(devicetoken.DeviceCode(deviceCode)). |
||||
Only(context.TODO()) |
||||
if err != nil { |
||||
return storage.DeviceToken{}, convertDBError("get device token: %w", err) |
||||
} |
||||
return toStorageDeviceToken(deviceToken), nil |
||||
} |
||||
|
||||
// UpdateDeviceToken changes a token by device code using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateDeviceToken(deviceCode string, updater func(old storage.DeviceToken) (storage.DeviceToken, error)) error { |
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update device token tx: %w", err) |
||||
} |
||||
|
||||
token, err := tx.DeviceToken.Query(). |
||||
Where(devicetoken.DeviceCode(deviceCode)). |
||||
Only(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update device token database: %w", err) |
||||
} |
||||
|
||||
newToken, err := updater(toStorageDeviceToken(token)) |
||||
if err != nil { |
||||
return rollback(tx, "update device token updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.DeviceToken.Update(). |
||||
Where(devicetoken.DeviceCode(newToken.DeviceCode)). |
||||
SetDeviceCode(newToken.DeviceCode). |
||||
SetToken([]byte(newToken.Token)). |
||||
SetPollInterval(newToken.PollIntervalSeconds). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetExpiry(newToken.Expiry.UTC()). |
||||
SetLastRequest(newToken.LastRequestTime.UTC()). |
||||
SetStatus(newToken.Status). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update device token uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update device token commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,81 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db" |
||||
) |
||||
|
||||
func getKeys(client *db.KeysClient) (storage.Keys, error) { |
||||
rawKeys, err := client.Get(context.TODO(), keysRowID) |
||||
if err != nil { |
||||
return storage.Keys{}, convertDBError("get keys: %w", err) |
||||
} |
||||
|
||||
return toStorageKeys(rawKeys), nil |
||||
} |
||||
|
||||
// GetKeys returns signing keys, public keys and verification keys from the database.
|
||||
func (d *Database) GetKeys() (storage.Keys, error) { |
||||
return getKeys(d.client.Keys) |
||||
} |
||||
|
||||
// UpdateKeys rotates keys using updater function.
|
||||
func (d *Database) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { |
||||
firstUpdate := false |
||||
|
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update keys tx: %w", err) |
||||
} |
||||
|
||||
storageKeys, err := getKeys(tx.Keys) |
||||
if err != nil { |
||||
if !errors.Is(err, storage.ErrNotFound) { |
||||
return rollback(tx, "update keys get: %w", err) |
||||
} |
||||
firstUpdate = true |
||||
} |
||||
|
||||
newKeys, err := updater(storageKeys) |
||||
if err != nil { |
||||
return rollback(tx, "update keys updating: %w", err) |
||||
} |
||||
|
||||
// ent doesn't have an upsert support yet
|
||||
// https://github.com/facebook/ent/issues/139
|
||||
if firstUpdate { |
||||
_, err = tx.Keys.Create(). |
||||
SetID(keysRowID). |
||||
SetNextRotation(newKeys.NextRotation). |
||||
SetSigningKey(*newKeys.SigningKey). |
||||
SetSigningKeyPub(*newKeys.SigningKeyPub). |
||||
SetVerificationKeys(newKeys.VerificationKeys). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "create keys: %w", err) |
||||
} |
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update keys commit: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
err = tx.Keys.UpdateOneID(keysRowID). |
||||
SetNextRotation(newKeys.NextRotation.UTC()). |
||||
SetSigningKey(*newKeys.SigningKey). |
||||
SetSigningKeyPub(*newKeys.SigningKeyPub). |
||||
SetVerificationKeys(newKeys.VerificationKeys). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update keys uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update keys commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,95 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"hash" |
||||
"time" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db" |
||||
"github.com/dexidp/dex/storage/ent/db/authcode" |
||||
"github.com/dexidp/dex/storage/ent/db/authrequest" |
||||
"github.com/dexidp/dex/storage/ent/db/devicerequest" |
||||
"github.com/dexidp/dex/storage/ent/db/devicetoken" |
||||
"github.com/dexidp/dex/storage/ent/db/migrate" |
||||
) |
||||
|
||||
var _ storage.Storage = (*Database)(nil) |
||||
|
||||
type Database struct { |
||||
client *db.Client |
||||
hasher func() hash.Hash |
||||
} |
||||
|
||||
// NewDatabase returns new database client with set options.
|
||||
func NewDatabase(opts ...func(*Database)) *Database { |
||||
database := &Database{} |
||||
for _, f := range opts { |
||||
f(database) |
||||
} |
||||
return database |
||||
} |
||||
|
||||
// WithClient sets client option of a Database object.
|
||||
func WithClient(c *db.Client) func(*Database) { |
||||
return func(s *Database) { |
||||
s.client = c |
||||
} |
||||
} |
||||
|
||||
// WithHasher sets client option of a Database object.
|
||||
func WithHasher(h func() hash.Hash) func(*Database) { |
||||
return func(s *Database) { |
||||
s.hasher = h |
||||
} |
||||
} |
||||
|
||||
// Schema exposes migration schema to perform migrations.
|
||||
func (d *Database) Schema() *migrate.Schema { |
||||
return d.client.Schema |
||||
} |
||||
|
||||
// Close calls the corresponding method of the ent database client.
|
||||
func (d *Database) Close() error { |
||||
return d.client.Close() |
||||
} |
||||
|
||||
// GarbageCollect removes expired entities from the database.
|
||||
func (d *Database) GarbageCollect(now time.Time) (storage.GCResult, error) { |
||||
result := storage.GCResult{} |
||||
utcNow := now.UTC() |
||||
|
||||
q, err := d.client.AuthRequest.Delete(). |
||||
Where(authrequest.ExpiryLT(utcNow)). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return result, convertDBError("gc auth request: %w", err) |
||||
} |
||||
result.AuthRequests = int64(q) |
||||
|
||||
q, err = d.client.AuthCode.Delete(). |
||||
Where(authcode.ExpiryLT(utcNow)). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return result, convertDBError("gc auth code: %w", err) |
||||
} |
||||
result.AuthCodes = int64(q) |
||||
|
||||
q, err = d.client.DeviceRequest.Delete(). |
||||
Where(devicerequest.ExpiryLT(utcNow)). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return result, convertDBError("gc device request: %w", err) |
||||
} |
||||
result.DeviceRequests = int64(q) |
||||
|
||||
q, err = d.client.DeviceToken.Delete(). |
||||
Where(devicetoken.ExpiryLT(utcNow)). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return result, convertDBError("gc device token: %w", err) |
||||
} |
||||
result.DeviceTokens = int64(q) |
||||
|
||||
return result, err |
||||
} |
||||
@ -0,0 +1,93 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateOfflineSessions saves provided offline session into the database.
|
||||
func (d *Database) CreateOfflineSessions(session storage.OfflineSessions) error { |
||||
encodedRefresh, err := json.Marshal(session.Refresh) |
||||
if err != nil { |
||||
return fmt.Errorf("encode refresh offline session: %w", err) |
||||
} |
||||
|
||||
id := offlineSessionID(session.UserID, session.ConnID, d.hasher) |
||||
_, err = d.client.OfflineSession.Create(). |
||||
SetID(id). |
||||
SetUserID(session.UserID). |
||||
SetConnID(session.ConnID). |
||||
SetConnectorData(session.ConnectorData). |
||||
SetRefresh(encodedRefresh). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create offline session: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// GetOfflineSessions extracts an offline session from the database by user id and connector id.
|
||||
func (d *Database) GetOfflineSessions(userID, connID string) (storage.OfflineSessions, error) { |
||||
id := offlineSessionID(userID, connID, d.hasher) |
||||
|
||||
offlineSession, err := d.client.OfflineSession.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.OfflineSessions{}, convertDBError("get offline session: %w", err) |
||||
} |
||||
return toStorageOfflineSession(offlineSession), nil |
||||
} |
||||
|
||||
// DeleteOfflineSessions deletes an offline session from the database by user id and connector id.
|
||||
func (d *Database) DeleteOfflineSessions(userID, connID string) error { |
||||
id := offlineSessionID(userID, connID, d.hasher) |
||||
|
||||
err := d.client.OfflineSession.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete offline session: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdatePassword changes an offline session by user id and connector id using an updater function.
|
||||
func (d *Database) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { |
||||
id := offlineSessionID(userID, connID, d.hasher) |
||||
|
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update offline session tx: %w", err) |
||||
} |
||||
|
||||
offlineSession, err := tx.OfflineSession.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return rollback(tx, "update offline session database: %w", err) |
||||
} |
||||
|
||||
newOfflineSession, err := updater(toStorageOfflineSession(offlineSession)) |
||||
if err != nil { |
||||
return rollback(tx, "update offline session updating: %w", err) |
||||
} |
||||
|
||||
encodedRefresh, err := json.Marshal(newOfflineSession.Refresh) |
||||
if err != nil { |
||||
return rollback(tx, "encode refresh offline session: %w", err) |
||||
} |
||||
|
||||
_, err = tx.OfflineSession.UpdateOneID(id). |
||||
SetUserID(newOfflineSession.UserID). |
||||
SetConnID(newOfflineSession.ConnID). |
||||
SetConnectorData(newOfflineSession.ConnectorData). |
||||
SetRefresh(encodedRefresh). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update offline session uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update password commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,100 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
"strings" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db/password" |
||||
) |
||||
|
||||
// CreatePassword saves provided password into the database.
|
||||
func (d *Database) CreatePassword(password storage.Password) error { |
||||
_, err := d.client.Password.Create(). |
||||
SetEmail(password.Email). |
||||
SetHash(password.Hash). |
||||
SetUsername(password.Username). |
||||
SetUserID(password.UserID). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create password: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ListPasswords extracts an array of passwords from the database.
|
||||
func (d *Database) ListPasswords() ([]storage.Password, error) { |
||||
passwords, err := d.client.Password.Query().All(context.TODO()) |
||||
if err != nil { |
||||
return nil, convertDBError("list passwords: %w", err) |
||||
} |
||||
|
||||
storagePasswords := make([]storage.Password, 0, len(passwords)) |
||||
for _, p := range passwords { |
||||
storagePasswords = append(storagePasswords, toStoragePassword(p)) |
||||
} |
||||
return storagePasswords, nil |
||||
} |
||||
|
||||
// GetPassword extracts a password from the database by email.
|
||||
func (d *Database) GetPassword(email string) (storage.Password, error) { |
||||
email = strings.ToLower(email) |
||||
passwordFromStorage, err := d.client.Password.Query(). |
||||
Where(password.Email(email)). |
||||
Only(context.TODO()) |
||||
if err != nil { |
||||
return storage.Password{}, convertDBError("get password: %w", err) |
||||
} |
||||
return toStoragePassword(passwordFromStorage), nil |
||||
} |
||||
|
||||
// DeletePassword deletes a password from the database by email.
|
||||
func (d *Database) DeletePassword(email string) error { |
||||
email = strings.ToLower(email) |
||||
_, err := d.client.Password.Delete(). |
||||
Where(password.Email(email)). |
||||
Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete password: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdatePassword changes a password by email using an updater function and saves it to the database.
|
||||
func (d *Database) UpdatePassword(email string, updater func(old storage.Password) (storage.Password, error)) error { |
||||
email = strings.ToLower(email) |
||||
|
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update connector tx: %w", err) |
||||
} |
||||
|
||||
passwordToUpdate, err := tx.Password.Query(). |
||||
Where(password.Email(email)). |
||||
Only(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update password database: %w", err) |
||||
} |
||||
|
||||
newPassword, err := updater(toStoragePassword(passwordToUpdate)) |
||||
if err != nil { |
||||
return rollback(tx, "update password updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.Password.Update(). |
||||
Where(password.Email(newPassword.Email)). |
||||
SetEmail(newPassword.Email). |
||||
SetHash(newPassword.Hash). |
||||
SetUsername(newPassword.Username). |
||||
SetUserID(newPassword.UserID). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update password uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update password commit: %w", err) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
@ -0,0 +1,109 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
// CreateRefresh saves provided refresh token into the database.
|
||||
func (d *Database) CreateRefresh(refresh storage.RefreshToken) error { |
||||
_, err := d.client.RefreshToken.Create(). |
||||
SetID(refresh.ID). |
||||
SetClientID(refresh.ClientID). |
||||
SetScopes(refresh.Scopes). |
||||
SetNonce(refresh.Nonce). |
||||
SetClaimsUserID(refresh.Claims.UserID). |
||||
SetClaimsEmail(refresh.Claims.Email). |
||||
SetClaimsEmailVerified(refresh.Claims.EmailVerified). |
||||
SetClaimsUsername(refresh.Claims.Username). |
||||
SetClaimsPreferredUsername(refresh.Claims.PreferredUsername). |
||||
SetClaimsGroups(refresh.Claims.Groups). |
||||
SetConnectorID(refresh.ConnectorID). |
||||
SetConnectorData(refresh.ConnectorData). |
||||
SetToken(refresh.Token). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(refresh.LastUsed.UTC()). |
||||
SetCreatedAt(refresh.CreatedAt.UTC()). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("create refresh token: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// ListRefreshTokens extracts an array of refresh tokens from the database.
|
||||
func (d *Database) ListRefreshTokens() ([]storage.RefreshToken, error) { |
||||
refreshTokens, err := d.client.RefreshToken.Query().All(context.TODO()) |
||||
if err != nil { |
||||
return nil, convertDBError("list refresh tokens: %w", err) |
||||
} |
||||
|
||||
storageRefreshTokens := make([]storage.RefreshToken, 0, len(refreshTokens)) |
||||
for _, r := range refreshTokens { |
||||
storageRefreshTokens = append(storageRefreshTokens, toStorageRefreshToken(r)) |
||||
} |
||||
return storageRefreshTokens, nil |
||||
} |
||||
|
||||
// GetRefresh extracts a refresh token from the database by id.
|
||||
func (d *Database) GetRefresh(id string) (storage.RefreshToken, error) { |
||||
refreshToken, err := d.client.RefreshToken.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return storage.RefreshToken{}, convertDBError("get refresh token: %w", err) |
||||
} |
||||
return toStorageRefreshToken(refreshToken), nil |
||||
} |
||||
|
||||
// DeleteRefresh deletes a refresh token from the database by id.
|
||||
func (d *Database) DeleteRefresh(id string) error { |
||||
err := d.client.RefreshToken.DeleteOneID(id).Exec(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("delete refresh token: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// UpdateRefreshToken changes a refresh token by id using an updater function and saves it to the database.
|
||||
func (d *Database) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { |
||||
tx, err := d.client.Tx(context.TODO()) |
||||
if err != nil { |
||||
return convertDBError("update refresh token tx: %w", err) |
||||
} |
||||
|
||||
token, err := tx.RefreshToken.Get(context.TODO(), id) |
||||
if err != nil { |
||||
return rollback(tx, "update refresh token database: %w", err) |
||||
} |
||||
|
||||
newtToken, err := updater(toStorageRefreshToken(token)) |
||||
if err != nil { |
||||
return rollback(tx, "update refresh token updating: %w", err) |
||||
} |
||||
|
||||
_, err = tx.RefreshToken.UpdateOneID(newtToken.ID). |
||||
SetClientID(newtToken.ClientID). |
||||
SetScopes(newtToken.Scopes). |
||||
SetNonce(newtToken.Nonce). |
||||
SetClaimsUserID(newtToken.Claims.UserID). |
||||
SetClaimsEmail(newtToken.Claims.Email). |
||||
SetClaimsEmailVerified(newtToken.Claims.EmailVerified). |
||||
SetClaimsUsername(newtToken.Claims.Username). |
||||
SetClaimsPreferredUsername(newtToken.Claims.PreferredUsername). |
||||
SetClaimsGroups(newtToken.Claims.Groups). |
||||
SetConnectorID(newtToken.ConnectorID). |
||||
SetConnectorData(newtToken.ConnectorData). |
||||
SetToken(newtToken.Token). |
||||
// Save utc time into database because ent doesn't support comparing dates with different timezones
|
||||
SetLastUsed(newtToken.LastUsed.UTC()). |
||||
SetCreatedAt(newtToken.CreatedAt.UTC()). |
||||
Save(context.TODO()) |
||||
if err != nil { |
||||
return rollback(tx, "update refresh token uploading: %w", err) |
||||
} |
||||
|
||||
if err = tx.Commit(); err != nil { |
||||
return rollback(tx, "update refresh token commit: %w", err) |
||||
} |
||||
return nil |
||||
} |
||||
@ -0,0 +1,167 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"strings" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db" |
||||
) |
||||
|
||||
const keysRowID = "keys" |
||||
|
||||
func toStorageKeys(keys *db.Keys) storage.Keys { |
||||
return storage.Keys{ |
||||
SigningKey: &keys.SigningKey, |
||||
SigningKeyPub: &keys.SigningKeyPub, |
||||
VerificationKeys: keys.VerificationKeys, |
||||
NextRotation: keys.NextRotation, |
||||
} |
||||
} |
||||
|
||||
func toStorageAuthRequest(a *db.AuthRequest) storage.AuthRequest { |
||||
return storage.AuthRequest{ |
||||
ID: a.ID, |
||||
ClientID: a.ClientID, |
||||
ResponseTypes: a.ResponseTypes, |
||||
Scopes: a.Scopes, |
||||
RedirectURI: a.RedirectURI, |
||||
Nonce: a.Nonce, |
||||
State: a.State, |
||||
ForceApprovalPrompt: a.ForceApprovalPrompt, |
||||
LoggedIn: a.LoggedIn, |
||||
ConnectorID: a.ConnectorID, |
||||
ConnectorData: *a.ConnectorData, |
||||
Expiry: a.Expiry, |
||||
Claims: storage.Claims{ |
||||
UserID: a.ClaimsUserID, |
||||
Username: a.ClaimsUsername, |
||||
PreferredUsername: a.ClaimsPreferredUsername, |
||||
Email: a.ClaimsEmail, |
||||
EmailVerified: a.ClaimsEmailVerified, |
||||
Groups: a.ClaimsGroups, |
||||
}, |
||||
PKCE: storage.PKCE{ |
||||
CodeChallenge: a.CodeChallenge, |
||||
CodeChallengeMethod: a.CodeChallengeMethod, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func toStorageAuthCode(a *db.AuthCode) storage.AuthCode { |
||||
return storage.AuthCode{ |
||||
ID: a.ID, |
||||
ClientID: a.ClientID, |
||||
Scopes: a.Scopes, |
||||
RedirectURI: a.RedirectURI, |
||||
Nonce: a.Nonce, |
||||
ConnectorID: a.ConnectorID, |
||||
ConnectorData: *a.ConnectorData, |
||||
Expiry: a.Expiry, |
||||
Claims: storage.Claims{ |
||||
UserID: a.ClaimsUserID, |
||||
Username: a.ClaimsUsername, |
||||
PreferredUsername: a.ClaimsPreferredUsername, |
||||
Email: a.ClaimsEmail, |
||||
EmailVerified: a.ClaimsEmailVerified, |
||||
Groups: a.ClaimsGroups, |
||||
}, |
||||
PKCE: storage.PKCE{ |
||||
CodeChallenge: a.CodeChallenge, |
||||
CodeChallengeMethod: a.CodeChallengeMethod, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func toStorageClient(c *db.OAuth2Client) storage.Client { |
||||
return storage.Client{ |
||||
ID: c.ID, |
||||
Secret: c.Secret, |
||||
RedirectURIs: c.RedirectUris, |
||||
TrustedPeers: c.TrustedPeers, |
||||
Public: c.Public, |
||||
Name: c.Name, |
||||
LogoURL: c.LogoURL, |
||||
} |
||||
} |
||||
|
||||
func toStorageConnector(c *db.Connector) storage.Connector { |
||||
return storage.Connector{ |
||||
ID: c.ID, |
||||
Type: c.Type, |
||||
Name: c.Name, |
||||
Config: c.Config, |
||||
} |
||||
} |
||||
|
||||
func toStorageOfflineSession(o *db.OfflineSession) storage.OfflineSessions { |
||||
s := storage.OfflineSessions{ |
||||
UserID: o.UserID, |
||||
ConnID: o.ConnID, |
||||
ConnectorData: *o.ConnectorData, |
||||
} |
||||
|
||||
if o.Refresh != nil { |
||||
if err := json.Unmarshal(o.Refresh, &s.Refresh); err != nil { |
||||
// Correctness of json structure if guaranteed on uploading
|
||||
panic(err) |
||||
} |
||||
} else { |
||||
// Server code assumes this will be non-nil.
|
||||
s.Refresh = make(map[string]*storage.RefreshTokenRef) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
func toStorageRefreshToken(r *db.RefreshToken) storage.RefreshToken { |
||||
return storage.RefreshToken{ |
||||
ID: r.ID, |
||||
Token: r.Token, |
||||
CreatedAt: r.CreatedAt, |
||||
LastUsed: r.LastUsed, |
||||
ClientID: r.ClientID, |
||||
ConnectorID: r.ConnectorID, |
||||
ConnectorData: *r.ConnectorData, |
||||
Scopes: r.Scopes, |
||||
Nonce: r.Nonce, |
||||
Claims: storage.Claims{ |
||||
UserID: r.ClaimsUserID, |
||||
Username: r.ClaimsUsername, |
||||
PreferredUsername: r.ClaimsPreferredUsername, |
||||
Email: r.ClaimsEmail, |
||||
EmailVerified: r.ClaimsEmailVerified, |
||||
Groups: r.ClaimsGroups, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func toStoragePassword(p *db.Password) storage.Password { |
||||
return storage.Password{ |
||||
Email: p.Email, |
||||
Hash: p.Hash, |
||||
Username: p.Username, |
||||
UserID: p.UserID, |
||||
} |
||||
} |
||||
|
||||
func toStorageDeviceRequest(r *db.DeviceRequest) storage.DeviceRequest { |
||||
return storage.DeviceRequest{ |
||||
UserCode: strings.ToUpper(r.UserCode), |
||||
DeviceCode: r.DeviceCode, |
||||
ClientID: r.ClientID, |
||||
ClientSecret: r.ClientSecret, |
||||
Scopes: r.Scopes, |
||||
Expiry: r.Expiry, |
||||
} |
||||
} |
||||
|
||||
func toStorageDeviceToken(t *db.DeviceToken) storage.DeviceToken { |
||||
return storage.DeviceToken{ |
||||
DeviceCode: t.DeviceCode, |
||||
Status: t.Status, |
||||
Token: string(*t.Token), |
||||
Expiry: t.Expiry, |
||||
LastRequestTime: t.LastRequest, |
||||
PollIntervalSeconds: t.PollInterval, |
||||
} |
||||
} |
||||
@ -0,0 +1,44 @@
|
||||
package client |
||||
|
||||
import ( |
||||
"fmt" |
||||
"hash" |
||||
|
||||
"github.com/pkg/errors" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/db" |
||||
) |
||||
|
||||
func rollback(tx *db.Tx, t string, err error) error { |
||||
rerr := tx.Rollback() |
||||
err = convertDBError(t, err) |
||||
|
||||
if rerr == nil { |
||||
return err |
||||
} |
||||
return errors.Wrapf(err, "rolling back transaction: %v", rerr) |
||||
} |
||||
|
||||
func convertDBError(t string, err error) error { |
||||
if db.IsNotFound(err) { |
||||
return storage.ErrNotFound |
||||
} |
||||
|
||||
if db.IsConstraintError(err) { |
||||
return storage.ErrAlreadyExists |
||||
} |
||||
|
||||
return fmt.Errorf(t, err) |
||||
} |
||||
|
||||
// compose hashed id from user and connection id to use it as primary key
|
||||
// ent doesn't support multi-key primary yet
|
||||
// https://github.com/facebook/ent/issues/400
|
||||
func offlineSessionID(userID string, connID string, hasher func() hash.Hash) string { |
||||
h := hasher() |
||||
|
||||
h.Write([]byte(userID)) |
||||
h.Write([]byte(connID)) |
||||
return fmt.Sprintf("%x", h.Sum(nil)) |
||||
} |
||||
@ -0,0 +1,3 @@
|
||||
package ent |
||||
|
||||
//go:generate go run github.com/facebook/ent/cmd/entc generate ./schema --target ./db
|
||||
@ -0,0 +1,89 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table auth_code |
||||
( |
||||
id text not null primary key, |
||||
client_id text not null, |
||||
scopes blob not null, |
||||
nonce text not null, |
||||
redirect_uri text not null, |
||||
claims_user_id text not null, |
||||
claims_username text not null, |
||||
claims_email text not null, |
||||
claims_email_verified integer not null, |
||||
claims_groups blob not null, |
||||
connector_id text not null, |
||||
connector_data blob, |
||||
expiry timestamp not null, |
||||
claims_preferred_username text default '' not null, |
||||
code_challenge text default '' not null, |
||||
code_challenge_method text default '' not null |
||||
); |
||||
*/ |
||||
|
||||
// AuthCode holds the schema definition for the AuthCode entity.
|
||||
type AuthCode struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the AuthCode.
|
||||
func (AuthCode) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("client_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.JSON("scopes", []string{}). |
||||
Optional(), |
||||
field.Text("nonce"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("redirect_uri"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
|
||||
field.Text("claims_user_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("claims_username"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("claims_email"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bool("claims_email_verified"), |
||||
field.JSON("claims_groups", []string{}). |
||||
Optional(), |
||||
field.Text("claims_preferred_username"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
|
||||
field.Text("connector_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bytes("connector_data"). |
||||
Nillable(). |
||||
Optional(), |
||||
field.Time("expiry"), |
||||
field.Text("code_challenge"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
field.Text("code_challenge_method"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
} |
||||
} |
||||
|
||||
// Edges of the AuthCode.
|
||||
func (AuthCode) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,94 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table auth_request |
||||
( |
||||
id text not null primary key, |
||||
client_id text not null, |
||||
response_types blob not null, |
||||
scopes blob not null, |
||||
redirect_uri text not null, |
||||
nonce text not null, |
||||
state text not null, |
||||
force_approval_prompt integer not null, |
||||
logged_in integer not null, |
||||
claims_user_id text not null, |
||||
claims_username text not null, |
||||
claims_email text not null, |
||||
claims_email_verified integer not null, |
||||
claims_groups blob not null, |
||||
connector_id text not null, |
||||
connector_data blob, |
||||
expiry timestamp not null, |
||||
claims_preferred_username text default '' not null, |
||||
code_challenge text default '' not null, |
||||
code_challenge_method text default '' not null |
||||
); |
||||
*/ |
||||
|
||||
// AuthRequest holds the schema definition for the AuthRequest entity.
|
||||
type AuthRequest struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the AuthRequest.
|
||||
func (AuthRequest) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("client_id"). |
||||
SchemaType(textSchema), |
||||
field.JSON("scopes", []string{}). |
||||
Optional(), |
||||
field.JSON("response_types", []string{}). |
||||
Optional(), |
||||
field.Text("redirect_uri"). |
||||
SchemaType(textSchema), |
||||
field.Text("nonce"). |
||||
SchemaType(textSchema), |
||||
field.Text("state"). |
||||
SchemaType(textSchema), |
||||
|
||||
field.Bool("force_approval_prompt"), |
||||
field.Bool("logged_in"), |
||||
|
||||
field.Text("claims_user_id"). |
||||
SchemaType(textSchema), |
||||
field.Text("claims_username"). |
||||
SchemaType(textSchema), |
||||
field.Text("claims_email"). |
||||
SchemaType(textSchema), |
||||
field.Bool("claims_email_verified"), |
||||
field.JSON("claims_groups", []string{}). |
||||
Optional(), |
||||
field.Text("claims_preferred_username"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
|
||||
field.Text("connector_id"). |
||||
SchemaType(textSchema), |
||||
field.Bytes("connector_data"). |
||||
Nillable(). |
||||
Optional(), |
||||
field.Time("expiry"), |
||||
|
||||
field.Text("code_challenge"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
field.Text("code_challenge_method"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
} |
||||
} |
||||
|
||||
// Edges of the AuthRequest.
|
||||
func (AuthRequest) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,53 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table client |
||||
( |
||||
id text not null primary key, |
||||
secret text not null, |
||||
redirect_uris blob not null, |
||||
trusted_peers blob not null, |
||||
public integer not null, |
||||
name text not null, |
||||
logo_url text not null |
||||
); |
||||
*/ |
||||
|
||||
// OAuth2Client holds the schema definition for the Client entity.
|
||||
type OAuth2Client struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the OAuth2Client.
|
||||
func (OAuth2Client) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("secret"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.JSON("redirect_uris", []string{}). |
||||
Optional(), |
||||
field.JSON("trusted_peers", []string{}). |
||||
Optional(), |
||||
field.Bool("public"), |
||||
field.Text("name"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("logo_url"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
} |
||||
} |
||||
|
||||
// Edges of the OAuth2Client.
|
||||
func (OAuth2Client) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,46 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table connector |
||||
( |
||||
id text not null primary key, |
||||
type text not null, |
||||
name text not null, |
||||
resource_version text not null, |
||||
config blob |
||||
); |
||||
*/ |
||||
|
||||
// Connector holds the schema definition for the Client entity.
|
||||
type Connector struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the Connector.
|
||||
func (Connector) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("type"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("name"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("resource_version"). |
||||
SchemaType(textSchema), |
||||
field.Bytes("config"), |
||||
} |
||||
} |
||||
|
||||
// Edges of the Connector.
|
||||
func (Connector) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,50 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table device_request |
||||
( |
||||
user_code text not null primary key, |
||||
device_code text not null, |
||||
client_id text not null, |
||||
client_secret text, |
||||
scopes blob not null, |
||||
expiry timestamp not null |
||||
); |
||||
*/ |
||||
|
||||
// DeviceRequest holds the schema definition for the DeviceRequest entity.
|
||||
type DeviceRequest struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the DeviceRequest.
|
||||
func (DeviceRequest) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("user_code"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("device_code"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("client_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("client_secret"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.JSON("scopes", []string{}). |
||||
Optional(), |
||||
field.Time("expiry"), |
||||
} |
||||
} |
||||
|
||||
// Edges of the DeviceRequest.
|
||||
func (DeviceRequest) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,45 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table device_token |
||||
( |
||||
device_code text not null primary key, |
||||
status text not null, |
||||
token blob, |
||||
expiry timestamp not null, |
||||
last_request timestamp not null, |
||||
poll_interval integer not null |
||||
); |
||||
*/ |
||||
|
||||
// DeviceToken holds the schema definition for the DeviceToken entity.
|
||||
type DeviceToken struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the DeviceToken.
|
||||
func (DeviceToken) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("device_code"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("status"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bytes("token").Nillable().Optional(), |
||||
field.Time("expiry"), |
||||
field.Time("last_request"), |
||||
field.Int("poll_interval"), |
||||
} |
||||
} |
||||
|
||||
// Edges of the DeviceToken.
|
||||
func (DeviceToken) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,44 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
"gopkg.in/square/go-jose.v2" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table keys |
||||
( |
||||
id text not null primary key, |
||||
verification_keys blob not null, |
||||
signing_key blob not null, |
||||
signing_key_pub blob not null, |
||||
next_rotation timestamp not null |
||||
); |
||||
*/ |
||||
|
||||
// Keys holds the schema definition for the Keys entity.
|
||||
type Keys struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the Keys.
|
||||
func (Keys) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.JSON("verification_keys", []storage.VerificationKey{}), |
||||
field.JSON("signing_key", jose.JSONWebKey{}), |
||||
field.JSON("signing_key_pub", jose.JSONWebKey{}), |
||||
field.Time("next_rotation"), |
||||
} |
||||
} |
||||
|
||||
// Edges of the Keys.
|
||||
func (Keys) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,46 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table offline_session |
||||
( |
||||
user_id text not null, |
||||
conn_id text not null, |
||||
refresh blob not null, |
||||
connector_data blob, |
||||
primary key (user_id, conn_id) |
||||
); |
||||
*/ |
||||
|
||||
// OfflineSession holds the schema definition for the OfflineSession entity.
|
||||
type OfflineSession struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the OfflineSession.
|
||||
func (OfflineSession) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
// Using id field here because it's impossible to create multi-key primary yet
|
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("user_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("conn_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bytes("refresh"), |
||||
field.Bytes("connector_data").Nillable().Optional(), |
||||
} |
||||
} |
||||
|
||||
// Edges of the OfflineSession.
|
||||
func (OfflineSession) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,44 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table password |
||||
( |
||||
email text not null primary key, |
||||
hash blob not null, |
||||
username text not null, |
||||
user_id text not null |
||||
); |
||||
*/ |
||||
|
||||
// Password holds the schema definition for the Password entity.
|
||||
type Password struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the Password.
|
||||
func (Password) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("email"). |
||||
SchemaType(textSchema). |
||||
StorageKey("email"). // use email as ID field to make querying easier
|
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Bytes("hash"), |
||||
field.Text("username"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("user_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
} |
||||
} |
||||
|
||||
// Edges of the Password.
|
||||
func (Password) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,89 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/facebook/ent" |
||||
"github.com/facebook/ent/schema/field" |
||||
) |
||||
|
||||
/* Original SQL table: |
||||
create table refresh_token |
||||
( |
||||
id text not null primary key, |
||||
client_id text not null, |
||||
scopes blob not null, |
||||
nonce text not null, |
||||
claims_user_id text not null, |
||||
claims_username text not null, |
||||
claims_email text not null, |
||||
claims_email_verified integer not null, |
||||
claims_groups blob not null, |
||||
connector_id text not null, |
||||
connector_data blob, |
||||
token text default '' not null, |
||||
created_at timestamp default '0001-01-01 00:00:00 UTC' not null, |
||||
last_used timestamp default '0001-01-01 00:00:00 UTC' not null, |
||||
claims_preferred_username text default '' not null |
||||
); |
||||
*/ |
||||
|
||||
// RefreshToken holds the schema definition for the RefreshToken entity.
|
||||
type RefreshToken struct { |
||||
ent.Schema |
||||
} |
||||
|
||||
// Fields of the RefreshToken.
|
||||
func (RefreshToken) Fields() []ent.Field { |
||||
return []ent.Field{ |
||||
field.Text("id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(). |
||||
Unique(), |
||||
field.Text("client_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.JSON("scopes", []string{}). |
||||
Optional(), |
||||
field.Text("nonce"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
|
||||
field.Text("claims_user_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("claims_username"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Text("claims_email"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bool("claims_email_verified"), |
||||
field.JSON("claims_groups", []string{}). |
||||
Optional(), |
||||
field.Text("claims_preferred_username"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
|
||||
field.Text("connector_id"). |
||||
SchemaType(textSchema). |
||||
NotEmpty(), |
||||
field.Bytes("connector_data"). |
||||
Nillable(). |
||||
Optional(), |
||||
|
||||
field.Text("token"). |
||||
SchemaType(textSchema). |
||||
Default(""), |
||||
|
||||
field.Time("created_at"). |
||||
Default(time.Now), |
||||
field.Time("last_used"). |
||||
Default(time.Now), |
||||
} |
||||
} |
||||
|
||||
// Edges of the RefreshToken.
|
||||
func (RefreshToken) Edges() []ent.Edge { |
||||
return []ent.Edge{} |
||||
} |
||||
@ -0,0 +1,9 @@
|
||||
package schema |
||||
|
||||
import ( |
||||
"github.com/facebook/ent/dialect" |
||||
) |
||||
|
||||
var textSchema = map[string]string{ |
||||
dialect.SQLite: "text", |
||||
} |
||||
@ -0,0 +1,65 @@
|
||||
package ent |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/sha256" |
||||
"strings" |
||||
|
||||
"github.com/facebook/ent/dialect/sql" |
||||
|
||||
// Register sqlite driver.
|
||||
_ "github.com/mattn/go-sqlite3" |
||||
|
||||
"github.com/dexidp/dex/pkg/log" |
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/ent/client" |
||||
"github.com/dexidp/dex/storage/ent/db" |
||||
) |
||||
|
||||
// SQLite3 options for creating an SQL db.
|
||||
type SQLite3 struct { |
||||
File string `json:"file"` |
||||
} |
||||
|
||||
// Open always returns a new in sqlite3 storage.
|
||||
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { |
||||
logger.Debug("experimental ent-based storage driver is enabled") |
||||
|
||||
// Implicitly set foreign_keys pragma to "on" because it is required by ent
|
||||
s.File = addFK(s.File) |
||||
|
||||
drv, err := sql.Open("sqlite3", s.File) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
pool := drv.DB() |
||||
if s.File == ":memory:" { |
||||
// sqlite3 uses file locks to coordinate concurrent access. In memory
|
||||
// doesn't support this, so limit the number of connections to 1.
|
||||
pool.SetMaxOpenConns(1) |
||||
} |
||||
|
||||
databaseClient := client.NewDatabase( |
||||
client.WithClient(db.NewClient(db.Driver(drv))), |
||||
client.WithHasher(sha256.New), |
||||
) |
||||
|
||||
if err := databaseClient.Schema().Create(context.TODO()); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return databaseClient, nil |
||||
} |
||||
|
||||
func addFK(dsn string) string { |
||||
if strings.Contains(dsn, "_fk") { |
||||
return dsn |
||||
} |
||||
|
||||
delim := "?" |
||||
if strings.Contains(dsn, "?") { |
||||
delim = "&" |
||||
} |
||||
return dsn + delim + "_fk=1" |
||||
} |
||||
@ -0,0 +1,31 @@
|
||||
package ent |
||||
|
||||
import ( |
||||
"os" |
||||
"testing" |
||||
|
||||
"github.com/sirupsen/logrus" |
||||
|
||||
"github.com/dexidp/dex/storage" |
||||
"github.com/dexidp/dex/storage/conformance" |
||||
) |
||||
|
||||
func newStorage() storage.Storage { |
||||
logger := &logrus.Logger{ |
||||
Out: os.Stderr, |
||||
Formatter: &logrus.TextFormatter{DisableColors: true}, |
||||
Level: logrus.DebugLevel, |
||||
} |
||||
|
||||
cfg := SQLite3{File: ":memory:"} |
||||
s, err := cfg.Open(logger) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return s |
||||
} |
||||
|
||||
func TestSQLite3(t *testing.T) { |
||||
conformance.RunTests(t, newStorage) |
||||
conformance.RunTransactionTests(t, newStorage) |
||||
} |
||||
Loading…
Reference in new issue