Browse Source

test: add concurrency tests for storage implementations (#4631)

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/3818/merge
Maksim Nabokikh 7 days ago committed by GitHub
parent
commit
3d97c59032
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 121
      storage/conformance/transactions.go
  2. 10
      storage/ent/mysql_test.go
  3. 5
      storage/ent/postgres_test.go
  4. 1
      storage/ent/sqlite_test.go
  5. 7
      storage/etcd/etcd_test.go
  6. 41
      storage/kubernetes/lock.go
  7. 1
      storage/kubernetes/storage_test.go
  8. 1
      storage/memory/memory_test.go
  9. 15
      storage/sql/config_test.go
  10. 2
      storage/sql/sqlite_test.go

121
storage/conformance/transactions.go

@ -2,9 +2,12 @@ package conformance
import ( import (
"context" "context"
"strconv"
"sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
@ -26,6 +29,16 @@ func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Sto
}) })
} }
// RunConcurrencyTests exercises a storage implementation under real
// goroutine-based contention. In contrast to RunTransactionTests, whose
// cases nest updates inside updates, every subtest here issues genuinely
// parallel updates, which makes it safe for backends whose locks are not
// reentrant.
func RunConcurrencyTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) {
	tests := []subTest{
		{"RefreshTokenParallelUpdate", testRefreshTokenParallelUpdate},
	}
	runTests(t, newStorage, tests)
}
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { func testClientConcurrentUpdate(t *testing.T, s storage.Storage) {
ctx := t.Context() ctx := t.Context()
c := storage.Client{ c := storage.Client{
@ -180,3 +193,111 @@ func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) {
} }
} }
} }
// testRefreshTokenParallelUpdate hammers a single refresh token with many
// concurrent UpdateRefreshToken calls and verifies the storage serializes
// them without losing writes.
//
// The Token field is used as a counter: every updater parses it, adds one,
// and records the value it committed. Afterwards the persisted counter must
// equal the number of successful updates, and the committed values must form
// the contiguous sequence 1..N with no duplicates — a duplicate or a gap
// means two updaters observed the same prior state (broken atomicity).
func testRefreshTokenParallelUpdate(t *testing.T, s storage.Storage) {
	ctx := t.Context()

	tokenID := storage.NewID()
	seed := storage.RefreshToken{
		ID:          tokenID,
		Token:       "0",
		Nonce:       "foo",
		ClientID:    "client_id",
		ConnectorID: "connector_id",
		Scopes:      []string{"openid"},
		CreatedAt:   time.Now().UTC().Round(time.Millisecond),
		LastUsed:    time.Now().UTC().Round(time.Millisecond),
		Claims: storage.Claims{
			UserID:   "1",
			Username: "jane",
			Email:    "jane@example.com",
		},
	}
	require.NoError(t, s.CreateRefresh(ctx, seed))

	const numWorkers = 100

	// One slot per worker: no locking is needed because each goroutine
	// writes only its own index.
	type outcome struct {
		err     error
		written string // counter value this worker's updater committed
	}
	outcomes := make([]outcome, numWorkers)

	var wg sync.WaitGroup
	wg.Add(numWorkers)
	for i := 0; i < numWorkers; i++ {
		go func(slot int) {
			defer wg.Done()
			outcomes[slot].err = s.UpdateRefreshToken(ctx, tokenID, func(cur storage.RefreshToken) (storage.RefreshToken, error) {
				n, _ := strconv.Atoi(cur.Token)
				cur.Token = strconv.Itoa(n + 1)
				outcomes[slot].written = cur.Token
				return cur, nil
			})
		}(i)
	}
	wg.Wait()

	var (
		successes     int
		errCounts     = map[string]int{}
		writtenTokens = map[string]int{}
	)
	for _, o := range outcomes {
		if o.err != nil {
			errCounts[o.err.Error()]++
			continue
		}
		successes++
		writtenTokens[o.written]++
	}
	for msg, n := range errCounts {
		t.Logf("error (x%d): %s", n, msg)
	}

	stored, err := s.GetRefresh(ctx, tokenID)
	require.NoError(t, err)
	counter, err := strconv.Atoi(stored.Token)
	require.NoError(t, err)

	t.Logf("parallel refresh token updates: %d/%d succeeded, final counter: %d", successes, numWorkers, counter)

	if successes < numWorkers {
		t.Errorf("not all updates succeeded: %d/%d (some failed under contention)", successes, numWorkers)
	}
	if counter != successes {
		t.Errorf("lost updates detected: %d successful updates but counter is %d", successes, counter)
	}

	// A counter value committed by more than one successful updater means
	// two goroutines read the same state before writing.
	for token, n := range writtenTokens {
		if n > 1 {
			t.Errorf("token %q was written by %d updaters — concurrent updaters saw the same state", token, n)
		}
	}

	// The successful writes must cover 1..successes exactly once each;
	// a gap would mean a write landed on top of stale state.
	for v := 1; v <= successes; v++ {
		if writtenTokens[strconv.Itoa(v)] != 1 {
			t.Errorf("expected token %q to be written exactly once, got %d", strconv.Itoa(v), writtenTokens[strconv.Itoa(v)])
		}
	}

	// Whatever the database holds must be the highest value written, i.e.
	// the last successful update is the one that persisted.
	if stored.Token != strconv.Itoa(successes) {
		t.Errorf("stored token %q does not match expected final value %q", stored.Token, strconv.Itoa(successes))
	}
}

