// Package conformance provides conformance tests for storage implementations. package conformance import ( "context" "reflect" "sort" "testing" "time" jose "github.com/go-jose/go-jose/v4" "github.com/kylelemons/godebug/pretty" "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "github.com/dexidp/dex/storage" ) // ensure that values being tested on never expire. var neverExpire = time.Now().UTC().Add(time.Hour * 24 * 365 * 100) type subTest struct { name string run func(t *testing.T, s storage.Storage) } func runTests(t *testing.T, newStorage func(t *testing.T) storage.Storage, tests []subTest) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := newStorage(t) test.run(t, s) s.Close() }) } } // RunTests runs a set of conformance tests against a storage. newStorage should // return an initialized but empty storage. The storage will be closed at the // end of each test run. func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { runTests(t, newStorage, []subTest{ {"AuthCodeCRUD", testAuthCodeCRUD}, {"AuthRequestCRUD", testAuthRequestCRUD}, {"ClientCRUD", testClientCRUD}, {"RefreshTokenCRUD", testRefreshTokenCRUD}, {"PasswordCRUD", testPasswordCRUD}, {"KeysCRUD", testKeysCRUD}, {"OfflineSessionCRUD", testOfflineSessionCRUD}, {"ConnectorCRUD", testConnectorCRUD}, {"GarbageCollection", testGC}, {"TimezoneSupport", testTimezones}, {"DeviceRequestCRUD", testDeviceRequestCRUD}, {"DeviceTokenCRUD", testDeviceTokenCRUD}, {"UserIdentityCRUD", testUserIdentityCRUD}, {"AuthSessionCRUD", testAuthSessionCRUD}, }) } func mustLoadJWK(b string) *jose.JSONWebKey { var jwt jose.JSONWebKey if err := jwt.UnmarshalJSON([]byte(b)); err != nil { panic(err) } return &jwt } func mustBeErrNotFound(t *testing.T, kind string, err error) { switch { case err == nil: t.Errorf("deleting nonexistent %s should return an error", kind) case err != storage.ErrNotFound: t.Errorf("deleting %s expected storage.ErrNotFound, got %v", kind, err) } } func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { switch { case err == nil: t.Errorf("attempting to create an existing %s should return an error", kind) case err != storage.ErrAlreadyExists: t.Errorf("creating an existing %s expected storage.ErrAlreadyExists, got %v", kind, err) } } func testAuthRequestCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", } a1 := storage.AuthRequest{ ID: storage.NewID(), ClientID: "client1", ResponseTypes: []string{"code"}, Scopes: []string{"openid", "email"}, RedirectURI: "https://localhost:80/callback", Nonce: "foo", State: "bar", ForceApprovalPrompt: true, LoggedIn: true, Expiry: neverExpire, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, PKCE: codeChallenge, HMACKey: []byte("hmac_key"), } identity := storage.Claims{Email: "foobar"} if err := s.CreateAuthRequest(ctx, a1); err != nil { t.Fatalf("failed creating auth request: %v", err) } // Attempt to create same AuthRequest twice. err := s.CreateAuthRequest(ctx, a1) mustBeErrAlreadyExists(t, "auth request", err) a2 := storage.AuthRequest{ ID: storage.NewID(), ClientID: "client2", ResponseTypes: []string{"code"}, Scopes: []string{"openid", "email"}, RedirectURI: "https://localhost:80/callback", Nonce: "bar", State: "foo", ForceApprovalPrompt: true, LoggedIn: true, Expiry: neverExpire, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "2", Username: "john", Email: "john.doe@example.com", EmailVerified: true, Groups: []string{"a"}, }, HMACKey: []byte("hmac_key"), } if err := s.CreateAuthRequest(ctx, a2); err != nil { t.Fatalf("failed creating auth request: %v", err) } if err := s.UpdateAuthRequest(ctx, a1.ID, func(old storage.AuthRequest) (storage.AuthRequest, error) { old.Claims = identity old.ConnectorID = "connID" return old, nil }); err != nil { t.Fatalf("failed to update auth request: %v", err) } got, err := s.GetAuthRequest(ctx, a1.ID) if err != nil { t.Fatalf("failed to get auth req: %v", err) } if !reflect.DeepEqual(got.Claims, identity) { t.Fatalf("update failed, wanted identity=%#v got %#v", identity, got.Claims) } if !reflect.DeepEqual(got.PKCE, codeChallenge) { t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) } if err := s.DeleteAuthRequest(ctx, a1.ID); err != nil { t.Fatalf("failed to delete auth request: %v", err) } if err := s.DeleteAuthRequest(ctx, a2.ID); err != nil { t.Fatalf("failed to delete auth request: %v", err) } _, err = s.GetAuthRequest(ctx, a1.ID) mustBeErrNotFound(t, "auth request", err) } func testAuthCodeCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() a1 := storage.AuthCode{ ID: storage.NewID(), ClientID: "client1", RedirectURI: "https://localhost:80/callback", Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: neverExpire, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), PKCE: storage.PKCE{ CodeChallenge: "12345", CodeChallengeMethod: "Whatever", }, Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, } if err := s.CreateAuthCode(ctx, a1); err != nil { t.Fatalf("failed creating auth code: %v", err) } a2 := storage.AuthCode{ ID: storage.NewID(), ClientID: "client2", RedirectURI: "https://localhost:80/callback", Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: neverExpire, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "2", Username: "john", Email: "john.doe@example.com", EmailVerified: true, Groups: []string{"a"}, }, } // Attempt to create same AuthCode twice. err := s.CreateAuthCode(ctx, a1) mustBeErrAlreadyExists(t, "auth code", err) if err := s.CreateAuthCode(ctx, a2); err != nil { t.Fatalf("failed creating auth code: %v", err) } got, err := s.GetAuthCode(ctx, a1.ID) if err != nil { t.Fatalf("failed to get auth code: %v", err) } if a1.Expiry.Unix() != got.Expiry.Unix() { t.Errorf("auth code expiry did not match want=%s vs got=%s", a1.Expiry, got.Expiry) } got.Expiry = a1.Expiry // time fields do not compare well if diff := pretty.Compare(a1, got); diff != "" { t.Errorf("auth code retrieved from storage did not match: %s", diff) } if err := s.DeleteAuthCode(ctx, a1.ID); err != nil { t.Fatalf("delete auth code: %v", err) } if err := s.DeleteAuthCode(ctx, a2.ID); err != nil { t.Fatalf("delete auth code: %v", err) } _, err = s.GetAuthCode(ctx, a1.ID) mustBeErrNotFound(t, "auth code", err) } func testClientCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() id1 := storage.NewID() c1 := storage.Client{ ID: id1, Secret: "foobar", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", AllowedConnectors: []string{"github", "google"}, } err := s.DeleteClient(ctx, id1) mustBeErrNotFound(t, "client", err) if err := s.CreateClient(ctx, c1); err != nil { t.Fatalf("create client: %v", err) } // Attempt to create same Client twice. err = s.CreateClient(ctx, c1) mustBeErrAlreadyExists(t, "client", err) id2 := storage.NewID() c2 := storage.Client{ ID: id2, Secret: "barfoo", RedirectURIs: []string{"foo://bar.com/", "https://auth.example.com"}, Name: "dex client", LogoURL: "https://goo.gl/JIyzIC", } if err := s.CreateClient(ctx, c2); err != nil { t.Fatalf("create client: %v", err) } getAndCompare := func(_ string, want storage.Client) { gc, err := s.GetClient(ctx, id1) if err != nil { t.Errorf("get client: %v", err) return } if diff := pretty.Compare(want, gc); diff != "" { t.Errorf("client retrieved from storage did not match: %s", diff) } } getAndCompare(id1, c1) newSecret := "barfoo" err = s.UpdateClient(ctx, id1, func(old storage.Client) (storage.Client, error) { old.Secret = newSecret return old, nil }) if err != nil { t.Errorf("update client: %v", err) } c1.Secret = newSecret getAndCompare(id1, c1) if err := s.DeleteClient(ctx, id1); err != nil { t.Fatalf("delete client: %v", err) } if err := s.DeleteClient(ctx, id2); err != nil { t.Fatalf("delete client: %v", err) } _, err = s.GetClient(ctx, id1) mustBeErrNotFound(t, "client", err) } func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() id := storage.NewID() refresh := storage.RefreshToken{ ID: id, Token: "bar", ObsoleteToken: "", Nonce: "foo", ClientID: "client_id", ConnectorID: "client_secret", Scopes: []string{"openid", "email", "profile"}, CreatedAt: time.Now().UTC().Round(time.Millisecond), LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, ConnectorData: []byte(`{"some":"data"}`), } if err := s.CreateRefresh(ctx, refresh); err != nil { t.Fatalf("create refresh token: %v", err) } // Attempt to create same Refresh Token twice. err := s.CreateRefresh(ctx, refresh) mustBeErrAlreadyExists(t, "refresh token", err) getAndCompare := func(id string, want storage.RefreshToken) { gr, err := s.GetRefresh(ctx, id) if err != nil { t.Errorf("get refresh: %v", err) return } if diff := pretty.Compare(gr.CreatedAt.UnixNano(), gr.CreatedAt.UnixNano()); diff != "" { t.Errorf("refresh token created timestamp retrieved from storage did not match: %s", diff) } if diff := pretty.Compare(gr.LastUsed.UnixNano(), gr.LastUsed.UnixNano()); diff != "" { t.Errorf("refresh token last used timestamp retrieved from storage did not match: %s", diff) } gr.CreatedAt = time.Time{} gr.LastUsed = time.Time{} want.CreatedAt = time.Time{} want.LastUsed = time.Time{} if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("refresh token retrieved from storage did not match: %s", diff) } } getAndCompare(id, refresh) id2 := storage.NewID() refresh2 := storage.RefreshToken{ ID: id2, Token: "bar_2", ObsoleteToken: refresh.Token, Nonce: "foo_2", ClientID: "client_id_2", ConnectorID: "client_secret", Scopes: []string{"openid", "email", "profile"}, CreatedAt: time.Now().UTC().Round(time.Millisecond), LastUsed: time.Now().UTC().Round(time.Millisecond), Claims: storage.Claims{ UserID: "2", Username: "john", Email: "john.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, ConnectorData: []byte(`{"some":"data"}`), } if err := s.CreateRefresh(ctx, refresh2); err != nil { t.Fatalf("create second refresh token: %v", err) } getAndCompare(id2, refresh2) updatedAt := time.Now().UTC().Round(time.Millisecond) updater := func(r storage.RefreshToken) (storage.RefreshToken, error) { r.Token = "spam" r.LastUsed = updatedAt return r, nil } if err := s.UpdateRefreshToken(ctx, id, updater); err != nil { t.Errorf("failed to update refresh token: %v", err) } refresh.Token = "spam" refresh.LastUsed = updatedAt getAndCompare(id, refresh) // Ensure that updating the first token doesn't impact the second. Issue #847. getAndCompare(id2, refresh2) if err := s.DeleteRefresh(ctx, id); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } if err := s.DeleteRefresh(ctx, id2); err != nil { t.Fatalf("failed to delete refresh request: %v", err) } _, err = s.GetRefresh(ctx, id) mustBeErrNotFound(t, "refresh token", err) } type byEmail []storage.Password func (n byEmail) Len() int { return len(n) } func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email } func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func boolPtr(v bool) *bool { return &v } func testPasswordCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() // Use bcrypt.MinCost to keep the tests short. passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { t.Fatal(err) } password1 := storage.Password{ Email: "jane@example.com", Hash: passwordHash1, Username: "jane", Name: "Jane Doe", PreferredUsername: "jane-public", EmailVerified: boolPtr(true), UserID: "foobar", Groups: []string{"team-a", "team-a/admins"}, } if err := s.CreatePassword(ctx, password1); err != nil { t.Fatalf("create password token: %v", err) } // Attempt to create same Password twice. err = s.CreatePassword(ctx, password1) mustBeErrAlreadyExists(t, "password", err) passwordHash2, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.MinCost) if err != nil { t.Fatal(err) } password2 := storage.Password{ Email: "john@example.com", Hash: passwordHash2, Username: "john", Name: "John Smith", PreferredUsername: "john-public", EmailVerified: boolPtr(false), UserID: "barfoo", Groups: []string{"team-b"}, } if err := s.CreatePassword(ctx, password2); err != nil { t.Fatalf("create password token: %v", err) } getAndCompare := func(id string, want storage.Password) { gr, err := s.GetPassword(ctx, id) if err != nil { t.Errorf("get password %q: %v", id, err) return } if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("password retrieved from storage did not match: %s", diff) } } getAndCompare("jane@example.com", password1) getAndCompare("JANE@example.com", password1) // Emails should be case insensitive if err := s.UpdatePassword(ctx, password1.Email, func(old storage.Password) (storage.Password, error) { old.Username = "jane doe" return old, nil }); err != nil { t.Fatalf("failed to update auth request: %v", err) } password1.Username = "jane doe" getAndCompare("jane@example.com", password1) var passwordList []storage.Password passwordList = append(passwordList, password1, password2) listAndCompare := func(want []storage.Password) { passwords, err := s.ListPasswords(ctx) if err != nil { t.Errorf("list password: %v", err) return } sort.Sort(byEmail(want)) sort.Sort(byEmail(passwords)) if diff := pretty.Compare(want, passwords); diff != "" { t.Errorf("password list retrieved from storage did not match: %s", diff) } } listAndCompare(passwordList) if err := s.DeletePassword(ctx, password1.Email); err != nil { t.Fatalf("failed to delete password: %v", err) } if err := s.DeletePassword(ctx, password2.Email); err != nil { t.Fatalf("failed to delete password: %v", err) } _, err = s.GetPassword(ctx, password1.Email) mustBeErrNotFound(t, "password", err) } func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() userID1 := storage.NewID() session1 := storage.OfflineSessions{ UserID: userID1, ConnID: "Conn1", Refresh: make(map[string]*storage.RefreshTokenRef), ConnectorData: []byte(`{"some":"data"}`), } // Creating an OfflineSession with an empty Refresh list to ensure that // an empty map is translated as expected by the storage. if err := s.CreateOfflineSessions(ctx, session1); err != nil { t.Fatalf("create offline session with UserID = %s: %v", session1.UserID, err) } // Attempt to create same OfflineSession twice. err := s.CreateOfflineSessions(ctx, session1) mustBeErrAlreadyExists(t, "offline session", err) userID2 := storage.NewID() session2 := storage.OfflineSessions{ UserID: userID2, ConnID: "Conn2", Refresh: make(map[string]*storage.RefreshTokenRef), ConnectorData: []byte(`{"some":"data"}`), } if err := s.CreateOfflineSessions(ctx, session2); err != nil { t.Fatalf("create offline session with UserID = %s: %v", session2.UserID, err) } getAndCompare := func(userID string, connID string, want storage.OfflineSessions) { gr, err := s.GetOfflineSessions(ctx, userID, connID) if err != nil { t.Errorf("get offline session: %v", err) return } if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("offline session retrieved from storage did not match: %s", diff) } } getAndCompare(userID1, "Conn1", session1) id := storage.NewID() tokenRef := storage.RefreshTokenRef{ ID: id, ClientID: "client_id", CreatedAt: time.Now().UTC().Round(time.Millisecond), LastUsed: time.Now().UTC().Round(time.Millisecond), } session1.Refresh[tokenRef.ClientID] = &tokenRef if err := s.UpdateOfflineSessions(ctx, session1.UserID, session1.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) { old.Refresh[tokenRef.ClientID] = &tokenRef return old, nil }); err != nil { t.Fatalf("failed to update offline session: %v", err) } getAndCompare(userID1, "Conn1", session1) if err := s.DeleteOfflineSessions(ctx, session1.UserID, session1.ConnID); err != nil { t.Fatalf("failed to delete offline session: %v", err) } if err := s.DeleteOfflineSessions(ctx, session2.UserID, session2.ConnID); err != nil { t.Fatalf("failed to delete offline session: %v", err) } _, err = s.GetOfflineSessions(ctx, session1.UserID, session1.ConnID) mustBeErrNotFound(t, "offline session", err) } func testConnectorCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() id1 := storage.NewID() config1 := []byte(`{"issuer": "https://accounts.google.com"}`) c1 := storage.Connector{ ID: id1, Type: "Default", Name: "Default", Config: config1, GrantTypes: []string{"authorization_code", "refresh_token"}, } if err := s.CreateConnector(ctx, c1); err != nil { t.Fatalf("create connector with ID = %s: %v", c1.ID, err) } // Attempt to create same Connector twice. err := s.CreateConnector(ctx, c1) mustBeErrAlreadyExists(t, "connector", err) id2 := storage.NewID() config2 := []byte(`{"redirectURI": "http://127.0.0.1:5556/dex/callback"}`) c2 := storage.Connector{ ID: id2, Type: "Mock", Name: "Mock", Config: config2, } if err := s.CreateConnector(ctx, c2); err != nil { t.Fatalf("create connector with ID = %s: %v", c2.ID, err) } getAndCompare := func(id string, want storage.Connector) { gr, err := s.GetConnector(ctx, id) if err != nil { t.Errorf("get connector: %v", err) return } // ignore resource version comparison gr.ResourceVersion = "" if diff := pretty.Compare(want, gr); diff != "" { t.Errorf("connector retrieved from storage did not match: %s", diff) } } getAndCompare(id1, c1) if err := s.UpdateConnector(ctx, c1.ID, func(old storage.Connector) (storage.Connector, error) { old.Type = "oidc" old.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:token-exchange"} return old, nil }); err != nil { t.Fatalf("failed to update Connector: %v", err) } c1.Type = "oidc" c1.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:token-exchange"} getAndCompare(id1, c1) connectorList := []storage.Connector{c1, c2} listAndCompare := func(want []storage.Connector) { connectors, err := s.ListConnectors(ctx) if err != nil { t.Errorf("list connectors: %v", err) return } // ignore resource version comparison for i := range connectors { connectors[i].ResourceVersion = "" } sort.Slice(connectors, func(i, j int) bool { return connectors[i].Name < connectors[j].Name }) if diff := pretty.Compare(want, connectors); diff != "" { t.Errorf("connector list retrieved from storage did not match: %s", diff) } } listAndCompare(connectorList) if err := s.DeleteConnector(ctx, c1.ID); err != nil { t.Fatalf("failed to delete connector: %v", err) } if err := s.DeleteConnector(ctx, c2.ID); err != nil { t.Fatalf("failed to delete connector: %v", err) } _, err = s.GetConnector(ctx, c1.ID) mustBeErrNotFound(t, "connector", err) } func testKeysCRUD(t *testing.T, s storage.Storage) { ctx := context.TODO() updateAndCompare := func(k storage.Keys) { err := s.UpdateKeys(ctx, func(oldKeys storage.Keys) (storage.Keys, error) { return k, nil }) if err != nil { t.Errorf("failed to update keys: %v", err) return } if got, err := s.GetKeys(ctx); err != nil { t.Errorf("failed to get keys: %v", err) } else { got.NextRotation = got.NextRotation.UTC() if diff := pretty.Compare(k, got); diff != "" { t.Errorf("got keys did not equal expected: %s", diff) } } } // Postgres isn't as accurate with nano seconds as we'd like n := time.Now().UTC().Round(time.Second) keys1 := storage.Keys{ SigningKey: jsonWebKeys[0].Private, SigningKeyPub: jsonWebKeys[0].Public, NextRotation: n, } keys2 := storage.Keys{ SigningKey: jsonWebKeys[2].Private, SigningKeyPub: jsonWebKeys[2].Public, NextRotation: n.Add(time.Hour), VerificationKeys: []storage.VerificationKey{ { PublicKey: jsonWebKeys[0].Public, Expiry: n.Add(time.Hour), }, { PublicKey: jsonWebKeys[1].Public, Expiry: n.Add(time.Hour * 2), }, }, } updateAndCompare(keys1) updateAndCompare(keys2) } func testGC(t *testing.T, s storage.Storage) { ctx := t.Context() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) } pst, err := time.LoadLocation("America/Los_Angeles") if err != nil { t.Fatal(err) } expiry := time.Now().In(est) c := storage.AuthCode{ ID: storage.NewID(), ClientID: "foobar", RedirectURI: "https://localhost:80/callback", Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: expiry, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, } if err := s.CreateAuthCode(ctx, c); err != nil { t.Fatalf("failed creating auth code: %v", err) } for _, tz := range []*time.Location{time.UTC, est, pst} { result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.AuthCodes != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } if _, err := s.GetAuthCode(ctx, c.ID); err != nil { t.Errorf("expected to be able to get auth code after GC: %v", err) } } if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthCodes != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthCodes) } if _, err := s.GetAuthCode(ctx, c.ID); err == nil { t.Errorf("expected auth code to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } a := storage.AuthRequest{ ID: storage.NewID(), ClientID: "foobar", ResponseTypes: []string{"code"}, Scopes: []string{"openid", "email"}, RedirectURI: "https://localhost:80/callback", Nonce: "foo", State: "bar", ForceApprovalPrompt: true, LoggedIn: true, Expiry: expiry, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, HMACKey: []byte("hmac_key"), } if err := s.CreateAuthRequest(ctx, a); err != nil { t.Fatalf("failed creating auth request: %v", err) } for _, tz := range []*time.Location{time.UTC, est, pst} { result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.AuthCodes != 0 || result.AuthRequests != 0 { t.Errorf("expected no garbage collection results, got %#v", result) } if _, err := s.GetAuthRequest(ctx, a.ID); err != nil { t.Errorf("expected to be able to get auth request after GC: %v", err) } } if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.AuthRequests != 1 { t.Errorf("expected to garbage collect 1 objects, got %d", r.AuthRequests) } if _, err := s.GetAuthRequest(ctx, a.ID); err == nil { t.Errorf("expected auth request to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } d := storage.DeviceRequest{ UserCode: storage.NewUserCode(), DeviceCode: storage.NewID(), ClientID: "client1", ClientSecret: "secret1", Scopes: []string{"openid", "email"}, Expiry: expiry, } if err := s.CreateDeviceRequest(ctx, d); err != nil { t.Fatalf("failed creating device request: %v", err) } for _, tz := range []*time.Location{time.UTC, est, pst} { result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.DeviceRequests != 0 { t.Errorf("expected no device garbage collection results, got %#v", result) } if _, err := s.GetDeviceRequest(ctx, d.UserCode); err != nil { t.Errorf("expected to be able to get auth request after GC: %v", err) } } if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.DeviceRequests != 1 { t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests) } if _, err := s.GetDeviceRequest(ctx, d.UserCode); err == nil { t.Errorf("expected device request to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } dt := storage.DeviceToken{ DeviceCode: storage.NewID(), Status: "pending", Token: "foo", Expiry: expiry, LastRequestTime: time.Now(), PollIntervalSeconds: 0, PKCE: storage.PKCE{ CodeChallenge: "challenge", CodeChallengeMethod: "S256", }, } if err := s.CreateDeviceToken(ctx, dt); err != nil { t.Fatalf("failed creating device token: %v", err) } for _, tz := range []*time.Location{time.UTC, est, pst} { result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) if err != nil { t.Errorf("garbage collection failed: %v", err) } else if result.DeviceTokens != 0 { t.Errorf("expected no device token garbage collection results, got %#v", result) } if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err != nil { t.Errorf("expected to be able to get device token after GC: %v", err) } } if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { t.Errorf("garbage collection failed: %v", err) } else if r.DeviceTokens != 1 { t.Errorf("expected to garbage collect 1 device token, got %d", r.DeviceTokens) } if _, err := s.GetDeviceToken(ctx, dt.DeviceCode); err == nil { t.Errorf("expected device token to be GC'd") } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } } // testTimezones tests that backends either fully support timezones or // do the correct standardization. func testTimezones(t *testing.T, s storage.Storage) { ctx := t.Context() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) } // Create an expiry with timezone info. Only expect backends to be // accurate to the millisecond expiry := time.Now().In(est).Round(time.Millisecond) c := storage.AuthCode{ ID: storage.NewID(), ClientID: "foobar", RedirectURI: "https://localhost:80/callback", Nonce: "foobar", Scopes: []string{"openid", "email"}, Expiry: expiry, ConnectorID: "ldap", ConnectorData: []byte(`{"some":"data"}`), Claims: storage.Claims{ UserID: "1", Username: "jane", Email: "jane.doe@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, } if err := s.CreateAuthCode(ctx, c); err != nil { t.Fatalf("failed creating auth code: %v", err) } got, err := s.GetAuthCode(ctx, c.ID) if err != nil { t.Fatalf("failed to get auth code: %v", err) } // Ensure that if the resulting time is converted to the same // timezone, it's the same value. We DO NOT expect timezones // to be preserved. gotTime := got.Expiry.In(est) wantTime := expiry if !gotTime.Equal(wantTime) { t.Fatalf("expected expiry %v got %v", wantTime, gotTime) } } func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() d1 := storage.DeviceRequest{ UserCode: storage.NewUserCode(), DeviceCode: storage.NewID(), ClientID: "client1", ClientSecret: "secret1", Scopes: []string{"openid", "email"}, Expiry: neverExpire.Round(time.Second), } if err := s.CreateDeviceRequest(ctx, d1); err != nil { t.Fatalf("failed creating device request: %v", err) } // Attempt to create same DeviceRequest twice. err := s.CreateDeviceRequest(ctx, d1) mustBeErrAlreadyExists(t, "device request", err) got, err := s.GetDeviceRequest(ctx, d1.UserCode) if err != nil { t.Fatalf("failed to get device request: %v", err) } require.Equal(t, d1, got) // No manual deletes for device requests, will be handled by garbage collection routines // see testGC } func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", } // Create a Token d1 := storage.DeviceToken{ DeviceCode: storage.NewID(), Status: "pending", Token: storage.NewID(), Expiry: neverExpire, LastRequestTime: time.Now(), PollIntervalSeconds: 0, PKCE: codeChallenge, } if err := s.CreateDeviceToken(ctx, d1); err != nil { t.Fatalf("failed creating device token: %v", err) } // Attempt to create same Device Token twice. err := s.CreateDeviceToken(ctx, d1) mustBeErrAlreadyExists(t, "device token", err) // Update the device token, simulate a redemption if err := s.UpdateDeviceToken(ctx, d1.DeviceCode, func(old storage.DeviceToken) (storage.DeviceToken, error) { old.Token = "token data" old.Status = "complete" return old, nil }); err != nil { t.Fatalf("failed to update device token: %v", err) } // Retrieve the device token got, err := s.GetDeviceToken(ctx, d1.DeviceCode) if err != nil { t.Fatalf("failed to get device token: %v", err) } // Validate expected result set if got.Status != "complete" { t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status) } if got.Token != "token data" { t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token) } if !reflect.DeepEqual(got.PKCE, codeChallenge) { t.Fatalf("storage does not support PKCE, wanted challenge=%#v got %#v", codeChallenge, got.PKCE) } } func testUserIdentityCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() now := time.Now().UTC().Round(time.Millisecond) u1 := storage.UserIdentity{ UserID: "user1", ConnectorID: "conn1", Claims: storage.Claims{ UserID: "user1", Username: "jane", Email: "jane@example.com", EmailVerified: true, Groups: []string{"a", "b"}, }, Consents: make(map[string][]string), CreatedAt: now, LastLogin: now, BlockedUntil: time.Unix(0, 0).UTC(), } // Create with empty Consents map. if err := s.CreateUserIdentity(ctx, u1); err != nil { t.Fatalf("create user identity: %v", err) } // Duplicate create should return ErrAlreadyExists. err := s.CreateUserIdentity(ctx, u1) mustBeErrAlreadyExists(t, "user identity", err) // Get and compare. got, err := s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) if err != nil { t.Fatalf("get user identity: %v", err) } got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond) got.LastLogin = got.LastLogin.UTC().Round(time.Millisecond) got.BlockedUntil = got.BlockedUntil.UTC().Round(time.Millisecond) u1.BlockedUntil = u1.BlockedUntil.UTC().Round(time.Millisecond) if diff := pretty.Compare(u1, got); diff != "" { t.Errorf("user identity retrieved from storage did not match: %s", diff) } // Update: add consent entry. if err := s.UpdateUserIdentity(ctx, u1.UserID, u1.ConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) { old.Consents["client1"] = []string{"openid", "email"} return old, nil }); err != nil { t.Fatalf("update user identity: %v", err) } // Get and verify updated consents. got, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) if err != nil { t.Fatalf("get user identity after update: %v", err) } wantConsents := map[string][]string{"client1": {"openid", "email"}} if diff := pretty.Compare(wantConsents, got.Consents); diff != "" { t.Errorf("user identity consents did not match after update: %s", diff) } // List and verify. identities, err := s.ListUserIdentities(ctx) if err != nil { t.Fatalf("list user identities: %v", err) } if len(identities) != 1 { t.Fatalf("expected 1 user identity, got %d", len(identities)) } // Delete. if err := s.DeleteUserIdentity(ctx, u1.UserID, u1.ConnectorID); err != nil { t.Fatalf("delete user identity: %v", err) } // Get deleted should return ErrNotFound. _, err = s.GetUserIdentity(ctx, u1.UserID, u1.ConnectorID) mustBeErrNotFound(t, "user identity", err) } func testAuthSessionCRUD(t *testing.T, s storage.Storage) { ctx := t.Context() now := time.Now().UTC().Round(time.Millisecond) session := storage.AuthSession{ ID: storage.NewID(), ClientStates: map[string]*storage.ClientAuthState{ "client1": { UserID: "user1", ConnectorID: "conn1", Active: true, ExpiresAt: now.Add(24 * time.Hour), LastActivity: now, LastTokenIssuedAt: now, }, }, CreatedAt: now, LastActivity: now, IPAddress: "192.168.1.1", UserAgent: "TestBrowser/1.0", } // Create. if err := s.CreateAuthSession(ctx, session); err != nil { t.Fatalf("create auth session: %v", err) } // Duplicate create should return ErrAlreadyExists. err := s.CreateAuthSession(ctx, session) mustBeErrAlreadyExists(t, "auth session", err) // Get and compare. got, err := s.GetAuthSession(ctx, session.ID) if err != nil { t.Fatalf("get auth session: %v", err) } got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond) got.LastActivity = got.LastActivity.UTC().Round(time.Millisecond) for _, cs := range got.ClientStates { cs.ExpiresAt = cs.ExpiresAt.UTC().Round(time.Millisecond) cs.LastActivity = cs.LastActivity.UTC().Round(time.Millisecond) cs.LastTokenIssuedAt = cs.LastTokenIssuedAt.UTC().Round(time.Millisecond) } if diff := pretty.Compare(session, got); diff != "" { t.Errorf("auth session retrieved from storage did not match: %s", diff) } // Update: add a new client state. newNow := now.Add(time.Minute) if err := s.UpdateAuthSession(ctx, session.ID, func(old storage.AuthSession) (storage.AuthSession, error) { old.ClientStates["client2"] = &storage.ClientAuthState{ UserID: "user2", ConnectorID: "conn2", Active: true, ExpiresAt: newNow.Add(24 * time.Hour), LastActivity: newNow, } old.LastActivity = newNow return old, nil }); err != nil { t.Fatalf("update auth session: %v", err) } // Get and verify update. got, err = s.GetAuthSession(ctx, session.ID) if err != nil { t.Fatalf("get auth session after update: %v", err) } if len(got.ClientStates) != 2 { t.Fatalf("expected 2 client states, got %d", len(got.ClientStates)) } if got.ClientStates["client2"] == nil { t.Fatal("expected client2 state to exist") } if got.ClientStates["client2"].UserID != "user2" { t.Errorf("expected client2 user_id to be user2, got %s", got.ClientStates["client2"].UserID) } // List and verify. sessions, err := s.ListAuthSessions(ctx) if err != nil { t.Fatalf("list auth sessions: %v", err) } if len(sessions) != 1 { t.Fatalf("expected 1 auth session, got %d", len(sessions)) } // Delete. if err := s.DeleteAuthSession(ctx, session.ID); err != nil { t.Fatalf("delete auth session: %v", err) } // Get deleted should return ErrNotFound. _, err = s.GetAuthSession(ctx, session.ID) mustBeErrNotFound(t, "auth session", err) }