diff --git a/server/api_test.go b/server/api_test.go index b61e15d3..5ddbcc4a 100644 --- a/server/api_test.go +++ b/server/api_test.go @@ -1,10 +1,9 @@ package server import ( - "context" "log/slog" "net" - "os" + "slices" "strings" "testing" "time" @@ -29,8 +28,12 @@ type apiClient struct { Close func() } +func newLogger(t *testing.T) *slog.Logger { + return slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) +} + // newAPI constructs a gRCP client connected to a backing server. -func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient { +func newAPI(t *testing.T, s storage.Storage, logger *slog.Logger) *apiClient { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -59,13 +62,14 @@ func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient { // Attempts to create, update and delete a test Password func TestPassword(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() + email := "test@example.com" p := api.Password{ Email: email, @@ -168,10 +172,10 @@ func TestPassword(t *testing.T) { // Ensures checkCost returns expected values func TestCheckCost(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() tests := []struct { @@ -221,13 +225,13 @@ func TestCheckCost(t *testing.T) { // Attempts to list and revoke an existing refresh token. func TestRefreshToken(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() // Creating a storage with an existing refresh token and offline session for the user. id := storage.NewID() @@ -330,12 +334,13 @@ func TestRefreshToken(t *testing.T) { } func TestUpdateClient(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + + ctx := t.Context() createClient := func(t *testing.T, clientId string) { resp, err := client.CreateClient(ctx, &api.CreateClientReq{ @@ -463,13 +468,13 @@ func TestUpdateClient(t *testing.T) { t.Errorf("expected stored client with LogoURL: %s, found %s", tc.req.LogoUrl, client.LogoURL) } for _, redirectURI := range tc.req.RedirectUris { - found := find(redirectURI, client.RedirectURIs) + found := slices.Contains(client.RedirectURIs, redirectURI) if !found { t.Errorf("expected redirect URI: %s", redirectURI) } } for _, peer := range tc.req.TrustedPeers { - found := find(peer, client.TrustedPeers) + found := slices.Contains(client.TrustedPeers, peer) if !found { t.Errorf("expected trusted peer: %s", peer) } @@ -483,26 +488,17 @@ func TestUpdateClient(t *testing.T) { } } -func find(item string, items []string) bool { - for _, i := range items { - if item == i { - return true - } - } - return false -} - func TestCreateConnector(t *testing.T) { - os.Setenv("DEX_API_CONNECTORS_CRUD", "true") - defer os.Unsetenv("DEX_API_CONNECTORS_CRUD") - - logger := slog.New(slog.DiscardHandler) + t.Setenv("DEX_API_CONNECTORS_CRUD", "true") + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() + connectorID := "connector123" connectorName := "TestConnector" connectorType := "TestType" @@ -543,16 +539,16 @@ func TestCreateConnector(t *testing.T) { } func TestUpdateConnector(t *testing.T) { - os.Setenv("DEX_API_CONNECTORS_CRUD", "true") - defer os.Unsetenv("DEX_API_CONNECTORS_CRUD") - - logger := slog.New(slog.DiscardHandler) + t.Setenv("DEX_API_CONNECTORS_CRUD", "true") + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() + connectorID := "connector123" newConnectorName := "UpdatedConnector" newConnectorType := "UpdatedType" @@ -611,16 +607,16 @@ func TestUpdateConnector(t *testing.T) { } func TestDeleteConnector(t *testing.T) { - os.Setenv("DEX_API_CONNECTORS_CRUD", "true") - defer os.Unsetenv("DEX_API_CONNECTORS_CRUD") - - logger := slog.New(slog.DiscardHandler) + t.Setenv("DEX_API_CONNECTORS_CRUD", "true") + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() + connectorID := "connector123" // Create a connector for testing @@ -655,16 +651,15 @@ func TestDeleteConnector(t *testing.T) { } func TestListConnectors(t *testing.T) { - os.Setenv("DEX_API_CONNECTORS_CRUD", "true") - defer os.Unsetenv("DEX_API_CONNECTORS_CRUD") - - logger := slog.New(slog.DiscardHandler) + t.Setenv("DEX_API_CONNECTORS_CRUD", "true") + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() // Create connectors for testing createReq1 := api.CreateConnectorReq{ @@ -698,13 +693,13 @@ func TestListConnectors(t *testing.T) { } func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() // Create connectors for testing createReq1 := api.CreateConnectorReq{ @@ -735,13 +730,13 @@ func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) { } func TestListClients(t *testing.T) { - logger := slog.New(slog.DiscardHandler) - + logger := newLogger(t) s := memory.New(logger) - client := newAPI(s, logger, t) + + client := newAPI(t, s, logger) defer client.Close() - ctx := context.Background() + ctx := t.Context() // List Clients listResp, err := client.ListClients(ctx, &api.ListClientReq{}) diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index 2b4fbbfa..3f3ea81e 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "context" "encoding/json" "io" "net/http" @@ -20,10 +19,8 @@ func TestDeviceVerificationURI(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) @@ -101,11 +98,8 @@ func TestHandleDeviceCode(t *testing.T) { } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) @@ -364,11 +358,10 @@ func TestDeviceCallback(t *testing.T) { } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { // c.Issuer = c.Issuer + "/non-root-path" c.Now = now }) @@ -658,11 +651,10 @@ func TestDeviceTokenResponse(t *testing.T) { } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) @@ -715,7 +707,7 @@ func TestDeviceTokenResponse(t *testing.T) { } func expectJSONErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { - jsonMap := make(map[string]interface{}) + jsonMap := make(map[string]any) err := json.Unmarshal(body, &jsonMap) if err != nil { t.Errorf("Unexpected error unmarshalling response: %v", err) @@ -792,11 +784,10 @@ func TestVerifyCodeResponse(t *testing.T) { } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) diff --git a/server/handlers_test.go b/server/handlers_test.go index 1aa4bfa5..114712ba 100644 --- a/server/handlers_test.go +++ b/server/handlers_test.go @@ -24,10 +24,7 @@ import ( ) func TestHandleHealth(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServer(ctx, t, nil) + httpServer, server := newTestServer(t, nil) defer httpServer.Close() rr := httptest.NewRecorder() @@ -38,10 +35,7 @@ func TestHandleHealth(t *testing.T) { } func TestHandleDiscovery(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServer(ctx, t, nil) + httpServer, server := newTestServer(t, nil) defer httpServer.Close() rr := httptest.NewRecorder() @@ -108,10 +102,7 @@ func TestHandleDiscovery(t *testing.T) { } func TestHandleHealthFailure(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServer(ctx, t, func(c *Config) { + httpServer, server := newTestServer(t, func(c *Config) { c.HealthChecker = gosundheit.New() c.HealthChecker.RegisterCheck( @@ -143,10 +134,7 @@ func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthReques } func TestHandleInvalidOAuth2Callbacks(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServer(ctx, t, func(c *Config) { + httpServer, server := newTestServer(t, func(c *Config) { c.Storage = &emptyStorage{c.Storage} }) defer httpServer.Close() @@ -171,10 +159,7 @@ func TestHandleInvalidOAuth2Callbacks(t *testing.T) { } func TestHandleInvalidSAMLCallbacks(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServer(ctx, t, func(c *Config) { + httpServer, server := newTestServer(t, func(c *Config) { c.Storage = &emptyStorage{c.Storage} }) defer httpServer.Close() @@ -251,10 +236,9 @@ func TestHandleAuthCode(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() - httpServer, s := newTestServer(ctx, t, func(c *Config) { c.Issuer += "/non-root-path" }) + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" }) defer httpServer.Close() p, err := oidc.NewProvider(ctx, httpServer.URL) @@ -303,7 +287,7 @@ func TestHandleAuthCode(t *testing.T) { } func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() c := storage.Client{ ID: "test", Secret: "barfoo", @@ -339,8 +323,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) { } func TestHandlePassword(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() tests := []struct { name string @@ -361,7 +344,7 @@ func TestHandlePassword(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.PasswordConnector = "test" c.Now = time.Now }) @@ -420,8 +403,7 @@ func TestHandlePassword(t *testing.T) { } func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() connID := "mockPw" authReqID := "test" @@ -525,7 +507,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.SkipApprovalScreen = tc.skipApproval c.Now = time.Now }) @@ -574,8 +556,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) { } func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() connID := "mock" authReqID := "test" @@ -679,7 +660,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.SkipApprovalScreen = tc.skipApproval c.Now = time.Now }) @@ -780,9 +761,8 @@ func TestHandleTokenExchange(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - httpServer, s := newTestServer(ctx, t, func(c *Config) { + ctx := t.Context() + httpServer, s := newTestServer(t, func(c *Config) { c.Storage.CreateClient(ctx, storage.Client{ ID: "client_1", Secret: "secret_1", diff --git a/server/introspectionhandler_test.go b/server/introspectionhandler_test.go index 695bbad8..6f18d056 100644 --- a/server/introspectionhandler_test.go +++ b/server/introspectionhandler_test.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "context" "encoding/json" "io" "net/http" @@ -29,7 +28,7 @@ func toJSON(a interface{}) string { } func mockTestStorage(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() c := storage.Client{ ID: "test", Secret: "barfoo", @@ -139,11 +138,8 @@ func TestGetTokenFromRequestSuccess(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) @@ -201,11 +197,9 @@ func TestGetTokenFromRequestFailure(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) @@ -238,19 +232,20 @@ func TestGetTokenFromRequestFailure(t *testing.T) { func TestHandleIntrospect(t *testing.T) { t0 := time.Now() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. now := func() time.Time { return t0 } + logger := newLogger(t) + refreshTokenPolicy, err := NewRefreshTokenPolicy(logger, false, "", "24h", "") if err != nil { t.Fatalf("failed to prepare rotation policy: %v", err) } refreshTokenPolicy.now = now - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.RefreshTokenPolicy = refreshTokenPolicy c.Now = now @@ -361,11 +356,9 @@ func TestIntrospectErrHelper(t *testing.T) { t0 := time.Now() now := func() time.Time { return t0 } - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now }) diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 70e4095c..3dff30d6 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -1,7 +1,6 @@ package server import ( - "context" "crypto/rand" "crypto/rsa" "net/http" @@ -323,10 +322,7 @@ func TestParseAuthorizationRequest(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) { + httpServer, server := newTestServerMultipleConnectors(t, func(c *Config) { c.SupportedResponseTypes = tc.supportedResponseTypes c.Storage = storage.WithStaticClients(c.Storage, tc.clients) }) @@ -598,8 +594,9 @@ func TestValidRedirectURI(t *testing.T) { } func TestStorageKeySet(t *testing.T) { + logger := newLogger(t) s := memory.New(logger) - if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) { + if err := s.UpdateKeys(t.Context(), func(keys storage.Keys) (storage.Keys, error) { keys.SigningKey = &jose.JSONWebKey{ Key: testKey, KeyID: "testkey", @@ -673,7 +670,7 @@ func TestStorageKeySet(t *testing.T) { keySet := &storageKeySet{s} - _, err = keySet.VerifySignature(context.Background(), jwt) + _, err = keySet.VerifySignature(t.Context(), jwt) if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err) } diff --git a/server/refreshhandlers_test.go b/server/refreshhandlers_test.go index 6b0925c2..f937769c 100644 --- a/server/refreshhandlers_test.go +++ b/server/refreshhandlers_test.go @@ -2,7 +2,6 @@ package server import ( "bytes" - "context" "encoding/json" "net/http" "net/http/httptest" @@ -18,7 +17,7 @@ import ( ) func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) { - ctx := context.Background() + ctx := t.Context() c := storage.Client{ ID: "test", Secret: "barfoo", @@ -153,11 +152,8 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(*testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.RefreshTokenPolicy = tc.policy c.Now = func() time.Time { return t0 } }) diff --git a/server/server_test.go b/server/server_test.go index c414eb88..a922aa75 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "log/slog" "net/http" "net/http/httptest" "net/http/httputil" @@ -76,14 +75,15 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= -----END RSA PRIVATE KEY-----`) -var logger = slog.New(slog.DiscardHandler) - -func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { +func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { var server *Server s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) })) + logger := newLogger(t) + ctx := t.Context() + config := Config{ Issuer: s.URL, Storage: memory.New(logger), @@ -135,12 +135,15 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi return s, server } -func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { +func newTestServerMultipleConnectors(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { var server *Server s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server.ServeHTTP(w, r) })) + logger := newLogger(t) + ctx := t.Context() + config := Config{ Issuer: s.URL, Storage: memory.New(logger), @@ -183,21 +186,16 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo } func TestNewTestServer(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - newTestServer(ctx, t, nil) + newTestServer(t, nil) } func TestDiscovery(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - httpServer, _ := newTestServer(ctx, t, func(c *Config) { + httpServer, _ := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" }) defer httpServer.Close() - p, err := oidc.NewProvider(ctx, httpServer.URL) + p, err := oidc.NewProvider(t.Context(), httpServer.URL) if err != nil { t.Fatalf("failed to get provider: %v", err) } @@ -734,11 +732,10 @@ func TestOAuth2CodeFlow(t *testing.T) { tests := makeOAuth2Tests(clientID, clientSecret, now) for _, tc := range tests.tests { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now c.IDTokensValidFor = idTokensValidFor @@ -890,10 +887,9 @@ func TestOAuth2CodeFlow(t *testing.T) { } func TestOAuth2ImplicitFlow(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { // Enable support for the implicit flow. c.SupportedResponseTypes = []string{"code", "token", "id_token"} }) @@ -1026,10 +1022,9 @@ func TestOAuth2ImplicitFlow(t *testing.T) { } func TestCrossClientScopes(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" }) defer httpServer.Close() @@ -1149,10 +1144,9 @@ func TestCrossClientScopes(t *testing.T) { } func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" }) defer httpServer.Close() @@ -1271,7 +1265,9 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) { } func TestPasswordDB(t *testing.T) { - ctx := context.Background() + ctx := t.Context() + + logger := newLogger(t) s := memory.New(logger) conn := newPasswordDB(s) @@ -1323,7 +1319,7 @@ func TestPasswordDB(t *testing.T) { } for _, tc := range tests { - ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password) + ident, valid, err := conn.Login(t.Context(), connector.Scopes{}, tc.username, tc.password) if err != nil { if !tc.wantErr { t.Errorf("%s: %v", tc.name, err) @@ -1355,6 +1351,7 @@ func TestPasswordDB(t *testing.T) { } func TestPasswordDBUsernamePrompt(t *testing.T) { + logger := newLogger(t) s := memory.New(logger) conn := newPasswordDB(s) @@ -1377,7 +1374,8 @@ func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, erro func TestKeyCacher(t *testing.T) { tNow := time.Now() now := func() time.Time { return tNow } - ctx := context.TODO() + ctx := t.Context() + logger := newLogger(t) s := memory.New(logger) tests := []struct { @@ -1428,7 +1426,7 @@ func TestKeyCacher(t *testing.T) { for i, tc := range tests { gotCall = false tc.before() - s.GetKeys(context.TODO()) + s.GetKeys(t.Context()) if gotCall != tc.wantCallToStorage { t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall) } @@ -1470,10 +1468,10 @@ type oauth2Client struct { func TestRefreshTokenFlow(t *testing.T) { state := "state" now := time.Now - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - httpServer, s := newTestServer(ctx, t, func(c *Config) { + ctx := t.Context() + + httpServer, s := newTestServer(t, func(c *Config) { c.Now = now }) defer httpServer.Close() @@ -1604,11 +1602,10 @@ func TestOAuth2DeviceFlow(t *testing.T) { for _, testCase := range testCases { for _, tc := range testCase.oauth2Tests.tests { t.Run(tc.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Setup a dex server. - httpServer, s := newTestServer(ctx, t, func(c *Config) { + httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" c.Now = now c.IDTokensValidFor = idTokensValidFor @@ -1789,17 +1786,16 @@ func TestServerSupportedGrants(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - _, srv := newTestServer(context.TODO(), t, tc.config) + _, srv := newTestServer(t, tc.config) require.Equal(t, tc.resGrants, srv.supportedGrantTypes) }) } } func TestHeaders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() - httpServer, _ := newTestServer(ctx, t, func(c *Config) { + httpServer, _ := newTestServer(t, func(c *Config) { c.Headers = map[string][]string{ "Strict-Transport-Security": {"max-age=31536000; includeSubDomains"}, } @@ -1818,8 +1814,7 @@ func TestHeaders(t *testing.T) { } func TestConnectorFailureHandling(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() tests := []struct { name string @@ -1959,6 +1954,8 @@ func TestConnectorFailureHandling(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + logger := newLogger(t) + config := Config{ Issuer: "http://localhost", Storage: memory.New(logger), diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 58ae3d95..f9d21961 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -24,10 +24,10 @@ type subTest struct { run func(t *testing.T, s storage.Storage) } -func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest) { +func runTests(t *testing.T, newStorage func(t *testing.T) storage.Storage, tests []subTest) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := newStorage() + s := newStorage(t) test.run(t, s) s.Close() }) @@ -37,7 +37,7 @@ func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest) // RunTests runs a set of conformance tests against a storage. newStorage should // return an initialized but empty storage. The storage will be closed at the // end of each test run. -func RunTests(t *testing.T, newStorage func() storage.Storage) { +func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { runTests(t, newStorage, []subTest{ {"AuthCodeCRUD", testAuthCodeCRUD}, {"AuthRequestCRUD", testAuthRequestCRUD}, @@ -81,7 +81,7 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) { } func testAuthRequestCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", @@ -181,7 +181,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) { } func testAuthCodeCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() a1 := storage.AuthCode{ ID: storage.NewID(), ClientID: "client1", @@ -259,7 +259,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) { } func testClientCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() id1 := storage.NewID() c1 := storage.Client{ ID: id1, @@ -329,7 +329,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) { } func testRefreshTokenCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() id := storage.NewID() refresh := storage.RefreshToken{ ID: id, @@ -448,7 +448,7 @@ func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email } func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] } func testPasswordCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() // Use bcrypt.MinCost to keep the tests short. passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { @@ -539,7 +539,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) { } func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() userID1 := storage.NewID() session1 := storage.OfflineSessions{ UserID: userID1, @@ -614,7 +614,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) { } func testConnectorCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() id1 := storage.NewID() config1 := []byte(`{"issuer": "https://accounts.google.com"}`) c1 := storage.Connector{ @@ -754,7 +754,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) { } func testGC(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) @@ -942,7 +942,7 @@ func testGC(t *testing.T, s storage.Storage) { // testTimezones tests that backends either fully support timezones or // do the correct standardization. func testTimezones(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() est, err := time.LoadLocation("America/New_York") if err != nil { t.Fatal(err) @@ -987,7 +987,7 @@ func testTimezones(t *testing.T, s storage.Storage) { } func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() d1 := storage.DeviceRequest{ UserCode: storage.NewUserCode(), DeviceCode: storage.NewID(), @@ -1017,7 +1017,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) { } func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() codeChallenge := storage.PKCE{ CodeChallenge: "code_challenge_test", CodeChallengeMethod: "plain", diff --git a/storage/conformance/transactions.go b/storage/conformance/transactions.go index 60365c9a..a67a6d7d 100644 --- a/storage/conformance/transactions.go +++ b/storage/conformance/transactions.go @@ -17,7 +17,7 @@ import ( // This call is separate from RunTests because some storage perform extremely // poorly under deadlocks, such as SQLite3, while others may be working towards // conformance. -func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) { +func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { runTests(t, newStorage, []subTest{ {"AuthRequestConcurrentUpdate", testAuthRequestConcurrentUpdate}, {"ClientConcurrentUpdate", testClientConcurrentUpdate}, @@ -27,7 +27,7 @@ func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) { } func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() c := storage.Client{ ID: storage.NewID(), Secret: "foobar", @@ -57,7 +57,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { } func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() a := storage.AuthRequest{ ID: storage.NewID(), ClientID: "foobar", @@ -102,7 +102,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) { } func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) { - ctx := context.Background() + ctx := t.Context() // Use bcrypt.MinCost to keep the tests short. passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost) if err != nil { diff --git a/storage/ent/mysql_test.go b/storage/ent/mysql_test.go index cc3260f8..a1baffb3 100644 --- a/storage/ent/mysql_test.go +++ b/storage/ent/mysql_test.go @@ -40,8 +40,8 @@ func mysqlTestConfig(host string, port uint64) *MySQL { } } -func newMySQLStorage(host string, port uint64) storage.Storage { - logger := slog.New(slog.DiscardHandler) +func newMySQLStorage(t *testing.T, host string, port uint64) storage.Storage { + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) cfg := mysqlTestConfig(host, port) s, err := cfg.Open(logger) @@ -65,8 +65,8 @@ func TestMySQL(t *testing.T) { require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err) } - newStorage := func() storage.Storage { - return newMySQLStorage(host, port) + newStorage := func(t *testing.T) storage.Storage { + return newMySQLStorage(t, host, port) } conformance.RunTests(t, newStorage) conformance.RunTransactionTests(t, newStorage) diff --git a/storage/ent/postgres_test.go b/storage/ent/postgres_test.go index fb4f959f..bbbde38e 100644 --- a/storage/ent/postgres_test.go +++ b/storage/ent/postgres_test.go @@ -35,8 +35,8 @@ func postgresTestConfig(host string, port uint64) *Postgres { } } -func newPostgresStorage(host string, port uint64) storage.Storage { - logger := slog.New(slog.DiscardHandler) +func newPostgresStorage(t *testing.T, host string, port uint64) storage.Storage { + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) cfg := postgresTestConfig(host, port) s, err := cfg.Open(logger) @@ -60,8 +60,8 @@ func TestPostgres(t *testing.T) { require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err) } - newStorage := func() storage.Storage { - return newPostgresStorage(host, port) + newStorage := func(t *testing.T) storage.Storage { + return newPostgresStorage(t, host, port) } conformance.RunTests(t, newStorage) conformance.RunTransactionTests(t, newStorage) diff --git a/storage/ent/sqlite_test.go b/storage/ent/sqlite_test.go index b72b7ff0..55c1b5c5 100644 --- a/storage/ent/sqlite_test.go +++ b/storage/ent/sqlite_test.go @@ -8,8 +8,8 @@ import ( "github.com/dexidp/dex/storage/conformance" ) -func newSQLiteStorage() storage.Storage { - logger := slog.New(slog.DiscardHandler) +func newSQLiteStorage(t *testing.T) storage.Storage { + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) cfg := SQLite3{File: ":memory:"} s, err := cfg.Open(logger) diff --git a/storage/etcd/etcd_test.go b/storage/etcd/etcd_test.go index 6e500c1c..6783c25b 100644 --- a/storage/etcd/etcd_test.go +++ b/storage/etcd/etcd_test.go @@ -55,8 +55,6 @@ func cleanDB(c *conn) error { return nil } -var logger = slog.New(slog.DiscardHandler) - func TestEtcd(t *testing.T) { testEtcdEnv := "DEX_ETCD_ENDPOINTS" endpointsStr := os.Getenv(testEtcdEnv) @@ -66,10 +64,11 @@ func TestEtcd(t *testing.T) { } endpoints := strings.Split(endpointsStr, ",") - newStorage := func() storage.Storage { + newStorage := func(t *testing.T) storage.Storage { s := &Etcd{ Endpoints: endpoints, } + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) conn, err := s.open(logger) if err != nil { fmt.Fprintln(os.Stdout, err) diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index aa1360d4..98ef25fa 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -57,7 +57,7 @@ func (s *StorageTestSuite) SetupTest() { KubeConfigFile: kubeconfigPath, } - logger := slog.New(slog.DiscardHandler) + logger := slog.New(slog.NewTextHandler(s.T().Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) kubeClient, err := config.open(logger, true) s.Require().NoError(err) @@ -66,7 +66,7 @@ func (s *StorageTestSuite) SetupTest() { } func (s *StorageTestSuite) TestStorage() { - newStorage := func() storage.Storage { + newStorage := func(t *testing.T) storage.Storage { for _, resource := range []string{ resourceAuthCode, resourceAuthRequest, diff --git a/storage/memory/memory_test.go b/storage/memory/memory_test.go index cf090810..e6e8232f 100644 --- a/storage/memory/memory_test.go +++ b/storage/memory/memory_test.go @@ -9,9 +9,9 @@ import ( ) func TestStorage(t *testing.T) { - logger := slog.New(slog.DiscardHandler) + newStorage := func(t *testing.T) storage.Storage { + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) - newStorage := func() storage.Storage { return New(logger) } conformance.RunTests(t, newStorage) diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index a9560643..93a593ea 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -46,20 +46,19 @@ func cleanDB(c *conn) error { return nil } -var logger = slog.New(slog.DiscardHandler) - type opener interface { open(logger *slog.Logger) (*conn, error) } func testDB(t *testing.T, o opener, withTransactions bool) { // t.Fatal has a bad habit of not actually printing the error - fatal := func(i interface{}) { + fatal := func(i any) { fmt.Fprintln(os.Stdout, i) t.Fatal(i) } - newStorage := func() storage.Storage { + newStorage := func(t *testing.T) storage.Storage { + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) conn, err := o.open(logger) if err != nil { fatal(err) diff --git a/storage/sql/postgres_test.go b/storage/sql/postgres_test.go index 3e5f8a8f..085c068a 100644 --- a/storage/sql/postgres_test.go +++ b/storage/sql/postgres_test.go @@ -4,6 +4,7 @@ package sql import ( + "log/slog" "os" "strconv" "testing" @@ -40,6 +41,7 @@ func TestPostgresTunables(t *testing.T) { t.Run("with nothing set, uses defaults", func(t *testing.T) { cfg := *baseCfg + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) c, err := cfg.open(logger) if err != nil { t.Fatalf("error opening connector: %s", err.Error()) @@ -53,6 +55,7 @@ func TestPostgresTunables(t *testing.T) { t.Run("with something set, uses that", func(t *testing.T) { cfg := *baseCfg cfg.MaxOpenConns = 101 + logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug})) c, err := cfg.open(logger) if err != nil { t.Fatalf("error opening connector: %s", err.Error())