Browse Source

test: use new Go features in tests

Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
pull/4277/head
Mark Sagi-Kazar 7 months ago
parent
commit
e230d9426d
No known key found for this signature in database
GPG Key ID: 31AB0439F4C5C90E
  1. 117
      server/api_test.go
  2. 27
      server/deviceflowhandlers_test.go
  3. 52
      server/handlers_test.go
  4. 23
      server/introspectionhandler_test.go
  5. 11
      server/oauth2_test.go
  6. 8
      server/refreshhandlers_test.go
  7. 79
      server/server_test.go
  8. 28
      storage/conformance/conformance.go
  9. 8
      storage/conformance/transactions.go
  10. 8
      storage/ent/mysql_test.go
  11. 8
      storage/ent/postgres_test.go
  12. 4
      storage/ent/sqlite_test.go
  13. 5
      storage/etcd/etcd_test.go
  14. 4
      storage/kubernetes/storage_test.go
  15. 4
      storage/memory/memory_test.go
  16. 7
      storage/sql/config_test.go
  17. 3
      storage/sql/postgres_test.go

117
server/api_test.go

@ -1,10 +1,9 @@
package server package server
import ( import (
"context"
"log/slog" "log/slog"
"net" "net"
"os" "slices"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -29,8 +28,12 @@ type apiClient struct {
Close func() Close func()
} }
func newLogger(t *testing.T) *slog.Logger {
return slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
}
// newAPI constructs a gRCP client connected to a backing server. // newAPI constructs a gRCP client connected to a backing server.
func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient { func newAPI(t *testing.T, s storage.Storage, logger *slog.Logger) *apiClient {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -59,13 +62,14 @@ func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient {
// Attempts to create, update and delete a test Password // Attempts to create, update and delete a test Password
func TestPassword(t *testing.T) { func TestPassword(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
email := "test@example.com" email := "test@example.com"
p := api.Password{ p := api.Password{
Email: email, Email: email,
@ -168,10 +172,10 @@ func TestPassword(t *testing.T) {
// Ensures checkCost returns expected values // Ensures checkCost returns expected values
func TestCheckCost(t *testing.T) { func TestCheckCost(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
tests := []struct { tests := []struct {
@ -221,13 +225,13 @@ func TestCheckCost(t *testing.T) {
// Attempts to list and revoke an existing refresh token. // Attempts to list and revoke an existing refresh token.
func TestRefreshToken(t *testing.T) { func TestRefreshToken(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
// Creating a storage with an existing refresh token and offline session for the user. // Creating a storage with an existing refresh token and offline session for the user.
id := storage.NewID() id := storage.NewID()
@ -330,12 +334,13 @@ func TestRefreshToken(t *testing.T) {
} }
func TestUpdateClient(t *testing.T) { func TestUpdateClient(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background()
ctx := t.Context()
createClient := func(t *testing.T, clientId string) { createClient := func(t *testing.T, clientId string) {
resp, err := client.CreateClient(ctx, &api.CreateClientReq{ resp, err := client.CreateClient(ctx, &api.CreateClientReq{
@ -463,13 +468,13 @@ func TestUpdateClient(t *testing.T) {
t.Errorf("expected stored client with LogoURL: %s, found %s", tc.req.LogoUrl, client.LogoURL) t.Errorf("expected stored client with LogoURL: %s, found %s", tc.req.LogoUrl, client.LogoURL)
} }
for _, redirectURI := range tc.req.RedirectUris { for _, redirectURI := range tc.req.RedirectUris {
found := find(redirectURI, client.RedirectURIs) found := slices.Contains(client.RedirectURIs, redirectURI)
if !found { if !found {
t.Errorf("expected redirect URI: %s", redirectURI) t.Errorf("expected redirect URI: %s", redirectURI)
} }
} }
for _, peer := range tc.req.TrustedPeers { for _, peer := range tc.req.TrustedPeers {
found := find(peer, client.TrustedPeers) found := slices.Contains(client.TrustedPeers, peer)
if !found { if !found {
t.Errorf("expected trusted peer: %s", peer) t.Errorf("expected trusted peer: %s", peer)
} }
@ -483,26 +488,17 @@ func TestUpdateClient(t *testing.T) {
} }
} }
func find(item string, items []string) bool {
for _, i := range items {
if item == i {
return true
}
}
return false
}
func TestCreateConnector(t *testing.T) { func TestCreateConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true") t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
connectorID := "connector123" connectorID := "connector123"
connectorName := "TestConnector" connectorName := "TestConnector"
connectorType := "TestType" connectorType := "TestType"
@ -543,16 +539,16 @@ func TestCreateConnector(t *testing.T) {
} }
func TestUpdateConnector(t *testing.T) { func TestUpdateConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true") t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
connectorID := "connector123" connectorID := "connector123"
newConnectorName := "UpdatedConnector" newConnectorName := "UpdatedConnector"
newConnectorType := "UpdatedType" newConnectorType := "UpdatedType"
@ -611,16 +607,16 @@ func TestUpdateConnector(t *testing.T) {
} }
func TestDeleteConnector(t *testing.T) { func TestDeleteConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true") t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
connectorID := "connector123" connectorID := "connector123"
// Create a connector for testing // Create a connector for testing
@ -655,16 +651,15 @@ func TestDeleteConnector(t *testing.T) {
} }
func TestListConnectors(t *testing.T) { func TestListConnectors(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true") t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
// Create connectors for testing // Create connectors for testing
createReq1 := api.CreateConnectorReq{ createReq1 := api.CreateConnectorReq{
@ -698,13 +693,13 @@ func TestListConnectors(t *testing.T) {
} }
func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) { func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
// Create connectors for testing // Create connectors for testing
createReq1 := api.CreateConnectorReq{ createReq1 := api.CreateConnectorReq{
@ -735,13 +730,13 @@ func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) {
} }
func TestListClients(t *testing.T) { func TestListClients(t *testing.T) {
logger := slog.New(slog.DiscardHandler) logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close() defer client.Close()
ctx := context.Background() ctx := t.Context()
// List Clients // List Clients
listResp, err := client.ListClients(ctx, &api.ListClientReq{}) listResp, err := client.ListClients(ctx, &api.ListClientReq{})

27
server/deviceflowhandlers_test.go

@ -2,7 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -20,10 +19,8 @@ func TestDeviceVerificationURI(t *testing.T) {
t0 := time.Now() t0 := time.Now()
now := func() time.Time { return t0 } now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })
@ -101,11 +98,8 @@ func TestHandleDeviceCode(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) { t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })
@ -364,11 +358,10 @@ func TestDeviceCallback(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) { t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
// c.Issuer = c.Issuer + "/non-root-path" // c.Issuer = c.Issuer + "/non-root-path"
c.Now = now c.Now = now
}) })
@ -658,11 +651,10 @@ func TestDeviceTokenResponse(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) { t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })
@ -715,7 +707,7 @@ func TestDeviceTokenResponse(t *testing.T) {
} }
func expectJSONErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { func expectJSONErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
jsonMap := make(map[string]interface{}) jsonMap := make(map[string]any)
err := json.Unmarshal(body, &jsonMap) err := json.Unmarshal(body, &jsonMap)
if err != nil { if err != nil {
t.Errorf("Unexpected error unmarshalling response: %v", err) t.Errorf("Unexpected error unmarshalling response: %v", err)
@ -792,11 +784,10 @@ func TestVerifyCodeResponse(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) { t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })

52
server/handlers_test.go

@ -24,10 +24,7 @@ import (
) )
func TestHandleHealth(t *testing.T) { func TestHandleHealth(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServer(t, nil)
defer cancel()
httpServer, server := newTestServer(ctx, t, nil)
defer httpServer.Close() defer httpServer.Close()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
@ -38,10 +35,7 @@ func TestHandleHealth(t *testing.T) {
} }
func TestHandleDiscovery(t *testing.T) { func TestHandleDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServer(t, nil)
defer cancel()
httpServer, server := newTestServer(ctx, t, nil)
defer httpServer.Close() defer httpServer.Close()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
@ -108,10 +102,7 @@ func TestHandleDiscovery(t *testing.T) {
} }
func TestHandleHealthFailure(t *testing.T) { func TestHandleHealthFailure(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServer(t, func(c *Config) {
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
c.HealthChecker = gosundheit.New() c.HealthChecker = gosundheit.New()
c.HealthChecker.RegisterCheck( c.HealthChecker.RegisterCheck(
@ -143,10 +134,7 @@ func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthReques
} }
func TestHandleInvalidOAuth2Callbacks(t *testing.T) { func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServer(t, func(c *Config) {
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
c.Storage = &emptyStorage{c.Storage} c.Storage = &emptyStorage{c.Storage}
}) })
defer httpServer.Close() defer httpServer.Close()
@ -171,10 +159,7 @@ func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
} }
func TestHandleInvalidSAMLCallbacks(t *testing.T) { func TestHandleInvalidSAMLCallbacks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServer(t, func(c *Config) {
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
c.Storage = &emptyStorage{c.Storage} c.Storage = &emptyStorage{c.Storage}
}) })
defer httpServer.Close() defer httpServer.Close()
@ -251,10 +236,9 @@ func TestHandleAuthCode(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) { c.Issuer += "/non-root-path" }) httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" })
defer httpServer.Close() defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL) p, err := oidc.NewProvider(ctx, httpServer.URL)
@ -303,7 +287,7 @@ func TestHandleAuthCode(t *testing.T) {
} }
func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
c := storage.Client{ c := storage.Client{
ID: "test", ID: "test",
Secret: "barfoo", Secret: "barfoo",
@ -339,8 +323,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
} }
func TestHandlePassword(t *testing.T) { func TestHandlePassword(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
tests := []struct { tests := []struct {
name string name string
@ -361,7 +344,7 @@ func TestHandlePassword(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.PasswordConnector = "test" c.PasswordConnector = "test"
c.Now = time.Now c.Now = time.Now
}) })
@ -420,8 +403,7 @@ func TestHandlePassword(t *testing.T) {
} }
func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
connID := "mockPw" connID := "mockPw"
authReqID := "test" authReqID := "test"
@ -525,7 +507,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = tc.skipApproval c.SkipApprovalScreen = tc.skipApproval
c.Now = time.Now c.Now = time.Now
}) })
@ -574,8 +556,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
} }
func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
connID := "mock" connID := "mock"
authReqID := "test" authReqID := "test"
@ -679,7 +660,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = tc.skipApproval c.SkipApprovalScreen = tc.skipApproval
c.Now = time.Now c.Now = time.Now
}) })
@ -780,9 +761,8 @@ func TestHandleTokenExchange(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel() httpServer, s := newTestServer(t, func(c *Config) {
httpServer, s := newTestServer(ctx, t, func(c *Config) {
c.Storage.CreateClient(ctx, storage.Client{ c.Storage.CreateClient(ctx, storage.Client{
ID: "client_1", ID: "client_1",
Secret: "secret_1", Secret: "secret_1",

23
server/introspectionhandler_test.go

@ -2,7 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -29,7 +28,7 @@ func toJSON(a interface{}) string {
} }
func mockTestStorage(t *testing.T, s storage.Storage) { func mockTestStorage(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
c := storage.Client{ c := storage.Client{
ID: "test", ID: "test",
Secret: "barfoo", Secret: "barfoo",
@ -139,11 +138,8 @@ func TestGetTokenFromRequestSuccess(t *testing.T) {
t0 := time.Now() t0 := time.Now()
now := func() time.Time { return t0 } now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })
@ -201,11 +197,9 @@ func TestGetTokenFromRequestFailure(t *testing.T) {
t0 := time.Now() t0 := time.Now()
now := func() time.Time { return t0 } now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })
@ -238,19 +232,20 @@ func TestGetTokenFromRequestFailure(t *testing.T) {
func TestHandleIntrospect(t *testing.T) { func TestHandleIntrospect(t *testing.T) {
t0 := time.Now() t0 := time.Now()
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
now := func() time.Time { return t0 } now := func() time.Time { return t0 }
logger := newLogger(t)
refreshTokenPolicy, err := NewRefreshTokenPolicy(logger, false, "", "24h", "") refreshTokenPolicy, err := NewRefreshTokenPolicy(logger, false, "", "24h", "")
if err != nil { if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err) t.Fatalf("failed to prepare rotation policy: %v", err)
} }
refreshTokenPolicy.now = now refreshTokenPolicy.now = now
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.RefreshTokenPolicy = refreshTokenPolicy c.RefreshTokenPolicy = refreshTokenPolicy
c.Now = now c.Now = now
@ -361,11 +356,9 @@ func TestIntrospectErrHelper(t *testing.T) {
t0 := time.Now() t0 := time.Now()
now := func() time.Time { return t0 } now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
}) })

11
server/oauth2_test.go

@ -1,7 +1,6 @@
package server package server
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"net/http" "net/http"
@ -323,10 +322,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, server := newTestServerMultipleConnectors(t, func(c *Config) {
defer cancel()
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
c.SupportedResponseTypes = tc.supportedResponseTypes c.SupportedResponseTypes = tc.supportedResponseTypes
c.Storage = storage.WithStaticClients(c.Storage, tc.clients) c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
}) })
@ -598,8 +594,9 @@ func TestValidRedirectURI(t *testing.T) {
} }
func TestStorageKeySet(t *testing.T) { func TestStorageKeySet(t *testing.T) {
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) { if err := s.UpdateKeys(t.Context(), func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{ keys.SigningKey = &jose.JSONWebKey{
Key: testKey, Key: testKey,
KeyID: "testkey", KeyID: "testkey",
@ -673,7 +670,7 @@ func TestStorageKeySet(t *testing.T) {
keySet := &storageKeySet{s} keySet := &storageKeySet{s}
_, err = keySet.VerifySignature(context.Background(), jwt) _, err = keySet.VerifySignature(t.Context(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err) t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
} }

8
server/refreshhandlers_test.go

@ -2,7 +2,6 @@ package server
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -18,7 +17,7 @@ import (
) )
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
ctx := context.Background() ctx := t.Context()
c := storage.Client{ c := storage.Client{
ID: "test", ID: "test",
Secret: "barfoo", Secret: "barfoo",
@ -153,11 +152,8 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(*testing.T) { t.Run(tc.name, func(*testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.RefreshTokenPolicy = tc.policy c.RefreshTokenPolicy = tc.policy
c.Now = func() time.Time { return t0 } c.Now = func() time.Time { return t0 }
}) })

79
server/server_test.go

@ -9,7 +9,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
@ -76,14 +75,15 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`) -----END RSA PRIVATE KEY-----`)
var logger = slog.New(slog.DiscardHandler) func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
})) }))
logger := newLogger(t)
ctx := t.Context()
config := Config{ config := Config{
Issuer: s.URL, Issuer: s.URL,
Storage: memory.New(logger), Storage: memory.New(logger),
@ -135,12 +135,15 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
return s, server return s, server
} }
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { func newTestServerMultipleConnectors(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
})) }))
logger := newLogger(t)
ctx := t.Context()
config := Config{ config := Config{
Issuer: s.URL, Issuer: s.URL,
Storage: memory.New(logger), Storage: memory.New(logger),
@ -183,21 +186,16 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo
} }
func TestNewTestServer(t *testing.T) { func TestNewTestServer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) newTestServer(t, nil)
defer cancel()
newTestServer(ctx, t, nil)
} }
func TestDiscovery(t *testing.T) { func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) httpServer, _ := newTestServer(t, func(c *Config) {
defer cancel()
httpServer, _ := newTestServer(ctx, t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL) p, err := oidc.NewProvider(t.Context(), httpServer.URL)
if err != nil { if err != nil {
t.Fatalf("failed to get provider: %v", err) t.Fatalf("failed to get provider: %v", err)
} }
@ -734,11 +732,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
tests := makeOAuth2Tests(clientID, clientSecret, now) tests := makeOAuth2Tests(clientID, clientSecret, now)
for _, tc := range tests.tests { for _, tc := range tests.tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
c.IDTokensValidFor = idTokensValidFor c.IDTokensValidFor = idTokensValidFor
@ -890,10 +887,9 @@ func TestOAuth2CodeFlow(t *testing.T) {
} }
func TestOAuth2ImplicitFlow(t *testing.T) { func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
// Enable support for the implicit flow. // Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token", "id_token"} c.SupportedResponseTypes = []string{"code", "token", "id_token"}
}) })
@ -1026,10 +1022,9 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
} }
func TestCrossClientScopes(t *testing.T) { func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -1149,10 +1144,9 @@ func TestCrossClientScopes(t *testing.T) {
} }
func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
}) })
defer httpServer.Close() defer httpServer.Close()
@ -1271,7 +1265,9 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
} }
func TestPasswordDB(t *testing.T) { func TestPasswordDB(t *testing.T) {
ctx := context.Background() ctx := t.Context()
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
conn := newPasswordDB(s) conn := newPasswordDB(s)
@ -1323,7 +1319,7 @@ func TestPasswordDB(t *testing.T) {
} }
for _, tc := range tests { for _, tc := range tests {
ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password) ident, valid, err := conn.Login(t.Context(), connector.Scopes{}, tc.username, tc.password)
if err != nil { if err != nil {
if !tc.wantErr { if !tc.wantErr {
t.Errorf("%s: %v", tc.name, err) t.Errorf("%s: %v", tc.name, err)
@ -1355,6 +1351,7 @@ func TestPasswordDB(t *testing.T) {
} }
func TestPasswordDBUsernamePrompt(t *testing.T) { func TestPasswordDBUsernamePrompt(t *testing.T) {
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
conn := newPasswordDB(s) conn := newPasswordDB(s)
@ -1377,7 +1374,8 @@ func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, erro
func TestKeyCacher(t *testing.T) { func TestKeyCacher(t *testing.T) {
tNow := time.Now() tNow := time.Now()
now := func() time.Time { return tNow } now := func() time.Time { return tNow }
ctx := context.TODO() ctx := t.Context()
logger := newLogger(t)
s := memory.New(logger) s := memory.New(logger)
tests := []struct { tests := []struct {
@ -1428,7 +1426,7 @@ func TestKeyCacher(t *testing.T) {
for i, tc := range tests { for i, tc := range tests {
gotCall = false gotCall = false
tc.before() tc.before()
s.GetKeys(context.TODO()) s.GetKeys(t.Context())
if gotCall != tc.wantCallToStorage { if gotCall != tc.wantCallToStorage {
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall) t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
} }
@ -1470,10 +1468,10 @@ type oauth2Client struct {
func TestRefreshTokenFlow(t *testing.T) { func TestRefreshTokenFlow(t *testing.T) {
state := "state" state := "state"
now := time.Now now := time.Now
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) { ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
c.Now = now c.Now = now
}) })
defer httpServer.Close() defer httpServer.Close()
@ -1604,11 +1602,10 @@ func TestOAuth2DeviceFlow(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
for _, tc := range testCase.oauth2Tests.tests { for _, tc := range testCase.oauth2Tests.tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
// Setup a dex server. // Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) { httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path" c.Issuer += "/non-root-path"
c.Now = now c.Now = now
c.IDTokensValidFor = idTokensValidFor c.IDTokensValidFor = idTokensValidFor
@ -1789,17 +1786,16 @@ func TestServerSupportedGrants(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
_, srv := newTestServer(context.TODO(), t, tc.config) _, srv := newTestServer(t, tc.config)
require.Equal(t, tc.resGrants, srv.supportedGrantTypes) require.Equal(t, tc.resGrants, srv.supportedGrantTypes)
}) })
} }
} }
func TestHeaders(t *testing.T) { func TestHeaders(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
httpServer, _ := newTestServer(ctx, t, func(c *Config) { httpServer, _ := newTestServer(t, func(c *Config) {
c.Headers = map[string][]string{ c.Headers = map[string][]string{
"Strict-Transport-Security": {"max-age=31536000; includeSubDomains"}, "Strict-Transport-Security": {"max-age=31536000; includeSubDomains"},
} }
@ -1818,8 +1814,7 @@ func TestHeaders(t *testing.T) {
} }
func TestConnectorFailureHandling(t *testing.T) { func TestConnectorFailureHandling(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx := t.Context()
defer cancel()
tests := []struct { tests := []struct {
name string name string
@ -1959,6 +1954,8 @@ func TestConnectorFailureHandling(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
logger := newLogger(t)
config := Config{ config := Config{
Issuer: "http://localhost", Issuer: "http://localhost",
Storage: memory.New(logger), Storage: memory.New(logger),

28
storage/conformance/conformance.go

@ -24,10 +24,10 @@ type subTest struct {
run func(t *testing.T, s storage.Storage) run func(t *testing.T, s storage.Storage)
} }
func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest) { func runTests(t *testing.T, newStorage func(t *testing.T) storage.Storage, tests []subTest) {
for _, test := range tests { for _, test := range tests {
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
s := newStorage() s := newStorage(t)
test.run(t, s) test.run(t, s)
s.Close() s.Close()
}) })
@ -37,7 +37,7 @@ func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest)
// RunTests runs a set of conformance tests against a storage. newStorage should // RunTests runs a set of conformance tests against a storage. newStorage should
// return an initialized but empty storage. The storage will be closed at the // return an initialized but empty storage. The storage will be closed at the
// end of each test run. // end of each test run.
func RunTests(t *testing.T, newStorage func() storage.Storage) { func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
runTests(t, newStorage, []subTest{ runTests(t, newStorage, []subTest{
{"AuthCodeCRUD", testAuthCodeCRUD}, {"AuthCodeCRUD", testAuthCodeCRUD},
{"AuthRequestCRUD", testAuthRequestCRUD}, {"AuthRequestCRUD", testAuthRequestCRUD},
@ -81,7 +81,7 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
} }
func testAuthRequestCRUD(t *testing.T, s storage.Storage) { func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
codeChallenge := storage.PKCE{ codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test", CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain", CodeChallengeMethod: "plain",
@ -181,7 +181,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
} }
func testAuthCodeCRUD(t *testing.T, s storage.Storage) { func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
a1 := storage.AuthCode{ a1 := storage.AuthCode{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: "client1", ClientID: "client1",
@ -259,7 +259,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
} }
func testClientCRUD(t *testing.T, s storage.Storage) { func testClientCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
id1 := storage.NewID() id1 := storage.NewID()
c1 := storage.Client{ c1 := storage.Client{
ID: id1, ID: id1,
@ -329,7 +329,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
} }
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
id := storage.NewID() id := storage.NewID()
refresh := storage.RefreshToken{ refresh := storage.RefreshToken{
ID: id, ID: id,
@ -448,7 +448,7 @@ func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email }
func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func testPasswordCRUD(t *testing.T, s storage.Storage) { func testPasswordCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
// Use bcrypt.MinCost to keep the tests short. // Use bcrypt.MinCost to keep the tests short.
passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil { if err != nil {
@ -539,7 +539,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
} }
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
userID1 := storage.NewID() userID1 := storage.NewID()
session1 := storage.OfflineSessions{ session1 := storage.OfflineSessions{
UserID: userID1, UserID: userID1,
@ -614,7 +614,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
} }
func testConnectorCRUD(t *testing.T, s storage.Storage) { func testConnectorCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
id1 := storage.NewID() id1 := storage.NewID()
config1 := []byte(`{"issuer": "https://accounts.google.com"}`) config1 := []byte(`{"issuer": "https://accounts.google.com"}`)
c1 := storage.Connector{ c1 := storage.Connector{
@ -754,7 +754,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
} }
func testGC(t *testing.T, s storage.Storage) { func testGC(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
est, err := time.LoadLocation("America/New_York") est, err := time.LoadLocation("America/New_York")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -942,7 +942,7 @@ func testGC(t *testing.T, s storage.Storage) {
// testTimezones tests that backends either fully support timezones or // testTimezones tests that backends either fully support timezones or
// do the correct standardization. // do the correct standardization.
func testTimezones(t *testing.T, s storage.Storage) { func testTimezones(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
est, err := time.LoadLocation("America/New_York") est, err := time.LoadLocation("America/New_York")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -987,7 +987,7 @@ func testTimezones(t *testing.T, s storage.Storage) {
} }
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
d1 := storage.DeviceRequest{ d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(), UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(), DeviceCode: storage.NewID(),
@ -1017,7 +1017,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
} }
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
codeChallenge := storage.PKCE{ codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test", CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain", CodeChallengeMethod: "plain",

8
storage/conformance/transactions.go

@ -17,7 +17,7 @@ import (
// This call is separate from RunTests because some storage perform extremely // This call is separate from RunTests because some storage perform extremely
// poorly under deadlocks, such as SQLite3, while others may be working towards // poorly under deadlocks, such as SQLite3, while others may be working towards
// conformance. // conformance.
func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) { func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
runTests(t, newStorage, []subTest{ runTests(t, newStorage, []subTest{
{"AuthRequestConcurrentUpdate", testAuthRequestConcurrentUpdate}, {"AuthRequestConcurrentUpdate", testAuthRequestConcurrentUpdate},
{"ClientConcurrentUpdate", testClientConcurrentUpdate}, {"ClientConcurrentUpdate", testClientConcurrentUpdate},
@ -27,7 +27,7 @@ func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) {
} }
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
c := storage.Client{ c := storage.Client{
ID: storage.NewID(), ID: storage.NewID(),
Secret: "foobar", Secret: "foobar",
@ -57,7 +57,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
} }
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
a := storage.AuthRequest{ a := storage.AuthRequest{
ID: storage.NewID(), ID: storage.NewID(),
ClientID: "foobar", ClientID: "foobar",
@ -102,7 +102,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
} }
func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background() ctx := t.Context()
// Use bcrypt.MinCost to keep the tests short. // Use bcrypt.MinCost to keep the tests short.
passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil { if err != nil {

8
storage/ent/mysql_test.go

@ -40,8 +40,8 @@ func mysqlTestConfig(host string, port uint64) *MySQL {
} }
} }
func newMySQLStorage(host string, port uint64) storage.Storage { func newMySQLStorage(t *testing.T, host string, port uint64) storage.Storage {
logger := slog.New(slog.DiscardHandler) logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := mysqlTestConfig(host, port) cfg := mysqlTestConfig(host, port)
s, err := cfg.Open(logger) s, err := cfg.Open(logger)
@ -65,8 +65,8 @@ func TestMySQL(t *testing.T) {
require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err) require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err)
} }
newStorage := func() storage.Storage { newStorage := func(t *testing.T) storage.Storage {
return newMySQLStorage(host, port) return newMySQLStorage(t, host, port)
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)

8
storage/ent/postgres_test.go

@ -35,8 +35,8 @@ func postgresTestConfig(host string, port uint64) *Postgres {
} }
} }
func newPostgresStorage(host string, port uint64) storage.Storage { func newPostgresStorage(t *testing.T, host string, port uint64) storage.Storage {
logger := slog.New(slog.DiscardHandler) logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := postgresTestConfig(host, port) cfg := postgresTestConfig(host, port)
s, err := cfg.Open(logger) s, err := cfg.Open(logger)
@ -60,8 +60,8 @@ func TestPostgres(t *testing.T) {
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
} }
newStorage := func() storage.Storage { newStorage := func(t *testing.T) storage.Storage {
return newPostgresStorage(host, port) return newPostgresStorage(t, host, port)
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)

4
storage/ent/sqlite_test.go

@ -8,8 +8,8 @@ import (
"github.com/dexidp/dex/storage/conformance" "github.com/dexidp/dex/storage/conformance"
) )
func newSQLiteStorage() storage.Storage { func newSQLiteStorage(t *testing.T) storage.Storage {
logger := slog.New(slog.DiscardHandler) logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := SQLite3{File: ":memory:"} cfg := SQLite3{File: ":memory:"}
s, err := cfg.Open(logger) s, err := cfg.Open(logger)

5
storage/etcd/etcd_test.go

@ -55,8 +55,6 @@ func cleanDB(c *conn) error {
return nil return nil
} }
var logger = slog.New(slog.DiscardHandler)
func TestEtcd(t *testing.T) { func TestEtcd(t *testing.T) {
testEtcdEnv := "DEX_ETCD_ENDPOINTS" testEtcdEnv := "DEX_ETCD_ENDPOINTS"
endpointsStr := os.Getenv(testEtcdEnv) endpointsStr := os.Getenv(testEtcdEnv)
@ -66,10 +64,11 @@ func TestEtcd(t *testing.T) {
} }
endpoints := strings.Split(endpointsStr, ",") endpoints := strings.Split(endpointsStr, ",")
newStorage := func() storage.Storage { newStorage := func(t *testing.T) storage.Storage {
s := &Etcd{ s := &Etcd{
Endpoints: endpoints, Endpoints: endpoints,
} }
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
conn, err := s.open(logger) conn, err := s.open(logger)
if err != nil { if err != nil {
fmt.Fprintln(os.Stdout, err) fmt.Fprintln(os.Stdout, err)

4
storage/kubernetes/storage_test.go

@ -57,7 +57,7 @@ func (s *StorageTestSuite) SetupTest() {
KubeConfigFile: kubeconfigPath, KubeConfigFile: kubeconfigPath,
} }
logger := slog.New(slog.DiscardHandler) logger := slog.New(slog.NewTextHandler(s.T().Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
kubeClient, err := config.open(logger, true) kubeClient, err := config.open(logger, true)
s.Require().NoError(err) s.Require().NoError(err)
@ -66,7 +66,7 @@ func (s *StorageTestSuite) SetupTest() {
} }
func (s *StorageTestSuite) TestStorage() { func (s *StorageTestSuite) TestStorage() {
newStorage := func() storage.Storage { newStorage := func(t *testing.T) storage.Storage {
for _, resource := range []string{ for _, resource := range []string{
resourceAuthCode, resourceAuthCode,
resourceAuthRequest, resourceAuthRequest,

4
storage/memory/memory_test.go

@ -9,9 +9,9 @@ import (
) )
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
logger := slog.New(slog.DiscardHandler) newStorage := func(t *testing.T) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
newStorage := func() storage.Storage {
return New(logger) return New(logger)
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)

7
storage/sql/config_test.go

@ -46,20 +46,19 @@ func cleanDB(c *conn) error {
return nil return nil
} }
var logger = slog.New(slog.DiscardHandler)
type opener interface { type opener interface {
open(logger *slog.Logger) (*conn, error) open(logger *slog.Logger) (*conn, error)
} }
func testDB(t *testing.T, o opener, withTransactions bool) { func testDB(t *testing.T, o opener, withTransactions bool) {
// t.Fatal has a bad habit of not actually printing the error // t.Fatal has a bad habit of not actually printing the error
fatal := func(i interface{}) { fatal := func(i any) {
fmt.Fprintln(os.Stdout, i) fmt.Fprintln(os.Stdout, i)
t.Fatal(i) t.Fatal(i)
} }
newStorage := func() storage.Storage { newStorage := func(t *testing.T) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
conn, err := o.open(logger) conn, err := o.open(logger)
if err != nil { if err != nil {
fatal(err) fatal(err)

3
storage/sql/postgres_test.go

@ -4,6 +4,7 @@
package sql package sql
import ( import (
"log/slog"
"os" "os"
"strconv" "strconv"
"testing" "testing"
@ -40,6 +41,7 @@ func TestPostgresTunables(t *testing.T) {
t.Run("with nothing set, uses defaults", func(t *testing.T) { t.Run("with nothing set, uses defaults", func(t *testing.T) {
cfg := *baseCfg cfg := *baseCfg
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
c, err := cfg.open(logger) c, err := cfg.open(logger)
if err != nil { if err != nil {
t.Fatalf("error opening connector: %s", err.Error()) t.Fatalf("error opening connector: %s", err.Error())
@ -53,6 +55,7 @@ func TestPostgresTunables(t *testing.T) {
t.Run("with something set, uses that", func(t *testing.T) { t.Run("with something set, uses that", func(t *testing.T) {
cfg := *baseCfg cfg := *baseCfg
cfg.MaxOpenConns = 101 cfg.MaxOpenConns = 101
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
c, err := cfg.open(logger) c, err := cfg.open(logger)
if err != nil { if err != nil {
t.Fatalf("error opening connector: %s", err.Error()) t.Fatalf("error opening connector: %s", err.Error())

Loading…
Cancel
Save