mirror of https://github.com/dexidp/dex.git
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
493 lines
11 KiB
493 lines
11 KiB
package manager |
|
|
|
import ( |
|
"net/url" |
|
"testing" |
|
"time" |
|
|
|
"github.com/coreos/go-oidc/jose" |
|
"github.com/jonboulle/clockwork" |
|
"github.com/kylelemons/godebug/pretty" |
|
|
|
"github.com/coreos/dex/connector" |
|
"github.com/coreos/dex/db" |
|
"github.com/coreos/dex/user" |
|
) |
|
|
|
type testFixtures struct { |
|
ur user.UserRepo |
|
pwr user.PasswordInfoRepo |
|
ccr connector.ConnectorConfigRepo |
|
mgr *UserManager |
|
clock clockwork.Clock |
|
} |
|
|
|
func makeTestFixtures() *testFixtures { |
|
f := &testFixtures{} |
|
f.clock = clockwork.NewFakeClock() |
|
|
|
dbMap := db.NewMemDB() |
|
f.ur = func() user.UserRepo { |
|
repo, err := db.NewUserRepoFromUsers(dbMap, []user.UserWithRemoteIdentities{ |
|
{ |
|
User: user.User{ |
|
ID: "ID-1", |
|
Email: "Email-1@example.com", |
|
}, |
|
RemoteIdentities: []user.RemoteIdentity{ |
|
{ |
|
ConnectorID: "local", |
|
ID: "1", |
|
}, |
|
}, |
|
}, { |
|
User: user.User{ |
|
ID: "ID-2", |
|
Email: "Email-2@example.com", |
|
EmailVerified: true, |
|
}, |
|
RemoteIdentities: []user.RemoteIdentity{ |
|
{ |
|
ConnectorID: "local", |
|
ID: "2", |
|
}, |
|
}, |
|
}, |
|
}) |
|
if err != nil { |
|
panic("Failed to create user repo: " + err.Error()) |
|
} |
|
return repo |
|
}() |
|
|
|
f.pwr = func() user.PasswordInfoRepo { |
|
repo, err := db.NewPasswordInfoRepoFromPasswordInfos(dbMap, []user.PasswordInfo{ |
|
{ |
|
UserID: "ID-1", |
|
Password: []byte("password-1"), |
|
}, |
|
{ |
|
UserID: "ID-2", |
|
Password: []byte("password-2"), |
|
}, |
|
}) |
|
if err != nil { |
|
panic("Failed to create user repo: " + err.Error()) |
|
} |
|
return repo |
|
}() |
|
|
|
f.ccr = func() connector.ConnectorConfigRepo { |
|
repo := db.NewConnectorConfigRepo(dbMap) |
|
c := []connector.ConnectorConfig{ |
|
&connector.LocalConnectorConfig{ID: "local"}, |
|
} |
|
if err := repo.Set(c); err != nil { |
|
panic(err) |
|
} |
|
return repo |
|
}() |
|
|
|
f.mgr = NewUserManager(f.ur, f.pwr, f.ccr, db.TransactionFactory(dbMap), ManagerOptions{}) |
|
f.mgr.Clock = f.clock |
|
return f |
|
} |
|
|
|
func TestRegisterWithRemoteIdentity(t *testing.T) { |
|
tests := []struct { |
|
email string |
|
emailVerified bool |
|
rid user.RemoteIdentity |
|
err error |
|
}{ |
|
{ |
|
email: "email@example.com", |
|
emailVerified: false, |
|
rid: user.RemoteIdentity{ |
|
ConnectorID: "local", |
|
ID: "1234", |
|
}, |
|
err: nil, |
|
}, |
|
{ |
|
emailVerified: false, |
|
rid: user.RemoteIdentity{ |
|
ConnectorID: "local", |
|
ID: "1234", |
|
}, |
|
err: user.ErrorInvalidEmail, |
|
}, |
|
{ |
|
email: "email@example.com", |
|
emailVerified: false, |
|
rid: user.RemoteIdentity{ |
|
ConnectorID: "local", |
|
ID: "1", |
|
}, |
|
err: user.ErrorDuplicateRemoteIdentity, |
|
}, |
|
{ |
|
email: "anotheremail@example.com", |
|
emailVerified: false, |
|
rid: user.RemoteIdentity{ |
|
ConnectorID: "idonotexist", |
|
ID: "1", |
|
}, |
|
err: connector.ErrorNotFound, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
f := makeTestFixtures() |
|
userID, err := f.mgr.RegisterWithRemoteIdentity( |
|
tt.email, |
|
tt.emailVerified, |
|
tt.rid) |
|
|
|
if tt.err != nil { |
|
if tt.err != err { |
|
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err) |
|
} |
|
continue |
|
} |
|
|
|
usr, err := f.ur.Get(nil, userID) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
} |
|
|
|
if usr.Email != tt.email { |
|
t.Errorf("case %d: user.Email: want=%q, got=%q", i, tt.email, usr.Email) |
|
} |
|
if usr.EmailVerified != tt.emailVerified { |
|
t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, tt.emailVerified, usr.EmailVerified) |
|
} |
|
|
|
ridUSR, err := f.ur.GetByRemoteIdentity(nil, tt.rid) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
} |
|
if diff := pretty.Compare(usr, ridUSR); diff != "" { |
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff) |
|
} |
|
} |
|
} |
|
|
|
func TestRegisterWithPassword(t *testing.T) { |
|
tests := []struct { |
|
email string |
|
plaintext string |
|
err error |
|
}{ |
|
{ |
|
email: "email@example.com", |
|
plaintext: "secretpassword123", |
|
err: nil, |
|
}, |
|
{ |
|
plaintext: "secretpassword123", |
|
err: user.ErrorInvalidEmail, |
|
}, |
|
{ |
|
email: "email@example.com", |
|
err: user.ErrorInvalidPassword, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
f := makeTestFixtures() |
|
connID := "local" |
|
userID, err := f.mgr.RegisterWithPassword( |
|
tt.email, |
|
tt.plaintext, |
|
connID) |
|
|
|
if tt.err != nil { |
|
if tt.err != err { |
|
t.Errorf("case %d: want=%q, got=%q", i, tt.err, err) |
|
} |
|
continue |
|
} |
|
|
|
usr, err := f.ur.Get(nil, userID) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
} |
|
|
|
if usr.Email != tt.email { |
|
t.Errorf("case %d: user.Email: want=%q, got=%q", i, tt.email, usr.Email) |
|
} |
|
if usr.EmailVerified != false { |
|
t.Errorf("case %d: user.EmailVerified: want=%v, got=%v", i, false, usr.EmailVerified) |
|
} |
|
|
|
ridUSR, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{ |
|
ID: userID, |
|
ConnectorID: connID, |
|
}) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
} |
|
if diff := pretty.Compare(usr, ridUSR); diff != "" { |
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff) |
|
continue |
|
} |
|
|
|
pwi, err := f.pwr.Get(nil, userID) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
continue |
|
} |
|
ident, err := pwi.Authenticate(tt.plaintext) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
continue |
|
} |
|
if ident.ID != userID { |
|
t.Errorf("case %d: ident.ID: want=%q, got=%q", i, userID, ident.ID) |
|
continue |
|
} |
|
|
|
_, err = pwi.Authenticate(tt.plaintext + "WRONG") |
|
if err == nil { |
|
t.Errorf("case %d: want non-nil err", i) |
|
} |
|
} |
|
} |
|
|
|
func TestVerifyEmail(t *testing.T) { |
|
now := time.Now() |
|
issuer, _ := url.Parse("http://example.com") |
|
clientID := "myclient" |
|
callback := "http://client.example.com/callback" |
|
expires := time.Hour * 3 |
|
|
|
makeClaims := func(usr user.User) jose.Claims { |
|
return map[string]interface{}{ |
|
"iss": issuer.String(), |
|
"aud": clientID, |
|
user.ClaimEmailVerificationCallback: callback, |
|
user.ClaimEmailVerificationEmail: usr.Email, |
|
"exp": float64(now.Add(expires).Unix()), |
|
"sub": usr.ID, |
|
"iat": float64(now.Unix()), |
|
} |
|
} |
|
|
|
tests := []struct { |
|
evClaims jose.Claims |
|
wantErr bool |
|
}{ |
|
{ |
|
// happy path |
|
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-1@example.com"}), |
|
}, |
|
{ |
|
// non-matching email |
|
evClaims: makeClaims(user.User{ID: "ID-1", Email: "Email-2@example.com"}), |
|
wantErr: true, |
|
}, |
|
{ |
|
// already verified email |
|
evClaims: makeClaims(user.User{ID: "ID-2", Email: "Email-2@example.com"}), |
|
wantErr: true, |
|
}, |
|
{ |
|
// non-existent user. |
|
evClaims: makeClaims(user.User{ID: "ID-UNKNOWN", Email: "noone@example.com"}), |
|
wantErr: true, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
f := makeTestFixtures() |
|
cb, err := f.mgr.VerifyEmail(user.EmailVerification{Claims: tt.evClaims}) |
|
if tt.wantErr { |
|
if err == nil { |
|
t.Errorf("case %d: want non-nil err", i) |
|
} |
|
continue |
|
} |
|
|
|
if err != nil { |
|
t.Errorf("case %d: want err=nil got=%q", i, err) |
|
continue |
|
} |
|
|
|
if cb.String() != tt.evClaims[user.ClaimEmailVerificationCallback] { |
|
t.Errorf("case %d: want=%q, got=%q", i, cb.String(), |
|
tt.evClaims[user.ClaimEmailVerificationCallback]) |
|
} |
|
} |
|
} |
|
|
|
func TestChangePassword(t *testing.T) { |
|
now := time.Now() |
|
issuer, _ := url.Parse("http://example.com") |
|
clientID := "myclient" |
|
callback := "http://client.example.com/callback" |
|
expires := time.Hour * 3 |
|
password := "password-1" |
|
|
|
makeClaims := func(usrID, callback string) jose.Claims { |
|
return map[string]interface{}{ |
|
"iss": issuer.String(), |
|
"aud": clientID, |
|
user.ClaimPasswordResetCallback: callback, |
|
user.ClaimPasswordResetPassword: password, |
|
"exp": float64(now.Add(expires).Unix()), |
|
"sub": usrID, |
|
"iat": float64(now.Unix()), |
|
} |
|
} |
|
|
|
tests := []struct { |
|
pwrClaims jose.Claims |
|
newPassword string |
|
wantErr bool |
|
}{ |
|
{ |
|
// happy path |
|
pwrClaims: makeClaims("ID-1", callback), |
|
newPassword: "password-1.1", |
|
}, |
|
{ |
|
// happy path with no callback |
|
pwrClaims: makeClaims("ID-1", ""), |
|
newPassword: "password-1.1", |
|
}, |
|
{ |
|
// passwords don't match changed |
|
pwrClaims: makeClaims("ID-2", callback), |
|
newPassword: "password-1.1", |
|
wantErr: true, |
|
}, |
|
{ |
|
// user doesn't exist |
|
pwrClaims: makeClaims("ID-123", callback), |
|
newPassword: "password-1.1", |
|
wantErr: true, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
f := makeTestFixtures() |
|
cb, err := f.mgr.ChangePassword(user.PasswordReset{Claims: tt.pwrClaims}, tt.newPassword) |
|
if tt.wantErr { |
|
if err == nil { |
|
t.Errorf("case %d: want non-nil err", i) |
|
} |
|
continue |
|
} |
|
|
|
if err != nil { |
|
t.Errorf("case %d: want err=nil got=%q", i, err) |
|
continue |
|
} |
|
|
|
var cbString string |
|
if cb != nil { |
|
cbString = cb.String() |
|
} |
|
if cbString != tt.pwrClaims[user.ClaimPasswordResetCallback] { |
|
t.Errorf("case %d: want=%q, got=%q", i, cb.String(), |
|
tt.pwrClaims[user.ClaimPasswordResetCallback]) |
|
} |
|
} |
|
} |
|
|
|
func TestCreateUser(t *testing.T) { |
|
tests := []struct { |
|
usr user.User |
|
hashedPW user.Password |
|
localID string // defaults to "local" |
|
|
|
wantErr bool |
|
}{ |
|
{ |
|
usr: user.User{ |
|
DisplayName: "Bob Exampleson", |
|
Email: "bob@example.com", |
|
}, |
|
hashedPW: user.Password("I am a hash"), |
|
}, |
|
{ |
|
usr: user.User{ |
|
DisplayName: "Al Adminson", |
|
Email: "al@example.com", |
|
Admin: true, |
|
}, |
|
hashedPW: user.Password("I am a hash"), |
|
}, |
|
{ |
|
usr: user.User{ |
|
DisplayName: "Ed Emailless", |
|
}, |
|
hashedPW: user.Password("I am a hash"), |
|
wantErr: true, |
|
}, |
|
{ |
|
usr: user.User{ |
|
DisplayName: "Eric Exampleson", |
|
Email: "eric@example.com", |
|
}, |
|
hashedPW: user.Password("I am a hash"), |
|
localID: "abadlocalid", |
|
wantErr: true, |
|
}, |
|
} |
|
|
|
for i, tt := range tests { |
|
f := makeTestFixtures() |
|
localID := "local" |
|
if tt.localID != "" { |
|
localID = tt.localID |
|
} |
|
id, err := f.mgr.CreateUser(tt.usr, tt.hashedPW, localID) |
|
if tt.wantErr { |
|
if err == nil { |
|
t.Errorf("case %d: want non-nil err", i) |
|
} |
|
continue |
|
} |
|
if id == "" { |
|
t.Errorf("case %d: want non-empty id", i) |
|
} |
|
|
|
if err != nil { |
|
t.Errorf("case %d: unexpected err: %v", i, err) |
|
continue |
|
} |
|
|
|
gotUsr, err := f.ur.Get(nil, id) |
|
if err != nil { |
|
t.Errorf("case %d: unexpected err: %v", i, err) |
|
} |
|
|
|
tt.usr.ID = id |
|
tt.usr.CreatedAt = f.clock.Now() |
|
if diff := pretty.Compare(tt.usr, gotUsr); diff != "" { |
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff) |
|
} |
|
|
|
pwi, err := f.pwr.Get(nil, id) |
|
if err != nil { |
|
t.Errorf("case %d: unexpected err: %v", i, err) |
|
} |
|
|
|
if string(pwi.Password) != string(tt.hashedPW) { |
|
t.Errorf("case %d: want=%q, got=%q", i, tt.hashedPW, pwi.Password) |
|
} |
|
|
|
ridUser, err := f.ur.GetByRemoteIdentity(nil, user.RemoteIdentity{ |
|
ID: id, |
|
ConnectorID: "local", |
|
}) |
|
if err != nil { |
|
t.Errorf("case %d: err != nil: %q", i, err) |
|
} |
|
if diff := pretty.Compare(gotUsr, ridUser); diff != "" { |
|
t.Errorf("case %d: Compare(want, got) = %v", i, diff) |
|
} |
|
} |
|
}
|
|
|