10
storage/ent/mysql_test.go

@ -105,6 +105,11 @@ func TestMySQL(t *testing.T) {
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)
// TODO(nabokihms): ent MySQL does not retry on deadlocks (Error 1213, SQLSTATE 40001:
// Deadlock found when trying to get lock; try restarting transaction).
// Under high contention most updates fail.
// conformance.RunConcurrencyTests(t, newStorage)
} }
func TestMySQL8(t *testing.T) { func TestMySQL8(t *testing.T) {
@ -126,6 +131,11 @@ func TestMySQL8(t *testing.T) {
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)
// TODO(nabokihms): ent MySQL 8 does not retry on deadlocks (Error 1213, SQLSTATE 40001:
// Deadlock found when trying to get lock; try restarting transaction).
// Under high contention most updates fail.
// conformance.RunConcurrencyTests(t, newStorage)
} }
func TestMySQLDSN(t *testing.T) { func TestMySQLDSN(t *testing.T) {

5
storage/ent/postgres_test.go

@ -65,6 +65,11 @@ func TestPostgres(t *testing.T) {
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)
// TODO(nabokihms): ent Postgres uses SERIALIZABLE transaction isolation for UpdateRefreshToken,
// but does not retry on serialization failures (pq: could not serialize access due to
// concurrent update, SQLSTATE 40001). Under high contention most updates fail immediately.
// conformance.RunConcurrencyTests(t, newStorage)
} }
func TestPostgresDSN(t *testing.T) { func TestPostgresDSN(t *testing.T) {

1
storage/ent/sqlite_test.go

@ -21,4 +21,5 @@ func newSQLiteStorage(t *testing.T) storage.Storage {
func TestSQLite3(t *testing.T) { func TestSQLite3(t *testing.T) {
conformance.RunTests(t, newSQLiteStorage) conformance.RunTests(t, newSQLiteStorage)
conformance.RunConcurrencyTests(t, newSQLiteStorage)
} }

7
storage/etcd/etcd_test.go

@ -89,4 +89,11 @@ func TestEtcd(t *testing.T) {
withTimeout(time.Minute*1, func() { withTimeout(time.Minute*1, func() {
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)
}) })
// TODO(nabokihms): etcd uses compare-and-swap (txnUpdate) for UpdateRefreshToken,
// but does not retry on CAS conflicts ("concurrent conflicting update happened").
// Under high contention virtually all updates fail — only the first writer succeeds.
// withTimeout(time.Minute*1, func() {
// conformance.RunConcurrencyTests(t, newStorage)
// })
} }

41
storage/kubernetes/lock.go

@ -22,8 +22,18 @@ var (
// - Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once. // - Some of OIDC providers could use the refresh token rotation feature which requires calling refresh only once.
// - Providers can limit the rate of requests to the token endpoint, which will lead to the error // - Providers can limit the rate of requests to the token endpoint, which will lead to the error
// in case of many concurrent requests. // in case of many concurrent requests.
//
// The lock uses a Kubernetes annotation on the refresh token resource as a mutex.
// Only one goroutine can hold the lock at a time; others poll until the annotation
// is removed (unlocked) or expires (broken). The Kubernetes resourceVersion on put
// acts as compare-and-swap: if two goroutines race to set the annotation, only one
// succeeds and the other gets a 409 Conflict.
type refreshTokenLock struct { type refreshTokenLock struct {
cli *client cli *client
// waitingState tracks whether this lock instance has lost a compare-and-swap race
// and is now polling for the lock to be released. Used by Unlock to skip the
// annotation removal — only the goroutine that successfully wrote the annotation
// should remove it.
waitingState bool waitingState bool
} }
@ -31,6 +41,8 @@ func newRefreshTokenLock(cli *client) *refreshTokenLock {
return &refreshTokenLock{cli: cli} return &refreshTokenLock{cli: cli}
} }
// Lock polls until the lock annotation can be set on the refresh token resource.
// Returns nil when the lock is acquired, or an error on timeout (60 attempts × 100ms).
func (l *refreshTokenLock) Lock(id string) error { func (l *refreshTokenLock) Lock(id string) error {
for i := 0; i <= 60; i++ { for i := 0; i <= 60; i++ {
ok, err := l.setLockAnnotation(id) ok, err := l.setLockAnnotation(id)
@ -45,9 +57,12 @@ func (l *refreshTokenLock) Lock(id string) error {
return fmt.Errorf("timeout waiting for refresh token %s lock", id) return fmt.Errorf("timeout waiting for refresh token %s lock", id)
} }
// Unlock removes the lock annotation from the refresh token resource.
// Only the holder of the lock (waitingState == false) performs the removal.
func (l *refreshTokenLock) Unlock(id string) { func (l *refreshTokenLock) Unlock(id string) {
if l.waitingState { if l.waitingState {
// Do not need to unlock for waiting goroutines, because the have not set it. // This goroutine never successfully wrote the annotation, so there's
// nothing to remove. Another goroutine holds (or held) the lock.
return return
} }
@ -64,6 +79,13 @@ func (l *refreshTokenLock) Unlock(id string) {
} }
} }
// setLockAnnotation attempts to acquire the lock by writing an annotation with
// an expiration timestamp. Returns (true, nil) when the caller should keep waiting,
// (false, nil) when the lock is acquired, or (false, err) on a non-retriable error.
//
// The locking protocol relies on Kubernetes optimistic concurrency: every put
// includes the resource's current resourceVersion, so concurrent writes to the
// same object result in a 409 Conflict for all but one writer.
func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) { func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
r, err := l.cli.getRefreshToken(id) r, err := l.cli.getRefreshToken(id)
if err != nil { if err != nil {
@ -77,13 +99,14 @@ func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
val, ok := r.Annotations[lockAnnotation] val, ok := r.Annotations[lockAnnotation]
if !ok { if !ok {
if l.waitingState { // No annotation means the lock is free. Every goroutine — whether it's
return false, nil // a first-time caller or was previously waiting — must compete by writing
} // the annotation. The put uses the current resourceVersion, so only one
// writer succeeds; the rest get a 409 Conflict and go back to polling.
r.Annotations = lockData r.Annotations = lockData
err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) err := l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err == nil { if err == nil {
l.waitingState = false
return false, nil return false, nil
} }
@ -100,24 +123,24 @@ func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
} }
if !currentTime.After(until) { if !currentTime.After(until) {
// waiting for the lock to be released // Lock is held by another goroutine and has not expired yet — keep polling.
l.waitingState = true l.waitingState = true
return true, nil return true, nil
} }
// Lock time is out, lets break the lock and take the advantage // Lock has expired (holder crashed or is too slow). Attempt to break it by
// overwriting the annotation with a new expiration. Again, only one writer
// can win the compare-and-swap race.
r.Annotations = lockData r.Annotations = lockData
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err == nil { if err == nil {
// break lock annotation
return false, nil return false, nil
} }
l.cli.logger.Debug("break lock annotation", "error", err) l.cli.logger.Debug("break lock annotation", "error", err)
if isKubernetesAPIConflictError(err) { if isKubernetesAPIConflictError(err) {
l.waitingState = true l.waitingState = true
// after breaking error waiting for the lock to be released
return true, nil return true, nil
} }
return false, err return false, err

1
storage/kubernetes/storage_test.go

@ -85,6 +85,7 @@ func (s *StorageTestSuite) TestStorage() {
} }
conformance.RunTests(s.T(), newStorage) conformance.RunTests(s.T(), newStorage)
conformance.RunConcurrencyTests(s.T(), newStorage)
conformance.RunTransactionTests(s.T(), newStorage) conformance.RunTransactionTests(s.T(), newStorage)
} }

1
storage/memory/memory_test.go

@ -15,4 +15,5 @@ func TestStorage(t *testing.T) {
return New(logger) return New(logger)
} }
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
conformance.RunConcurrencyTests(t, newStorage)
} }

