OpenID Connect (OIDC) identity and OAuth 2.0 provider with pluggable connectors
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1268 lines
37 KiB

// Package conformance provides conformance tests for storage implementations.
package conformance
import (
"context"
"reflect"
"sort"
"testing"
"time"
jose "github.com/go-jose/go-jose/v4"
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"github.com/dexidp/dex/storage"
)
// neverExpire is an expiry instant one hundred 365-day years after package
// initialisation, expressed in UTC. Objects stamped with it do not expire
// while the tests are running.
var neverExpire = time.Now().UTC().Add(100 * 365 * 24 * time.Hour)
// subTest couples a subtest label with the function that exercises a storage
// implementation under that label.
type subTest struct {
	// name is the label handed to t.Run.
	name string
	// run is the test body; it receives the storage created for this subtest.
	run func(t *testing.T, s storage.Storage)
}
// runTests runs each entry of tests as a named subtest of t. Every subtest
// gets its own storage from newStorage, and that storage is closed when the
// subtest finishes.
//
// The Close call is deferred so it also runs when the subtest body aborts via
// t.Fatal/t.FailNow. Those exit through runtime.Goexit, which skips any code
// after the call but still runs deferred functions. A Close error is logged
// rather than failed, so cleanup problems are visible without changing the
// subtest's pass/fail outcome.
func runTests(t *testing.T, newStorage func(t *testing.T) storage.Storage, tests []subTest) {
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			s := newStorage(t)
			defer func() {
				if err := s.Close(); err != nil {
					t.Logf("closing storage: %v", err)
				}
			}()
			test.run(t, s)
		})
	}
}
// RunTests runs the full storage conformance suite against one storage
// implementation.
//
// newStorage is called once per subtest and should return an initialized but
// empty storage. The storage will be closed at the end of each test run.
func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
	suite := []subTest{
		{name: "AuthCodeCRUD", run: testAuthCodeCRUD},
		{name: "AuthRequestCRUD", run: testAuthRequestCRUD},
		{name: "ClientCRUD", run: testClientCRUD},
		{name: "RefreshTokenCRUD", run: testRefreshTokenCRUD},
		{name: "PasswordCRUD", run: testPasswordCRUD},
		{name: "KeysCRUD", run: testKeysCRUD},
		{name: "OfflineSessionCRUD", run: testOfflineSessionCRUD},
		{name: "ConnectorCRUD", run: testConnectorCRUD},
		{name: "GarbageCollection", run: testGC},
		{name: "TimezoneSupport", run: testTimezones},
		{name: "DeviceRequestCRUD", run: testDeviceRequestCRUD},
		{name: "DeviceTokenCRUD", run: testDeviceTokenCRUD},
		{name: "UserIdentityCRUD", run: testUserIdentityCRUD},
		{name: "AuthSessionCRUD", run: testAuthSessionCRUD},
	}
	runTests(t, newStorage, suite)
}
// mustLoadJWK decodes b, a JSON-encoded JSON Web Key, and returns a pointer to
// the decoded key. Following the Must naming convention, it panics instead of
// returning an error when b cannot be decoded.
func mustLoadJWK(b string) *jose.JSONWebKey {
	key := new(jose.JSONWebKey)
	err := key.UnmarshalJSON([]byte(b))
	if err != nil {
		panic(err)
	}
	return key
}
// mustBeErrNotFound records a test failure unless err is storage.ErrNotFound.
// kind names the object type (e.g. "client") for the failure message.
//
// Callers use it after both lookups and deletes of objects that must not
// exist, so the messages do not name a specific operation. t.Helper makes
// failures point at the calling line instead of this function.
func mustBeErrNotFound(t *testing.T, kind string, err error) {
	t.Helper()
	switch {
	case err == nil:
		t.Errorf("operation on nonexistent %s should return an error", kind)
	case err != storage.ErrNotFound:
		// Compared by identity: a wrapped ErrNotFound is reported as a failure.
		t.Errorf("%s: expected storage.ErrNotFound, got %v", kind, err)
	}
}
// mustBeErrAlreadyExists records a test failure unless err is
// storage.ErrAlreadyExists. It is meant to be called with the result of
// creating an object that already exists; kind names the object type for the
// failure message.
//
// t.Helper makes failures point at the calling line instead of this function.
func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
	t.Helper()
	switch {
	case err == nil:
		t.Errorf("attempting to create an existing %s should return an error", kind)
	case err != storage.ErrAlreadyExists:
		// Compared by identity: a wrapped ErrAlreadyExists is reported as a failure.
		t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err)
	}
}
// testAuthRequestCRUD exercises the auth request lifecycle: create, a rejected
// duplicate create, update, read-back of claims and PKCE data, and delete.
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	pkce := storage.PKCE{
		CodeChallenge:       "code_challenge_test",
		CodeChallengeMethod: "plain",
	}
	first := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "client1",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://localhost:80/callback",
		Nonce:               "foo",
		State:               "bar",
		ForceApprovalPrompt: true,
		LoggedIn:            true,
		Expiry:              neverExpire,
		ConnectorID:         "ldap",
		ConnectorData:       []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		PKCE:    pkce,
		HMACKey: []byte("hmac_key"),
	}
	// Claims that the update below writes into the first request.
	replacementClaims := storage.Claims{Email: "foobar"}

	if err := s.CreateAuthRequest(ctx, first); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}
	// A second create with an ID that is already stored must be rejected.
	mustBeErrAlreadyExists(t, "auth request", s.CreateAuthRequest(ctx, first))

	second := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "client2",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://localhost:80/callback",
		Nonce:               "bar",
		State:               "foo",
		ForceApprovalPrompt: true,
		LoggedIn:            true,
		Expiry:              neverExpire,
		ConnectorID:         "ldap",
		ConnectorData:       []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "2",
			Username:      "john",
			Email:         "john.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a"},
		},
		HMACKey: []byte("hmac_key"),
	}
	if err := s.CreateAuthRequest(ctx, second); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}

	updater := func(old storage.AuthRequest) (storage.AuthRequest, error) {
		old.Claims = replacementClaims
		old.ConnectorID = "connID"
		return old, nil
	}
	if err := s.UpdateAuthRequest(ctx, first.ID, updater); err != nil {
		t.Fatalf("failed to update auth request: %v", err)
	}

	stored, err := s.GetAuthRequest(ctx, first.ID)
	if err != nil {
		t.Fatalf("failed to get auth req: %v", err)
	}
	if !reflect.DeepEqual(stored.Claims, replacementClaims) {
		t.Fatalf("update failed, wanted identity=%#v got %#v", replacementClaims, stored.Claims)
	}
	if !reflect.DeepEqual(stored.PKCE, pkce) {
		t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", pkce, stored.PKCE)
	}

	for _, reqID := range []string{first.ID, second.ID} {
		if err := s.DeleteAuthRequest(ctx, reqID); err != nil {
			t.Fatalf("failed to delete auth request: %v", err)
		}
	}

	_, err = s.GetAuthRequest(ctx, first.ID)
	mustBeErrNotFound(t, "auth request", err)
}
// testAuthCodeCRUD exercises auth codes: create, a rejected duplicate create,
// read-back (including PKCE and claims), and delete.
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	code1 := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "client1",
		RedirectURI:   "https://localhost:80/callback",
		Nonce:         "foobar",
		Scopes:        []string{"openid", "email"},
		Expiry:        neverExpire,
		ConnectorID:   "ldap",
		ConnectorData: []byte(`{"some":"data"}`),
		PKCE: storage.PKCE{
			CodeChallenge:       "12345",
			CodeChallengeMethod: "Whatever",
		},
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(ctx, code1); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	code2 := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "client2",
		RedirectURI:   "https://localhost:80/callback",
		Nonce:         "foobar",
		Scopes:        []string{"openid", "email"},
		Expiry:        neverExpire,
		ConnectorID:   "ldap",
		ConnectorData: []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "2",
			Username:      "john",
			Email:         "john.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a"},
		},
	}

	// Re-creating the first code must be rejected.
	mustBeErrAlreadyExists(t, "auth code", s.CreateAuthCode(ctx, code1))

	if err := s.CreateAuthCode(ctx, code2); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	stored, err := s.GetAuthCode(ctx, code1.ID)
	if err != nil {
		t.Fatalf("failed to get auth code: %v", err)
	}
	// Check expiry at one-second granularity, then overwrite it so the
	// structural diff below ignores how the backend represents time.
	if stored.Expiry.Unix() != code1.Expiry.Unix() {
		t.Errorf("auth code expiry did not match want=%s vs got=%s", code1.Expiry, stored.Expiry)
	}
	stored.Expiry = code1.Expiry
	if diff := pretty.Compare(code1, stored); diff != "" {
		t.Errorf("auth code retrieved from storage did not match: %s", diff)
	}

	for _, codeID := range []string{code1.ID, code2.ID} {
		if err := s.DeleteAuthCode(ctx, codeID); err != nil {
			t.Fatalf("delete auth code: %v", err)
		}
	}

	_, err = s.GetAuthCode(ctx, code1.ID)
	mustBeErrNotFound(t, "auth code", err)
}
// testClientCRUD exercises OAuth2 clients. It checks that deleting a missing
// client returns ErrNotFound, and covers create, rejected duplicate create,
// get, update and delete.
func testClientCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()
	id1 := storage.NewID()
	c1 := storage.Client{
		ID:                id1,
		Secret:            "foobar",
		RedirectURIs:      []string{"foo://bar.com/", "https://auth.example.com"},
		Name:              "dex client",
		LogoURL:           "https://goo.gl/JIyzIC",
		AllowedConnectors: []string{"github", "google"},
	}
	// Deleting a client that was never created must report ErrNotFound.
	err := s.DeleteClient(ctx, id1)
	mustBeErrNotFound(t, "client", err)
	if err := s.CreateClient(ctx, c1); err != nil {
		t.Fatalf("create client: %v", err)
	}
	// Attempt to create same Client twice.
	err = s.CreateClient(ctx, c1)
	mustBeErrAlreadyExists(t, "client", err)
	id2 := storage.NewID()
	c2 := storage.Client{
		ID:           id2,
		Secret:       "barfoo",
		RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"},
		Name:         "dex client",
		LogoURL:      "https://goo.gl/JIyzIC",
	}
	if err := s.CreateClient(ctx, c2); err != nil {
		t.Fatalf("create client: %v", err)
	}
	// getAndCompare fetches the client stored under id and diffs it against want.
	getAndCompare := func(id string, want storage.Client) {
		gc, err := s.GetClient(ctx, id)
		if err != nil {
			t.Errorf("get client: %v", err)
			return
		}
		if diff := pretty.Compare(want, gc); diff != "" {
			t.Errorf("client retrieved from storage did not match: %s", diff)
		}
	}
	getAndCompare(id1, c1)
	newSecret := "barfoo"
	err = s.UpdateClient(ctx, id1, func(old storage.Client) (storage.Client, error) {
		old.Secret = newSecret
		return old, nil
	})
	if err != nil {
		t.Errorf("update client: %v", err)
	}
	c1.Secret = newSecret
	getAndCompare(id1, c1)
	if err := s.DeleteClient(ctx, id1); err != nil {
		t.Fatalf("delete client: %v", err)
	}
	if err := s.DeleteClient(ctx, id2); err != nil {
		t.Fatalf("delete client: %v", err)
	}
	_, err = s.GetClient(ctx, id1)
	mustBeErrNotFound(t, "client", err)
}
// testRefreshTokenCRUD exercises refresh tokens: create, rejected duplicate
// create, get, update and delete. It also checks that updating one token
// leaves another untouched (issue #847).
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()
	id := storage.NewID()
	refresh := storage.RefreshToken{
		ID:            id,
		Token:         "bar",
		ObsoleteToken: "",
		Nonce:         "foo",
		ClientID:      "client_id",
		ConnectorID:   "client_secret",
		Scopes:        []string{"openid", "email", "profile"},
		CreatedAt:     time.Now().UTC().Round(time.Millisecond),
		LastUsed:      time.Now().UTC().Round(time.Millisecond),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		ConnectorData: []byte(`{"some":"data"}`),
	}
	if err := s.CreateRefresh(ctx, refresh); err != nil {
		t.Fatalf("create refresh token: %v", err)
	}
	// Attempt to create same Refresh Token twice.
	err := s.CreateRefresh(ctx, refresh)
	mustBeErrAlreadyExists(t, "refresh token", err)

	// getAndCompare fetches the token stored under id and checks it against want.
	// Timestamps are compared as instants (UnixNano), which does not depend on
	// the location the backend returns them in. They are then zeroed so the
	// structural diff covers only the remaining fields.
	getAndCompare := func(id string, want storage.RefreshToken) {
		gr, err := s.GetRefresh(ctx, id)
		if err != nil {
			t.Errorf("get refresh: %v", err)
			return
		}
		// Compare against want, not against gr itself, so timestamp
		// round-trip errors are actually detected.
		if diff := pretty.Compare(want.CreatedAt.UnixNano(), gr.CreatedAt.UnixNano()); diff != "" {
			t.Errorf("refresh token created timestamp retrieved from storage did not match: %s", diff)
		}
		if diff := pretty.Compare(want.LastUsed.UnixNano(), gr.LastUsed.UnixNano()); diff != "" {
			t.Errorf("refresh token last used timestamp retrieved from storage did not match: %s", diff)
		}
		gr.CreatedAt = time.Time{}
		gr.LastUsed = time.Time{}
		want.CreatedAt = time.Time{}
		want.LastUsed = time.Time{}
		if diff := pretty.Compare(want, gr); diff != "" {
			t.Errorf("refresh token retrieved from storage did not match: %s", diff)
		}
	}
	getAndCompare(id, refresh)

	id2 := storage.NewID()
	refresh2 := storage.RefreshToken{
		ID:            id2,
		Token:         "bar_2",
		ObsoleteToken: refresh.Token,
		Nonce:         "foo_2",
		ClientID:      "client_id_2",
		ConnectorID:   "client_secret",
		Scopes:        []string{"openid", "email", "profile"},
		CreatedAt:     time.Now().UTC().Round(time.Millisecond),
		LastUsed:      time.Now().UTC().Round(time.Millisecond),
		Claims: storage.Claims{
			UserID:        "2",
			Username:      "john",
			Email:         "john.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		ConnectorData: []byte(`{"some":"data"}`),
	}
	if err := s.CreateRefresh(ctx, refresh2); err != nil {
		t.Fatalf("create second refresh token: %v", err)
	}
	getAndCompare(id2, refresh2)

	updatedAt := time.Now().UTC().Round(time.Millisecond)
	updater := func(r storage.RefreshToken) (storage.RefreshToken, error) {
		r.Token = "spam"
		r.LastUsed = updatedAt
		return r, nil
	}
	if err := s.UpdateRefreshToken(ctx, id, updater); err != nil {
		t.Errorf("failed to update refresh token: %v", err)
	}
	refresh.Token = "spam"
	refresh.LastUsed = updatedAt
	getAndCompare(id, refresh)

	// Ensure that updating the first token doesn't impact the second. Issue #847.
	getAndCompare(id2, refresh2)

	if err := s.DeleteRefresh(ctx, id); err != nil {
		t.Fatalf("failed to delete refresh request: %v", err)
	}
	if err := s.DeleteRefresh(ctx, id2); err != nil {
		t.Fatalf("failed to delete refresh request: %v", err)
	}
	_, err = s.GetRefresh(ctx, id)
	mustBeErrNotFound(t, "refresh token", err)
}
// byEmail implements sort.Interface. It orders passwords by their Email field
// using plain string comparison.
type byEmail []storage.Password

