mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
389 lines
9.4 KiB
389 lines
9.4 KiB
package repo |
|
|
|
import ( |
|
"encoding/base64" |
|
"fmt" |
|
"net/url" |
|
"sort" |
|
"testing" |
|
"time" |
|
|
|
"github.com/coreos/go-oidc/oidc" |
|
"github.com/kylelemons/godebug/pretty" |
|
|
|
"github.com/coreos/dex/client" |
|
"github.com/coreos/dex/db" |
|
"github.com/coreos/dex/refresh" |
|
"github.com/coreos/dex/user" |
|
) |
|
|
|
var ( |
|
testRefreshClientID = "client1" |
|
testRefreshClientID2 = "client2" |
|
|
|
testRefreshConnectorID = "IDPC-1" |
|
|
|
testRefreshClients = []client.LoadableClient{ |
|
{ |
|
Client: client.Client{ |
|
Credentials: oidc.ClientCredentials{ |
|
ID: testRefreshClientID, |
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), |
|
}, |
|
Metadata: oidc.ClientMetadata{ |
|
RedirectURIs: []url.URL{ |
|
url.URL{Scheme: "https", Host: "client1.example.com", Path: "/callback"}, |
|
}, |
|
}, |
|
}, |
|
}, |
|
{ |
|
Client: client.Client{ |
|
Credentials: oidc.ClientCredentials{ |
|
ID: testRefreshClientID2, |
|
Secret: base64.URLEncoding.EncodeToString([]byte("secret-2")), |
|
}, |
|
Metadata: oidc.ClientMetadata{ |
|
RedirectURIs: []url.URL{ |
|
url.URL{Scheme: "https", Host: "client2.example.com", Path: "/callback"}, |
|
}, |
|
}, |
|
}, |
|
}, |
|
} |
|
|
|
testRefreshUserID = "user1" |
|
testRefreshUsers = []user.UserWithRemoteIdentities{ |
|
{ |
|
User: user.User{ |
|
ID: testRefreshUserID, |
|
Email: "Email-1@example.com", |
|
CreatedAt: time.Now().Truncate(time.Second), |
|
}, |
|
RemoteIdentities: []user.RemoteIdentity{ |
|
{ |
|
ConnectorID: testRefreshConnectorID, |
|
ID: "RID-1", |
|
}, |
|
}, |
|
}, |
|
} |
|
) |
|
|
|
func newRefreshRepo(t *testing.T, users []user.UserWithRemoteIdentities, clients []client.LoadableClient) refresh.RefreshTokenRepo { |
|
dbMap := connect(t) |
|
if _, err := db.NewUserRepoFromUsers(dbMap, users); err != nil { |
|
t.Fatalf("Unable to add users: %v", err) |
|
} |
|
|
|
if _, err := db.NewClientRepoFromClients(dbMap, clients); err != nil { |
|
t.Fatalf("Unable to add clients: %v", err) |
|
} |
|
|
|
return db.NewRefreshTokenRepo(dbMap) |
|
} |
|
|
|
func TestRefreshTokenRepoCreateVerify(t *testing.T) { |
|
tests := []struct { |
|
createScopes []string |
|
verifyClientID string |
|
wantVerifyErr bool |
|
}{ |
|
{ |
|
createScopes: []string{"openid", "profile"}, |
|
verifyClientID: testRefreshClientID, |
|
}, |
|
{ |
|
createScopes: []string{}, |
|
verifyClientID: testRefreshClientID, |
|
}, |
|
{ |
|
createScopes: []string{"openid", "profile"}, |
|
verifyClientID: "not-a-client", |
|
wantVerifyErr: true, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) |
|
tok, err := repo.Create(testRefreshUserID, testRefreshClientID, testRefreshConnectorID, tt.createScopes) |
|
if err != nil { |
|
t.Fatalf("case %d: failed to create refresh token: %v", i, err) |
|
} |
|
|
|
tokUserID, gotConnectorID, gotScopes, err := repo.Verify(tt.verifyClientID, tok) |
|
if tt.wantVerifyErr { |
|
if err == nil { |
|
t.Errorf("case %d: want non-nil error.", i) |
|
} |
|
continue |
|
} |
|
|
|
if diff := pretty.Compare(tt.createScopes, gotScopes); diff != "" { |
|
t.Errorf("case %d: Compare(want, got): %v", i, diff) |
|
} |
|
|
|
if err != nil { |
|
t.Errorf("case %d: Could not verify token: %v", i, err) |
|
} else if tokUserID != testRefreshUserID { |
|
t.Errorf("case %d: Verified token returned wrong user id, want=%s, got=%s", i, |
|
testRefreshUserID, tokUserID) |
|
} |
|
|
|
if gotConnectorID != testRefreshConnectorID { |
|
t.Errorf("case %d: wanted connector_id=%q got=%q", i, testRefreshConnectorID, gotConnectorID) |
|
} |
|
} |
|
} |
|
|
|
// buildRefreshToken combines the token ID and token payload to create a new token. |
|
// used in the tests to created a refresh token. |
|
func buildRefreshToken(tokenID int64, tokenPayload []byte) string { |
|
return fmt.Sprintf("%d%s%s", tokenID, refresh.TokenDelimer, base64.URLEncoding.EncodeToString(tokenPayload)) |
|
} |
|
|
|
func TestRefreshRepoVerifyInvalidTokens(t *testing.T) { |
|
r := db.NewRefreshTokenRepo(connect(t)) |
|
|
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
tokenWithBadID := "404" + token[1:] |
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) |
|
|
|
tests := []struct { |
|
token string |
|
creds oidc.ClientCredentials |
|
err error |
|
expected string |
|
}{ |
|
{ |
|
"invalid-token-format", |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
"b/invalid-base64-encoded-format", |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
"1/invalid-base64-encoded-format", |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
token + "corrupted-token-payload", |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
// The token's ID content is invalid. |
|
tokenWithBadID, |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
// The token's payload content is invalid. |
|
tokenWithBadPayload, |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidToken, |
|
"", |
|
}, |
|
{ |
|
token, |
|
oidc.ClientCredentials{ID: "invalid-client", Secret: "secret-foo"}, |
|
refresh.ErrorInvalidClientID, |
|
"", |
|
}, |
|
{ |
|
token, |
|
oidc.ClientCredentials{ID: "client-foo", Secret: "secret-foo"}, |
|
nil, |
|
"user-foo", |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
result, _, _, err := r.Verify(tt.creds.ID, tt.token) |
|
if err != tt.err { |
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) |
|
} |
|
if result != tt.expected { |
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.expected, result) |
|
} |
|
} |
|
} |
|
|
|
func TestRefreshTokenRepoClientsWithRefreshTokens(t *testing.T) { |
|
tests := []struct { |
|
clientIDs []string |
|
}{ |
|
{clientIDs: []string{"client1", "client2"}}, |
|
{clientIDs: []string{"client1"}}, |
|
{clientIDs: []string{}}, |
|
} |
|
|
|
for i, tt := range tests { |
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) |
|
|
|
for _, clientID := range tt.clientIDs { |
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) |
|
if err != nil { |
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) |
|
} |
|
} |
|
|
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) |
|
if err != nil { |
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err) |
|
} |
|
var clientIDs []string |
|
for _, client := range clients { |
|
clientIDs = append(clientIDs, client.Credentials.ID) |
|
} |
|
sort.Strings(clientIDs) |
|
|
|
if diff := pretty.Compare(clientIDs, tt.clientIDs); diff != "" { |
|
t.Errorf("case %d: Compare(want, got): %v", i, diff) |
|
} |
|
} |
|
} |
|
|
|
func TestRefreshTokenRepoRevokeForClient(t *testing.T) { |
|
tests := []struct { |
|
createIDs []string |
|
revokeID string |
|
}{ |
|
{ |
|
createIDs: []string{"client1", "client2"}, |
|
revokeID: "client1", |
|
}, |
|
{ |
|
createIDs: []string{"client2"}, |
|
revokeID: "client1", |
|
}, |
|
{ |
|
createIDs: []string{"client1"}, |
|
revokeID: "client1", |
|
}, |
|
{ |
|
createIDs: []string{}, |
|
revokeID: "oops", |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
repo := newRefreshRepo(t, testRefreshUsers, testRefreshClients) |
|
|
|
for _, clientID := range tt.createIDs { |
|
_, err := repo.Create(testRefreshUserID, clientID, testRefreshConnectorID, []string{"openid"}) |
|
if err != nil { |
|
t.Fatalf("case %d: client_id: %s couldn't create refresh token: %v", i, clientID, err) |
|
} |
|
|
|
if err := repo.RevokeTokensForClient(testRefreshUserID, tt.revokeID); err != nil { |
|
t.Fatalf("case %d: couldn't revoke refresh token(s): %v", i, err) |
|
} |
|
} |
|
|
|
var wantIDs []string |
|
for _, id := range tt.createIDs { |
|
if id != tt.revokeID { |
|
wantIDs = append(wantIDs, id) |
|
} |
|
} |
|
|
|
clients, err := repo.ClientsWithRefreshTokens(testRefreshUserID) |
|
if err != nil { |
|
t.Fatalf("case %d: unexpected error fetching clients %q", i, err) |
|
} |
|
|
|
var gotIDs []string |
|
for _, client := range clients { |
|
gotIDs = append(gotIDs, client.Credentials.ID) |
|
} |
|
sort.Strings(gotIDs) |
|
|
|
if diff := pretty.Compare(wantIDs, gotIDs); diff != "" { |
|
t.Errorf("case %d: Compare(wantIDs, gotIDs): %v", i, diff) |
|
} |
|
} |
|
} |
|
|
|
func TestRefreshRepoRevoke(t *testing.T) { |
|
r := db.NewRefreshTokenRepo(connect(t)) |
|
|
|
token, err := r.Create("user-foo", "client-foo", testRefreshConnectorID, oidc.DefaultScope) |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
|
|
badTokenPayload, err := refresh.DefaultRefreshTokenGenerator() |
|
if err != nil { |
|
t.Fatalf("Unexpected error: %v", err) |
|
} |
|
tokenWithBadID := "404" + token[1:] |
|
tokenWithBadPayload := buildRefreshToken(1, badTokenPayload) |
|
|
|
tests := []struct { |
|
token string |
|
userID string |
|
err error |
|
}{ |
|
{ |
|
"invalid-token-format", |
|
"user-foo", |
|
refresh.ErrorInvalidToken, |
|
}, |
|
{ |
|
"1/invalid-base64-encoded-format", |
|
"user-foo", |
|
refresh.ErrorInvalidToken, |
|
}, |
|
{ |
|
token + "corrupted-token-payload", |
|
"user-foo", |
|
refresh.ErrorInvalidToken, |
|
}, |
|
{ |
|
// The token's ID is invalid. |
|
tokenWithBadID, |
|
"user-foo", |
|
refresh.ErrorInvalidToken, |
|
}, |
|
{ |
|
// The token's payload is invalid. |
|
tokenWithBadPayload, |
|
"user-foo", |
|
refresh.ErrorInvalidToken, |
|
}, |
|
{ |
|
token, |
|
"invalid-user", |
|
refresh.ErrorInvalidUserID, |
|
}, |
|
{ |
|
token, |
|
"user-foo", |
|
nil, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
if err := r.Revoke(tt.userID, tt.token); err != tt.err { |
|
t.Errorf("Case #%d: expected: %v, got: %v", i, tt.err, err) |
|
} |
|
} |
|
}
|
|
|