@ -2,13 +2,13 @@ package repo
import (
"encoding/base64"
"fmt"
"net/url"
"o s"
"sort "
"testing"
"time"
"github.com/coreos/go-oidc/oidc"
"github.com/go-gorp/gorp"
"github.com/kylelemons/godebug/pretty"
"github.com/coreos/dex/client"
@ -17,40 +17,43 @@ import (
"github.com/coreos/dex/user"
)
func newRefreshRepo ( t * testing . T , users [ ] user . UserWithRemoteIdentities , clients [ ] client . Client ) refresh . RefreshTokenRepo {
var dbMap * gorp . DbMap
if dsn := os . Getenv ( "DEX_TEST_DSN" ) ; dsn == "" {
dbMap = db . NewMemDB ( )
} else {
dbMap = connect ( t )
}
if _ , err := db . NewUserRepoFromUsers ( dbMap , users ) ; err != nil {
t . Fatalf ( "Unable to add users: %v" , err )
}
return db . NewRefreshTokenRepo ( dbMap )
}
func TestRefreshTokenRepo ( t * testing . T ) {
clientID := "client1"
userID := "user1"
clients := [ ] client . Client {
var (
testRefreshClientID = "client1"
testRefreshClientID2 = "client2"
testRefreshClients = [ ] client . LoadableClient {
{
Credentials : oidc . ClientCredentials {
ID : clientID ,
Secret : base64 . URLEncoding . EncodeToString ( [ ] byte ( "secret-2" ) ) ,
Client : client . Client {
Credentials : oidc . ClientCredentials {
ID : testRefreshClientID ,
Secret : base64 . URLEncoding . EncodeToString ( [ ] byte ( "secret-2" ) ) ,
} ,
Metadata : oidc . ClientMetadata {
RedirectURIs : [ ] url . URL {
url . URL { Scheme : "https" , Host : "client1.example.com" , Path : "/callback" } ,
} ,
} ,
} ,
Metadata : oidc . ClientMetadata {
RedirectURIs : [ ] url . URL {
url . URL { Scheme : "https" , Host : "client1.example.com" , Path : "/callback" } ,
} ,
{
Client : client . Client {
Credentials : oidc . ClientCredentials {
ID : testRefreshClientID2 ,
Secret : base64 . URLEncoding . EncodeToString ( [ ] byte ( "secret-2" ) ) ,
} ,
Metadata : oidc . ClientMetadata {
RedirectURIs : [ ] url . URL {
url . URL { Scheme : "https" , Host : "client2.example.com" , Path : "/callback" } ,
} ,
} ,
} ,
} ,
}
users := [ ] user . UserWithRemoteIdentities {
testRefreshUserID = "user1"
testRefreshUsers = [ ] user . UserWithRemoteIdentities {
{
User : user . User {
ID : userID ,
ID : testRefreshU serID,
Email : "Email-1@example.com" ,
CreatedAt : time . Now ( ) . Truncate ( time . Second ) ,
} ,
@ -62,31 +65,318 @@ func TestRefreshTokenRepo(t *testing.T) {
} ,
} ,
}
)
func newRefreshRepo ( t * testing . T , users [ ] user . UserWithRemoteIdentities , clients [ ] client . LoadableClient ) refresh . RefreshTokenRepo {
dbMap := connect ( t )
if _ , err := db . NewUserRepoFromUsers ( dbMap , users ) ; err != nil {
t . Fatalf ( "Unable to add users: %v" , err )
}
if _ , err := db . NewClientRepoFromClients ( dbMap , clients ) ; err != nil {
t . Fatalf ( "Unable to add clients: %v" , err )
}
return db . NewRefreshTokenRepo ( dbMap )
}
func TestRefreshTokenRepoCreateVerify ( t * testing . T ) {
tests := [ ] struct {
createScopes [ ] string
verifyClientID string
wantVerifyErr bool
} {
{
createScopes : [ ] string { "openid" , "profile" } ,
verifyClientID : testRefreshClientID ,
} ,
{
createScopes : [ ] string { } ,
verifyClientID : testRefreshClientID ,
} ,
{
createScopes : [ ] string { "openid" , "profile" } ,
verifyClientID : "not-a-client" ,
wantVerifyErr : true ,
} ,
}
repo := newRefreshRepo ( t , users , clients )
tok , err := repo . Create ( userID , clientID )
for i , tt := range tests {
repo := newRefreshRepo ( t , testRefreshUsers , testRefreshClients )
tok , err := repo . Create ( testRefreshUserID , testRefreshClientID , tt . createScopes )
if err != nil {
t . Fatalf ( "case %d: failed to create refresh token: %v" , i , err )
}
tokUserID , gotScopes , err := repo . Verify ( tt . verifyClientID , tok )
if tt . wantVerifyErr {
if err == nil {
t . Errorf ( "case %d: want non-nil error." , i )
}
continue
}
if diff := pretty . Compare ( tt . createScopes , gotScopes ) ; diff != "" {
t . Errorf ( "case %d: Compare(want, got): %v" , i , diff )
}
if err != nil {
t . Errorf ( "case %d: Could not verify token: %v" , i , err )
} else if tokUserID != testRefreshUserID {
t . Errorf ( "case %d: Verified token returned wrong user id, want=%s, got=%s" , i ,
testRefreshUserID , tokUserID )
}
}
}
// buildRefreshToken combines the token ID and token payload to create a new token.
// used in the tests to created a refresh token.
func buildRefreshToken ( tokenID int64 , tokenPayload [ ] byte ) string {
return fmt . Sprintf ( "%d%s%s" , tokenID , refresh . TokenDelimer , base64 . URLEncoding . EncodeToString ( tokenPayload ) )
}
func TestRefreshRepoVerifyInvalidTokens ( t * testing . T ) {
r := db . NewRefreshTokenRepo ( connect ( t ) )
token , err := r . Create ( "user-foo" , "client-foo" , oidc . DefaultScope )
if err != nil {
t . Fatalf ( "failed to create refresh token: %v" , err )
t . Fatalf ( "Unexpected error : %v" , err )
}
if tokUserID , err := repo . Verify ( clientID , tok ) ; err != nil {
t . Errorf ( "Could not verify token: %v" , err )
} else if tokUserID != userID {
t . Errorf ( "Verified token returned wrong user id, want=%s, got=%s" , userID , tokUserID )
badTokenPayload , err := refresh . DefaultRefreshTokenGenerator ( )
if err != nil {
t . Fatalf ( "Unexpected error: %v ", err )
}
tokenWithBadID := "404" + token [ 1 : ]
tokenWithBadPayload := buildRefreshToken ( 1 , badTokenPayload )
if userClients , err := repo . ClientsWithRefreshTokens ( userID ) ; err != nil {
t . Errorf ( "Failed to get the list of clients the user was logged into: %v" , err )
} else {
if diff := pretty . Compare ( userClients , clients ) ; diff == "" {
t . Errorf ( "Clients user logged into: want did not equal got %s" , diff )
tests := [ ] struct {
token string
creds oidc . ClientCredentials
err error
expected string
} {
{
"invalid-token-format" ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
"b/invalid-base64-encoded-format" ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
"1/invalid-base64-encoded-format" ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
token + "corrupted-token-payload" ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
// The token's ID content is invalid.
tokenWithBadID ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
// The token's payload content is invalid.
tokenWithBadPayload ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
refresh . ErrorInvalidToken ,
"" ,
} ,
{
token ,
oidc . ClientCredentials { ID : "invalid-client" , Secret : "secret-foo" } ,
refresh . ErrorInvalidClientID ,
"" ,
} ,
{
token ,
oidc . ClientCredentials { ID : "client-foo" , Secret : "secret-foo" } ,
nil ,
"user-foo" ,
} ,
}
for i , tt := range tests {
result , _ , err := r . Verify ( tt . creds . ID , tt . token )
if err != tt . err {
t . Errorf ( "Case #%d: expected: %v, got: %v" , i , tt . err , err )
}
if result != tt . expected {
t . Errorf ( "Case #%d: expected: %v, got: %v" , i , tt . expected , result )
}
}
}
func TestRefreshTokenRepoClientsWithRefreshTokens ( t * testing . T ) {
tests := [ ] struct {
clientIDs [ ] string
} {
{ clientIDs : [ ] string { "client1" , "client2" } } ,
{ clientIDs : [ ] string { "client1" } } ,
{ clientIDs : [ ] string { } } ,
}
for i , tt := range tests {
repo := newRefreshRepo ( t , testRefreshUsers , testRefreshClients )
for _ , clientID := range tt . clientIDs {
_ , err := repo . Create ( testRefreshUserID , clientID , [ ] string { "openid" } )
if err != nil {
t . Fatalf ( "case %d: client_id: %s couldn't create refresh token: %v" , i , clientID , err )
}
}
clients , err := repo . ClientsWithRefreshTokens ( testRefreshUserID )
if err != nil {
t . Fatalf ( "case %d: unexpected error fetching clients %q" , i , err )
}
var clientIDs [ ] string
for _ , client := range clients {
clientIDs = append ( clientIDs , client . Credentials . ID )
}
sort . Strings ( clientIDs )
if diff := pretty . Compare ( clientIDs , tt . clientIDs ) ; diff != "" {
t . Errorf ( "case %d: Compare(want, got): %v" , i , diff )
}
}
}
func TestRefreshTokenRepoRevokeForClient ( t * testing . T ) {
tests := [ ] struct {
createIDs [ ] string
revokeID string
} {
{
createIDs : [ ] string { "client1" , "client2" } ,
revokeID : "client1" ,
} ,
{
createIDs : [ ] string { "client2" } ,
revokeID : "client1" ,
} ,
{
createIDs : [ ] string { "client1" } ,
revokeID : "client1" ,
} ,
{
createIDs : [ ] string { } ,
revokeID : "oops" ,
} ,
}
for i , tt := range tests {
repo := newRefreshRepo ( t , testRefreshUsers , testRefreshClients )
for _ , clientID := range tt . createIDs {
_ , err := repo . Create ( testRefreshUserID , clientID , [ ] string { "openid" } )
if err != nil {
t . Fatalf ( "case %d: client_id: %s couldn't create refresh token: %v" , i , clientID , err )
}
if err := repo . RevokeTokensForClient ( testRefreshUserID , tt . revokeID ) ; err != nil {
t . Fatalf ( "case %d: couldn't revoke refresh token(s): %v" , i , err )
}
}
var wantIDs [ ] string
for _ , id := range tt . createIDs {
if id != tt . revokeID {
wantIDs = append ( wantIDs , id )
}
}
clients , err := repo . ClientsWithRefreshTokens ( testRefreshUserID )
if err != nil {
t . Fatalf ( "case %d: unexpected error fetching clients %q" , i , err )
}
var gotIDs [ ] string
for _ , client := range clients {
gotIDs = append ( gotIDs , client . Credentials . ID )
}
sort . Strings ( gotIDs )
if diff := pretty . Compare ( wantIDs , gotIDs ) ; diff != "" {
t . Errorf ( "case %d: Compare(wantIDs, gotIDs): %v" , i , diff )
}
}
}
func TestRefreshRepoRevoke ( t * testing . T ) {
r := db . NewRefreshTokenRepo ( connect ( t ) )
token , err := r . Create ( "user-foo" , "client-foo" , oidc . DefaultScope )
if err != nil {
t . Fatalf ( "Unexpected error: %v" , err )
}
if err := repo . RevokeTokensForClient ( userID , clientID ) ; err != nil {
t . Errorf ( "Failed to revoke refresh token: %v" , err )
badTokenPayload , err := refresh . DefaultRefreshTokenGenerator ( )
if err != nil {
t . Fatalf ( "Unexpected error: %v" , err )
}
tokenWithBadID := "404" + token [ 1 : ]
tokenWithBadPayload := buildRefreshToken ( 1 , badTokenPayload )
if _ , err := repo . Verify ( clientID , tok ) ; err == nil {
t . Errorf ( "Token which should have been revoked was verified" )
tests := [ ] struct {
token string
userID string
err error
} {
{
"invalid-token-format" ,
"user-foo" ,
refresh . ErrorInvalidToken ,
} ,
{
"1/invalid-base64-encoded-format" ,
"user-foo" ,
refresh . ErrorInvalidToken ,
} ,
{
token + "corrupted-token-payload" ,
"user-foo" ,
refresh . ErrorInvalidToken ,
} ,
{
// The token's ID is invalid.
tokenWithBadID ,
"user-foo" ,
refresh . ErrorInvalidToken ,
} ,
{
// The token's payload is invalid.
tokenWithBadPayload ,
"user-foo" ,
refresh . ErrorInvalidToken ,
} ,
{
token ,
"invalid-user" ,
refresh . ErrorInvalidUserID ,
} ,
{
token ,
"user-foo" ,
nil ,
} ,
}
for i , tt := range tests {
if err := r . Revoke ( tt . userID , tt . token ) ; err != tt . err {
t . Errorf ( "Case #%d: expected: %v, got: %v" , i , tt . err , err )
}
}
}