15
storage/sql/config_test.go

@ -50,7 +50,7 @@ type opener interface {
open(logger *slog.Logger) (*conn, error) open(logger *slog.Logger) (*conn, error)
} }
func testDB(t *testing.T, o opener, withTransactions bool) { func testDB(t *testing.T, o opener, withTransactions, withConcurrentTests bool) {
// t.Fatal has a bad habit of not actually printing the error // t.Fatal has a bad habit of not actually printing the error
fatal := func(i any) { fatal := func(i any) {
fmt.Fprintln(os.Stdout, i) fmt.Fprintln(os.Stdout, i)
@ -71,11 +71,18 @@ func testDB(t *testing.T, o opener, withTransactions bool) {
withTimeout(time.Minute*1, func() { withTimeout(time.Minute*1, func() {
conformance.RunTests(t, newStorage) conformance.RunTests(t, newStorage)
}) })
if withTransactions { if withTransactions {
withTimeout(time.Minute*1, func() { withTimeout(time.Minute*1, func() {
conformance.RunTransactionTests(t, newStorage) conformance.RunTransactionTests(t, newStorage)
}) })
} }
if withConcurrentTests {
withTimeout(time.Minute*1, func() {
conformance.RunConcurrencyTests(t, newStorage)
})
}
} }
func getenv(key, defaultVal string) string { func getenv(key, defaultVal string) string {
@ -236,7 +243,7 @@ func TestPostgres(t *testing.T) {
Mode: pgSSLDisable, // Postgres container doesn't support SSL. Mode: pgSSLDisable, // Postgres container doesn't support SSL.
}, },
} }
testDB(t, p, true) testDB(t, p, true, false)
} }
const testMySQLEnv = "DEX_MYSQL_HOST" const testMySQLEnv = "DEX_MYSQL_HOST"
@ -273,7 +280,7 @@ func TestMySQL(t *testing.T) {
"innodb_lock_wait_timeout": "3", "innodb_lock_wait_timeout": "3",
}, },
} }
testDB(t, s, true) testDB(t, s, true, false)
} }
const testMySQL8Env = "DEX_MYSQL8_HOST" const testMySQL8Env = "DEX_MYSQL8_HOST"
@ -310,5 +317,5 @@ func TestMySQL8(t *testing.T) {
"innodb_lock_wait_timeout": "3", "innodb_lock_wait_timeout": "3",
}, },
} }
testDB(t, s, true) testDB(t, s, true, false)
} }

2
storage/sql/sqlite_test.go

@ -8,5 +8,5 @@ import (
) )
func TestSQLite3(t *testing.T) { func TestSQLite3(t *testing.T) {
testDB(t, &SQLite3{":memory:"}, false) testDB(t, &SQLite3{":memory:"}, false, true)
} }

Loading…
Cancel
Save