mirror of https://github.com/dexidp/dex.git
Browse Source
This patch adds etcd storage implementation. This should be useful in environments where - we dont want to depends on a separate, hard to maintain SQL cluster - we dont want to incur the overhead of talking to kubernetes apiservers - kubernetes is not available yet, or if kubernetes depends on dex to perform authentication and the operator would like to remove any circular dependency if possible.pull/1108/head
6 changed files with 1058 additions and 0 deletions
@ -0,0 +1,92 @@
|
||||
package etcd |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/coreos/dex/storage" |
||||
"github.com/coreos/etcd/clientv3" |
||||
"github.com/coreos/etcd/clientv3/namespace" |
||||
"github.com/coreos/etcd/pkg/transport" |
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
var ( |
||||
defaultDialTimeout = 2 * time.Second |
||||
) |
||||
|
||||
// SSL represents SSL options for etcd databases.
|
||||
type SSL struct { |
||||
ServerName string |
||||
CAFile string |
||||
KeyFile string |
||||
CertFile string |
||||
} |
||||
|
||||
// Etcd options for connecting to etcd databases.
|
||||
// If you are using a shared etcd cluster for storage, it might be useful to
|
||||
// configure an etcd namespace either via Namespace field or using `etcd grpc-proxy
|
||||
// --namespace=<prefix>`
|
||||
type Etcd struct { |
||||
Endpoints []string |
||||
Namespace string |
||||
Username string |
||||
Password string |
||||
SSL SSL |
||||
} |
||||
|
||||
// Open creates a new storage implementation backed by Etcd
|
||||
func (p *Etcd) Open(logger logrus.FieldLogger) (storage.Storage, error) { |
||||
return p.open(logger) |
||||
} |
||||
|
||||
func (p *Etcd) open(logger logrus.FieldLogger) (*conn, error) { |
||||
cfg := clientv3.Config{ |
||||
Endpoints: p.Endpoints, |
||||
DialTimeout: defaultDialTimeout * time.Second, |
||||
Username: p.Username, |
||||
Password: p.Password, |
||||
} |
||||
|
||||
var cfgtls *transport.TLSInfo |
||||
tlsinfo := transport.TLSInfo{} |
||||
if p.SSL.CertFile != "" { |
||||
tlsinfo.CertFile = p.SSL.CertFile |
||||
cfgtls = &tlsinfo |
||||
} |
||||
|
||||
if p.SSL.KeyFile != "" { |
||||
tlsinfo.KeyFile = p.SSL.KeyFile |
||||
cfgtls = &tlsinfo |
||||
} |
||||
|
||||
if p.SSL.CAFile != "" { |
||||
tlsinfo.CAFile = p.SSL.CAFile |
||||
cfgtls = &tlsinfo |
||||
} |
||||
|
||||
if p.SSL.ServerName != "" { |
||||
tlsinfo.ServerName = p.SSL.ServerName |
||||
cfgtls = &tlsinfo |
||||
} |
||||
|
||||
if cfgtls != nil { |
||||
clientTLS, err := cfgtls.ClientConfig() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
cfg.TLS = clientTLS |
||||
} |
||||
|
||||
db, err := clientv3.New(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(p.Namespace) > 0 { |
||||
db.KV = namespace.NewKV(db.KV, p.Namespace) |
||||
} |
||||
c := &conn{ |
||||
db: db, |
||||
logger: logger, |
||||
} |
||||
return c, nil |
||||
} |
||||
@ -0,0 +1,532 @@
|
||||
package etcd |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/coreos/dex/storage" |
||||
"github.com/coreos/etcd/clientv3" |
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
const ( |
||||
clientPrefix = "client/" |
||||
authCodePrefix = "auth_code/" |
||||
refreshTokenPrefix = "refresh_token/" |
||||
authRequestPrefix = "auth_req/" |
||||
passwordPrefix = "password/" |
||||
offlineSessionPrefix = "offline_session/" |
||||
connectorPrefix = "connector/" |
||||
keysName = "openid-connect-keys" |
||||
|
||||
// defaultStorageTimeout will be applied to all storage's operations.
|
||||
defaultStorageTimeout = 5 * time.Second |
||||
) |
||||
|
||||
type conn struct { |
||||
db *clientv3.Client |
||||
logger logrus.FieldLogger |
||||
} |
||||
|
||||
func (c *conn) Close() error { |
||||
return c.db.Close() |
||||
} |
||||
|
||||
func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
authRequests, err := c.listAuthRequests(ctx) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
var delErr error |
||||
for _, authRequest := range authRequests { |
||||
if now.After(authRequest.Expiry) { |
||||
if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil { |
||||
c.logger.Errorf("failed to delete auth request: %v", err) |
||||
delErr = fmt.Errorf("failed to delete auth request: %v", err) |
||||
} |
||||
result.AuthRequests++ |
||||
} |
||||
} |
||||
if delErr != nil { |
||||
return result, delErr |
||||
} |
||||
|
||||
authCodes, err := c.listAuthCodes(ctx) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for _, authCode := range authCodes { |
||||
if now.After(authCode.Expiry) { |
||||
if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil { |
||||
c.logger.Errorf("failed to delete auth code %v", err) |
||||
delErr = fmt.Errorf("failed to delete auth code: %v", err) |
||||
} |
||||
result.AuthCodes++ |
||||
} |
||||
} |
||||
return result, delErr |
||||
} |
||||
|
||||
func (c *conn) CreateAuthRequest(a storage.AuthRequest) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keyID(authRequestPrefix, a.ID), fromStorageAuthRequest(a)) |
||||
} |
||||
|
||||
func (c *conn) GetAuthRequest(id string) (a storage.AuthRequest, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
var req AuthRequest |
||||
if err = c.getKey(ctx, keyID(authRequestPrefix, id), &req); err != nil { |
||||
return |
||||
} |
||||
return toStorageAuthRequest(req), nil |
||||
} |
||||
|
||||
func (c *conn) UpdateAuthRequest(id string, updater func(a storage.AuthRequest) (storage.AuthRequest, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keyID(authRequestPrefix, id), func(currentValue []byte) ([]byte, error) { |
||||
var current AuthRequest |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(toStorageAuthRequest(current)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(fromStorageAuthRequest(updated)) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) DeleteAuthRequest(id string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyID(authRequestPrefix, id)) |
||||
} |
||||
|
||||
func (c *conn) CreateAuthCode(a storage.AuthCode) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keyID(authCodePrefix, a.ID), fromStorageAuthCode(a)) |
||||
} |
||||
|
||||
func (c *conn) GetAuthCode(id string) (a storage.AuthCode, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
err = c.getKey(ctx, keyID(authCodePrefix, id), &a) |
||||
return a, err |
||||
} |
||||
|
||||
func (c *conn) DeleteAuthCode(id string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyID(authCodePrefix, id)) |
||||
} |
||||
|
||||
func (c *conn) CreateRefresh(r storage.RefreshToken) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keyID(refreshTokenPrefix, r.ID), fromStorageRefreshToken(r)) |
||||
} |
||||
|
||||
func (c *conn) GetRefresh(id string) (r storage.RefreshToken, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
var token RefreshToken |
||||
if err = c.getKey(ctx, keyID(refreshTokenPrefix, id), &token); err != nil { |
||||
return |
||||
} |
||||
return toStorageRefreshToken(token), nil |
||||
} |
||||
|
||||
func (c *conn) UpdateRefreshToken(id string, updater func(old storage.RefreshToken) (storage.RefreshToken, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keyID(refreshTokenPrefix, id), func(currentValue []byte) ([]byte, error) { |
||||
var current RefreshToken |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal([]byte(currentValue), ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(toStorageRefreshToken(current)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(fromStorageRefreshToken(updated)) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) DeleteRefresh(id string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyID(refreshTokenPrefix, id)) |
||||
} |
||||
|
||||
func (c *conn) ListRefreshTokens() (tokens []storage.RefreshToken, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
res, err := c.db.Get(ctx, refreshTokenPrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return tokens, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var token RefreshToken |
||||
if err = json.Unmarshal(v.Value, &token); err != nil { |
||||
return tokens, err |
||||
} |
||||
tokens = append(tokens, toStorageRefreshToken(token)) |
||||
} |
||||
return tokens, nil |
||||
} |
||||
|
||||
func (c *conn) CreateClient(cli storage.Client) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keyID(clientPrefix, cli.ID), cli) |
||||
} |
||||
|
||||
func (c *conn) GetClient(id string) (cli storage.Client, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
err = c.getKey(ctx, keyID(clientPrefix, id), &cli) |
||||
return cli, err |
||||
} |
||||
|
||||
func (c *conn) UpdateClient(id string, updater func(old storage.Client) (storage.Client, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keyID(clientPrefix, id), func(currentValue []byte) ([]byte, error) { |
||||
var current storage.Client |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(current) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(updated) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) DeleteClient(id string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyID(clientPrefix, id)) |
||||
} |
||||
|
||||
func (c *conn) ListClients() (clients []storage.Client, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
res, err := c.db.Get(ctx, clientPrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return clients, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var cli storage.Client |
||||
if err = json.Unmarshal(v.Value, &cli); err != nil { |
||||
return clients, err |
||||
} |
||||
clients = append(clients, cli) |
||||
} |
||||
return clients, nil |
||||
} |
||||
|
||||
func (c *conn) CreatePassword(p storage.Password) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, passwordPrefix+strings.ToLower(p.Email), p) |
||||
} |
||||
|
||||
func (c *conn) GetPassword(email string) (p storage.Password, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
err = c.getKey(ctx, keyEmail(passwordPrefix, email), &p) |
||||
return p, err |
||||
} |
||||
|
||||
func (c *conn) UpdatePassword(email string, updater func(p storage.Password) (storage.Password, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keyEmail(passwordPrefix, email), func(currentValue []byte) ([]byte, error) { |
||||
var current storage.Password |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(current) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(updated) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) DeletePassword(email string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyEmail(passwordPrefix, email)) |
||||
} |
||||
|
||||
func (c *conn) ListPasswords() (passwords []storage.Password, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
res, err := c.db.Get(ctx, passwordPrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return passwords, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var p storage.Password |
||||
if err = json.Unmarshal(v.Value, &p); err != nil { |
||||
return passwords, err |
||||
} |
||||
passwords = append(passwords, p) |
||||
} |
||||
return passwords, nil |
||||
} |
||||
|
||||
func (c *conn) CreateOfflineSessions(s storage.OfflineSessions) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keySession(offlineSessionPrefix, s.UserID, s.ConnID), fromStorageOfflineSessions(s)) |
||||
} |
||||
|
||||
func (c *conn) UpdateOfflineSessions(userID string, connID string, updater func(s storage.OfflineSessions) (storage.OfflineSessions, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keySession(offlineSessionPrefix, userID, connID), func(currentValue []byte) ([]byte, error) { |
||||
var current OfflineSessions |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(toStorageOfflineSessions(current)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(fromStorageOfflineSessions(updated)) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) GetOfflineSessions(userID string, connID string) (s storage.OfflineSessions, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
var os OfflineSessions |
||||
if err = c.getKey(ctx, keySession(offlineSessionPrefix, userID, connID), &os); err != nil { |
||||
return |
||||
} |
||||
return toStorageOfflineSessions(os), nil |
||||
} |
||||
|
||||
func (c *conn) DeleteOfflineSessions(userID string, connID string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keySession(offlineSessionPrefix, userID, connID)) |
||||
} |
||||
|
||||
func (c *conn) CreateConnector(connector storage.Connector) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnCreate(ctx, keyID(connectorPrefix, connector.ID), connector) |
||||
} |
||||
|
||||
func (c *conn) GetConnector(id string) (conn storage.Connector, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
err = c.getKey(ctx, keyID(connectorPrefix, id), &conn) |
||||
return conn, err |
||||
} |
||||
|
||||
func (c *conn) UpdateConnector(id string, updater func(s storage.Connector) (storage.Connector, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keyID(connectorPrefix, id), func(currentValue []byte) ([]byte, error) { |
||||
var current storage.Connector |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(current) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(updated) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) DeleteConnector(id string) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.deleteKey(ctx, keyID(connectorPrefix, id)) |
||||
} |
||||
|
||||
func (c *conn) ListConnectors() (connectors []storage.Connector, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
res, err := c.db.Get(ctx, connectorPrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var c storage.Connector |
||||
if err = json.Unmarshal(v.Value, &c); err != nil { |
||||
return nil, err |
||||
} |
||||
connectors = append(connectors, c) |
||||
} |
||||
return connectors, nil |
||||
} |
||||
|
||||
func (c *conn) GetKeys() (keys storage.Keys, err error) { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
res, err := c.db.Get(ctx, keysName) |
||||
if err != nil { |
||||
return keys, err |
||||
} |
||||
if res.Count > 0 && len(res.Kvs) > 0 { |
||||
err = json.Unmarshal(res.Kvs[0].Value, &keys) |
||||
} |
||||
return keys, err |
||||
} |
||||
|
||||
func (c *conn) UpdateKeys(updater func(old storage.Keys) (storage.Keys, error)) error { |
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout) |
||||
defer cancel() |
||||
return c.txnUpdate(ctx, keysName, func(currentValue []byte) ([]byte, error) { |
||||
var current storage.Keys |
||||
if len(currentValue) > 0 { |
||||
if err := json.Unmarshal(currentValue, ¤t); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
updated, err := updater(current) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return json.Marshal(updated) |
||||
}) |
||||
} |
||||
|
||||
func (c *conn) deleteKey(ctx context.Context, key string) error { |
||||
res, err := c.db.Delete(ctx, key) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if res.Deleted == 0 { |
||||
return storage.ErrNotFound |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *conn) getKey(ctx context.Context, key string, value interface{}) error { |
||||
r, err := c.db.Get(ctx, key) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if r.Count == 0 { |
||||
return storage.ErrNotFound |
||||
} |
||||
return json.Unmarshal(r.Kvs[0].Value, value) |
||||
} |
||||
|
||||
func (c *conn) listAuthRequests(ctx context.Context) (reqs []AuthRequest, err error) { |
||||
res, err := c.db.Get(ctx, authRequestPrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return reqs, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var r AuthRequest |
||||
if err = json.Unmarshal(v.Value, &r); err != nil { |
||||
return reqs, err |
||||
} |
||||
reqs = append(reqs, r) |
||||
} |
||||
return reqs, nil |
||||
} |
||||
|
||||
func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) { |
||||
res, err := c.db.Get(ctx, authCodePrefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return codes, err |
||||
} |
||||
for _, v := range res.Kvs { |
||||
var c AuthCode |
||||
if err = json.Unmarshal(v.Value, &c); err != nil { |
||||
return codes, err |
||||
} |
||||
codes = append(codes, c) |
||||
} |
||||
return codes, nil |
||||
} |
||||
|
||||
func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error { |
||||
b, err := json.Marshal(value) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
txn := c.db.Txn(ctx) |
||||
res, err := txn. |
||||
If(clientv3.Compare(clientv3.CreateRevision(key), "=", 0)). |
||||
Then(clientv3.OpPut(key, string(b))). |
||||
Commit() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !res.Succeeded { |
||||
return storage.ErrAlreadyExists |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (c *conn) txnUpdate(ctx context.Context, key string, update func(current []byte) ([]byte, error)) error { |
||||
getResp, err := c.db.Get(ctx, key) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var currentValue []byte |
||||
var modRev int64 |
||||
if len(getResp.Kvs) > 0 { |
||||
currentValue = getResp.Kvs[0].Value |
||||
modRev = getResp.Kvs[0].ModRevision |
||||
} |
||||
|
||||
updatedValue, err := update(currentValue) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
txn := c.db.Txn(ctx) |
||||
updateResp, err := txn. |
||||
If(clientv3.Compare(clientv3.ModRevision(key), "=", modRev)). |
||||
Then(clientv3.OpPut(key, string(updatedValue))). |
||||
Commit() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !updateResp.Succeeded { |
||||
return fmt.Errorf("failed to update key=%q: concurrent conflicting update happened", key) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func keyID(prefix, id string) string { return prefix + id } |
||||
func keyEmail(prefix, email string) string { return prefix + strings.ToLower(email) } |
||||
func keySession(prefix, userID, connID string) string { |
||||
return prefix + strings.ToLower(userID+"|"+connID) |
||||
} |
||||
@ -0,0 +1,94 @@
|
||||
package etcd |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"runtime" |
||||
"strings" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/coreos/dex/storage" |
||||
"github.com/coreos/dex/storage/conformance" |
||||
"github.com/coreos/etcd/clientv3" |
||||
"github.com/sirupsen/logrus" |
||||
) |
||||
|
||||
func withTimeout(t time.Duration, f func()) { |
||||
c := make(chan struct{}) |
||||
defer close(c) |
||||
|
||||
go func() { |
||||
select { |
||||
case <-c: |
||||
case <-time.After(t): |
||||
// Dump a stack trace of the program. Useful for debugging deadlocks.
|
||||
buf := make([]byte, 2<<20) |
||||
fmt.Fprintf(os.Stderr, "%s\n", buf[:runtime.Stack(buf, true)]) |
||||
panic("test took too long") |
||||
} |
||||
}() |
||||
|
||||
f() |
||||
} |
||||
|
||||
func cleanDB(c *conn) error { |
||||
ctx := context.TODO() |
||||
for _, prefix := range []string{ |
||||
clientPrefix, |
||||
authCodePrefix, |
||||
refreshTokenPrefix, |
||||
authRequestPrefix, |
||||
passwordPrefix, |
||||
offlineSessionPrefix, |
||||
connectorPrefix, |
||||
} { |
||||
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix()) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
var logger = &logrus.Logger{ |
||||
Out: os.Stderr, |
||||
Formatter: &logrus.TextFormatter{DisableColors: true}, |
||||
Level: logrus.DebugLevel, |
||||
} |
||||
|
||||
func TestEtcd(t *testing.T) { |
||||
testEtcdEnv := "DEX_ETCD_ENDPOINTS" |
||||
endpointsStr := os.Getenv(testEtcdEnv) |
||||
if endpointsStr == "" { |
||||
t.Skipf("test environment variable %q not set, skipping", testEtcdEnv) |
||||
return |
||||
} |
||||
endpoints := strings.Split(endpointsStr, ",") |
||||
|
||||
newStorage := func() storage.Storage { |
||||
s := &Etcd{ |
||||
Endpoints: endpoints, |
||||
} |
||||
conn, err := s.open(logger) |
||||
if err != nil { |
||||
fmt.Fprintln(os.Stdout, err) |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
if err := cleanDB(conn); err != nil { |
||||
fmt.Fprintln(os.Stdout, err) |
||||
t.Fatal(err) |
||||
} |
||||
return conn |
||||
} |
||||
|
||||
withTimeout(time.Second*10, func() { |
||||
conformance.RunTests(t, newStorage) |
||||
}) |
||||
|
||||
withTimeout(time.Minute*1, func() { |
||||
conformance.RunTransactionTests(t, newStorage) |
||||
}) |
||||
} |
||||
@ -0,0 +1,109 @@
|
||||
#!/bin/bash |
||||
|
||||
if [ "$EUID" -ne 0 ] |
||||
then echo "Please run as root" |
||||
exit |
||||
fi |
||||
|
||||
function usage { |
||||
cat << EOF >> /dev/stderr |
||||
Usage: sudo ./standup.sh [create|destroy] [etcd] |
||||
|
||||
This is a script for standing up test databases. It uses systemd to daemonize |
||||
rkt containers running on a local loopback IP. |
||||
|
||||
The general workflow is to create a daemonized container, use the output to set |
||||
the test environment variables, run the tests, then destroy the container. |
||||
|
||||
sudo ./standup.sh create etcd |
||||
# Copy environment variables and run tests. |
||||
go test -v -i # always install test dependencies |
||||
go test -v |
||||
sudo ./standup.sh destroy etcd |
||||
|
||||
EOF |
||||
exit 2 |
||||
} |
||||
|
||||
function main { |
||||
if [ "$#" -ne 2 ]; then |
||||
usage |
||||
exit 2 |
||||
fi |
||||
|
||||
case "$1" in |
||||
"create") |
||||
case "$2" in |
||||
"etcd") |
||||
create_etcd;; |
||||
*) |
||||
usage |
||||
exit 2 |
||||
;; |
||||
esac |
||||
;; |
||||
"destroy") |
||||
case "$2" in |
||||
"etcd") |
||||
destroy_etcd;; |
||||
*) |
||||
usage |
||||
exit 2 |
||||
;; |
||||
esac |
||||
;; |
||||
*) |
||||
usage |
||||
exit 2 |
||||
;; |
||||
esac |
||||
} |
||||
|
||||
function wait_for_file { |
||||
while [ ! -f $1 ]; do |
||||
sleep 1 |
||||
done |
||||
} |
||||
|
||||
function wait_for_container { |
||||
while [ -z "$( rkt list --full | grep $1 | grep running )" ]; do |
||||
sleep 1 |
||||
done |
||||
} |
||||
|
||||
function create_etcd { |
||||
UUID_FILE=/tmp/dex-etcd-uuid |
||||
if [ -f $UUID_FILE ]; then |
||||
echo "etcd database already exists, try ./standup.sh destroy etcd" |
||||
exit 2 |
||||
fi |
||||
|
||||
echo "Starting etcd . To view progress run:" |
||||
echo "" |
||||
echo " journalctl -fu dex-etcd" |
||||
echo "" |
||||
UNIFIED_CGROUP_HIERARCHY=no \ |
||||
systemd-run --unit=dex-etcd \ |
||||
rkt run --uuid-file-save=$UUID_FILE --insecure-options=image \ |
||||
--net=host \ |
||||
docker://quay.io/coreos/etcd:v3.2.9 |
||||
|
||||
wait_for_file $UUID_FILE |
||||
|
||||
UUID=$( cat $UUID_FILE ) |
||||
wait_for_container $UUID |
||||
echo "To run tests export the following environment variables:" |
||||
echo "" |
||||
echo " export DEX_ETCD_ENDPOINTS=http://localhost:2379" |
||||
echo "" |
||||
} |
||||
|
||||
function destroy_etcd { |
||||
UUID_FILE=/tmp/dex-etcd-uuid |
||||
systemctl stop dex-etcd |
||||
rkt rm --uuid-file=$UUID_FILE |
||||
rm $UUID_FILE |
||||
} |
||||
|
||||
|
||||
main $@ |
||||
@ -0,0 +1,229 @@
|
||||
package etcd |
||||
|
||||
import ( |
||||
"time" |
||||
|
||||
"github.com/coreos/dex/storage" |
||||
jose "gopkg.in/square/go-jose.v2" |
||||
) |
||||
|
||||
// AuthCode is a mirrored struct from storage with JSON struct tags
|
||||
type AuthCode struct { |
||||
ID string `json:"ID"` |
||||
ClientID string `json:"clientID"` |
||||
RedirectURI string `json:"redirectURI"` |
||||
Nonce string `json:"nonce,omitempty"` |
||||
Scopes []string `json:"scopes,omitempty"` |
||||
|
||||
ConnectorID string `json:"connectorID,omitempty"` |
||||
ConnectorData []byte `json:"connectorData,omitempty"` |
||||
Claims Claims `json:"claims,omitempty"` |
||||
|
||||
Expiry time.Time `json:"expiry"` |
||||
} |
||||
|
||||
func fromStorageAuthCode(a storage.AuthCode) AuthCode { |
||||
return AuthCode{ |
||||
ID: a.ID, |
||||
ClientID: a.ClientID, |
||||
RedirectURI: a.RedirectURI, |
||||
ConnectorID: a.ConnectorID, |
||||
ConnectorData: a.ConnectorData, |
||||
Nonce: a.Nonce, |
||||
Scopes: a.Scopes, |
||||
Claims: fromStorageClaims(a.Claims), |
||||
Expiry: a.Expiry, |
||||
} |
||||
} |
||||
|
||||
// AuthRequest is a mirrored struct from storage with JSON struct tags
|
||||
type AuthRequest struct { |
||||
ID string `json:"id"` |
||||
ClientID string `json:"client_id"` |
||||
|
||||
ResponseTypes []string `json:"response_types"` |
||||
Scopes []string `json:"scopes"` |
||||
RedirectURI string `json:"redirect_uri"` |
||||
Nonce string `json:"nonce"` |
||||
State string `json:"state"` |
||||
|
||||
ForceApprovalPrompt bool `json:"force_approval_prompt"` |
||||
|
||||
Expiry time.Time `json:"expiry"` |
||||
|
||||
LoggedIn bool `json:"logged_in"` |
||||
|
||||
Claims Claims `json:"claims"` |
||||
|
||||
ConnectorID string `json:"connector_id"` |
||||
ConnectorData []byte `json:"connector_data"` |
||||
} |
||||
|
||||
func fromStorageAuthRequest(a storage.AuthRequest) AuthRequest { |
||||
return AuthRequest{ |
||||
ID: a.ID, |
||||
ClientID: a.ClientID, |
||||
ResponseTypes: a.ResponseTypes, |
||||
Scopes: a.Scopes, |
||||
RedirectURI: a.RedirectURI, |
||||
Nonce: a.Nonce, |
||||
State: a.State, |
||||
ForceApprovalPrompt: a.ForceApprovalPrompt, |
||||
Expiry: a.Expiry, |
||||
LoggedIn: a.LoggedIn, |
||||
Claims: fromStorageClaims(a.Claims), |
||||
ConnectorID: a.ConnectorID, |
||||
ConnectorData: a.ConnectorData, |
||||
} |
||||
} |
||||
|
||||
func toStorageAuthRequest(a AuthRequest) storage.AuthRequest { |
||||
return storage.AuthRequest{ |
||||
ID: a.ID, |
||||
ClientID: a.ClientID, |
||||
ResponseTypes: a.ResponseTypes, |
||||
Scopes: a.Scopes, |
||||
RedirectURI: a.RedirectURI, |
||||
Nonce: a.Nonce, |
||||
State: a.State, |
||||
ForceApprovalPrompt: a.ForceApprovalPrompt, |
||||
LoggedIn: a.LoggedIn, |
||||
ConnectorID: a.ConnectorID, |
||||
ConnectorData: a.ConnectorData, |
||||
Expiry: a.Expiry, |
||||
Claims: toStorageClaims(a.Claims), |
||||
} |
||||
} |
||||
|
||||
// RefreshToken is a mirrored struct from storage with JSON struct tags
|
||||
type RefreshToken struct { |
||||
ID string `json:"id"` |
||||
|
||||
Token string `json:"token"` |
||||
|
||||
CreatedAt time.Time `json:"created_at"` |
||||
LastUsed time.Time `json:"last_used"` |
||||
|
||||
ClientID string `json:"client_id"` |
||||
|
||||
ConnectorID string `json:"connector_id"` |
||||
ConnectorData []byte `json:"connector_data"` |
||||
Claims Claims `json:"claims"` |
||||
|
||||
Scopes []string `json:"scopes"` |
||||
|
||||
Nonce string `json:"nonce"` |
||||
} |
||||
|
||||
func toStorageRefreshToken(r RefreshToken) storage.RefreshToken { |
||||
return storage.RefreshToken{ |
||||
ID: r.ID, |
||||
Token: r.Token, |
||||
CreatedAt: r.CreatedAt, |
||||
LastUsed: r.LastUsed, |
||||
ClientID: r.ClientID, |
||||
ConnectorID: r.ConnectorID, |
||||
ConnectorData: r.ConnectorData, |
||||
Scopes: r.Scopes, |
||||
Nonce: r.Nonce, |
||||
Claims: toStorageClaims(r.Claims), |
||||
} |
||||
} |
||||
|
||||
func fromStorageRefreshToken(r storage.RefreshToken) RefreshToken { |
||||
return RefreshToken{ |
||||
ID: r.ID, |
||||
Token: r.Token, |
||||
CreatedAt: r.CreatedAt, |
||||
LastUsed: r.LastUsed, |
||||
ClientID: r.ClientID, |
||||
ConnectorID: r.ConnectorID, |
||||
ConnectorData: r.ConnectorData, |
||||
Scopes: r.Scopes, |
||||
Nonce: r.Nonce, |
||||
Claims: fromStorageClaims(r.Claims), |
||||
} |
||||
} |
||||
|
||||
// Claims is a mirrored struct from storage with JSON struct tags.
|
||||
type Claims struct { |
||||
UserID string `json:"userID"` |
||||
Username string `json:"username"` |
||||
Email string `json:"email"` |
||||
EmailVerified bool `json:"emailVerified"` |
||||
Groups []string `json:"groups,omitempty"` |
||||
} |
||||
|
||||
func fromStorageClaims(i storage.Claims) Claims { |
||||
return Claims{ |
||||
UserID: i.UserID, |
||||
Username: i.Username, |
||||
Email: i.Email, |
||||
EmailVerified: i.EmailVerified, |
||||
Groups: i.Groups, |
||||
} |
||||
} |
||||
|
||||
func toStorageClaims(i Claims) storage.Claims { |
||||
return storage.Claims{ |
||||
UserID: i.UserID, |
||||
Username: i.Username, |
||||
Email: i.Email, |
||||
EmailVerified: i.EmailVerified, |
||||
Groups: i.Groups, |
||||
} |
||||
} |
||||
|
||||
// Keys is a mirrored struct from storage with JSON struct tags
|
||||
type Keys struct { |
||||
SigningKey *jose.JSONWebKey `json:"signing_key,omitempty"` |
||||
SigningKeyPub *jose.JSONWebKey `json:"signing_key_pub,omitempty"` |
||||
VerificationKeys []storage.VerificationKey `json:"verification_keys"` |
||||
NextRotation time.Time `json:"next_rotation"` |
||||
} |
||||
|
||||
func fromStorageKeys(keys storage.Keys) Keys { |
||||
return Keys{ |
||||
SigningKey: keys.SigningKey, |
||||
SigningKeyPub: keys.SigningKeyPub, |
||||
VerificationKeys: keys.VerificationKeys, |
||||
NextRotation: keys.NextRotation, |
||||
} |
||||
} |
||||
|
||||
func toStorageKeys(keys Keys) storage.Keys { |
||||
return storage.Keys{ |
||||
SigningKey: keys.SigningKey, |
||||
SigningKeyPub: keys.SigningKeyPub, |
||||
VerificationKeys: keys.VerificationKeys, |
||||
NextRotation: keys.NextRotation, |
||||
} |
||||
} |
||||
|
||||
// OfflineSessions is a mirrored struct from storage with JSON struct tags
|
||||
type OfflineSessions struct { |
||||
UserID string `json:"user_id,omitempty"` |
||||
ConnID string `json:"conn_id,omitempty"` |
||||
Refresh map[string]*storage.RefreshTokenRef `json:"refresh,omitempty"` |
||||
} |
||||
|
||||
func fromStorageOfflineSessions(o storage.OfflineSessions) OfflineSessions { |
||||
return OfflineSessions{ |
||||
UserID: o.UserID, |
||||
ConnID: o.ConnID, |
||||
Refresh: o.Refresh, |
||||
} |
||||
} |
||||
|
||||
func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions { |
||||
s := storage.OfflineSessions{ |
||||
UserID: o.UserID, |
||||
ConnID: o.ConnID, |
||||
Refresh: o.Refresh, |
||||
} |
||||
if s.Refresh == nil { |
||||
// Server code assumes this will be non-nil.
|
||||
s.Refresh = make(map[string]*storage.RefreshTokenRef) |
||||
} |
||||
return s |
||||
} |
||||
Loading…
Reference in new issue