|
|
|
|
@ -37,8 +37,7 @@ type app struct {
|
|
|
|
|
// or does it use "access_type=offline" (e.g. Google)?
|
|
|
|
|
offlineAsScope bool |
|
|
|
|
|
|
|
|
|
ctx context.Context |
|
|
|
|
cancel context.CancelFunc |
|
|
|
|
client *http.Client |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// return an HTTP client which trusts the provided root CAs.
|
|
|
|
|
@ -118,31 +117,31 @@ func cmd() *cobra.Command {
|
|
|
|
|
return fmt.Errorf("parse listen address: %v", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
a.ctx, a.cancel = context.WithCancel(context.Background()) |
|
|
|
|
|
|
|
|
|
if rootCAs != "" { |
|
|
|
|
client, err := httpClientForRootCAs(rootCAs) |
|
|
|
|
if err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// This sets the OAuth2 client and oidc client.
|
|
|
|
|
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, client) |
|
|
|
|
a.client = client |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if debug { |
|
|
|
|
client, ok := a.ctx.Value(oauth2.HTTPClient).(*http.Client) |
|
|
|
|
if ok { |
|
|
|
|
client.Transport = debugTransport{client.Transport} |
|
|
|
|
} else { |
|
|
|
|
a.ctx = context.WithValue(a.ctx, oauth2.HTTPClient, &http.Client{ |
|
|
|
|
if a.client == nil { |
|
|
|
|
a.client = &http.Client{ |
|
|
|
|
Transport: debugTransport{http.DefaultTransport}, |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} else { |
|
|
|
|
a.client.Transport = debugTransport{a.client.Transport} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if a.client == nil { |
|
|
|
|
a.client = http.DefaultClient |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TODO(ericchiang): Retry with backoff
|
|
|
|
|
provider, err := oidc.NewProvider(a.ctx, issuerURL) |
|
|
|
|
ctx := oidc.ClientContext(context.Background(), a.client) |
|
|
|
|
provider, err := oidc.NewProvider(ctx, issuerURL) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("Failed to query provider %q: %v", issuerURL, err) |
|
|
|
|
} |
|
|
|
|
@ -258,6 +257,8 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
err error |
|
|
|
|
token *oauth2.Token |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
ctx := oidc.ClientContext(r.Context(), a.client) |
|
|
|
|
oauth2Config := a.oauth2Config(nil) |
|
|
|
|
switch r.Method { |
|
|
|
|
case "GET": |
|
|
|
|
@ -275,7 +276,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
http.Error(w, fmt.Sprintf("expected state %q got %q", exampleAppState, state), http.StatusBadRequest) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
token, err = oauth2Config.Exchange(a.ctx, code) |
|
|
|
|
token, err = oauth2Config.Exchange(ctx, code) |
|
|
|
|
case "POST": |
|
|
|
|
// Form request from frontend to refresh a token.
|
|
|
|
|
refresh := r.FormValue("refresh_token") |
|
|
|
|
@ -287,7 +288,7 @@ func (a *app) handleCallback(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
RefreshToken: refresh, |
|
|
|
|
Expiry: time.Now().Add(-time.Hour), |
|
|
|
|
} |
|
|
|
|
token, err = oauth2Config.TokenSource(r.Context(), t).Token() |
|
|
|
|
token, err = oauth2Config.TokenSource(ctx, t).Token() |
|
|
|
|
default: |
|
|
|
|
http.Error(w, fmt.Sprintf("method not implemented: %s", r.Method), http.StatusBadRequest) |
|
|
|
|
return |
|
|
|
|
|