Browse Source

test: use new Go features in tests

Signed-off-by: Mark Sagi-Kazar <mark.sagikazar@gmail.com>
pull/4277/head
Mark Sagi-Kazar 7 months ago
parent
commit
e230d9426d
No known key found for this signature in database
GPG Key ID: 31AB0439F4C5C90E
  1. 117
      server/api_test.go
  2. 27
      server/deviceflowhandlers_test.go
  3. 52
      server/handlers_test.go
  4. 23
      server/introspectionhandler_test.go
  5. 11
      server/oauth2_test.go
  6. 8
      server/refreshhandlers_test.go
  7. 79
      server/server_test.go
  8. 28
      storage/conformance/conformance.go
  9. 8
      storage/conformance/transactions.go
  10. 8
      storage/ent/mysql_test.go
  11. 8
      storage/ent/postgres_test.go
  12. 4
      storage/ent/sqlite_test.go
  13. 5
      storage/etcd/etcd_test.go
  14. 4
      storage/kubernetes/storage_test.go
  15. 4
      storage/memory/memory_test.go
  16. 7
      storage/sql/config_test.go
  17. 3
      storage/sql/postgres_test.go

117
server/api_test.go

@ -1,10 +1,9 @@
package server
import (
"context"
"log/slog"
"net"
"os"
"slices"
"strings"
"testing"
"time"
@ -29,8 +28,12 @@ type apiClient struct {
Close func()
}
func newLogger(t *testing.T) *slog.Logger {
return slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
}
// newAPI constructs a gRCP client connected to a backing server.
func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient {
func newAPI(t *testing.T, s storage.Storage, logger *slog.Logger) *apiClient {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
@ -59,13 +62,14 @@ func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient {
// Attempts to create, update and delete a test Password
func TestPassword(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
email := "test@example.com"
p := api.Password{
Email: email,
@ -168,10 +172,10 @@ func TestPassword(t *testing.T) {
// Ensures checkCost returns expected values
func TestCheckCost(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
tests := []struct {
@ -221,13 +225,13 @@ func TestCheckCost(t *testing.T) {
// Attempts to list and revoke an existing refresh token.
func TestRefreshToken(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
// Creating a storage with an existing refresh token and offline session for the user.
id := storage.NewID()
@ -330,12 +334,13 @@ func TestRefreshToken(t *testing.T) {
}
func TestUpdateClient(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
createClient := func(t *testing.T, clientId string) {
resp, err := client.CreateClient(ctx, &api.CreateClientReq{
@ -463,13 +468,13 @@ func TestUpdateClient(t *testing.T) {
t.Errorf("expected stored client with LogoURL: %s, found %s", tc.req.LogoUrl, client.LogoURL)
}
for _, redirectURI := range tc.req.RedirectUris {
found := find(redirectURI, client.RedirectURIs)
found := slices.Contains(client.RedirectURIs, redirectURI)
if !found {
t.Errorf("expected redirect URI: %s", redirectURI)
}
}
for _, peer := range tc.req.TrustedPeers {
found := find(peer, client.TrustedPeers)
found := slices.Contains(client.TrustedPeers, peer)
if !found {
t.Errorf("expected trusted peer: %s", peer)
}
@ -483,26 +488,17 @@ func TestUpdateClient(t *testing.T) {
}
}
func find(item string, items []string) bool {
for _, i := range items {
if item == i {
return true
}
}
return false
}
func TestCreateConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
connectorID := "connector123"
connectorName := "TestConnector"
connectorType := "TestType"
@ -543,16 +539,16 @@ func TestCreateConnector(t *testing.T) {
}
func TestUpdateConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
connectorID := "connector123"
newConnectorName := "UpdatedConnector"
newConnectorType := "UpdatedType"
@ -611,16 +607,16 @@ func TestUpdateConnector(t *testing.T) {
}
func TestDeleteConnector(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
connectorID := "connector123"
// Create a connector for testing
@ -655,16 +651,15 @@ func TestDeleteConnector(t *testing.T) {
}
func TestListConnectors(t *testing.T) {
os.Setenv("DEX_API_CONNECTORS_CRUD", "true")
defer os.Unsetenv("DEX_API_CONNECTORS_CRUD")
logger := slog.New(slog.DiscardHandler)
t.Setenv("DEX_API_CONNECTORS_CRUD", "true")
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
// Create connectors for testing
createReq1 := api.CreateConnectorReq{
@ -698,13 +693,13 @@ func TestListConnectors(t *testing.T) {
}
func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
// Create connectors for testing
createReq1 := api.CreateConnectorReq{
@ -735,13 +730,13 @@ func TestMissingConnectorsCRUDFeatureFlag(t *testing.T) {
}
func TestListClients(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
logger := newLogger(t)
s := memory.New(logger)
client := newAPI(s, logger, t)
client := newAPI(t, s, logger)
defer client.Close()
ctx := context.Background()
ctx := t.Context()
// List Clients
listResp, err := client.ListClients(ctx, &api.ListClientReq{})

27
server/deviceflowhandlers_test.go

@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -20,10 +19,8 @@ func TestDeviceVerificationURI(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})
@ -101,11 +98,8 @@ func TestHandleDeviceCode(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})
@ -364,11 +358,10 @@ func TestDeviceCallback(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
// c.Issuer = c.Issuer + "/non-root-path"
c.Now = now
})
@ -658,11 +651,10 @@ func TestDeviceTokenResponse(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})
@ -715,7 +707,7 @@ func TestDeviceTokenResponse(t *testing.T) {
}
func expectJSONErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) {
jsonMap := make(map[string]interface{})
jsonMap := make(map[string]any)
err := json.Unmarshal(body, &jsonMap)
if err != nil {
t.Errorf("Unexpected error unmarshalling response: %v", err)
@ -792,11 +784,10 @@ func TestVerifyCodeResponse(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.testName, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})

52
server/handlers_test.go

@ -24,10 +24,7 @@ import (
)
func TestHandleHealth(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, nil)
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
@ -38,10 +35,7 @@ func TestHandleHealth(t *testing.T) {
}
func TestHandleDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, nil)
httpServer, server := newTestServer(t, nil)
defer httpServer.Close()
rr := httptest.NewRecorder()
@ -108,10 +102,7 @@ func TestHandleDiscovery(t *testing.T) {
}
func TestHandleHealthFailure(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
httpServer, server := newTestServer(t, func(c *Config) {
c.HealthChecker = gosundheit.New()
c.HealthChecker.RegisterCheck(
@ -143,10 +134,7 @@ func (*emptyStorage) GetAuthRequest(context.Context, string) (storage.AuthReques
}
func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
httpServer, server := newTestServer(t, func(c *Config) {
c.Storage = &emptyStorage{c.Storage}
})
defer httpServer.Close()
@ -171,10 +159,7 @@ func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
}
func TestHandleInvalidSAMLCallbacks(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServer(ctx, t, func(c *Config) {
httpServer, server := newTestServer(t, func(c *Config) {
c.Storage = &emptyStorage{c.Storage}
})
defer httpServer.Close()
@ -251,10 +236,9 @@ func TestHandleAuthCode(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
httpServer, s := newTestServer(ctx, t, func(c *Config) { c.Issuer += "/non-root-path" })
httpServer, s := newTestServer(t, func(c *Config) { c.Issuer += "/non-root-path" })
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
@ -303,7 +287,7 @@ func TestHandleAuthCode(t *testing.T) {
}
func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
c := storage.Client{
ID: "test",
Secret: "barfoo",
@ -339,8 +323,7 @@ func mockConnectorDataTestStorage(t *testing.T, s storage.Storage) {
}
func TestHandlePassword(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
tests := []struct {
name string
@ -361,7 +344,7 @@ func TestHandlePassword(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.PasswordConnector = "test"
c.Now = time.Now
})
@ -420,8 +403,7 @@ func TestHandlePassword(t *testing.T) {
}
func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
connID := "mockPw"
authReqID := "test"
@ -525,7 +507,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = tc.skipApproval
c.Now = time.Now
})
@ -574,8 +556,7 @@ func TestHandlePasswordLoginWithSkipApproval(t *testing.T) {
}
func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
connID := "mock"
authReqID := "test"
@ -679,7 +660,7 @@ func TestHandleConnectorCallbackWithSkipApproval(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.SkipApprovalScreen = tc.skipApproval
c.Now = time.Now
})
@ -780,9 +761,8 @@ func TestHandleTokenExchange(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
c.Storage.CreateClient(ctx, storage.Client{
ID: "client_1",
Secret: "secret_1",

23
server/introspectionhandler_test.go

@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
@ -29,7 +28,7 @@ func toJSON(a interface{}) string {
}
func mockTestStorage(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
c := storage.Client{
ID: "test",
Secret: "barfoo",
@ -139,11 +138,8 @@ func TestGetTokenFromRequestSuccess(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})
@ -201,11 +197,9 @@ func TestGetTokenFromRequestFailure(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})
@ -238,19 +232,20 @@ func TestGetTokenFromRequestFailure(t *testing.T) {
func TestHandleIntrospect(t *testing.T) {
t0 := time.Now()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
now := func() time.Time { return t0 }
logger := newLogger(t)
refreshTokenPolicy, err := NewRefreshTokenPolicy(logger, false, "", "24h", "")
if err != nil {
t.Fatalf("failed to prepare rotation policy: %v", err)
}
refreshTokenPolicy.now = now
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.RefreshTokenPolicy = refreshTokenPolicy
c.Now = now
@ -361,11 +356,9 @@ func TestIntrospectErrHelper(t *testing.T) {
t0 := time.Now()
now := func() time.Time { return t0 }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
})

11
server/oauth2_test.go

@ -1,7 +1,6 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"net/http"
@ -323,10 +322,7 @@ func TestParseAuthorizationRequest(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
httpServer, server := newTestServerMultipleConnectors(t, func(c *Config) {
c.SupportedResponseTypes = tc.supportedResponseTypes
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
})
@ -598,8 +594,9 @@ func TestValidRedirectURI(t *testing.T) {
}
func TestStorageKeySet(t *testing.T) {
logger := newLogger(t)
s := memory.New(logger)
if err := s.UpdateKeys(context.TODO(), func(keys storage.Keys) (storage.Keys, error) {
if err := s.UpdateKeys(t.Context(), func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",
@ -673,7 +670,7 @@ func TestStorageKeySet(t *testing.T) {
keySet := &storageKeySet{s}
_, err = keySet.VerifySignature(context.Background(), jwt)
_, err = keySet.VerifySignature(t.Context(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
}

8
server/refreshhandlers_test.go

@ -2,7 +2,6 @@ package server
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@ -18,7 +17,7 @@ import (
)
func mockRefreshTokenTestStorage(t *testing.T, s storage.Storage, useObsolete bool) {
ctx := context.Background()
ctx := t.Context()
c := storage.Client{
ID: "test",
Secret: "barfoo",
@ -153,11 +152,8 @@ func TestRefreshTokenExpirationScenarios(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(*testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.RefreshTokenPolicy = tc.policy
c.Now = func() time.Time { return t0 }
})

79
server/server_test.go

@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"net/http/httputil"
@ -76,14 +75,15 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`)
var logger = slog.New(slog.DiscardHandler)
func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
func newTestServer(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))
logger := newLogger(t)
ctx := t.Context()
config := Config{
Issuer: s.URL,
Storage: memory.New(logger),
@ -135,12 +135,15 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
return s, server
}
func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
func newTestServerMultipleConnectors(t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))
logger := newLogger(t)
ctx := t.Context()
config := Config{
Issuer: s.URL,
Storage: memory.New(logger),
@ -183,21 +186,16 @@ func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateCo
}
func TestNewTestServer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
newTestServer(ctx, t, nil)
newTestServer(t, nil)
}
func TestDiscovery(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, _ := newTestServer(ctx, t, func(c *Config) {
httpServer, _ := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
})
defer httpServer.Close()
p, err := oidc.NewProvider(ctx, httpServer.URL)
p, err := oidc.NewProvider(t.Context(), httpServer.URL)
if err != nil {
t.Fatalf("failed to get provider: %v", err)
}
@ -734,11 +732,10 @@ func TestOAuth2CodeFlow(t *testing.T) {
tests := makeOAuth2Tests(clientID, clientSecret, now)
for _, tc := range tests.tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
c.IDTokensValidFor = idTokensValidFor
@ -890,10 +887,9 @@ func TestOAuth2CodeFlow(t *testing.T) {
}
func TestOAuth2ImplicitFlow(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
// Enable support for the implicit flow.
c.SupportedResponseTypes = []string{"code", "token", "id_token"}
})
@ -1026,10 +1022,9 @@ func TestOAuth2ImplicitFlow(t *testing.T) {
}
func TestCrossClientScopes(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
})
defer httpServer.Close()
@ -1149,10 +1144,9 @@ func TestCrossClientScopes(t *testing.T) {
}
func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
})
defer httpServer.Close()
@ -1271,7 +1265,9 @@ func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
}
func TestPasswordDB(t *testing.T) {
ctx := context.Background()
ctx := t.Context()
logger := newLogger(t)
s := memory.New(logger)
conn := newPasswordDB(s)
@ -1323,7 +1319,7 @@ func TestPasswordDB(t *testing.T) {
}
for _, tc := range tests {
ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password)
ident, valid, err := conn.Login(t.Context(), connector.Scopes{}, tc.username, tc.password)
if err != nil {
if !tc.wantErr {
t.Errorf("%s: %v", tc.name, err)
@ -1355,6 +1351,7 @@ func TestPasswordDB(t *testing.T) {
}
func TestPasswordDBUsernamePrompt(t *testing.T) {
logger := newLogger(t)
s := memory.New(logger)
conn := newPasswordDB(s)
@ -1377,7 +1374,8 @@ func (s storageWithKeysTrigger) GetKeys(ctx context.Context) (storage.Keys, erro
func TestKeyCacher(t *testing.T) {
tNow := time.Now()
now := func() time.Time { return tNow }
ctx := context.TODO()
ctx := t.Context()
logger := newLogger(t)
s := memory.New(logger)
tests := []struct {
@ -1428,7 +1426,7 @@ func TestKeyCacher(t *testing.T) {
for i, tc := range tests {
gotCall = false
tc.before()
s.GetKeys(context.TODO())
s.GetKeys(t.Context())
if gotCall != tc.wantCallToStorage {
t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
}
@ -1470,10 +1468,10 @@ type oauth2Client struct {
func TestRefreshTokenFlow(t *testing.T) {
state := "state"
now := time.Now
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
httpServer, s := newTestServer(ctx, t, func(c *Config) {
ctx := t.Context()
httpServer, s := newTestServer(t, func(c *Config) {
c.Now = now
})
defer httpServer.Close()
@ -1604,11 +1602,10 @@ func TestOAuth2DeviceFlow(t *testing.T) {
for _, testCase := range testCases {
for _, tc := range testCase.oauth2Tests.tests {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Setup a dex server.
httpServer, s := newTestServer(ctx, t, func(c *Config) {
httpServer, s := newTestServer(t, func(c *Config) {
c.Issuer += "/non-root-path"
c.Now = now
c.IDTokensValidFor = idTokensValidFor
@ -1789,17 +1786,16 @@ func TestServerSupportedGrants(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
_, srv := newTestServer(context.TODO(), t, tc.config)
_, srv := newTestServer(t, tc.config)
require.Equal(t, tc.resGrants, srv.supportedGrantTypes)
})
}
}
func TestHeaders(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
httpServer, _ := newTestServer(ctx, t, func(c *Config) {
httpServer, _ := newTestServer(t, func(c *Config) {
c.Headers = map[string][]string{
"Strict-Transport-Security": {"max-age=31536000; includeSubDomains"},
}
@ -1818,8 +1814,7 @@ func TestHeaders(t *testing.T) {
}
func TestConnectorFailureHandling(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
tests := []struct {
name string
@ -1959,6 +1954,8 @@ func TestConnectorFailureHandling(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
logger := newLogger(t)
config := Config{
Issuer: "http://localhost",
Storage: memory.New(logger),

28
storage/conformance/conformance.go

@ -24,10 +24,10 @@ type subTest struct {
run func(t *testing.T, s storage.Storage)
}
func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest) {
func runTests(t *testing.T, newStorage func(t *testing.T) storage.Storage, tests []subTest) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := newStorage()
s := newStorage(t)
test.run(t, s)
s.Close()
})
@ -37,7 +37,7 @@ func runTests(t *testing.T, newStorage func() storage.Storage, tests []subTest)
// RunTests runs a set of conformance tests against a storage. newStorage should
// return an initialized but empty storage. The storage will be closed at the
// end of each test run.
func RunTests(t *testing.T, newStorage func() storage.Storage) {
func RunTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
runTests(t, newStorage, []subTest{
{"AuthCodeCRUD", testAuthCodeCRUD},
{"AuthRequestCRUD", testAuthRequestCRUD},
@ -81,7 +81,7 @@ func mustBeErrAlreadyExists(t *testing.T, kind string, err error) {
}
func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",
@ -181,7 +181,7 @@ func testAuthRequestCRUD(t *testing.T, s storage.Storage) {
}
func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
a1 := storage.AuthCode{
ID: storage.NewID(),
ClientID: "client1",
@ -259,7 +259,7 @@ func testAuthCodeCRUD(t *testing.T, s storage.Storage) {
}
func testClientCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
id1 := storage.NewID()
c1 := storage.Client{
ID: id1,
@ -329,7 +329,7 @@ func testClientCRUD(t *testing.T, s storage.Storage) {
}
func testRefreshTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
id := storage.NewID()
refresh := storage.RefreshToken{
ID: id,
@ -448,7 +448,7 @@ func (n byEmail) Less(i, j int) bool { return n[i].Email < n[j].Email }
func (n byEmail) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func testPasswordCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
// Use bcrypt.MinCost to keep the tests short.
passwordHash1, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil {
@ -539,7 +539,7 @@ func testPasswordCRUD(t *testing.T, s storage.Storage) {
}
func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
userID1 := storage.NewID()
session1 := storage.OfflineSessions{
UserID: userID1,
@ -614,7 +614,7 @@ func testOfflineSessionCRUD(t *testing.T, s storage.Storage) {
}
func testConnectorCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
id1 := storage.NewID()
config1 := []byte(`{"issuer": "https://accounts.google.com"}`)
c1 := storage.Connector{
@ -754,7 +754,7 @@ func testKeysCRUD(t *testing.T, s storage.Storage) {
}
func testGC(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
@ -942,7 +942,7 @@ func testGC(t *testing.T, s storage.Storage) {
// testTimezones tests that backends either fully support timezones or
// do the correct standardization.
func testTimezones(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
est, err := time.LoadLocation("America/New_York")
if err != nil {
t.Fatal(err)
@ -987,7 +987,7 @@ func testTimezones(t *testing.T, s storage.Storage) {
}
func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
d1 := storage.DeviceRequest{
UserCode: storage.NewUserCode(),
DeviceCode: storage.NewID(),
@ -1017,7 +1017,7 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
}
func testDeviceTokenCRUD(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
codeChallenge := storage.PKCE{
CodeChallenge: "code_challenge_test",
CodeChallengeMethod: "plain",

8
storage/conformance/transactions.go

@ -17,7 +17,7 @@ import (
// This call is separate from RunTests because some storage perform extremely
// poorly under deadlocks, such as SQLite3, while others may be working towards
// conformance.
func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) {
func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
runTests(t, newStorage, []subTest{
{"AuthRequestConcurrentUpdate", testAuthRequestConcurrentUpdate},
{"ClientConcurrentUpdate", testClientConcurrentUpdate},
@ -27,7 +27,7 @@ func RunTransactionTests(t *testing.T, newStorage func() storage.Storage) {
}
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
c := storage.Client{
ID: storage.NewID(),
Secret: "foobar",
@ -57,7 +57,7 @@ func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
}
func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
a := storage.AuthRequest{
ID: storage.NewID(),
ClientID: "foobar",
@ -102,7 +102,7 @@ func testAuthRequestConcurrentUpdate(t *testing.T, s storage.Storage) {
}
func testPasswordConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := context.Background()
ctx := t.Context()
// Use bcrypt.MinCost to keep the tests short.
passwordHash, err := bcrypt.GenerateFromPassword([]byte("secret"), bcrypt.MinCost)
if err != nil {

8
storage/ent/mysql_test.go

@ -40,8 +40,8 @@ func mysqlTestConfig(host string, port uint64) *MySQL {
}
}
func newMySQLStorage(host string, port uint64) storage.Storage {
logger := slog.New(slog.DiscardHandler)
func newMySQLStorage(t *testing.T, host string, port uint64) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := mysqlTestConfig(host, port)
s, err := cfg.Open(logger)
@ -65,8 +65,8 @@ func TestMySQL(t *testing.T) {
require.NoError(t, err, "invalid mysql port %q: %s", rawPort, err)
}
newStorage := func() storage.Storage {
return newMySQLStorage(host, port)
newStorage := func(t *testing.T) storage.Storage {
return newMySQLStorage(t, host, port)
}
conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage)

8
storage/ent/postgres_test.go

@ -35,8 +35,8 @@ func postgresTestConfig(host string, port uint64) *Postgres {
}
}
func newPostgresStorage(host string, port uint64) storage.Storage {
logger := slog.New(slog.DiscardHandler)
func newPostgresStorage(t *testing.T, host string, port uint64) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := postgresTestConfig(host, port)
s, err := cfg.Open(logger)
@ -60,8 +60,8 @@ func TestPostgres(t *testing.T) {
require.NoError(t, err, "invalid postgres port %q: %s", rawPort, err)
}
newStorage := func() storage.Storage {
return newPostgresStorage(host, port)
newStorage := func(t *testing.T) storage.Storage {
return newPostgresStorage(t, host, port)
}
conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage)

4
storage/ent/sqlite_test.go

@ -8,8 +8,8 @@ import (
"github.com/dexidp/dex/storage/conformance"
)
func newSQLiteStorage() storage.Storage {
logger := slog.New(slog.DiscardHandler)
func newSQLiteStorage(t *testing.T) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
cfg := SQLite3{File: ":memory:"}
s, err := cfg.Open(logger)

5
storage/etcd/etcd_test.go

@ -55,8 +55,6 @@ func cleanDB(c *conn) error {
return nil
}
var logger = slog.New(slog.DiscardHandler)
func TestEtcd(t *testing.T) {
testEtcdEnv := "DEX_ETCD_ENDPOINTS"
endpointsStr := os.Getenv(testEtcdEnv)
@ -66,10 +64,11 @@ func TestEtcd(t *testing.T) {
}
endpoints := strings.Split(endpointsStr, ",")
newStorage := func() storage.Storage {
newStorage := func(t *testing.T) storage.Storage {
s := &Etcd{
Endpoints: endpoints,
}
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
conn, err := s.open(logger)
if err != nil {
fmt.Fprintln(os.Stdout, err)

4
storage/kubernetes/storage_test.go

@ -57,7 +57,7 @@ func (s *StorageTestSuite) SetupTest() {
KubeConfigFile: kubeconfigPath,
}
logger := slog.New(slog.DiscardHandler)
logger := slog.New(slog.NewTextHandler(s.T().Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
kubeClient, err := config.open(logger, true)
s.Require().NoError(err)
@ -66,7 +66,7 @@ func (s *StorageTestSuite) SetupTest() {
}
func (s *StorageTestSuite) TestStorage() {
newStorage := func() storage.Storage {
newStorage := func(t *testing.T) storage.Storage {
for _, resource := range []string{
resourceAuthCode,
resourceAuthRequest,

4
storage/memory/memory_test.go

@ -9,9 +9,9 @@ import (
)
func TestStorage(t *testing.T) {
logger := slog.New(slog.DiscardHandler)
newStorage := func(t *testing.T) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
newStorage := func() storage.Storage {
return New(logger)
}
conformance.RunTests(t, newStorage)

7
storage/sql/config_test.go

@ -46,20 +46,19 @@ func cleanDB(c *conn) error {
return nil
}
var logger = slog.New(slog.DiscardHandler)
type opener interface {
open(logger *slog.Logger) (*conn, error)
}
func testDB(t *testing.T, o opener, withTransactions bool) {
// t.Fatal has a bad habit of not actually printing the error
fatal := func(i interface{}) {
fatal := func(i any) {
fmt.Fprintln(os.Stdout, i)
t.Fatal(i)
}
newStorage := func() storage.Storage {
newStorage := func(t *testing.T) storage.Storage {
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
conn, err := o.open(logger)
if err != nil {
fatal(err)

3
storage/sql/postgres_test.go

@ -4,6 +4,7 @@
package sql
import (
"log/slog"
"os"
"strconv"
"testing"
@ -40,6 +41,7 @@ func TestPostgresTunables(t *testing.T) {
t.Run("with nothing set, uses defaults", func(t *testing.T) {
cfg := *baseCfg
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
c, err := cfg.open(logger)
if err != nil {
t.Fatalf("error opening connector: %s", err.Error())
@ -53,6 +55,7 @@ func TestPostgresTunables(t *testing.T) {
t.Run("with something set, uses that", func(t *testing.T) {
cfg := *baseCfg
cfg.MaxOpenConns = 101
logger := slog.New(slog.NewTextHandler(t.Output(), &slog.HandlerOptions{Level: slog.LevelDebug}))
c, err := cfg.open(logger)
if err != nil {
t.Fatalf("error opening connector: %s", err.Error())

Loading…
Cancel
Save