From 3e11ec45954a193fbbe0f59522d30d9300fb7671 Mon Sep 17 00:00:00 2001 From: Ben Dronen Date: Tue, 10 Feb 2026 20:35:22 -0500 Subject: [PATCH] feat(microsoft): support using client_assertion Signed-off-by: Ben Dronen --- connector/microsoft/microsoft.go | 83 ++++++++++++++++++++++++++- connector/microsoft/microsoft_test.go | 51 ++++++++++++++++ 2 files changed, 133 insertions(+), 1 deletion(-) diff --git a/connector/microsoft/microsoft.go b/connector/microsoft/microsoft.go index 2fcf6a75..e3b2c234 100644 --- a/connector/microsoft/microsoft.go +++ b/connector/microsoft/microsoft.go @@ -10,6 +10,8 @@ import ( "io" "log/slog" "net/http" + "net/url" + "os" "strings" "sync" "time" @@ -42,10 +44,11 @@ const ( scopeOfflineAccess = "offline_access" ) -// Config holds configuration options for microsoft logins. +// Config holds configuration options for Microsoft logins. type Config struct { ClientID string `json:"clientID"` ClientSecret string `json:"clientSecret"` + ClientAssertion string `json:"clientAssertion"` RedirectURI string `json:"redirectURI"` Tenant string `json:"tenant"` OnlySecurityGroups bool `json:"onlySecurityGroups"` @@ -73,6 +76,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, erro redirectURI: c.RedirectURI, clientID: c.ClientID, clientSecret: c.ClientSecret, + clientAssertion: c.ClientAssertion, tenant: c.Tenant, onlySecurityGroups: c.OnlySecurityGroups, groups: c.Groups, @@ -128,6 +132,7 @@ type microsoftConnector struct { redirectURI string clientID string clientSecret string + clientAssertion string tenant string onlySecurityGroups bool groupNameFormat GroupNameFormat @@ -191,6 +196,46 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil } +// assertionTransport is an http.RoundTripper that intercepts token endpoint requests +// and injects client_assertion parameters while removing client_secret. +type assertionTransport struct { + assertion string + tokenURL string + base http.RoundTripper +} + +func (t *assertionTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Only modify requests to the token endpoint + if req.URL.String() != t.tokenURL { + return t.base.RoundTrip(req) + } + + // Read the original request body + body, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %v", err) + } + req.Body.Close() + + // Parse the form data + values, err := url.ParseQuery(string(body)) + if err != nil { + return nil, fmt.Errorf("failed to parse request body: %v", err) + } + + // Remove client_secret and add client_assertion parameters + values.Del("client_secret") + values.Set("client_assertion", t.assertion) + values.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + + // Create new request with modified body + newBody := strings.NewReader(values.Encode()) + req.Body = io.NopCloser(newBody) + req.ContentLength = int64(len(values.Encode())) + + return t.base.RoundTrip(req) +} + func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { @@ -201,6 +246,24 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) ctx := r.Context() + // If using client assertion, wrap the HTTP client with a custom transport + if c.clientAssertion != "" { + assertionBytes, err := os.ReadFile(c.clientAssertion) + if err != nil { + return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err) + } + + // Create HTTP client with custom transport that injects client_assertion + httpClient := &http.Client{ + Transport: &assertionTransport{ + assertion: string(assertionBytes), + tokenURL: oauth2Config.Endpoint.TokenURL, + base: http.DefaultTransport, + }, + } + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + } + token, err := oauth2Config.Exchange(ctx, q.Get("code")) if err != nil { return identity, fmt.Errorf("microsoft: failed to get token: %v", err) @@ -290,6 +353,24 @@ func (c *microsoftConnector) Refresh(ctx context.Context, s connector.Scopes, id Expiry: data.Expiry, } + // If using client assertion, wrap the HTTP client with a custom transport + if c.clientAssertion != "" { + assertionBytes, err := os.ReadFile(c.clientAssertion) + if err != nil { + return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err) + } + + oauth2Config := c.oauth2Config(s) + httpClient := &http.Client{ + Transport: &assertionTransport{ + assertion: string(assertionBytes), + tokenURL: oauth2Config.Endpoint.TokenURL, + base: http.DefaultTransport, + }, + } + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + } + client := oauth2.NewClient(ctx, ¬ifyRefreshTokenSource{ new: c.oauth2Config(s).TokenSource(ctx, tok), t: tok, diff --git a/connector/microsoft/microsoft_test.go b/connector/microsoft/microsoft_test.go index 67be660f..97e6c0af 100644 --- a/connector/microsoft/microsoft_test.go +++ b/connector/microsoft/microsoft_test.go @@ -3,6 +3,7 @@ package microsoft import ( "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "net/url" @@ -119,6 +120,56 @@ func TestUserGroupsFromGraphAPI(t *testing.T) { expectEquals(t, identity.Groups, []string{"a", "b"}) } +func TestClientAssertionTokenExchange(t *testing.T) { + assertion := "dummy-jwt-assertion" + file, err := os.CreateTemp("", "assertion.jwt") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + defer os.Remove(file.Name()) + file.WriteString(assertion) + file.Close() + + tokenCalled := false + var receivedAssertion, receivedSecret string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "POST" && r.URL.Path == "/testtenant/oauth2/v2.0/token" { + bodyBytes, _ := io.ReadAll(r.Body) + r.Body.Close() + form, _ := url.ParseQuery(string(bodyBytes)) + receivedAssertion = form.Get("client_assertion") + receivedSecret = form.Get("client_secret") + tokenCalled = true + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token": "token", "expires_in": 3600}`)) + } + })) + defer ts.Close() + + conn := microsoftConnector{ + apiURL: ts.URL, + graphURL: ts.URL, + redirectURI: "https://test.com", + clientID: clientID, + clientSecret: "should-not-be-used", + tenant: "testtenant", + clientAssertion: file.Name(), + } + + req, _ := http.NewRequest("GET", ts.URL, nil) + conn.HandleCallback(connector.Scopes{}, req) + + if !tokenCalled { + t.Errorf("Token endpoint was not called") + } + if receivedAssertion != assertion { + t.Errorf("Expected client_assertion to be %q, got %q", assertion, receivedAssertion) + } + if receivedSecret != "" { + t.Errorf("Expected client_secret to be empty, got %q", receivedSecret) + } +} + func newTestServer(responses map[string]testResponse) *httptest.Server { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response, found := responses[r.RequestURI]