Browse Source

feat(microsoft): support using client_assertion

Signed-off-by: Ben Dronen <dronenb@users.noreply.github.com>
pull/4521/head
Ben Dronen 1 month ago
parent
commit
3e11ec4595
No known key found for this signature in database
GPG Key ID: 96D0948A3EFFABFC
  1. 83
      connector/microsoft/microsoft.go
  2. 51
      connector/microsoft/microsoft_test.go

83
connector/microsoft/microsoft.go

@ -10,6 +10,8 @@ import (
"io"
"log/slog"
"net/http"
"net/url"
"os"
"strings"
"sync"
"time"
@ -42,10 +44,11 @@ const (
scopeOfflineAccess = "offline_access"
)
// Config holds configuration options for microsoft logins.
// Config holds configuration options for Microsoft logins.
type Config struct {
ClientID string `json:"clientID"`
ClientSecret string `json:"clientSecret"`
ClientAssertion string `json:"clientAssertion"`
RedirectURI string `json:"redirectURI"`
Tenant string `json:"tenant"`
OnlySecurityGroups bool `json:"onlySecurityGroups"`
@ -73,6 +76,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, erro
redirectURI: c.RedirectURI,
clientID: c.ClientID,
clientSecret: c.ClientSecret,
clientAssertion: c.ClientAssertion,
tenant: c.Tenant,
onlySecurityGroups: c.OnlySecurityGroups,
groups: c.Groups,
@ -128,6 +132,7 @@ type microsoftConnector struct {
redirectURI string
clientID string
clientSecret string
clientAssertion string
tenant string
onlySecurityGroups bool
groupNameFormat GroupNameFormat
@ -191,6 +196,46 @@ func (c *microsoftConnector) LoginURL(scopes connector.Scopes, callbackURL, stat
return c.oauth2Config(scopes).AuthCodeURL(state, options...), nil
}
// assertionTransport is an http.RoundTripper that intercepts token endpoint requests
// and injects client_assertion parameters while removing client_secret.
type assertionTransport struct {
assertion string
tokenURL string
base http.RoundTripper
}
func (t *assertionTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Only modify requests to the token endpoint
if req.URL.String() != t.tokenURL {
return t.base.RoundTrip(req)
}
// Read the original request body
body, err := io.ReadAll(req.Body)
if err != nil {
return nil, fmt.Errorf("failed to read request body: %v", err)
}
req.Body.Close()
// Parse the form data
values, err := url.ParseQuery(string(body))
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %v", err)
}
// Remove client_secret and add client_assertion parameters
values.Del("client_secret")
values.Set("client_assertion", t.assertion)
values.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
// Create new request with modified body
newBody := strings.NewReader(values.Encode())
req.Body = io.NopCloser(newBody)
req.ContentLength = int64(len(values.Encode()))
return t.base.RoundTrip(req)
}
func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request) (identity connector.Identity, err error) {
q := r.URL.Query()
if errType := q.Get("error"); errType != "" {
@ -201,6 +246,24 @@ func (c *microsoftConnector) HandleCallback(s connector.Scopes, r *http.Request)
ctx := r.Context()
// If using client assertion, wrap the HTTP client with a custom transport
if c.clientAssertion != "" {
assertionBytes, err := os.ReadFile(c.clientAssertion)
if err != nil {
return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err)
}
// Create HTTP client with custom transport that injects client_assertion
httpClient := &http.Client{
Transport: &assertionTransport{
assertion: string(assertionBytes),
tokenURL: oauth2Config.Endpoint.TokenURL,
base: http.DefaultTransport,
},
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
token, err := oauth2Config.Exchange(ctx, q.Get("code"))
if err != nil {
return identity, fmt.Errorf("microsoft: failed to get token: %v", err)
@ -290,6 +353,24 @@ func (c *microsoftConnector) Refresh(ctx context.Context, s connector.Scopes, id
Expiry: data.Expiry,
}
// If using client assertion, wrap the HTTP client with a custom transport
if c.clientAssertion != "" {
assertionBytes, err := os.ReadFile(c.clientAssertion)
if err != nil {
return identity, fmt.Errorf("microsoft: failed to read client assertion: %v", err)
}
oauth2Config := c.oauth2Config(s)
httpClient := &http.Client{
Transport: &assertionTransport{
assertion: string(assertionBytes),
tokenURL: oauth2Config.Endpoint.TokenURL,
base: http.DefaultTransport,
},
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
client := oauth2.NewClient(ctx, &notifyRefreshTokenSource{
new: c.oauth2Config(s).TokenSource(ctx, tok),
t: tok,

51
connector/microsoft/microsoft_test.go

@ -3,6 +3,7 @@ package microsoft
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
@ -119,6 +120,56 @@ func TestUserGroupsFromGraphAPI(t *testing.T) {
expectEquals(t, identity.Groups, []string{"a", "b"})
}
func TestClientAssertionTokenExchange(t *testing.T) {
assertion := "dummy-jwt-assertion"
file, err := os.CreateTemp("", "assertion.jwt")
if err != nil {
t.Fatalf("failed to create temp file: %v", err)
}
defer os.Remove(file.Name())
file.WriteString(assertion)
file.Close()
tokenCalled := false
var receivedAssertion, receivedSecret string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" && r.URL.Path == "/testtenant/oauth2/v2.0/token" {
bodyBytes, _ := io.ReadAll(r.Body)
r.Body.Close()
form, _ := url.ParseQuery(string(bodyBytes))
receivedAssertion = form.Get("client_assertion")
receivedSecret = form.Get("client_secret")
tokenCalled = true
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token": "token", "expires_in": 3600}`))
}
}))
defer ts.Close()
conn := microsoftConnector{
apiURL: ts.URL,
graphURL: ts.URL,
redirectURI: "https://test.com",
clientID: clientID,
clientSecret: "should-not-be-used",
tenant: "testtenant",
clientAssertion: file.Name(),
}
req, _ := http.NewRequest("GET", ts.URL, nil)
conn.HandleCallback(connector.Scopes{}, req)
if !tokenCalled {
t.Errorf("Token endpoint was not called")
}
if receivedAssertion != assertion {
t.Errorf("Expected client_assertion to be %q, got %q", assertion, receivedAssertion)
}
if receivedSecret != "" {
t.Errorf("Expected client_secret to be empty, got %q", receivedSecret)
}
}
func newTestServer(responses map[string]testResponse) *httptest.Server {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response, found := responses[r.RequestURI]

Loading…
Cancel
Save