|
|
|
|
@ -2,9 +2,12 @@ package conformance
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"strconv" |
|
|
|
|
"sync" |
|
|
|
|
"testing" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/require" |
|
|
|
|
"golang.org/x/crypto/bcrypt" |
|
|
|
|
|
|
|
|
|
"github.com/dexidp/dex/storage" |
|
|
|
|
@ -26,6 +29,16 @@ func RunTransactionTests(t *testing.T, newStorage func(t *testing.T) storage.Sto
|
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// RunConcurrencyTests runs tests that verify storage implementations handle
|
|
|
|
|
// high-contention parallel updates correctly. Unlike RunTransactionTests,
|
|
|
|
|
// these tests use real goroutine-based parallelism rather than nested calls,
|
|
|
|
|
// and are safe to run on all storage backends (including those with non-reentrant locks).
|
|
|
|
|
func RunConcurrencyTests(t *testing.T, newStorage func(t *testing.T) storage.Storage) { |
|
|
|
|
runTests(t, newStorage, []subTest{ |
|
|
|
|
{"RefreshTokenParallelUpdate", testRefreshTokenParallelUpdate}, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func testClientConcurrentUpdate(t *testing.T, s storage.Storage) { |
|
|
|
|
ctx := t.Context() |
|
|
|
|
c := storage.Client{ |
|
|
|
|
@ -180,3 +193,111 @@ func testKeysConcurrentUpdate(t *testing.T, s storage.Storage) {
|
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// testRefreshTokenParallelUpdate tests that many parallel updates to the same
|
|
|
|
|
// refresh token are serialized correctly by the storage and no updates are lost.
|
|
|
|
|
//
|
|
|
|
|
// Each goroutine atomically increments a counter stored in the Token field.
|
|
|
|
|
// After all goroutines finish, the counter must equal the number of successful updates.
|
|
|
|
|
// A mismatch indicates lost updates due to broken atomicity.
|
|
|
|
|
func testRefreshTokenParallelUpdate(t *testing.T, s storage.Storage) { |
|
|
|
|
ctx := t.Context() |
|
|
|
|
|
|
|
|
|
id := storage.NewID() |
|
|
|
|
refresh := storage.RefreshToken{ |
|
|
|
|
ID: id, |
|
|
|
|
Token: "0", |
|
|
|
|
Nonce: "foo", |
|
|
|
|
ClientID: "client_id", |
|
|
|
|
ConnectorID: "connector_id", |
|
|
|
|
Scopes: []string{"openid"}, |
|
|
|
|
CreatedAt: time.Now().UTC().Round(time.Millisecond), |
|
|
|
|
LastUsed: time.Now().UTC().Round(time.Millisecond), |
|
|
|
|
Claims: storage.Claims{ |
|
|
|
|
UserID: "1", |
|
|
|
|
Username: "jane", |
|
|
|
|
Email: "jane@example.com", |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
require.NoError(t, s.CreateRefresh(ctx, refresh)) |
|
|
|
|
|
|
|
|
|
const numWorkers = 100 |
|
|
|
|
|
|
|
|
|
type updateResult struct { |
|
|
|
|
err error |
|
|
|
|
newToken string // token value written by this worker's updater
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup |
|
|
|
|
results := make([]updateResult, numWorkers) |
|
|
|
|
|
|
|
|
|
for i := range numWorkers { |
|
|
|
|
wg.Add(1) |
|
|
|
|
go func() { |
|
|
|
|
defer wg.Done() |
|
|
|
|
results[i].err = s.UpdateRefreshToken(ctx, id, func(old storage.RefreshToken) (storage.RefreshToken, error) { |
|
|
|
|
counter, _ := strconv.Atoi(old.Token) |
|
|
|
|
old.Token = strconv.Itoa(counter + 1) |
|
|
|
|
results[i].newToken = old.Token |
|
|
|
|
return old, nil |
|
|
|
|
}) |
|
|
|
|
}() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
wg.Wait() |
|
|
|
|
|
|
|
|
|
errCounts := map[string]int{} |
|
|
|
|
var successes int |
|
|
|
|
writtenTokens := map[string]int{} |
|
|
|
|
for _, r := range results { |
|
|
|
|
if r.err == nil { |
|
|
|
|
successes++ |
|
|
|
|
writtenTokens[r.newToken]++ |
|
|
|
|
} else { |
|
|
|
|
errCounts[r.err.Error()]++ |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for msg, count := range errCounts { |
|
|
|
|
t.Logf("error (x%d): %s", count, msg) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
stored, err := s.GetRefresh(ctx, id) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
|
|
counter, err := strconv.Atoi(stored.Token) |
|
|
|
|
require.NoError(t, err) |
|
|
|
|
|
|
|
|
|
t.Logf("parallel refresh token updates: %d/%d succeeded, final counter: %d", successes, numWorkers, counter) |
|
|
|
|
|
|
|
|
|
if successes < numWorkers { |
|
|
|
|
t.Errorf("not all updates succeeded: %d/%d (some failed under contention)", successes, numWorkers) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if counter != successes { |
|
|
|
|
t.Errorf("lost updates detected: %d successful updates but counter is %d", successes, counter) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Each successful updater must have seen a unique counter value.
|
|
|
|
|
// Duplicates would mean two updaters read the same state — a sign of broken atomicity.
|
|
|
|
|
for token, count := range writtenTokens { |
|
|
|
|
if count > 1 { |
|
|
|
|
t.Errorf("token %q was written by %d updaters — concurrent updaters saw the same state", token, count) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Successful updaters must have produced a contiguous sequence 1..N.
|
|
|
|
|
// A gap would mean an updater saw stale state even though the write succeeded.
|
|
|
|
|
for i := 1; i <= successes; i++ { |
|
|
|
|
if writtenTokens[strconv.Itoa(i)] != 1 { |
|
|
|
|
t.Errorf("expected token %q to be written exactly once, got %d", strconv.Itoa(i), writtenTokens[strconv.Itoa(i)]) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// The token stored in the database must match the highest value written.
|
|
|
|
|
// This confirms that the last successful update is the one persisted.
|
|
|
|
|
if stored.Token != strconv.Itoa(successes) { |
|
|
|
|
t.Errorf("stored token %q does not match expected final value %q", stored.Token, strconv.Itoa(successes)) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|