Browse Source

*: remove in memory session repos

Move manager to it's own package so it can import db. Move all
references to the in memory session repos to use sqlite3.
pull/304/head
Eric Chiang 10 years ago
parent
commit
7bac93aa20
  1. 12
      db/conn.go
  2. 2
      db/migrate_sqlite3.go
  3. 4
      functional/repo/session_repo_test.go
  4. 7
      integration/oidc_test.go
  5. 16
      server/config.go
  6. 7
      server/http_test.go
  7. 10
      server/password.go
  8. 7
      server/register.go
  9. 7
      server/server.go
  10. 19
      server/server_test.go
  11. 13
      server/testutil.go
  12. 38
      session/manager/manager.go
  13. 25
      session/manager/manager_test.go
  14. 91
      session/repo.go
  15. 2
      test

12
db/conn.go

@ -101,3 +101,15 @@ func rollback(tx *gorp.Transaction) {
log.Errorf("unable to rollback: %v", err)
}
}
// NewMemDB creates a new in memory sqlite3 database.
func NewMemDB() *gorp.DbMap {
dbMap, err := NewConnection(Config{DSN: "sqlite3://:memory:"})
if err != nil {
panic("Failed to create in memory database: " + err.Error())
}
if _, err := MigrateToLatest(dbMap); err != nil {
panic("In memory database migration failed: " + err.Error())
}
return dbMap
}

2
db/migrate_sqlite3.go