func (p byEmail) Len() int           { return len(p) }
func (p byEmail) Less(i, j int) bool { return p[i].Email < p[j].Email }
func (p byEmail) Swap(i, j int)      { p[j], p[i] = p[i], p[j] }
// boolPtr returns a pointer to a newly allocated bool holding v. This is
// convenient for filling *bool fields inside struct literals.
func boolPtr(v bool) *bool {
	p := new(bool)
	*p = v
	return p
}
// testPasswordCRUD exercises local password entries: create, rejected
// duplicate create, case-insensitive lookup by email, update, list and delete.
func testPasswordCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()
	// Use bcrypt.MinCost to keep the tests short.
	passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
	if err != nil {
		t.Fatal(err)
	}
	password1 := storage.Password{
		Email:             "jane@example.com",
		Hash:              passwordHash1,
		Username:          "jane",
		Name:              "Jane Doe",
		PreferredUsername: "jane-public",
		EmailVerified:     boolPtr(true),
		UserID:            "foobar",
		Groups:            []string{"team-a", "team-a/admins"},
	}
	if err := s.CreatePassword(ctx, password1); err != nil {
		t.Fatalf("create password token: %v", err)
	}
	// Attempt to create same Password twice.
	err = s.CreatePassword(ctx, password1)
	mustBeErrAlreadyExists(t, "password", err)
	passwordHash2, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost)
	if err != nil {
		t.Fatal(err)
	}
	password2 := storage.Password{
		Email:             "john@example.com",
		Hash:              passwordHash2,
		Username:          "john",
		Name:              "John Smith",
		PreferredUsername: "john-public",
		EmailVerified:     boolPtr(false),
		UserID:            "barfoo",
		Groups:            []string{"team-b"},
	}
	if err := s.CreatePassword(ctx, password2); err != nil {
		t.Fatalf("create password token: %v", err)
	}
	// getAndCompare looks up the password stored under the given email and
	// diffs it against want.
	getAndCompare := func(id string, want storage.Password) {
		gr, err := s.GetPassword(ctx, id)
		if err != nil {
			t.Errorf("get password %q: %v", id, err)
			return
		}
		if diff := pretty.Compare(want, gr); diff != "" {
			t.Errorf("password retrieved from storage did not match: %s", diff)
		}
	}
	getAndCompare("jane@example.com", password1)
	getAndCompare("JANE@example.com", password1) // Emails should be case insensitive
	if err := s.UpdatePassword(ctx, password1.Email, func(old storage.Password) (storage.Password, error) {
		old.Username = "jane doe"
		return old, nil
	}); err != nil {
		t.Fatalf("failed to update password: %v", err)
	}
	password1.Username = "jane doe"
	getAndCompare("jane@example.com", password1)
	var passwordList []storage.Password
	passwordList = append(passwordList, password1, password2)
	// listAndCompare diffs the stored passwords against want. Both sides are
	// sorted by email because the listing order is not asserted.
	listAndCompare := func(want []storage.Password) {
		passwords, err := s.ListPasswords(ctx)
		if err != nil {
			t.Errorf("list password: %v", err)
			return
		}
		sort.Sort(byEmail(want))
		sort.Sort(byEmail(passwords))
		if diff := pretty.Compare(want, passwords); diff != "" {
			t.Errorf("password list retrieved from storage did not match: %s", diff)
		}
	}
	listAndCompare(passwordList)
	if err := s.DeletePassword(ctx, password1.Email); err != nil {
		t.Fatalf("failed to delete password: %v", err)
	}
	if err := s.DeletePassword(ctx, password2.Email); err != nil {
		t.Fatalf("failed to delete password: %v", err)
	}
	_, err = s.GetPassword(ctx, password1.Email)
	mustBeErrNotFound(t, "password", err)
}
// testOfflineSessionCRUD exercises offline sessions keyed by (user ID,
// connector ID): create, rejected duplicate create, get, update of the
// refresh-token map, and delete.
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	firstUser := storage.NewID()
	first := storage.OfflineSessions{
		UserID:        firstUser,
		ConnID:        "Conn1",
		Refresh:       make(map[string]*storage.RefreshTokenRef),
		ConnectorData: []byte(`{"some":"data"}`),
	}
	// The empty Refresh map checks that the backend round-trips an empty map
	// as expected.
	if err := s.CreateOfflineSessions(ctx, first); err != nil {
		t.Fatalf("create offline session with UserID = %s: %v", first.UserID, err)
	}
	// Creating the same session a second time must be rejected.
	mustBeErrAlreadyExists(t, "offline session", s.CreateOfflineSessions(ctx, first))

	secondUser := storage.NewID()
	second := storage.OfflineSessions{
		UserID:        secondUser,
		ConnID:        "Conn2",
		Refresh:       make(map[string]*storage.RefreshTokenRef),
		ConnectorData: []byte(`{"some":"data"}`),
	}
	if err := s.CreateOfflineSessions(ctx, second); err != nil {
		t.Fatalf("create offline session with UserID = %s: %v", second.UserID, err)
	}

	fetchAndDiff := func(userID string, connID string, want storage.OfflineSessions) {
		stored, err := s.GetOfflineSessions(ctx, userID, connID)
		if err != nil {
			t.Errorf("get offline session: %v", err)
			return
		}
		if diff := pretty.Compare(want, stored); diff != "" {
			t.Errorf("offline session retrieved from storage did not match: %s", diff)
		}
	}
	fetchAndDiff(firstUser, "Conn1", first)

	ref := storage.RefreshTokenRef{
		ID:        storage.NewID(),
		ClientID:  "client_id",
		CreatedAt: time.Now().UTC().Round(time.Millisecond),
		LastUsed:  time.Now().UTC().Round(time.Millisecond),
	}
	first.Refresh[ref.ClientID] = &ref
	addRef := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
		old.Refresh[ref.ClientID] = &ref
		return old, nil
	}
	if err := s.UpdateOfflineSessions(ctx, first.UserID, first.ConnID, addRef); err != nil {
		t.Fatalf("failed to update offline session: %v", err)
	}
	fetchAndDiff(firstUser, "Conn1", first)

	for _, sess := range []storage.OfflineSessions{first, second} {
		if err := s.DeleteOfflineSessions(ctx, sess.UserID, sess.ConnID); err != nil {
			t.Fatalf("failed to delete offline session: %v", err)
		}
	}

	_, err := s.GetOfflineSessions(ctx, first.UserID, first.ConnID)
	mustBeErrNotFound(t, "offline session", err)
}
// testConnectorCRUD exercises connectors: create, rejected duplicate create,
// get, update, list and delete. ResourceVersion is cleared on everything read
// back, because it is not part of the comparison.
func testConnectorCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()
	const tokenExchange = "urn:ietf:params:oauth:grant-type:token-exchange"

	first := storage.Connector{
		ID:         storage.NewID(),
		Type:       "Default",
		Name:       "Default",
		Config:     []byte(`{"issuer": "https://accounts.google.com"}`),
		GrantTypes: []string{"authorization_code", "refresh_token"},
	}
	if err := s.CreateConnector(ctx, first); err != nil {
		t.Fatalf("create connector with ID = %s: %v", first.ID, err)
	}
	// Creating the same connector a second time must be rejected.
	mustBeErrAlreadyExists(t, "connector", s.CreateConnector(ctx, first))

	second := storage.Connector{
		ID:     storage.NewID(),
		Type:   "Mock",
		Name:   "Mock",
		Config: []byte(`{"redirectURI": "http://127.0.0.1:5556/dex/callback"}`),
	}
	if err := s.CreateConnector(ctx, second); err != nil {
		t.Fatalf("create connector with ID = %s: %v", second.ID, err)
	}

	fetchAndDiff := func(connID string, want storage.Connector) {
		stored, err := s.GetConnector(ctx, connID)
		if err != nil {
			t.Errorf("get connector: %v", err)
			return
		}
		stored.ResourceVersion = ""
		if diff := pretty.Compare(want, stored); diff != "" {
			t.Errorf("connector retrieved from storage did not match: %s", diff)
		}
	}
	fetchAndDiff(first.ID, first)

	retype := func(old storage.Connector) (storage.Connector, error) {
		old.Type = "oidc"
		old.GrantTypes = []string{tokenExchange}
		return old, nil
	}
	if err := s.UpdateConnector(ctx, first.ID, retype); err != nil {
		t.Fatalf("failed to update Connector: %v", err)
	}
	first.Type = "oidc"
	first.GrantTypes = []string{tokenExchange}
	fetchAndDiff(first.ID, first)

	listAndDiff := func(want []storage.Connector) {
		all, err := s.ListConnectors(ctx)
		if err != nil {
			t.Errorf("list connectors: %v", err)
			return
		}
		for i := range all {
			all[i].ResourceVersion = ""
		}
		// The listing order is not asserted, so sort by name ("Default" < "Mock").
		sort.Slice(all, func(a, b int) bool { return all[a].Name < all[b].Name })
		if diff := pretty.Compare(want, all); diff != "" {
			t.Errorf("connector list retrieved from storage did not match: %s", diff)
		}
	}
	listAndDiff([]storage.Connector{first, second})

	for _, connID := range []string{first.ID, second.ID} {
		if err := s.DeleteConnector(ctx, connID); err != nil {
			t.Fatalf("failed to delete connector: %v", err)
		}
	}

	_, err := s.GetConnector(ctx, first.ID)
	mustBeErrNotFound(t, "connector", err)
}
// testKeysCRUD overwrites the stored signing keys twice and checks after each
// write that GetKeys returns exactly what was written. jsonWebKeys is a
// fixture defined outside this view.
func testKeysCRUD(t *testing.T, s storage.Storage) {
	ctx := context.TODO()

	storeAndVerify := func(want storage.Keys) {
		replace := func(storage.Keys) (storage.Keys, error) { return want, nil }
		if err := s.UpdateKeys(ctx, replace); err != nil {
			t.Errorf("failed to update keys: %v", err)
			return
		}
		got, err := s.GetKeys(ctx)
		if err != nil {
			t.Errorf("failed to get keys: %v", err)
			return
		}
		// Normalize the zone; the instant is what matters.
		got.NextRotation = got.NextRotation.UTC()
		if diff := pretty.Compare(want, got); diff != "" {
			t.Errorf("got keys did not equal expected: %s", diff)
		}
	}

	// Rounded to whole seconds because Postgres does not keep full
	// nanosecond precision.
	base := time.Now().UTC().Round(time.Second)

	initial := storage.Keys{
		SigningKey:    jsonWebKeys[0].Private,
		SigningKeyPub: jsonWebKeys[0].Public,
		NextRotation:  base,
	}
	rotated := storage.Keys{
		SigningKey:    jsonWebKeys[2].Private,
		SigningKeyPub: jsonWebKeys[2].Public,
		NextRotation:  base.Add(time.Hour),
		VerificationKeys: []storage.VerificationKey{
			{
				PublicKey: jsonWebKeys[0].Public,
				Expiry:    base.Add(time.Hour),
			},
			{
				PublicKey: jsonWebKeys[1].Public,
				Expiry:    base.Add(time.Hour * 2),
			},
		},
	}

	for _, k := range []storage.Keys{initial, rotated} {
		storeAndVerify(k)
	}
}
// testGC checks GarbageCollect for auth codes, auth requests, device requests
// and device tokens. For each kind it creates one object expiring at `expiry`
// and then checks two things:
//   - A GC cutoff one hour before expiry, given in several time zones,
//     collects nothing and leaves the object readable.
//   - A cutoff one hour after expiry collects exactly that object, and a
//     later lookup returns storage.ErrNotFound.
func testGC(t *testing.T, s storage.Storage) {
	ctx := t.Context()
	est, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	pst, err := time.LoadLocation("America/Los_Angeles")
	if err != nil {
		t.Fatal(err)
	}
	expiry := time.Now().In(est)

	// --- Auth codes ---
	c := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://localhost:80/callback",
		Nonce:         "foobar",
		Scopes:        []string{"openid", "email"},
		Expiry:        expiry,
		ConnectorID:   "ldap",
		ConnectorData: []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(ctx, c); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else if result.AuthCodes != 0 || result.AuthRequests != 0 {
			t.Errorf("expected no garbage collection results, got %#v", result)
		}
		if _, err := s.GetAuthCode(ctx, c.ID); err != nil {
			t.Errorf("expected to be able to get auth code after GC: %v", err)
		}
	}
	if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthCodes != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes)
	}
	if _, err := s.GetAuthCode(ctx, c.ID); err == nil {
		t.Errorf("expected auth code to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// --- Auth requests ---
	a := storage.AuthRequest{
		ID:                  storage.NewID(),
		ClientID:            "foobar",
		ResponseTypes:       []string{"code"},
		Scopes:              []string{"openid", "email"},
		RedirectURI:         "https://localhost:80/callback",
		Nonce:               "foo",
		State:               "bar",
		ForceApprovalPrompt: true,
		LoggedIn:            true,
		Expiry:              expiry,
		ConnectorID:         "ldap",
		ConnectorData:       []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		HMACKey: []byte("hmac_key"),
	}
	if err := s.CreateAuthRequest(ctx, a); err != nil {
		t.Fatalf("failed creating auth request: %v", err)
	}
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else if result.AuthCodes != 0 || result.AuthRequests != 0 {
			t.Errorf("expected no garbage collection results, got %#v", result)
		}
		if _, err := s.GetAuthRequest(ctx, a.ID); err != nil {
			t.Errorf("expected to be able to get auth request after GC: %v", err)
		}
	}
	if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.AuthRequests != 1 {
		t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests)
	}
	if _, err := s.GetAuthRequest(ctx, a.ID); err == nil {
		t.Errorf("expected auth request to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// --- Device requests ---
	d := storage.DeviceRequest{
		UserCode:     storage.NewUserCode(),
		DeviceCode:   storage.NewID(),
		ClientID:     "client1",
		ClientSecret: "secret1",
		Scopes:       []string{"openid", "email"},
		Expiry:       expiry,
	}
	if err := s.CreateDeviceRequest(ctx, d); err != nil {
		t.Fatalf("failed creating device request: %v", err)
	}
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else if result.DeviceRequests != 0 {
			t.Errorf("expected no device garbage collection results, got %#v", result)
		}
		if _, err := s.GetDeviceRequest(ctx, d.UserCode); err != nil {
			t.Errorf("expected to be able to get device request after GC: %v", err)
		}
	}
	if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.DeviceRequests != 1 {
		t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
	}
	if _, err := s.GetDeviceRequest(ctx, d.UserCode); err == nil {
		t.Errorf("expected device request to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}

	// --- Device tokens ---
	dt := storage.DeviceToken{
		DeviceCode:          storage.NewID(),
		Status:              "pending",
		Token:               "foo",
		Expiry:              expiry,
		LastRequestTime:     time.Now(),
		PollIntervalSeconds: 0,
		PKCE: storage.PKCE{
			CodeChallenge:       "challenge",
			CodeChallengeMethod: "S256",
		},
	}
	if err := s.CreateDeviceToken(ctx, dt); err != nil {
		t.Fatalf("failed creating device token: %v", err)
	}
	for _, tz := range []*time.Location{time.UTC, est, pst} {
		result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz))
		if err != nil {
			t.Errorf("garbage collection failed: %v", err)
		} else if result.DeviceTokens != 0 {
			t.Errorf("expected no device token garbage collection results, got %#v", result)
		}
		if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err != nil {
			t.Errorf("expected to be able to get device token after GC: %v", err)
		}
	}
	if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil {
		t.Errorf("garbage collection failed: %v", err)
	} else if r.DeviceTokens != 1 {
		t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens)
	}
	if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err == nil {
		t.Errorf("expected device token to be GC'd")
	} else if err != storage.ErrNotFound {
		t.Errorf("expected storage.ErrNotFound, got %v", err)
	}
}
// testTimezones checks that a backend either preserves time zones or
// normalizes them correctly. The expiry read back must be the same instant,
// to the millisecond, once converted to the original zone. The zone itself
// need not be preserved.
func testTimezones(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	newYork, err := time.LoadLocation("America/New_York")
	if err != nil {
		t.Fatal(err)
	}
	// Backends are only required to be accurate to the millisecond.
	want := time.Now().In(newYork).Round(time.Millisecond)

	code := storage.AuthCode{
		ID:            storage.NewID(),
		ClientID:      "foobar",
		RedirectURI:   "https://localhost:80/callback",
		Nonce:         "foobar",
		Scopes:        []string{"openid", "email"},
		Expiry:        want,
		ConnectorID:   "ldap",
		ConnectorData: []byte(`{"some":"data"}`),
		Claims: storage.Claims{
			UserID:        "1",
			Username:      "jane",
			Email:         "jane.doe@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
	}
	if err := s.CreateAuthCode(ctx, code); err != nil {
		t.Fatalf("failed creating auth code: %v", err)
	}

	stored, err := s.GetAuthCode(ctx, code.ID)
	if err != nil {
		t.Fatalf("failed to get auth code: %v", err)
	}

	if gotExpiry := stored.Expiry.In(newYork); !gotExpiry.Equal(want) {
		t.Fatalf("expected expiry %v got %v", want, gotExpiry)
	}
}
// testDeviceRequestCRUD checks create, rejected duplicate create, and exact
// read-back of a device request. Deletion is not exercised here because
// device requests are removed by garbage collection (see testGC).
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	want := storage.DeviceRequest{
		UserCode:     storage.NewUserCode(),
		DeviceCode:   storage.NewID(),
		ClientID:     "client1",
		ClientSecret: "secret1",
		Scopes:       []string{"openid", "email"},
		Expiry:       neverExpire.Round(time.Second),
	}
	if err := s.CreateDeviceRequest(ctx, want); err != nil {
		t.Fatalf("failed creating device request: %v", err)
	}

	// Creating the same device request twice must be rejected.
	mustBeErrAlreadyExists(t, "device request", s.CreateDeviceRequest(ctx, want))

	got, err := s.GetDeviceRequest(ctx, want.UserCode)
	if err != nil {
		t.Fatalf("failed to get device request: %v", err)
	}
	require.Equal(t, want, got)
}
// testDeviceTokenCRUD checks create, rejected duplicate create, and an update
// that simulates redeeming a device token. It then verifies that the updated
// status, the token data and the PKCE challenge are read back.
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	const (
		redeemedStatus = "complete"
		redeemedToken  = "token data"
	)
	challenge := storage.PKCE{
		CodeChallenge:       "code_challenge_test",
		CodeChallengeMethod: "plain",
	}

	token := storage.DeviceToken{
		DeviceCode:          storage.NewID(),
		Status:              "pending",
		Token:               storage.NewID(),
		Expiry:              neverExpire,
		LastRequestTime:     time.Now(),
		PollIntervalSeconds: 0,
		PKCE:                challenge,
	}
	if err := s.CreateDeviceToken(ctx, token); err != nil {
		t.Fatalf("failed creating device token: %v", err)
	}

	// Creating the same device token twice must be rejected.
	mustBeErrAlreadyExists(t, "device token", s.CreateDeviceToken(ctx, token))

	redeem := func(old storage.DeviceToken) (storage.DeviceToken, error) {
		old.Token = redeemedToken
		old.Status = redeemedStatus
		return old, nil
	}
	if err := s.UpdateDeviceToken(ctx, token.DeviceCode, redeem); err != nil {
		t.Fatalf("failed to update device token: %v", err)
	}

	stored, err := s.GetDeviceToken(ctx, token.DeviceCode)
	if err != nil {
		t.Fatalf("failed to get device token: %v", err)
	}

	switch {
	case stored.Status != redeemedStatus:
		t.Fatalf("update failed, wanted token status=%v got %v", redeemedStatus, stored.Status)
	case stored.Token != redeemedToken:
		t.Fatalf("update failed, wanted token %v got %v", redeemedToken, stored.Token)
	case !reflect.DeepEqual(stored.PKCE, challenge):
		t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", challenge, stored.PKCE)
	}
}
// testUserIdentityCRUD exercises user identities keyed by (user ID, connector
// ID): create with an empty consent map, rejected duplicate create, get,
// update of consents, list and delete.
func testUserIdentityCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	// toMillis brings a timestamp to UTC at millisecond precision so it can be
	// compared structurally.
	toMillis := func(ts time.Time) time.Time { return ts.UTC().Round(time.Millisecond) }

	now := toMillis(time.Now())
	want := storage.UserIdentity{
		UserID:      "user1",
		ConnectorID: "conn1",
		Claims: storage.Claims{
			UserID:        "user1",
			Username:      "jane",
			Email:         "jane@example.com",
			EmailVerified: true,
			Groups:        []string{"a", "b"},
		},
		Consents:     make(map[string][]string),
		CreatedAt:    now,
		LastLogin:    now,
		BlockedUntil: time.Unix(0, 0).UTC(),
	}

	if err := s.CreateUserIdentity(ctx, want); err != nil {
		t.Fatalf("create user identity: %v", err)
	}
	// A duplicate create must report ErrAlreadyExists.
	mustBeErrAlreadyExists(t, "user identity", s.CreateUserIdentity(ctx, want))

	stored, err := s.GetUserIdentity(ctx, want.UserID, want.ConnectorID)
	if err != nil {
		t.Fatalf("get user identity: %v", err)
	}
	stored.CreatedAt = toMillis(stored.CreatedAt)
	stored.LastLogin = toMillis(stored.LastLogin)
	stored.BlockedUntil = toMillis(stored.BlockedUntil)
	want.BlockedUntil = toMillis(want.BlockedUntil)
	if diff := pretty.Compare(want, stored); diff != "" {
		t.Errorf("user identity retrieved from storage did not match: %s", diff)
	}

	grantConsent := func(old storage.UserIdentity) (storage.UserIdentity, error) {
		old.Consents["client1"] = []string{"openid", "email"}
		return old, nil
	}
	if err := s.UpdateUserIdentity(ctx, want.UserID, want.ConnectorID, grantConsent); err != nil {
		t.Fatalf("update user identity: %v", err)
	}

	stored, err = s.GetUserIdentity(ctx, want.UserID, want.ConnectorID)
	if err != nil {
		t.Fatalf("get user identity after update: %v", err)
	}
	expectedConsents := map[string][]string{"client1": {"openid", "email"}}
	if diff := pretty.Compare(expectedConsents, stored.Consents); diff != "" {
		t.Errorf("user identity consents did not match after update: %s", diff)
	}

	all, err := s.ListUserIdentities(ctx)
	if err != nil {
		t.Fatalf("list user identities: %v", err)
	}
	if n := len(all); n != 1 {
		t.Fatalf("expected 1 user identity, got %d", n)
	}

	if err := s.DeleteUserIdentity(ctx, want.UserID, want.ConnectorID); err != nil {
		t.Fatalf("delete user identity: %v", err)
	}

	// Reading a deleted identity must report ErrNotFound.
	_, err = s.GetUserIdentity(ctx, want.UserID, want.ConnectorID)
	mustBeErrNotFound(t, "user identity", err)
}
// testAuthSessionCRUD exercises auth sessions: create, rejected duplicate
// create, get, an update that adds a second per-client state, list and delete.
func testAuthSessionCRUD(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	// toMillis brings a timestamp to UTC at millisecond precision so it can be
	// compared structurally.
	toMillis := func(ts time.Time) time.Time { return ts.UTC().Round(time.Millisecond) }

	now := toMillis(time.Now())
	want := storage.AuthSession{
		ID: storage.NewID(),
		ClientStates: map[string]*storage.ClientAuthState{
			"client1": {
				UserID:            "user1",
				ConnectorID:       "conn1",
				Active:            true,
				ExpiresAt:         now.Add(24 * time.Hour),
				LastActivity:      now,
				LastTokenIssuedAt: now,
			},
		},
		CreatedAt:    now,
		LastActivity: now,
		IPAddress:    "192.168.1.1",
		UserAgent:    "TestBrowser/1.0",
	}

	if err := s.CreateAuthSession(ctx, want); err != nil {
		t.Fatalf("create auth session: %v", err)
	}
	// A duplicate create must report ErrAlreadyExists.
	mustBeErrAlreadyExists(t, "auth session", s.CreateAuthSession(ctx, want))

	stored, err := s.GetAuthSession(ctx, want.ID)
	if err != nil {
		t.Fatalf("get auth session: %v", err)
	}
	stored.CreatedAt = toMillis(stored.CreatedAt)
	stored.LastActivity = toMillis(stored.LastActivity)
	for _, state := range stored.ClientStates {
		state.ExpiresAt = toMillis(state.ExpiresAt)
		state.LastActivity = toMillis(state.LastActivity)
		state.LastTokenIssuedAt = toMillis(state.LastTokenIssuedAt)
	}
	if diff := pretty.Compare(want, stored); diff != "" {
		t.Errorf("auth session retrieved from storage did not match: %s", diff)
	}

	later := now.Add(time.Minute)
	addClient := func(old storage.AuthSession) (storage.AuthSession, error) {
		old.ClientStates["client2"] = &storage.ClientAuthState{
			UserID:       "user2",
			ConnectorID:  "conn2",
			Active:       true,
			ExpiresAt:    later.Add(24 * time.Hour),
			LastActivity: later,
		}
		old.LastActivity = later
		return old, nil
	}
	if err := s.UpdateAuthSession(ctx, want.ID, addClient); err != nil {
		t.Fatalf("update auth session: %v", err)
	}

	stored, err = s.GetAuthSession(ctx, want.ID)
	if err != nil {
		t.Fatalf("get auth session after update: %v", err)
	}
	if n := len(stored.ClientStates); n != 2 {
		t.Fatalf("expected 2 client states, got %d", n)
	}
	added := stored.ClientStates["client2"]
	if added == nil {
		t.Fatal("expected client2 state to exist")
	}
	if added.UserID != "user2" {
		t.Errorf("expected client2 user_id to be user2, got %s", added.UserID)
	}

	all, err := s.ListAuthSessions(ctx)
	if err != nil {
		t.Fatalf("list auth sessions: %v", err)
	}
	if n := len(all); n != 1 {
		t.Fatalf("expected 1 auth session, got %d", n)
	}

	if err := s.DeleteAuthSession(ctx, want.ID); err != nil {
		t.Fatalf("delete auth session: %v", err)
	}

	// Reading a deleted session must report ErrNotFound.
	_, err = s.GetAuthSession(ctx, want.ID)
	mustBeErrNotFound(t, "auth session", err)
}