mirror of https://github.com/dexidp/dex.git
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
162 lines
3.3 KiB
162 lines
3.3 KiB
package session |
|
|
|
import ( |
|
"crypto/rand" |
|
"encoding/base64" |
|
"errors" |
|
"fmt" |
|
"net/url" |
|
"time" |
|
|
|
"github.com/jonboulle/clockwork" |
|
|
|
"github.com/coreos/go-oidc/oidc" |
|
) |
|
|
|
// GenerateCodeFunc is the signature of a code generator: each call yields a
// code string, or an error when no code could be produced.
type GenerateCodeFunc func() (string, error)
|
|
|
// DefaultGenerateCode returns 8 bytes read from crypto/rand, encoded with the
// URL-safe base64 alphabet (padding included). It fails if the random source
// reports an error or delivers fewer bytes than requested.
func DefaultGenerateCode() (string, error) {
	// Number of random bytes behind every generated code.
	const codeBytes = 8

	buf := make([]byte, codeBytes)
	switch n, err := rand.Read(buf); {
	case err != nil:
		return "", err
	case n != codeBytes:
		return "", errors.New("unable to read enough random bytes")
	}

	return base64.URLEncoding.EncodeToString(buf), nil
}
|
|
|
func NewSessionManager(sRepo SessionRepo, skRepo SessionKeyRepo) *SessionManager { |
|
return &SessionManager{ |
|
GenerateCode: DefaultGenerateCode, |
|
Clock: clockwork.NewRealClock(), |
|
ValidityWindow: DefaultSessionValidityWindow, |
|
sessions: sRepo, |
|
keys: skRepo, |
|
} |
|
} |
|
|
|
type SessionManager struct { |
|
GenerateCode GenerateCodeFunc |
|
Clock clockwork.Clock |
|
ValidityWindow time.Duration |
|
sessions SessionRepo |
|
keys SessionKeyRepo |
|
} |
|
|
|
func (m *SessionManager) NewSession(connectorID, clientID, clientState string, redirectURL url.URL, nonce string, register bool, scope []string) (string, error) { |
|
sID, err := m.GenerateCode() |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
now := m.Clock.Now() |
|
s := Session{ |
|
ConnectorID: connectorID, |
|
ID: sID, |
|
State: SessionStateNew, |
|
CreatedAt: now, |
|
ExpiresAt: now.Add(m.ValidityWindow), |
|
ClientID: clientID, |
|
ClientState: clientState, |
|
RedirectURL: redirectURL, |
|
Register: register, |
|
Nonce: nonce, |
|
Scope: scope, |
|
} |
|
|
|
err = m.sessions.Create(s) |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
return sID, nil |
|
} |
|
|
|
func (m *SessionManager) NewSessionKey(sessionID string) (string, error) { |
|
key, err := m.GenerateCode() |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
k := SessionKey{ |
|
Key: key, |
|
SessionID: sessionID, |
|
} |
|
|
|
err = m.keys.Push(k, sessionKeyValidityWindow) |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
return k.Key, nil |
|
} |
|
|
|
func (m *SessionManager) ExchangeKey(key string) (string, error) { |
|
return m.keys.Pop(key) |
|
} |
|
|
|
func (m *SessionManager) getSessionInState(sessionID string, state SessionState) (*Session, error) { |
|
s, err := m.sessions.Get(sessionID) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
if s.State != state { |
|
return nil, fmt.Errorf("session state %s, expect %s", s.State, state) |
|
} |
|
|
|
return s, nil |
|
} |
|
|
|
func (m *SessionManager) AttachRemoteIdentity(sessionID string, ident oidc.Identity) (*Session, error) { |
|
s, err := m.getSessionInState(sessionID, SessionStateNew) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
s.Identity = ident |
|
s.State = SessionStateRemoteAttached |
|
|
|
if err = m.sessions.Update(*s); err != nil { |
|
return nil, err |
|
} |
|
|
|
return s, nil |
|
} |
|
|
|
func (m *SessionManager) AttachUser(sessionID string, userID string) (*Session, error) { |
|
s, err := m.getSessionInState(sessionID, SessionStateRemoteAttached) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
s.UserID = userID |
|
s.State = SessionStateIdentified |
|
|
|
if err = m.sessions.Update(*s); err != nil { |
|
return nil, err |
|
} |
|
|
|
return s, nil |
|
} |
|
|
|
func (m *SessionManager) Kill(sessionID string) (*Session, error) { |
|
s, err := m.sessions.Get(sessionID) |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
s.State = SessionStateDead |
|
|
|
if err = m.sessions.Update(*s); err != nil { |
|
return nil, err |
|
} |
|
|
|
return s, nil |
|
} |
|
|
|
func (m *SessionManager) Get(sessionID string) (*Session, error) { |
|
return m.sessions.Get(sessionID) |
|
}
|
|
|