mirror of https://github.com/dexidp/dex.git
Browse Source
Co-authored-by: Shash Reddy <sreddy@pivotal.io> Signed-off-by: Joshua Winters <jwinters@pivotal.io>pull/1630/head
3 changed files with 478 additions and 0 deletions
@ -0,0 +1,242 @@
|
||||
package oauth |
||||
|
||||
import ( |
||||
"context" |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"net" |
||||
"net/http" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/dexidp/dex/connector" |
||||
"github.com/dexidp/dex/pkg/log" |
||||
"golang.org/x/oauth2" |
||||
) |
||||
|
||||
type oauthConnector struct { |
||||
clientID string |
||||
clientSecret string |
||||
redirectURI string |
||||
tokenURL string |
||||
authorizationURL string |
||||
userInfoURL string |
||||
scopes []string |
||||
groupsKey string |
||||
httpClient *http.Client |
||||
logger log.Logger |
||||
} |
||||
|
||||
type connectorData struct { |
||||
AccessToken string |
||||
} |
||||
|
||||
type Config struct { |
||||
ClientID string `json:"clientID"` |
||||
ClientSecret string `json:"clientSecret"` |
||||
RedirectURI string `json:"redirectURI"` |
||||
TokenURL string `json:"tokenURL"` |
||||
AuthorizationURL string `json:"authorizationURL"` |
||||
UserInfoURL string `json:"userInfoURL"` |
||||
Scopes []string `json:"scopes"` |
||||
GroupsKey string `json:"groupsKey"` |
||||
RootCAs []string `json:"rootCAs"` |
||||
InsecureSkipVerify bool `json:"insecureSkipVerify"` |
||||
} |
||||
|
||||
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { |
||||
var err error |
||||
|
||||
oauthConn := &oauthConnector{ |
||||
clientID: c.ClientID, |
||||
clientSecret: c.ClientSecret, |
||||
tokenURL: c.TokenURL, |
||||
authorizationURL: c.AuthorizationURL, |
||||
userInfoURL: c.UserInfoURL, |
||||
scopes: c.Scopes, |
||||
groupsKey: c.GroupsKey, |
||||
redirectURI: c.RedirectURI, |
||||
logger: logger, |
||||
} |
||||
|
||||
oauthConn.httpClient, err = newHTTPClient(c.RootCAs, c.InsecureSkipVerify) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return oauthConn, err |
||||
} |
||||
|
||||
func newHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) { |
||||
pool, err := x509.SystemCertPool() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify} |
||||
for _, rootCA := range rootCAs { |
||||
rootCABytes, err := ioutil.ReadFile(rootCA) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to read root-ca: %v", err) |
||||
} |
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { |
||||
return nil, fmt.Errorf("no certs found in root CA file %q", rootCA) |
||||
} |
||||
} |
||||
|
||||
return &http.Client{ |
||||
Transport: &http.Transport{ |
||||
TLSClientConfig: &tlsConfig, |
||||
Proxy: http.ProxyFromEnvironment, |
||||
DialContext: (&net.Dialer{ |
||||
Timeout: 30 * time.Second, |
||||
KeepAlive: 30 * time.Second, |
||||
DualStack: true, |
||||
}).DialContext, |
||||
MaxIdleConns: 100, |
||||
IdleConnTimeout: 90 * time.Second, |
||||
TLSHandshakeTimeout: 10 * time.Second, |
||||
ExpectContinueTimeout: 1 * time.Second, |
||||
}, |
||||
}, nil |
||||
} |
||||
|
||||
func (c *oauthConnector) LoginURL(scopes connector.Scopes, callbackURL, state string) (string, error) { |
||||
|
||||
if c.redirectURI != callbackURL { |
||||
return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI) |
||||
} |
||||
|
||||
oauth2Config := &oauth2.Config{ |
||||
ClientID: c.clientID, |
||||
ClientSecret: c.clientSecret, |
||||
Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, |
||||
RedirectURL: c.redirectURI, |
||||
Scopes: c.scopes, |
||||
} |
||||
|
||||
return oauth2Config.AuthCodeURL(state), nil |
||||
} |
||||
|
||||
func (c *oauthConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { |
||||
|
||||
q := r.URL.Query() |
||||
if errType := q.Get("error"); errType != "" { |
||||
return identity, errors.New(q.Get("error_description")) |
||||
} |
||||
|
||||
oauth2Config := &oauth2.Config{ |
||||
ClientID: c.clientID, |
||||
ClientSecret: c.clientSecret, |
||||
Endpoint: oauth2.Endpoint{TokenURL: c.tokenURL, AuthURL: c.authorizationURL}, |
||||
RedirectURL: c.redirectURI, |
||||
Scopes: c.scopes, |
||||
} |
||||
|
||||
ctx := context.WithValue(r.Context(), oauth2.HTTPClient, c.httpClient) |
||||
|
||||
token, err := oauth2Config.Exchange(ctx, q.Get("code")) |
||||
if err != nil { |
||||
return identity, fmt.Errorf("OAuth connector: failed to get token: %v", err) |
||||
} |
||||
|
||||
client := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token)) |
||||
|
||||
userInfoResp, err := client.Get(c.userInfoURL) |
||||
if err != nil { |
||||
return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: %v", err) |
||||
} |
||||
|
||||
if userInfoResp.StatusCode != http.StatusOK { |
||||
return identity, fmt.Errorf("OAuth Connector: failed to execute request to userinfo: status %d", userInfoResp.StatusCode) |
||||
} |
||||
|
||||
defer userInfoResp.Body.Close() |
||||
|
||||
var userInfoResult map[string]interface{} |
||||
err = json.NewDecoder(userInfoResp.Body).Decode(&userInfoResult) |
||||
|
||||
if err != nil { |
||||
return identity, fmt.Errorf("OAuth Connector: failed to parse userinfo: %v", err) |
||||
} |
||||
|
||||
identity.UserID, _ = userInfoResult["user_id"].(string) |
||||
identity.Name, _ = userInfoResult["name"].(string) |
||||
identity.Username, _ = userInfoResult["user_name"].(string) |
||||
identity.Email, _ = userInfoResult["email"].(string) |
||||
identity.EmailVerified, _ = userInfoResult["email_verified"].(bool) |
||||
|
||||
if s.Groups { |
||||
if c.groupsKey == "" { |
||||
c.groupsKey = "groups" |
||||
} |
||||
|
||||
groups := map[string]bool{} |
||||
|
||||
c.addGroupsFromMap(groups, userInfoResult) |
||||
c.addGroupsFromToken(groups, token.AccessToken) |
||||
|
||||
for groupName, _ := range groups { |
||||
identity.Groups = append(identity.Groups, groupName) |
||||
} |
||||
} |
||||
|
||||
if s.OfflineAccess { |
||||
data := connectorData{AccessToken: token.AccessToken} |
||||
connData, err := json.Marshal(data) |
||||
if err != nil { |
||||
return identity, fmt.Errorf("OAuth Connector: failed to parse connector data for offline access: %v", err) |
||||
} |
||||
identity.ConnectorData = connData |
||||
} |
||||
|
||||
return identity, nil |
||||
} |
||||
|
||||
func (c *oauthConnector) addGroupsFromMap(groups map[string]bool, result map[string]interface{}) error { |
||||
groupsClaim, ok := result[c.groupsKey].([]interface{}) |
||||
if !ok { |
||||
return errors.New("Cant convert to array") |
||||
} |
||||
|
||||
for _, group := range groupsClaim { |
||||
if groupString, ok := group.(string); ok { |
||||
groups[groupString] = true |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *oauthConnector) addGroupsFromToken(groups map[string]bool, token string) error { |
||||
parts := strings.Split(token, ".") |
||||
if len(parts) < 2 { |
||||
return errors.New("Invalid token") |
||||
} |
||||
|
||||
decoded, err := decode(parts[1]) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
var claimsMap map[string]interface{} |
||||
err = json.Unmarshal(decoded, &claimsMap) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
return c.addGroupsFromMap(groups, claimsMap) |
||||
} |
||||
|
||||
func decode(seg string) ([]byte, error) { |
||||
if l := len(seg) % 4; l > 0 { |
||||
seg += strings.Repeat("=", 4-l) |
||||
} |
||||
|
||||
return base64.URLEncoding.DecodeString(seg) |
||||
} |
||||
@ -0,0 +1,234 @@
|
||||
package oauth |
||||
|
||||
import ( |
||||
"crypto/rand" |
||||
"crypto/rsa" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"net/url" |
||||
"reflect" |
||||
"sort" |
||||
"testing" |
||||
|
||||
"github.com/dexidp/dex/connector" |
||||
"github.com/sirupsen/logrus" |
||||
jose "gopkg.in/square/go-jose.v2" |
||||
) |
||||
|
||||
func TestOpen(t *testing.T) { |
||||
tokenClaims := map[string]interface{}{} |
||||
userInfoClaims := map[string]interface{}{} |
||||
|
||||
testServer := testSetup(t, tokenClaims, userInfoClaims) |
||||
defer testServer.Close() |
||||
|
||||
conn := newConnector(t, testServer.URL) |
||||
|
||||
sort.Strings(conn.scopes) |
||||
|
||||
expectEqual(t, conn.clientID, "testClient") |
||||
expectEqual(t, conn.clientSecret, "testSecret") |
||||
expectEqual(t, conn.redirectURI, testServer.URL+"/callback") |
||||
expectEqual(t, conn.tokenURL, testServer.URL+"/token") |
||||
expectEqual(t, conn.authorizationURL, testServer.URL+"/authorize") |
||||
expectEqual(t, conn.userInfoURL, testServer.URL+"/userinfo") |
||||
expectEqual(t, len(conn.scopes), 2) |
||||
expectEqual(t, conn.scopes[0], "groups") |
||||
expectEqual(t, conn.scopes[1], "openid") |
||||
} |
||||
|
||||
func TestLoginURL(t *testing.T) { |
||||
tokenClaims := map[string]interface{}{} |
||||
userInfoClaims := map[string]interface{}{} |
||||
|
||||
testServer := testSetup(t, tokenClaims, userInfoClaims) |
||||
defer testServer.Close() |
||||
|
||||
conn := newConnector(t, testServer.URL) |
||||
|
||||
loginURL, err := conn.LoginURL(connector.Scopes{}, conn.redirectURI, "some-state") |
||||
expectEqual(t, err, nil) |
||||
|
||||
expectedURL, err := url.Parse(testServer.URL + "/authorize") |
||||
expectEqual(t, err, nil) |
||||
|
||||
values := url.Values{} |
||||
values.Add("client_id", "testClient") |
||||
values.Add("redirect_uri", conn.redirectURI) |
||||
values.Add("response_type", "code") |
||||
values.Add("scope", "openid groups") |
||||
values.Add("state", "some-state") |
||||
expectedURL.RawQuery = values.Encode() |
||||
|
||||
expectEqual(t, loginURL, expectedURL.String()) |
||||
} |
||||
|
||||
func TestHandleCallBackForGroupsInUserInfo(t *testing.T) { |
||||
|
||||
tokenClaims := map[string]interface{}{} |
||||
|
||||
userInfoClaims := map[string]interface{}{ |
||||
"name": "test-name", |
||||
"user_name": "test-username", |
||||
"user_id": "test-user-id", |
||||
"email": "test-email", |
||||
"email_verified": true, |
||||
"groups_key": []string{"admin-group", "user-group"}, |
||||
} |
||||
|
||||
testServer := testSetup(t, tokenClaims, userInfoClaims) |
||||
defer testServer.Close() |
||||
|
||||
conn := newConnector(t, testServer.URL) |
||||
req := newRequestWithAuthCode(t, testServer.URL, "some-code") |
||||
|
||||
identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) |
||||
expectEqual(t, err, nil) |
||||
|
||||
sort.Strings(identity.Groups) |
||||
expectEqual(t, len(identity.Groups), 2) |
||||
expectEqual(t, identity.Groups[0], "admin-group") |
||||
expectEqual(t, identity.Groups[1], "user-group") |
||||
expectEqual(t, identity.Name, "test-name") |
||||
expectEqual(t, identity.Username, "test-username") |
||||
expectEqual(t, identity.Email, "test-email") |
||||
expectEqual(t, identity.EmailVerified, true) |
||||
} |
||||
|
||||
func TestHandleCallBackForGroupsInToken(t *testing.T) { |
||||
|
||||
tokenClaims := map[string]interface{}{ |
||||
"groups_key": []string{"test-group"}, |
||||
} |
||||
|
||||
userInfoClaims := map[string]interface{}{ |
||||
"name": "test-name", |
||||
"user_name": "test-username", |
||||
"user_id": "test-user-id", |
||||
"email": "test-email", |
||||
"email_verified": true, |
||||
} |
||||
|
||||
testServer := testSetup(t, tokenClaims, userInfoClaims) |
||||
defer testServer.Close() |
||||
|
||||
conn := newConnector(t, testServer.URL) |
||||
req := newRequestWithAuthCode(t, testServer.URL, "some-code") |
||||
|
||||
identity, err := conn.HandleCallback(connector.Scopes{Groups: true}, req) |
||||
expectEqual(t, err, nil) |
||||
|
||||
expectEqual(t, len(identity.Groups), 1) |
||||
expectEqual(t, identity.Groups[0], "test-group") |
||||
expectEqual(t, identity.Name, "test-name") |
||||
expectEqual(t, identity.Username, "test-username") |
||||
expectEqual(t, identity.Email, "test-email") |
||||
expectEqual(t, identity.EmailVerified, true) |
||||
} |
||||
|
||||
func testSetup(t *testing.T, tokenClaims map[string]interface{}, userInfoClaims map[string]interface{}) *httptest.Server { |
||||
|
||||
key, err := rsa.GenerateKey(rand.Reader, 1024) |
||||
if err != nil { |
||||
t.Fatal("Failed to generate rsa key", err) |
||||
} |
||||
|
||||
jwk := jose.JSONWebKey{ |
||||
Key: key, |
||||
KeyID: "some-key", |
||||
Algorithm: "RSA", |
||||
} |
||||
|
||||
mux := http.NewServeMux() |
||||
|
||||
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { |
||||
token, err := newToken(&jwk, tokenClaims) |
||||
if err != nil { |
||||
t.Fatal("unable to generate token", err) |
||||
} |
||||
|
||||
w.Header().Add("Content-Type", "application/json") |
||||
json.NewEncoder(w).Encode(&map[string]string{ |
||||
"access_token": token, |
||||
"id_token": token, |
||||
"token_type": "Bearer", |
||||
}) |
||||
}) |
||||
|
||||
mux.HandleFunc("/userinfo", func(w http.ResponseWriter, r *http.Request) { |
||||
w.Header().Add("Content-Type", "application/json") |
||||
json.NewEncoder(w).Encode(userInfoClaims) |
||||
}) |
||||
|
||||
return httptest.NewServer(mux) |
||||
} |
||||
|
||||
func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, error) { |
||||
signingKey := jose.SigningKey{Key: key, Algorithm: jose.RS256} |
||||
|
||||
signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) |
||||
if err != nil { |
||||
return "", fmt.Errorf("new signer: %v", err) |
||||
} |
||||
|
||||
payload, err := json.Marshal(claims) |
||||
if err != nil { |
||||
return "", fmt.Errorf("marshaling claims: %v", err) |
||||
} |
||||
|
||||
signature, err := signer.Sign(payload) |
||||
if err != nil { |
||||
return "", fmt.Errorf("signing payload: %v", err) |
||||
} |
||||
|
||||
return signature.CompactSerialize() |
||||
} |
||||
|
||||
func newConnector(t *testing.T, serverURL string) *oauthConnector { |
||||
testConfig := Config{ |
||||
ClientID: "testClient", |
||||
ClientSecret: "testSecret", |
||||
RedirectURI: serverURL + "/callback", |
||||
TokenURL: serverURL + "/token", |
||||
AuthorizationURL: serverURL + "/authorize", |
||||
UserInfoURL: serverURL + "/userinfo", |
||||
Scopes: []string{"openid", "groups"}, |
||||
GroupsKey: "groups_key", |
||||
} |
||||
|
||||
log := logrus.New() |
||||
|
||||
conn, err := testConfig.Open("id", log) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
oauthConn, ok := conn.(*oauthConnector) |
||||
if !ok { |
||||
t.Fatal(errors.New("failed to convert to oauthConnector")) |
||||
} |
||||
|
||||
return oauthConn |
||||
} |
||||
|
||||
func newRequestWithAuthCode(t *testing.T, serverURL string, code string) *http.Request { |
||||
req, err := http.NewRequest("GET", serverURL, nil) |
||||
if err != nil { |
||||
t.Fatal("failed to create request", err) |
||||
} |
||||
|
||||
values := req.URL.Query() |
||||
values.Add("code", code) |
||||
req.URL.RawQuery = values.Encode() |
||||
|
||||
return req |
||||
} |
||||
|
||||
func expectEqual(t *testing.T, a interface{}, b interface{}) { |
||||
if !reflect.DeepEqual(a, b) { |
||||
t.Fatalf("Expected %+v to equal %+v", a, b) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue