mirror of https://github.com/dexidp/dex.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
4.0 KiB
166 lines
4.0 KiB
package refresh |
|
|
|
import ( |
|
"crypto/rand" |
|
"encoding/base64" |
|
"errors" |
|
"fmt" |
|
"strconv" |
|
"strings" |
|
) |
|
|
|
const ( |
|
DefaultRefreshTokenPayloadLength = 64 |
|
TokenDelimer = "/" |
|
) |
|
|
|
var ( |
|
ErrorInvalidUserID = errors.New("invalid user ID") |
|
ErrorInvalidClientID = errors.New("invalid client ID") |
|
ErrorInvalidToken = errors.New("invalid token") |
|
) |
|
|
|
type RefreshTokenGenerator func() (string, error) |
|
|
|
func (g RefreshTokenGenerator) Generate() (string, error) { |
|
return g() |
|
} |
|
|
|
func DefaultRefreshTokenGenerator() (string, error) { |
|
// TODO(yifan) Remove this duplicated token generate function. |
|
b := make([]byte, DefaultRefreshTokenPayloadLength) |
|
n, err := rand.Read(b) |
|
if err != nil { |
|
return "", err |
|
} else if n != DefaultRefreshTokenPayloadLength { |
|
return "", errors.New("unable to read enough random bytes") |
|
} |
|
return base64.URLEncoding.EncodeToString(b), nil |
|
} |
|
|
|
type RefreshTokenRepo interface { |
|
// Create generates and returns a new refresh token for the given client-user pair. |
|
// On success the token will be return. |
|
Create(userID, clientID string) (string, error) |
|
|
|
// Verify verifies that a token belongs to the client, and returns the corresponding user ID. |
|
// Note that this assumes the client validation is currently done in the application layer, |
|
Verify(clientID, token string) (string, error) |
|
|
|
// Revoke deletes the refresh token if the token belongs to the given userID. |
|
Revoke(userID, token string) error |
|
} |
|
|
|
type refreshToken struct { |
|
payload string |
|
userID string |
|
clientID string |
|
} |
|
|
|
type memRefreshTokenRepo struct { |
|
store map[int]refreshToken |
|
tokenGenerator RefreshTokenGenerator |
|
} |
|
|
|
// buildToken combines the token ID and token payload to create a new token. |
|
func buildToken(tokenID int, tokenPayload string) string { |
|
return fmt.Sprintf("%d%s%s", tokenID, TokenDelimer, tokenPayload) |
|
} |
|
|
|
// parseToken parses a token and returns the token ID and token payload. |
|
func parseToken(token string) (int, string, error) { |
|
parts := strings.SplitN(token, TokenDelimer, 2) |
|
if len(parts) != 2 { |
|
return -1, "", ErrorInvalidToken |
|
} |
|
id, err := strconv.Atoi(parts[0]) |
|
if err != nil { |
|
return -1, "", ErrorInvalidToken |
|
} |
|
return id, parts[1], nil |
|
} |
|
|
|
// NewRefreshTokenRepo returns an in-memory RefreshTokenRepo useful for development. |
|
func NewRefreshTokenRepo() RefreshTokenRepo { |
|
return NewRefreshTokenRepoWithTokenGenerator(DefaultRefreshTokenGenerator) |
|
} |
|
|
|
func NewRefreshTokenRepoWithTokenGenerator(tokenGenerator RefreshTokenGenerator) RefreshTokenRepo { |
|
repo := &memRefreshTokenRepo{} |
|
repo.store = make(map[int]refreshToken) |
|
repo.tokenGenerator = tokenGenerator |
|
return repo |
|
} |
|
|
|
func (r *memRefreshTokenRepo) Create(userID, clientID string) (string, error) { |
|
// Validate userID. |
|
if userID == "" { |
|
return "", ErrorInvalidUserID |
|
} |
|
|
|
// Validate clientID. |
|
if clientID == "" { |
|
return "", ErrorInvalidClientID |
|
} |
|
|
|
// Generate and store token. |
|
tokenPayload, err := r.tokenGenerator.Generate() |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
tokenID := len(r.store) // Should only be used in single threaded tests. |
|
|
|
// No limits on the number of tokens per user/client for this in-memory repo. |
|
r.store[tokenID] = refreshToken{ |
|
payload: tokenPayload, |
|
userID: userID, |
|
clientID: clientID, |
|
} |
|
return buildToken(tokenID, tokenPayload), nil |
|
} |
|
|
|
func (r *memRefreshTokenRepo) Verify(clientID, token string) (string, error) { |
|
tokenID, tokenPayload, err := parseToken(token) |
|
if err != nil { |
|
return "", err |
|
} |
|
|
|
record, ok := r.store[tokenID] |
|
if !ok { |
|
return "", ErrorInvalidToken |
|
} |
|
|
|
if record.payload != tokenPayload { |
|
return "", ErrorInvalidToken |
|
} |
|
|
|
if record.clientID != clientID { |
|
return "", ErrorInvalidClientID |
|
} |
|
|
|
return record.userID, nil |
|
} |
|
|
|
func (r *memRefreshTokenRepo) Revoke(userID, token string) error { |
|
tokenID, tokenPayload, err := parseToken(token) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
record, ok := r.store[tokenID] |
|
if !ok { |
|
return ErrorInvalidToken |
|
} |
|
|
|
if record.payload != tokenPayload { |
|
return ErrorInvalidToken |
|
} |
|
|
|
if record.userID != userID { |
|
return ErrorInvalidUserID |
|
} |
|
|
|
delete(r.store, tokenID) |
|
return nil |
|
}
|
|
|