|
|
|
|
@ -10,6 +10,8 @@ import (
|
|
|
|
|
"io" |
|
|
|
|
"log/slog" |
|
|
|
|
"net/http" |
|
|
|
|
"net/url" |
|
|
|
|
"os" |
|
|
|
|
"strings" |
|
|
|
|
"sync" |
|
|
|
|
"time" |
|
|
|
|
@ -42,10 +44,11 @@ const (
|
|
|
|
|
scopeOfflineAccess = "offline_access" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// Config holds configuration options for microsoft logins.
|
|
|
|
|
// Config holds configuration options for Microsoft logins.
|
|
|
|
|
type Config struct { |
|
|
|
|
ClientID string `json:"clientID"` |
|
|
|
|
ClientSecret string `json:"clientSecret"` |
|
|
|
|
ClientAssertion string `json:"clientAssertion"` |
|
|
|
|
RedirectURI string `json:"redirectURI"` |
|
|
|
|
Tenant string `json:"tenant"` |
|
|
|
|
OnlySecurityGroups bool `json:"onlySecurityGroups"` |
|
|
|
|
@ -73,6 +76,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, erro
|
|
|
|
|
redirectURI: c.RedirectURI, |
|
|
|
|
clientID: c.ClientID, |
|
|
|
|
clientSecret: c.ClientSecret, |
|
|
|
|
clientAssertion: c.ClientAssertion, |
|
|
|
|
tenant: c.Tenant, |
|
|
|
|
onlySecurityGroups: c.OnlySecurityGroups, |
|
|
|
|
groups: c.Groups, |
|
|
|
|
@ -128,6 +132,7 @@ type microsoftConnector struct {
|
|
|
|
|
redirectURI string |
|
|
|
|
clientID string |
|
|
|
|
clientSecret string |
|
|
|
|
clientAssertion string |
|
|
|
|
tenant string |
|
|
|
|
onlySecurityGroups bool |
|
|
|
|
groupNameFormat GroupNameFormat |
|
|
|
|
@ -191,6 +196,46 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat
|
|
|
|
|
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// assertionTransport is an http.RoundTripper that intercepts token endpoint requests
|
|
|
|
|
// and injects client_assertion parameters while removing client_secret.
|
|
|
|
|
type assertionTransport struct { |
|
|
|
|
assertion string |
|
|
|
|
tokenURL string |
|
|
|
|
base http.RoundTripper |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (t *assertionTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
|
|
|
|
// Only modify requests to the token endpoint
|
|
|
|
|
if req.URL.String() != t.tokenURL { |
|
|
|
|
return t.base.RoundTrip(req) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Read the original request body
|
|
|
|
|
body, err := io.ReadAll(req.Body) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("failed to read request body: %v", err) |
|
|
|
|
} |
|
|
|
|
req.Body.Close() |
|
|
|
|
|
|
|
|
|
// Parse the form data
|
|
|
|
|
values, err := url.ParseQuery(string(body)) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("failed to parse request body: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Remove client_secret and add client_assertion parameters
|
|
|
|
|
values.Del("client_secret") |
|
|
|
|
values.Set("client_assertion", t.assertion) |
|
|
|
|
values.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") |
|
|
|
|
|
|
|
|
|
// Create new request with modified body
|
|
|
|
|
newBody := strings.NewReader(values.Encode()) |
|
|
|
|
req.Body = io.NopCloser(newBody) |
|
|
|
|
req.ContentLength = int64(len(values.Encode())) |
|
|
|
|
|
|
|
|
|
return t.base.RoundTrip(req) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { |
|
|
|
|
q := r.URL.Query() |
|
|
|
|
if errType := q.Get("error"); errType != "" { |
|
|
|
|
@ -201,6 +246,24 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request)
|
|
|
|
|
|
|
|
|
|
ctx := r.Context() |
|
|
|
|
|
|
|
|
|
// If using client assertion, wrap the HTTP client with a custom transport
|
|
|
|
|
if c.clientAssertion != "" { |
|
|
|
|
assertionBytes, err := os.ReadFile(c.clientAssertion) |
|
|
|
|
if err != nil { |
|
|
|
|
return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Create HTTP client with custom transport that injects client_assertion
|
|
|
|
|
httpClient := &http.Client{ |
|
|
|
|
Transport: &assertionTransport{ |
|
|
|
|
assertion: string(assertionBytes), |
|
|
|
|
tokenURL: oauth2Config.Endpoint.TokenURL, |
|
|
|
|
base: http.DefaultTransport, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
token, err := oauth2Config.Exchange(ctx, q.Get("code")) |
|
|
|
|
if err != nil { |
|
|
|
|
return identity, fmt.Errorf("microsoft: failed to get token: %v", err) |
|
|
|
|
@ -290,6 +353,24 @@ func (c *microsoftConnector) Refresh(ctx context.Context, s connector.Scopes, id
|
|
|
|
|
Expiry: data.Expiry, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If using client assertion, wrap the HTTP client with a custom transport
|
|
|
|
|
if c.clientAssertion != "" { |
|
|
|
|
assertionBytes, err := os.ReadFile(c.clientAssertion) |
|
|
|
|
if err != nil { |
|
|
|
|
return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
oauth2Config := c.oauth2Config(s) |
|
|
|
|
httpClient := &http.Client{ |
|
|
|
|
Transport: &assertionTransport{ |
|
|
|
|
assertion: string(assertionBytes), |
|
|
|
|
tokenURL: oauth2Config.Endpoint.TokenURL, |
|
|
|
|
base: http.DefaultTransport, |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
client := oauth2.NewClient(ctx, ¬ifyRefreshTokenSource{ |
|
|
|
|
new: c.oauth2Config(s).TokenSource(ctx, tok), |
|
|
|
|
t: tok, |
|
|
|
|
|