@ -65,7 +65,7 @@ CREATE TABLE session (
);
CREATE TABLE session_key (
key text NOT NULL UNIQUE,
key text NOT NULL,
session_id text,
expires_at bigint,
stale integer

4
functional/repo/session_repo_test.go

@ -15,7 +15,7 @@ import (
func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionRepoWithClock(clock), clock
return db.NewSessionRepoWithClock(db.NewMemDB(), clock), clock
}
dbMap := connect(t)
return db.NewSessionRepoWithClock(dbMap, clock), clock
@ -24,7 +24,7 @@ func newSessionRepo(t *testing.T) (session.SessionRepo, clockwork.FakeClock) {
func newSessionKeyRepo(t *testing.T) (session.SessionKeyRepo, clockwork.FakeClock) {
clock := clockwork.NewFakeClock()
if os.Getenv("DEX_TEST_DSN") == "" {
return session.NewSessionKeyRepoWithClock(clock), clock
return db.NewSessionKeyRepoWithClock(db.NewMemDB(), clock), clock
}
dbMap := connect(t)
return db.NewSessionKeyRepoWithClock(dbMap, clock), clock

7
integration/oidc_test.go

@ -10,10 +10,11 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
phttp "github.com/coreos/dex/pkg/http"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/server"
"github.com/coreos/dex/session"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
@ -33,7 +34,7 @@ func mockServer(cis []oidc.ClientIdentity) (*server.Server, error) {
return nil, err
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &server.Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
@ -120,7 +121,7 @@ func TestHTTPExchangeTokenRefreshToken(t *testing.T) {
cir := client.NewClientIdentityRepo([]oidc.ClientIdentity{ci})
issuerURL := url.URL{Scheme: "http", Host: "server.example.com"}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
k, err := key.GeneratePrivateKey()
if err != nil {

16
server/config.go

@ -19,10 +19,10 @@ import (
"github.com/coreos/dex/email"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
type ServerConfig struct {
@ -128,9 +128,9 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
}
cfgRepo := connector.NewConnectorConfigRepoFromConfigs(cfgs)
sRepo := session.NewSessionRepo()
skRepo := session.NewSessionKeyRepo()
sm := session.NewSessionManager(sRepo, skRepo)
sRepo := db.NewSessionRepo(db.NewMemDB())
skRepo := db.NewSessionKeyRepo(db.NewMemDB())
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
userRepo, err := user.NewUserRepoFromFile(cfg.UsersFile)
if err != nil {
@ -142,7 +142,7 @@ func (cfg *SingleServerConfig) Configure(srv *Server) error {
refTokRepo := refresh.NewRefreshTokenRepo()
txnFactory := repo.InMemTransactionFactory
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, manager.ManagerOptions{})
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, txnFactory, usermanager.ManagerOptions{})
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo
srv.ConnectorConfigRepo = cfgRepo
@ -180,10 +180,10 @@ func (cfg *MultiServerConfig) Configure(srv *Server) error {
cfgRepo := db.NewConnectorConfigRepo(dbc)
userRepo := db.NewUserRepo(dbc)
pwiRepo := db.NewPasswordInfoRepo(dbc)
userManager := manager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), manager.ManagerOptions{})
userManager := usermanager.NewUserManager(userRepo, pwiRepo, cfgRepo, db.TransactionFactory(dbc), usermanager.ManagerOptions{})
refreshTokenRepo := db.NewRefreshTokenRepo(dbc)
sm := session.NewSessionManager(sRepo, skRepo)
sm := sessionmanager.NewSessionManager(sRepo, skRepo)
srv.ClientIdentityRepo = ciRepo
srv.KeySetRepo = kRepo

7
server/http_test.go

@ -17,7 +17,8 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/session"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session/manager"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
@ -75,7 +76,7 @@ func TestHandleAuthFuncResponsesSingleRedirectURL(t *testing.T) {
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{
@ -198,7 +199,7 @@ func TestHandleAuthFuncResponsesMultipleRedirectURLs(t *testing.T) {
}
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
SessionManager: session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo()),
SessionManager: manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB())),
ClientIdentityRepo: client.NewClientIdentityRepo([]oidc.ClientIdentity{
oidc.ClientIdentity{
Credentials: oidc.ClientCredentials{

10
server/password.go

@ -9,10 +9,10 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
type sendResetPasswordEmailData struct {
@ -28,7 +28,7 @@ type sendResetPasswordEmailData struct {
type SendResetPasswordEmailHandler struct {
tpl *template.Template
emailer *useremail.UserEmailer
sm *session.SessionManager
sm *sessionmanager.SessionManager
cr client.ClientIdentityRepo
}
@ -182,7 +182,7 @@ type resetPasswordTemplateData struct {
type ResetPasswordHandler struct {
tpl *template.Template
issuerURL url.URL
um *manager.UserManager
um *usermanager.UserManager
keysFunc func() ([]key.PublicKey, error)
}
@ -238,7 +238,7 @@ func (r *resetPasswordRequest) handlePOST() {
cbURL, err := r.h.um.ChangePassword(r.pwReset, plaintext)
if err != nil {
switch err {
case manager.ErrorPasswordAlreadyChanged:
case usermanager.ErrorPasswordAlreadyChanged:
r.data.Error = "Link Expired"
r.data.Message = "The link in your email is no longer valid. If you need to change your password, generate a new email."
r.data.DontShowForm = true

7
server/register.go

@ -10,8 +10,9 @@ import (
"github.com/coreos/dex/connector"
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
"github.com/coreos/go-oidc/oidc"
)
@ -274,7 +275,7 @@ func makeClientRedirectURL(baseRedirURL url.URL, code, clientState string) *url.
return &ru
}
func registerFromLocalConnector(userManager *manager.UserManager, sessionManager *session.SessionManager, ses *session.Session, email, password string) (string, error) {
func registerFromLocalConnector(userManager *usermanager.UserManager, sessionManager *sessionmanager.SessionManager, ses *session.Session, email, password string) (string, error) {
userID, err := userManager.RegisterWithPassword(email, password, ses.ConnectorID)
if err != nil {
return "", err
@ -289,7 +290,7 @@ func registerFromLocalConnector(userManager *manager.UserManager, sessionManager
return userID, nil
}
func registerFromRemoteConnector(userManager *manager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
func registerFromRemoteConnector(userManager *usermanager.UserManager, ses *session.Session, email string, emailVerified bool) (string, error) {
if ses.Identity.ID == "" {
return "", errors.New("No Identity found in session.")
}

7
server/server.go

@ -22,10 +22,11 @@ import (
"github.com/coreos/dex/pkg/log"
"github.com/coreos/dex/refresh"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
usersapi "github.com/coreos/dex/user/api"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
const (
@ -57,7 +58,7 @@ type Server struct {
IssuerURL url.URL
KeyManager key.PrivateKeyManager
KeySetRepo key.PrivateKeySetRepo
SessionManager *session.SessionManager
SessionManager *sessionmanager.SessionManager
ClientIdentityRepo client.ClientIdentityRepo
ConnectorConfigRepo connector.ConnectorConfigRepo
Templates *template.Template
@ -69,7 +70,7 @@ type Server struct {
HealthChecks []health.Checkable
Connectors []connector.Connector
UserRepo user.UserRepo
UserManager *manager.UserManager
UserManager *usermanager.UserManager
PasswordInfoRepo user.PasswordInfoRepo
RefreshTokenRepo refresh.RefreshTokenRepo
UserEmailer *useremail.UserEmailer

19
server/server_test.go

@ -10,8 +10,9 @@ import (
"time"
"github.com/coreos/dex/client"
"github.com/coreos/dex/db"
"github.com/coreos/dex/refresh/refreshtest"
"github.com/coreos/dex/session"
"github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/key"
@ -68,7 +69,7 @@ func (ss *StaticSigner) JWK() jose.JWK {
return jose.JWK{}
}
func staticGenerateCodeFunc(code string) session.GenerateCodeFunc {
func staticGenerateCodeFunc(code string) manager.GenerateCodeFunc {
return func() (string, error) {
return code, nil
}
@ -120,7 +121,7 @@ func TestServerProviderConfig(t *testing.T) {
}
func TestServerNewSession(t *testing.T) {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
SessionManager: sm,
}
@ -197,7 +198,7 @@ func TestServerLogin(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
@ -245,7 +246,7 @@ func TestServerLoginUnrecognizedSessionKey(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: nil, err: errors.New("fail")},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
KeyManager: km,
@ -286,7 +287,7 @@ func TestServerLoginDisabledUser(t *testing.T) {
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = staticGenerateCodeFunc("fakecode")
sessionID, err := sm.NewSession("test_connector_id", ci.Credentials.ID, "bogus", ci.Metadata.RedirectURIs[0], "", false, []string{"openid"})
if err != nil {
@ -343,7 +344,7 @@ func TestServerCodeToken(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
userRepo, err := makeNewUserRepo()
if err != nil {
@ -424,7 +425,7 @@ func TestServerTokenUnrecognizedKey(t *testing.T) {
km := &StaticKeyManager{
signer: &StaticSigner{sig: []byte("beer"), err: nil},
}
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
srv := &Server{
IssuerURL: url.URL{Scheme: "http", Host: "server.example.com"},
@ -518,7 +519,7 @@ func TestServerTokenFail(t *testing.T) {
}
for i, tt := range tests {
sm := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sm := manager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sm.GenerateCode = func() (string, error) { return keyFixture, nil }
sessionID, err := sm.NewSession("connector_id", ccFixture.ID, "bogus", url.URL{}, "", false, tt.scope)

13
server/testutil.go

@ -10,12 +10,13 @@ import (
"github.com/coreos/dex/client"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/db"
"github.com/coreos/dex/email"
"github.com/coreos/dex/repo"
"github.com/coreos/dex/session"
sessionmanager "github.com/coreos/dex/session/manager"
"github.com/coreos/dex/user"
useremail "github.com/coreos/dex/user/email"
"github.com/coreos/dex/user/manager"
usermanager "github.com/coreos/dex/user/manager"
)
const (
@ -75,13 +76,13 @@ var (
type testFixtures struct {
srv *Server
userRepo user.UserRepo
sessionManager *session.SessionManager
sessionManager *sessionmanager.SessionManager
emailer *email.TemplatizedEmailer
redirectURL url.URL
clientIdentityRepo client.ClientIdentityRepo
}
func sequentialGenerateCodeFunc() session.GenerateCodeFunc {
func sequentialGenerateCodeFunc() sessionmanager.GenerateCodeFunc {
x := 0
return func() (string, error) {
x += 1
@ -113,9 +114,9 @@ func makeTestFixtures() (*testFixtures, error) {
}
connCfgRepo := connector.NewConnectorConfigRepoFromConfigs(connConfigs)
manager := manager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, manager.ManagerOptions{})
manager := usermanager.NewUserManager(userRepo, pwRepo, connCfgRepo, repo.InMemTransactionFactory, usermanager.ManagerOptions{})
sessionManager := session.NewSessionManager(session.NewSessionRepo(), session.NewSessionKeyRepo())
sessionManager := sessionmanager.NewSessionManager(db.NewSessionRepo(db.NewMemDB()), db.NewSessionKeyRepo(db.NewMemDB()))
sessionManager.GenerateCode = sequentialGenerateCodeFunc()
emailer, err := email.NewTemplatizedEmailerFromGlobs(

38
session/manager.go → session/manager/manager.go

@ -1,4 +1,4 @@
package session
package manager
import (
"crypto/rand"
@ -10,6 +10,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
@ -27,11 +28,11 @@ func DefaultGenerateCode() (string, error) {
return base64.URLEncoding.EncodeToString(b), nil
}
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager {
func NewSessionManager(sRepo session.SessionRepo, skRepo session.SessionKeyRepo) *SessionManager {
return &SessionManager{
GenerateCode: DefaultGenerateCode,
Clock: clockwork.NewRealClock(),
ValidityWindow: DefaultSessionValidityWindow,
ValidityWindow: session.DefaultSessionValidityWindow,
sessions: sRepo,
keys: skRepo,
}
@ -41,8 +42,8 @@ type SessionManager struct {
GenerateCode GenerateCodeFunc
Clock clockwork.Clock
ValidityWindow time.Duration
sessions SessionRepo
keys SessionKeyRepo
sessions session.SessionRepo
keys session.SessionKeyRepo
}
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) {
@ -52,10 +53,10 @@ func (m *SessionManager) NewSession(connectorID, clientID, clientState string, r
}
now := m.Clock.Now()
s := Session{
s := session.Session{
ConnectorID: connectorID,
ID: sID,
State: SessionStateNew,
State: session.SessionStateNew,
CreatedAt: now,
ExpiresAt: now.Add(m.ValidityWindow),
ClientID: clientID,
@ -80,11 +81,12 @@ func (m *SessionManager) NewSessionKey(sessionID string) (string, error) {
return "", err
}
k := SessionKey{
k := session.SessionKey{
Key: key,
SessionID: sessionID,
}
sessionKeyValidityWindow := 10 * time.Minute //RFC6749
err = m.keys.Push(k, sessionKeyValidityWindow)
if err != nil {
return "", err
@ -97,7 +99,7 @@ func (m *SessionManager) ExchangeKey(key string) (string, error) {
return m.keys.Pop(key)
}
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) {
func (m *SessionManager) getSessionInState(sessionID string, state session.SessionState) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
@ -110,14 +112,14 @@ func (m *SessionManager) getSessionInState(sessionID string, state SessionState)
return s, nil
}
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateNew)
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateNew)
if err != nil {
return nil, err
}
s.Identity = ident
s.State = SessionStateRemoteAttached
s.State = session.SessionStateRemoteAttached
if err = m.sessions.Update(*s); err != nil {
return nil, err
@ -126,14 +128,14 @@ func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Ident
return s, nil
}
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) {
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached)
func (m *SessionManager) AttachUser(sessionID string, userID string) (*session.Session, error) {
s, err := m.getSessionInState(sessionID, session.SessionStateRemoteAttached)
if err != nil {
return nil, err
}
s.UserID = userID
s.State = SessionStateIdentified
s.State = session.SessionStateIdentified
if err = m.sessions.Update(*s); err != nil {
return nil, err
@ -142,13 +144,13 @@ func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session,
return s, nil
}
func (m *SessionManager) Kill(sessionID string) (*Session, error) {
func (m *SessionManager) Kill(sessionID string) (*session.Session, error) {
s, err := m.sessions.Get(sessionID)
if err != nil {
return nil, err
}
s.State = SessionStateDead
s.State = session.SessionStateDead
if err = m.sessions.Update(*s); err != nil {
return nil, err
@ -157,6 +159,6 @@ func (m *SessionManager) Kill(sessionID string) (*Session, error) {
return s, nil
}
func (m *SessionManager) Get(sessionID string) (*Session, error) {
func (m *SessionManager) Get(sessionID string) (*session.Session, error) {
return m.sessions.Get(sessionID)
}

25
session/manager_test.go → session/manager/manager_test.go

@ -1,9 +1,11 @@
package session
package manager
import (
"net/url"
"testing"
"github.com/coreos/dex/db"
"github.com/coreos/dex/session"
"github.com/coreos/go-oidc/oidc"
)
@ -13,8 +15,13 @@ func staticGenerateCodeFunc(code string) GenerateCodeFunc {
}
}
func newManager(t *testing.T) *SessionManager {
dbMap := db.NewMemDB()
return NewSessionManager(db.NewSessionRepo(dbMap), db.NewSessionKeyRepo(dbMap))
}
func TestSessionManagerNewSession(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sm.GenerateCode = staticGenerateCodeFunc("boo")
got, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
@ -26,7 +33,7 @@ func TestSessionManagerNewSession(t *testing.T) {
}
func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("bogus_idpc", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
@ -43,7 +50,7 @@ func TestSessionAttachRemoteIdentityTwice(t *testing.T) {
}
func TestSessionManagerExchangeKey(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
@ -68,8 +75,8 @@ func TestSessionManagerExchangeKey(t *testing.T) {
}
func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
ses, err := sm.getSessionInState("123", SessionStateNew)
sm := newManager(t)
ses, err := sm.getSessionInState("123", session.SessionStateNew)
if err == nil {
t.Errorf("Expected non-nil error")
}
@ -79,12 +86,12 @@ func TestSessionManagerGetSessionInStateNoExist(t *testing.T) {
}
func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
ses, err := sm.getSessionInState(sessionID, SessionStateDead)
ses, err := sm.getSessionInState(sessionID, session.SessionStateDead)
if err == nil {
t.Errorf("Expected non-nil error")
}
@ -94,7 +101,7 @@ func TestSessionManagerGetSessionInStateWrongState(t *testing.T) {
}
func TestSessionManagerKill(t *testing.T) {
sm := NewSessionManager(NewSessionRepo(), NewSessionKeyRepo())
sm := newManager(t)
sessionID, err := sm.NewSession("connector_id", "XXX", "bogus", url.URL{}, "", false, []string{"openid"})
if err != nil {
t.Fatalf("Unexpected error: %v", err)

91
session/repo.go

@ -1,11 +1,6 @@
package session
import (
"errors"
"time"
"github.com/jonboulle/clockwork"
)
import "time"
type SessionRepo interface {
Get(string) (*Session, error)
@ -17,87 +12,3 @@ type SessionKeyRepo interface {
Push(SessionKey, time.Duration) error
Pop(string) (string, error)
}
func NewSessionRepo() SessionRepo {
return NewSessionRepoWithClock(clockwork.NewRealClock())
}
func NewSessionRepoWithClock(clock clockwork.Clock) SessionRepo {
return &memSessionRepo{
store: make(map[string]Session),
clock: clock,
}
}
type memSessionRepo struct {
store map[string]Session
clock clockwork.Clock
}
func (m *memSessionRepo) Get(sessionID string) (*Session, error) {
s, ok := m.store[sessionID]
if !ok || s.ExpiresAt.Before(m.clock.Now()) {
return nil, errors.New("unrecognized ID")
}
return &s, nil
}
func (m *memSessionRepo) Create(s Session) error {
if _, ok := m.store[s.ID]; ok {
return errors.New("ID exists")
}
m.store[s.ID] = s
return nil
}
func (m *memSessionRepo) Update(s Session) error {
if _, ok := m.store[s.ID]; !ok {
return errors.New("unrecognized ID")
}
m.store[s.ID] = s
return nil
}
type expiringSessionKey struct {
SessionKey
expiresAt time.Time
}
func NewSessionKeyRepo() SessionKeyRepo {
return NewSessionKeyRepoWithClock(clockwork.NewRealClock())
}
func NewSessionKeyRepoWithClock(clock clockwork.Clock) SessionKeyRepo {
return &memSessionKeyRepo{
store: make(map[string]expiringSessionKey),
clock: clock,
}
}
type memSessionKeyRepo struct {
store map[string]expiringSessionKey
clock clockwork.Clock
}
func (m *memSessionKeyRepo) Pop(key string) (string, error) {
esk, ok := m.store[key]
if !ok {
return "", errors.New("unrecognized key")
}
defer delete(m.store, key)
if esk.expiresAt.Before(m.clock.Now()) {
return "", errors.New("expired key")
}
return esk.SessionKey.SessionID, nil
}
func (m *memSessionKeyRepo) Push(sk SessionKey, ttl time.Duration) error {
m.store[sk.Key] = expiringSessionKey{
SessionKey: sk,
expiresAt: m.clock.Now().Add(ttl),
}
return nil
}

2
test

@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session user user/api user/manager email admin"
TESTABLE="connector db integration pkg/crypto pkg/flag pkg/http pkg/net pkg/time pkg/html functional/repo server session session/manager user user/api user/manager email admin"
FORMATTABLE="$TESTABLE cmd/dexctl cmd/dex-worker cmd/dex-overlord examples/app functional pkg/log"
# user has not provided PKG override

Loading…
Cancel
Save