From 3d97c59032d4feeea731eab8728caa2b13d7a674 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Tue, 10 Mar 2026 15:55:10 +0100 Subject: [PATCH] test: add concurrency tests for storage implementations (#4631) Signed-off-by: maksim.nabokikh --- storage/conformance/transactions.go | 121 ++++++++++++++++++++++++++++ storage/ent/mysql_test.go | 10 +++ storage/ent/postgres_test.go | 5 ++ storage/ent/sqlite_test.go | 1 + storage/etcd/etcd_test.go | 7 ++ storage/kubernetes/lock.go | 43 +++++++--- storage/kubernetes/storage_test.go | 1 + storage/memory/memory_test.go | 1 + storage/sql/config_test.go | 15 +++- storage/sql/sqlite_test.go | 2 +- 10 files changed, 191 insertions(+), 15 deletions(-) diff --git a/storage/conformance/transactions.go b/storage/conformance/transactions.go index 1383c8e7..5889a024 100644 --- a/storage/conformance/transactions.go +++ b/storage/conformance/transactions.go @@ -2,9 +2,12 @@ package conformance import ( "context" + "strconv" + "sync" "testing" "time" + "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "github.com/dexidp/dex/storage" @@ -26,6 +29,16 @@ func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Sto }) } +// RunConcurrencyTests runs tests that verify storage implementations handle +// high-contention parallel updates correctly. Unlike RunTransactionTests, +// these tests use real goroutine-based parallelism rather than nested calls, +// and are safe to run on all storage backends (including those with non-reentrant locks). 
+func RunConcurrencyTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { + runTests(t, newStorage, []subTest{ + {"RefreshTokenParallelUpdate", testRefreshTokenParallelUpdate}, + }) +} + func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { ctx := t.Context() c := storage.Client{ @@ -180,3 +193,111 @@ func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) { } } } + +// testRefreshTokenParallelUpdate tests that many parallel updates to the same +// refresh token are serialized correctly by the storage and no updates are lost. +// +// Each goroutine atomically increments a counter stored in the Token field. +// After all goroutines finish, the counter must equal the number of successful updates. +// A mismatch indicates lost updates due to broken atomicity. +func testRefreshTokenParallelUpdate(t *testing.T, s storage.Storage) { + ctx := t.Context() + + id := storage.NewID() + refresh := storage.RefreshToken{ + ID: id, + Token: "0", + Nonce: "foo", + ClientID: "client_id", + ConnectorID: "connector_id", + Scopes: []string{"openid"}, + CreatedAt: time.Now().UTC().Round(time.Millisecond), + LastUsed: time.Now().UTC().Round(time.Millisecond), + Claims: storage.Claims{ + UserID: "1", + Username: "jane", + Email: "jane@example.com", + }, + } + + require.NoError(t, s.CreateRefresh(ctx, refresh)) + + const numWorkers = 100 + + type updateResult struct { + err error + newToken string // token value written by this worker's updater + } + + var wg sync.WaitGroup + results := make([]updateResult, numWorkers) + + for i := range numWorkers { + wg.Add(1) + go func() { + defer wg.Done() + results[i].err = s.UpdateRefreshToken(ctx, id, func(old storage.RefreshToken) (storage.RefreshToken, error) { + counter, _ := strconv.Atoi(old.Token) + old.Token = strconv.Itoa(counter + 1) + results[i].newToken = old.Token + return old, nil + }) + }() + } + + wg.Wait() + + errCounts := map[string]int{} + var successes int + writtenTokens := map[string]int{} + 
for _, r := range results { + if r.err == nil { + successes++ + writtenTokens[r.newToken]++ + } else { + errCounts[r.err.Error()]++ + } + } + + for msg, count := range errCounts { + t.Logf("error (x%d): %s", count, msg) + } + + stored, err := s.GetRefresh(ctx, id) + require.NoError(t, err) + + counter, err := strconv.Atoi(stored.Token) + require.NoError(t, err) + + t.Logf("parallel refresh token updates: %d/%d succeeded, final counter: %d", successes, numWorkers, counter) + + if successes < numWorkers { + t.Errorf("not all updates succeeded: %d/%d (some failed under contention)", successes, numWorkers) + } + + if counter != successes { + t.Errorf("lost updates detected: %d successful updates but counter is %d", successes, counter) + } + + // Each successful updater must have seen a unique counter value. + // Duplicates would mean two updaters read the same state — a sign of broken atomicity. + for token, count := range writtenTokens { + if count > 1 { + t.Errorf("token %q was written by %d updaters — concurrent updaters saw the same state", token, count) + } + } + + // Successful updaters must have produced a contiguous sequence 1..N. + // A gap would mean an updater saw stale state even though the write succeeded. + for i := 1; i <= successes; i++ { + if writtenTokens[strconv.Itoa(i)] != 1 { + t.Errorf("expected token %q to be written exactly once, got %d", strconv.Itoa(i), writtenTokens[strconv.Itoa(i)]) + } + } + + // The token stored in the database must match the highest value written. + // This confirms that the last successful update is the one persisted. 
+ if stored.Token != strconv.Itoa(successes) { + t.Errorf("stored token %q does not match expected final value %q", stored.Token, strconv.Itoa(successes)) + } +} diff --git a/storage/ent/mysql_test.go b/storage/ent/mysql_test.go index b602e4a5..d7a06ffa 100644 --- a/storage/ent/mysql_test.go +++ b/storage/ent/mysql_test.go @@ -105,6 +105,11 @@ func TestMySQL(t *testing.T) { } conformance.RunTests(t, newStorage) conformance.RunTransactionTests(t, newStorage) + + // TODO(nabokihms): ent MySQL does not retry on deadlocks (Error 1213, SQLSTATE 40001: + // Deadlock found when trying to get lock; try restarting transaction). + // Under high contention most updates fail. + // conformance.RunConcurrencyTests(t, newStorage) } func TestMySQL8(t *testing.T) { @@ -126,6 +131,11 @@ func TestMySQL8(t *testing.T) { } conformance.RunTests(t, newStorage) conformance.RunTransactionTests(t, newStorage) + + // TODO(nabokihms): ent MySQL 8 does not retry on deadlocks (Error 1213, SQLSTATE 40001: + // Deadlock found when trying to get lock; try restarting transaction). + // Under high contention most updates fail. + // conformance.RunConcurrencyTests(t, newStorage) } func TestMySQLDSN(t *testing.T) { diff --git a/storage/ent/postgres_test.go b/storage/ent/postgres_test.go index bbbde38e..b53ed382 100644 --- a/storage/ent/postgres_test.go +++ b/storage/ent/postgres_test.go @@ -65,6 +65,11 @@ func TestPostgres(t *testing.T) { } conformance.RunTests(t, newStorage) conformance.RunTransactionTests(t, newStorage) + + // TODO(nabokihms): ent Postgres uses SERIALIZABLE transaction isolation for UpdateRefreshToken, + // but does not retry on serialization failures (pq: could not serialize access due to + // concurrent update, SQLSTATE 40001). Under high contention most updates fail immediately. 
+ // conformance.RunConcurrencyTests(t, newStorage) } func TestPostgresDSN(t *testing.T) { diff --git a/storage/ent/sqlite_test.go b/storage/ent/sqlite_test.go index 55c1b5c5..54638c19 100644 --- a/storage/ent/sqlite_test.go +++ b/storage/ent/sqlite_test.go @@ -21,4 +21,5 @@ func newSQLiteStorage(t *testing.T) storage.Storage { func TestSQLite3(t *testing.T) { conformance.RunTests(t, newSQLiteStorage) + conformance.RunConcurrencyTests(t, newSQLiteStorage) } diff --git a/storage/etcd/etcd_test.go b/storage/etcd/etcd_test.go index 6783c25b..55501b49 100644 --- a/storage/etcd/etcd_test.go +++ b/storage/etcd/etcd_test.go @@ -89,4 +89,11 @@ func TestEtcd(t *testing.T) { withTimeout(time.Minute*1, func() { conformance.RunTransactionTests(t, newStorage) }) + + // TODO(nabokihms): etcd uses compare-and-swap (txnUpdate) for UpdateRefreshToken, + // but does not retry on CAS conflicts ("concurrent conflicting update happened"). + // Under high contention virtually all updates fail — only the first writer succeeds. + // withTimeout(time.Minute*1, func() { + // conformance.RunConcurrencyTests(t, newStorage) + // }) } diff --git a/storage/kubernetes/lock.go b/storage/kubernetes/lock.go index c67380dc..e3bdb444 100644 --- a/storage/kubernetes/lock.go +++ b/storage/kubernetes/lock.go @@ -22,8 +22,18 @@ var ( // - Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once. // - Providers can limit the rate of requests to the token endpoint, which will lead to the error // in case of many concurrent requests. +// +// The lock uses a Kubernetes annotation on the refresh token resource as a mutex. +// Only one goroutine can hold the lock at a time; others poll until the annotation +// is removed (unlocked) or expires (broken). The Kubernetes resourceVersion on put +// acts as compare-and-swap: if two goroutines race to set the annotation, only one +// succeeds and the other gets a 409 Conflict. 
type refreshTokenLock struct { - cli *client + cli *client + // waitingState tracks whether this lock instance has lost a compare-and-swap race + // and is now polling for the lock to be released. Used by Unlock to skip the + // annotation removal — only the goroutine that successfully wrote the annotation + // should remove it. waitingState bool } @@ -31,6 +41,8 @@ func newRefreshTokenLock(cli *client) *refreshTokenLock { return &refreshTokenLock{cli: cli} } +// Lock polls until the lock annotation can be set on the refresh token resource. +// Returns nil when the lock is acquired, or an error on timeout (61 attempts × 100ms). func (l *refreshTokenLock) Lock(id string) error { for i := 0; i <= 60; i++ { ok, err := l.setLockAnnotation(id) @@ -45,9 +57,12 @@ func (l *refreshTokenLock) Lock(id string) error { return fmt.Errorf("timeout waiting for refresh token %s lock", id) } +// Unlock removes the lock annotation from the refresh token resource. +// Only the holder of the lock (waitingState == false) performs the removal. func (l *refreshTokenLock) Unlock(id string) { if l.waitingState { - // Do not need to unlock for waiting goroutines, because the have not set it. + // This goroutine never successfully wrote the annotation, so there's + // nothing to remove. Another goroutine holds (or held) the lock. return } @@ -64,6 +79,13 @@ func (l *refreshTokenLock) Unlock(id string) { } } +// setLockAnnotation attempts to acquire the lock by writing an annotation with +// an expiration timestamp. Returns (true, nil) when the caller should keep waiting, +// (false, nil) when the lock is acquired, or (false, err) on a non-retriable error. +// +// The locking protocol relies on Kubernetes optimistic concurrency: every put +// includes the resource's current resourceVersion, so concurrent writes to the +// same object result in a 409 Conflict for all but one writer. 
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) { r, err := l.cli.getRefreshToken(id) if err != nil { @@ -77,13 +99,14 @@ func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) { val, ok := r.Annotations[lockAnnotation] if !ok { - if l.waitingState { - return false, nil - } - + // No annotation means the lock is free. Every goroutine — whether it's + // a first-time caller or was previously waiting — must compete by writing + // the annotation. The put uses the current resourceVersion, so only one + // writer succeeds; the rest get a 409 Conflict and go back to polling. r.Annotations = lockData err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) if err == nil { + l.waitingState = false return false, nil } @@ -100,24 +123,24 @@ func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) { } if !currentTime.After(until) { - // waiting for the lock to be released + // Lock is held by another goroutine and has not expired yet — keep polling. l.waitingState = true return true, nil } - // Lock time is out, lets break the lock and take the advantage + // Lock has expired (holder crashed or is too slow). Attempt to break it by + // overwriting the annotation with a new expiration. Again, only one writer + // can win the compare-and-swap race. 
r.Annotations = lockData err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) if err == nil { - // break lock annotation return false, nil } l.cli.logger.Debug("break lock annotation", "error", err) if isKubernetesAPIConflictError(err) { l.waitingState = true - // after breaking error waiting for the lock to be released return true, nil } return false, err diff --git a/storage/kubernetes/storage_test.go b/storage/kubernetes/storage_test.go index 98ef25fa..906e5ce5 100644 --- a/storage/kubernetes/storage_test.go +++ b/storage/kubernetes/storage_test.go @@ -85,6 +85,7 @@ func (s *StorageTestSuite) TestStorage() { } conformance.RunTests(s.T(), newStorage) + conformance.RunConcurrencyTests(s.T(), newStorage) conformance.RunTransactionTests(s.T(), newStorage) } diff --git a/storage/memory/memory_test.go b/storage/memory/memory_test.go index e6e8232f..acd0be1b 100644 --- a/storage/memory/memory_test.go +++ b/storage/memory/memory_test.go @@ -15,4 +15,5 @@ func TestStorage(t *testing.T) { return New(logger) } conformance.RunTests(t, newStorage) + conformance.RunConcurrencyTests(t, newStorage) } diff --git a/storage/sql/config_test.go b/storage/sql/config_test.go index 606c95a8..c098919b 100644 --- a/storage/sql/config_test.go +++ b/storage/sql/config_test.go @@ -50,7 +50,7 @@ type opener interface { open(logger *slog.Logger) (*conn, error) } -func testDB(t *testing.T, o opener, withTransactions bool) { +func testDB(t *testing.T, o opener, withTransactions, withConcurrentTests bool) { // t.Fatal has a bad habit of not actually printing the error fatal := func(i any) { fmt.Fprintln(os.Stdout, i) @@ -71,11 +71,18 @@ func testDB(t *testing.T, o opener, withTransactions bool) { withTimeout(time.Minute*1, func() { conformance.RunTests(t, newStorage) }) + if withTransactions { withTimeout(time.Minute*1, func() { conformance.RunTransactionTests(t, newStorage) }) } + + if withConcurrentTests { + withTimeout(time.Minute*1, func() { + conformance.RunConcurrencyTests(t, 
newStorage) + }) + } } func getenv(key, defaultVal string) string { @@ -236,7 +243,7 @@ func TestPostgres(t *testing.T) { Mode: pgSSLDisable, // Postgres container doesn't support SSL. }, } - testDB(t, p, true) + testDB(t, p, true, false) } const testMySQLEnv = "DEX_MYSQL_HOST" @@ -273,7 +280,7 @@ func TestMySQL(t *testing.T) { "innodb_lock_wait_timeout": "3", }, } - testDB(t, s, true) + testDB(t, s, true, false) } const testMySQL8Env = "DEX_MYSQL8_HOST" @@ -310,5 +317,5 @@ func TestMySQL8(t *testing.T) { "innodb_lock_wait_timeout": "3", }, } - testDB(t, s, true) + testDB(t, s, true, false) } diff --git a/storage/sql/sqlite_test.go b/storage/sql/sqlite_test.go index 89d06aee..e21cbcc2 100644 --- a/storage/sql/sqlite_test.go +++ b/storage/sql/sqlite_test.go @@ -8,5 +8,5 @@ import ( ) func TestSQLite3(t *testing.T) { - testDB(t, &SQLite3{":memory:"}, false) + testDB(t, &SQLite3{":memory:"}, false, true) }