From 20fdfef99bc9f0dd2872355aa1bf0f417d78307d Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Fri, 6 Feb 2026 09:36:14 +0100 Subject: [PATCH] feat: Add Vault signer for JWT Introduce a signer interface with Vault and local implementations for JWT signing. Signed-off-by: maksim.nabokikh --- cmd/dex/config.go | 3 + cmd/dex/serve.go | 1 + examples/config-dev.yaml | 10 + go.mod | 15 ++ go.sum | 41 ++++- server/handlers.go | 29 +-- server/introspectionhandler.go | 2 +- server/oauth2.go | 39 ++-- server/oauth2_test.go | 9 +- server/rotation.go | 30 --- server/server.go | 25 ++- server/signer.go | 22 +++ server/signer_local.go | 105 +++++++++++ server/signer_vault.go | 327 +++++++++++++++++++++++++++++++++ 14 files changed, 585 insertions(+), 73 deletions(-) create mode 100644 server/signer.go create mode 100644 server/signer_local.go create mode 100644 server/signer_vault.go diff --git a/cmd/dex/config.go b/cmd/dex/config.go index c76ff030..cbe7b000 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -35,6 +35,9 @@ type Config struct { Frontend server.WebConfig `json:"frontend"` + // Signer configuration controls signing of JWT tokens issued by Dex. + Signer server.SignerConfig `json:"signer"` + // StaticConnectors are user defined connectors specified in the ConfigMap // Write operations, like updating a connector, will fail. StaticConnectors []Connector `json:"connectors"` diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index ac715e60..b0c49dc3 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -307,6 +307,7 @@ func runServe(options serveOptions) error { PrometheusRegistry: prometheusRegistry, HealthChecker: healthChecker, ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), + Signer: c.Signer, } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/examples/config-dev.yaml b/examples/config-dev.yaml index 0fdf350c..7857c04f 100644 --- a/examples/config-dev.yaml +++ b/examples/config-dev.yaml @@ -169,3 +169,13 @@ staticPasswords: - "team-a" - "team-a/admins" userID: "08a8684b-db88-4b73-90a9-3cd1661f5466" + +# Settings for signing JWT tokens. Available options: +# - "local": use local keys (only RSA keys supported) +# - "vault": use Vault Transit backend (RSA and EC keys supported) +# signer: +# type: vault +# vault: +# addr: http://127.0.0.1:8200 +# token: root +# keyName: dex-key diff --git a/go.mod b/go.mod index b00f0f76..ffbe54f9 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/gorilla/handlers v1.5.2 github.com/gorilla/mux v1.8.1 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 + github.com/hashicorp/vault/api v1.22.0 github.com/kylelemons/godebug v1.1.0 github.com/lib/pq v1.11.1 github.com/mattermost/xml-roundtrip-validator v0.1.0 @@ -54,6 +55,7 @@ require ( github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/coreos/go-semver v0.3.1 // indirect github.com/coreos/go-systemd/v22 v22.5.0 // indirect @@ -70,13 +72,24 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect github.com/googleapis/gax-go/v2 v2.16.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/hashicorp/go-retryablehttp v0.7.8 // indirect + github.com/hashicorp/go-rootcerts v1.0.2 // indirect + github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect + github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect + github.com/hashicorp/go-sockaddr v1.0.7 // indirect + github.com/hashicorp/hcl v1.0.1-vault-7 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/huandu/xstrings v1.5.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jonboulle/clockwork v0.5.0 // indirect github.com/mattn/go-runewidth v0.0.9 // indirect github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect @@ -84,6 +97,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.66.1 // indirect github.com/prometheus/procfs v0.16.1 // indirect + github.com/ryanuber/go-glob v1.0.0 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/spf13/pflag v1.0.9 // indirect @@ -102,6 +116,7 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.40.0 // indirect golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/go.sum b/go.sum index c46db6d8..30ba0b0e 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc= @@ -50,6 +52,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6N github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= @@ -75,8 +79,8 @@ github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNP github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4= github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= -github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68= -github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -100,10 +104,33 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92Bcuy github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48= +github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM= +github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= +github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4= +github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw= +github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw= github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I= +github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM= github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo= github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE= +github.com/hashicorp/vault/api v1.22.0 h1:+HYFquE35/B74fHoIeXlZIP2YADVboaPjaSicHEZiH0= +github.com/hashicorp/vault/api v1.22.0/go.mod h1:IUZA2cDvr4Ok3+NtK2Oq/r+lJeXkeCrHRmqdyWfpmGM= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= @@ -139,14 +166,22 @@ github.com/lib/pq v1.11.1 h1:wuChtj2hfsGmmx3nf1m7xC2XpK6OtelS2shMY+bGMtI= github.com/lib/pq v1.11.1/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= @@ -173,6 +208,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7 github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw= github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= +github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= diff --git a/server/handlers.go b/server/handlers.go index e46c7b8f..763b373b 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -34,25 +34,24 @@ const ( func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // TODO(ericchiang): Cache this. - keys, err := s.storage.GetKeys(ctx) + keys, err := s.signer.ValidationKeys(ctx) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get keys", "err", err) s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } - if keys.SigningKeyPub == nil { + if len(keys) == 0 { s.logger.ErrorContext(r.Context(), "no public keys found.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } jwks := jose.JSONWebKeySet{ - Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1), + Keys: make([]jose.JSONWebKey, len(keys)), } - jwks.Keys[0] = *keys.SigningKeyPub - for i, verificationKey := range keys.VerificationKeys { - jwks.Keys[i+1] = *verificationKey.PublicKey + for i, key := range keys { + jwks.Keys[i] = *key } data, err := json.MarshalIndent(jwks, "", " ") @@ -61,10 +60,10 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) { s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") return } - maxAge := keys.NextRotation.Sub(s.now()) - if maxAge < (time.Minute * 2) { - maxAge = time.Minute * 2 - } + + // We don't have NextRotation info from Signer interface easily, + // so we'll just set a reasonable default cache time. + maxAge := time.Minute * 10 w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", int(maxAge.Seconds()))) w.Header().Set("Content-Type", "application/json") @@ -125,6 +124,14 @@ func (s *Server) constructDiscovery() discovery { }, } + // Determine signing algorithm from signer + signingAlg, err := s.signer.Algorithm(context.Background()) + if err != nil { + s.logger.Error("failed to get signing algorithm", "err", err) + } else { + d.IDTokenAlgs = []string{string(signingAlg)} + } + for responseType := range s.supportedResponseTypes { d.ResponseTypes = append(d.ResponseTypes, responseType) } @@ -1099,7 +1106,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) { } rawIDToken := auth[len(prefix):] - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), &signerKeySet{s.signer}, &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { s.logger.ErrorContext(r.Context(), "failed to verify ID token", "err", err) diff --git a/server/introspectionhandler.go b/server/introspectionhandler.go index 4b0073db..dd7cce83 100644 --- a/server/introspectionhandler.go +++ b/server/introspectionhandler.go @@ -245,7 +245,7 @@ func (s *Server) introspectRefreshToken(ctx context.Context, token string) (*Int } func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Introspection, error) { - verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true}) + verifier := oidc.NewVerifier(s.issuerURL.String(), &signerKeySet{s.signer}, &oidc.Config{SkipClientIDCheck: true}) idToken, err := verifier.Verify(ctx, token) if err != nil { return nil, newIntrospectInactiveTokenError() diff --git a/server/oauth2.go b/server/oauth2.go index 7268bcfd..6164f5ae 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -351,21 +351,6 @@ func genSubject(userID string, connID string) (string, error) { } func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { - keys, err := s.storage.GetKeys(ctx) - if err != nil { - s.logger.ErrorContext(ctx, "failed to get keys", "err", err) - return "", expiry, err - } - - signingKey := keys.SigningKey - if signingKey == nil { - return "", expiry, fmt.Errorf("no key to sign payload with") - } - signingAlg, err := signatureAlgorithm(signingKey) - if err != nil { - return "", expiry, err - } - issuedAt := s.now() expiry = issuedAt.Add(s.idTokensValidFor) @@ -383,6 +368,13 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage IssuedAt: issuedAt.Unix(), } + // Determine signing algorithm from signer + signingAlg, err := s.signer.Algorithm(ctx) + if err != nil { + s.logger.ErrorContext(ctx, "failed to get signing algorithm", "err", err) + return "", expiry, fmt.Errorf("failed to get signing algorithm: %v", err) + } + if accessToken != "" { atHash, err := accessTokenHash(signingAlg, accessToken) if err != nil { @@ -445,7 +437,7 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage return "", expiry, fmt.Errorf("could not serialize claims: %v", err) } - if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil { + if idToken, err = s.signer.Sign(ctx, payload); err != nil { return "", expiry, fmt.Errorf("failed to sign payload: %v", err) } return idToken, expiry, nil @@ -705,12 +697,12 @@ func validateConnectorID(connectors []storage.Connector, connectorID string) boo return false } -// storageKeySet implements the oidc.KeySet interface backed by Dex storage -type storageKeySet struct { - storage.Storage +// signerKeySet implements the oidc.KeySet interface backed by the Dex signer +type signerKeySet struct { + signer Signer } -func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { +func (s *signerKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) { jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512}) if err != nil { return nil, err @@ -722,16 +714,11 @@ func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payloa break } - skeys, err := s.Storage.GetKeys(ctx) + keys, err := s.signer.ValidationKeys(ctx) if err != nil { return nil, err } - keys := []*jose.JSONWebKey{skeys.SigningKeyPub} - for _, vk := range skeys.VerificationKeys { - keys = append(keys, vk.PublicKey) - } - for _, key := range keys { if keyID == "" || key.KeyID == keyID { if payload, err := jws.Verify(key); err == nil { diff --git a/server/oauth2_test.go b/server/oauth2_test.go index 3dff30d6..ae90eb5b 100644 --- a/server/oauth2_test.go +++ b/server/oauth2_test.go @@ -593,7 +593,7 @@ func TestValidRedirectURI(t *testing.T) { } } -func TestStorageKeySet(t *testing.T) { +func TestSignerKeySet(t *testing.T) { logger := newLogger(t) s := memory.New(logger) if err := s.UpdateKeys(t.Context(), func(keys storage.Keys) (storage.Keys, error) { @@ -668,7 +668,12 @@ func TestStorageKeySet(t *testing.T) { t.Fatal(err) } - keySet := &storageKeySet{s} + // We use a localSigner here to bridge the storage to the Signer interface. + // Since VerifySignature only needs ValidationKeys (which only needs storage), + // we don't need to initialize the rotator or other fields. + keySet := &signerKeySet{ + signer: &localSigner{storage: s}, + } _, err = keySet.VerifySignature(t.Context(), jwt) if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) { diff --git a/server/rotation.go b/server/rotation.go index 286b4b57..70d7a9bf 100644 --- a/server/rotation.go +++ b/server/rotation.go @@ -64,36 +64,6 @@ type keyRotator struct { logger *slog.Logger } -// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. -// -// The method blocks until after the first attempt to rotate keys has completed. That way -// healthy storages will return from this call with valid keys. -func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy, now func() time.Time) { - rotator := keyRotator{s.storage, strategy, now, s.logger} - - // Try to rotate immediately so properly configured storages will have keys. - if err := rotator.rotate(); err != nil { - if err == errAlreadyRotated { - s.logger.Info("key rotation not needed", "err", err) - } else { - s.logger.Error("failed to rotate keys", "err", err) - } - } - - go func() { - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second * 30): - if err := rotator.rotate(); err != nil { - s.logger.Error("failed to rotate keys", "err", err) - } - } - } - }() -} - func (k keyRotator) rotate() error { keys, err := k.GetKeys(context.Background()) if err != nil && err != storage.ErrNotFound { diff --git a/server/server.go b/server/server.go index d81a0f71..5d66c8a8 100644 --- a/server/server.go +++ b/server/server.go @@ -116,6 +116,8 @@ type Config struct { Logger *slog.Logger + Signer SignerConfig + PrometheusRegistry *prometheus.Registry HealthChecker gosundheit.Health @@ -156,6 +158,12 @@ type WebConfig struct { Extra map[string]string } +// SignerConfig holds the server's signer configuration. +type SignerConfig struct { + Type string `json:"type"` + Vault VaultSignerConfig `json:"vault"` +} + func value(val, defaultValue time.Duration) time.Duration { if val == 0 { return defaultValue @@ -200,6 +208,8 @@ type Server struct { refreshTokenPolicy *RefreshTokenPolicy logger *slog.Logger + + signer Signer } // NewServer constructs a server from the provided config. @@ -318,6 +328,19 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) logger: c.Logger, } + // Initialize signer + if c.Signer.Type == "vault" { + s.signer, err = newVaultSigner(c.Signer.Vault) + if err != nil { + return nil, fmt.Errorf("failed to initialize vault signer: %v", err) + } + s.logger.Info("signer configured", "type", "vault") + } else { + // Default to local signer + s.signer = newLocalSigner(c.Storage, rotationStrategy, now, c.Logger) + s.logger.Info("signer configured", "type", "local") + } + // Retrieves connector objects in backend storage. This list includes the static connectors // defined in the ConfigMap and dynamic connectors retrieved from the storage. storageConnectors, err := c.Storage.ListConnectors(ctx) @@ -514,7 +537,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) s.mux = r - s.startKeyRotation(ctx, rotationStrategy, now) + s.signer.Start(ctx) s.startGarbageCollection(ctx, value(c.GCFrequency, 5*time.Minute), now) return s, nil diff --git a/server/signer.go b/server/signer.go new file mode 100644 index 00000000..d38abb1f --- /dev/null +++ b/server/signer.go @@ -0,0 +1,22 @@ +package server + +import ( + "context" + + "github.com/go-jose/go-jose/v4" +) + +// Signer is an interface for signing payloads and retrieving validation keys. +type Signer interface { + // Sign signs the provided payload. + Sign(ctx context.Context, payload []byte) (string, error) + + // ValidationKeys returns the current public keys used for signature validation. + ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) + + // Algorithm returns the signing algorithm used by this signer. + Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error) + + // Start starts any background tasks required by the signer (e.g., key rotation). + Start(ctx context.Context) +} diff --git a/server/signer_local.go b/server/signer_local.go new file mode 100644 index 00000000..3d9f5e0d --- /dev/null +++ b/server/signer_local.go @@ -0,0 +1,105 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "time" + + "github.com/go-jose/go-jose/v4" + + "github.com/dexidp/dex/storage" +) + +// localSigner signs payloads using keys stored in the Dex storage. +// It manages key rotation and storage using the existing keyRotator logic. +type localSigner struct { + storage storage.Storage + rotator *keyRotator + logger *slog.Logger +} + +// newLocalSigner creates a new local signer and starts the key rotation loop. +func newLocalSigner(s storage.Storage, strategy rotationStrategy, now func() time.Time, logger *slog.Logger) *localSigner { + r := &keyRotator{s, strategy, now, logger} + return &localSigner{ + storage: s, + rotator: r, + logger: logger, + } +} + +// Start begins key rotation in a new goroutine, closing once the context is canceled. +// +// The method blocks until after the first attempt to rotate keys has completed. That way +// healthy storages will return from this call with valid keys. +func (l *localSigner) Start(ctx context.Context) { + // Try to rotate immediately so properly configured storages will have keys. + if err := l.rotator.rotate(); err != nil { + if err == errAlreadyRotated { + l.logger.Info("key rotation not needed", "err", err) + } else { + l.logger.Error("failed to rotate keys", "err", err) + } + } + + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Second * 30): + if err := l.rotator.rotate(); err != nil { + l.logger.Error("failed to rotate keys", "err", err) + } + } + } + }() +} + +func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error) { + keys, err := l.storage.GetKeys(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys: %v", err) + } + + signingKey := keys.SigningKey + if signingKey == nil { + return "", fmt.Errorf("no key to sign payload with") + } + signingAlg, err := signatureAlgorithm(signingKey) + if err != nil { + return "", err + } + + return signPayload(signingKey, signingAlg, payload) +} + +func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { + keys, err := l.storage.GetKeys(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get keys: %v", err) + } + + if keys.SigningKeyPub == nil { + return nil, fmt.Errorf("no public keys found") + } + + jwks := make([]*jose.JSONWebKey, len(keys.VerificationKeys)+1) + jwks[0] = keys.SigningKeyPub + for i, verificationKey := range keys.VerificationKeys { + jwks[i+1] = verificationKey.PublicKey + } + return jwks, nil +} + +func (l *localSigner) Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error) { + keys, err := l.storage.GetKeys(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys: %v", err) + } + if keys.SigningKey == nil { + return "", fmt.Errorf("no signing key found") + } + return signatureAlgorithm(keys.SigningKey) +} diff --git a/server/signer_vault.go b/server/signer_vault.go new file mode 100644 index 00000000..61221a85 --- /dev/null +++ b/server/signer_vault.go @@ -0,0 +1,327 @@ +package server + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "hash" + + "github.com/go-jose/go-jose/v4" + vault "github.com/hashicorp/vault/api" +) + +// VaultSignerConfig holds configuration for the Vault signer. +type VaultSignerConfig struct { + Addr string `json:"addr"` + Token string `json:"token"` + KeyName string `json:"keyName"` +} + +// vaultSigner signs payloads using HashiCorp Vault's Transit backend. +type vaultSigner struct { + client *vault.Client + keyName string +} + +// newVaultSigner creates a new Vault signer that uses Transit backend for signing. +func newVaultSigner(c VaultSignerConfig) (*vaultSigner, error) { + config := vault.DefaultConfig() + config.Address = c.Addr + + client, err := vault.NewClient(config) + if err != nil { + return nil, fmt.Errorf("failed to create vault client: %v", err) + } + + if c.Token != "" { + client.SetToken(c.Token) + } + + return &vaultSigner{ + client: client, + keyName: c.KeyName, + }, nil +} + +func (v *vaultSigner) Start(ctx context.Context) { + // Vault signer does not need background rotation tasks +} + +func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error) { + // 1. Fetch keys to determine the key to use (latest version) and its ID. + keysMap, latestVersion, err := v.getTransitKeysMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys for signing context: %v", err) + } + + // Determine the key version and ID to use + // We use the latest version by default + signingJWK, ok := keysMap[latestVersion] + if !ok { + return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion) + } + + // 2. Construct JWS Header and Payload first (Signing Input) + header := map[string]interface{}{ + "alg": signingJWK.Algorithm, + "kid": signingJWK.KeyID, + } + + headerBytes, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("failed to marshal header: %v", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes) + payloadB64 := base64.RawURLEncoding.EncodeToString(payload) + + // The input to the signature is "header.payload" + signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64) + + // 3. Sign the signingInput using Vault + var vaultInput string + data := map[string]interface{}{} + + // Determine Vault params based on JWS algorithm + params, err := getVaultParams(signingJWK.Algorithm) + if err != nil { + return "", err + } + + // Apply params to data map + for k, v := range params.extraParams { + data[k] = v + } + + // Hash input if needed + if params.hasher != nil { + params.hasher.Write([]byte(signingInput)) + hash := params.hasher.Sum(nil) + vaultInput = base64.StdEncoding.EncodeToString(hash) + } else { + // No pre-hashing (EdDSA) + vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput)) + } + data["input"] = vaultInput + + signPath := fmt.Sprintf("transit/sign/%s", v.keyName) + signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data) + if err != nil { + return "", fmt.Errorf("vault sign: %v", err) + } + + signatureString, ok := signSecret.Data["signature"].(string) + if !ok { + return "", fmt.Errorf("vault response missing signature") + } + + // Parse vault signature: "vault:v1:base64sig" + var signatureB64 []byte + if len(signatureString) > 8 && signatureString[:6] == "vault:" { + parts := splitVaultSignature(signatureString) + if len(parts) == 3 { + // part 1 is "vault", part 2 is "v1", part 3 is signature + // The signature is already base64 encoded, decoding it is not needed and + // will make the code failing. + signatureB64 = []byte(parts[2]) + } + } else { + return "", fmt.Errorf("unexpected signature format: %s", signatureString) + } + + return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil +} + +func (v *vaultSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) { + keysMap, _, err := v.getTransitKeysMap(ctx) + if err != nil { + return nil, err + } + + keys := make([]*jose.JSONWebKey, 0, len(keysMap)) + for _, k := range keysMap { + keys = append(keys, k) + } + return keys, nil +} + +// getTransitKeysMap returns a map of key_version -> JWK and the latest version number +func (v *vaultSigner) getTransitKeysMap(ctx context.Context) (map[int64]*jose.JSONWebKey, int64, error) { + path := fmt.Sprintf("transit/keys/%s", v.keyName) + secret, err := v.client.Logical().ReadWithContext(ctx, path) + if err != nil { + return nil, 0, fmt.Errorf("failed to read key from vault: %v", err) + } + if secret == nil { + return nil, 0, fmt.Errorf("key %q not found in vault", v.keyName) + } + + latestVersion, ok := secret.Data["latest_version"].(json.Number) + if !ok { + // Try float64 which is default for unmarshal interface{} + if lv, ok := secret.Data["latest_version"].(float64); ok { + latestVersion = json.Number(fmt.Sprintf("%d", int(lv))) + } else if lv, ok := secret.Data["latest_version"].(int); ok { + latestVersion = json.Number(fmt.Sprintf("%d", lv)) + } + } + latestVerInt, err := latestVersion.Int64() + if err != nil { + return nil, 0, fmt.Errorf("failed to get latest version: %v", err) + } + + keysObj, ok := secret.Data["keys"].(map[string]interface{}) + if !ok { + return nil, 0, fmt.Errorf("invalid response from vault") + } + + jwksMap := make(map[int64]*jose.JSONWebKey) + + for verStr, data := range keysObj { + d, ok := data.(map[string]interface{}) + if !ok { + continue + } + + var ver int64 + fmt.Sscanf(verStr, "%d", &ver) + + pemStr, ok := d["public_key"].(string) + if !ok { + continue + } + + jwk, err := parsePEMToJWK(pemStr) + if err != nil { + continue + } + + jwksMap[ver] = jwk + } + + return jwksMap, latestVerInt, nil +} + +func parsePEMToJWK(pemStr string) (*jose.JSONWebKey, error) { + block, _ := pem.Decode([]byte(pemStr)) + if block == nil { + return nil, fmt.Errorf("failed to parse PEM block") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %v", err) + } + + alg := "" + switch k := pub.(type) { + case *rsa.PublicKey: + alg = "RS256" + case *ecdsa.PublicKey: + switch k.Curve { + case elliptic.P256(): + alg = "ES256" + case elliptic.P384(): + alg = "ES384" + case elliptic.P521(): + alg = "ES512" + default: + return nil, fmt.Errorf("unsupported ECDSA curve") + } + case ed25519.PublicKey: + alg = "EdDSA" + default: + return nil, fmt.Errorf("unsupported key type %T", pub) + } + + jwk := &jose.JSONWebKey{ + Key: pub, + Algorithm: alg, + Use: "sig", + } + + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return nil, err + } + jwk.KeyID = base64.RawURLEncoding.EncodeToString(thumbprint) + + return jwk, nil +} + +func splitVaultSignature(sig string) []string { + // Basic split implementation + // "vault:v1:signature" + var parts []string + start := 0 + for i := 0; i < len(sig); i++ { + if sig[i] == ':' { + parts = append(parts, sig[start:i]) + start = i + 1 + } + } + parts = append(parts, sig[start:]) + return parts +} + +func (v *vaultSigner) Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error) { + keysMap, latestVersion, err := v.getTransitKeysMap(ctx) + if err != nil { + return "", fmt.Errorf("failed to get keys: %v", err) + } + + signingJWK, ok := keysMap[latestVersion] + if !ok { + return "", fmt.Errorf("latest key version %d not found", latestVersion) + } + return jose.SignatureAlgorithm(signingJWK.Algorithm), nil +} + +type vaultAlgoParams struct { + hasher hash.Hash + extraParams map[string]interface{} +} + +func getVaultParams(alg string) (vaultAlgoParams, error) { + params := vaultAlgoParams{ + extraParams: map[string]interface{}{ + "marshaling_algorithm": "jws", + "signature_algorithm": "pkcs1v15", + }, + } + + switch alg { + case "RS256": + params.hasher = sha256.New() + params.extraParams["prehashed"] = true + params.extraParams["hash_algorithm"] = "sha2-256" + case "ES256": + params.hasher = sha256.New() + params.extraParams["prehashed"] = true + params.extraParams["hash_algorithm"] = "sha2-256" + case "ES384": + params.hasher = sha512.New384() + params.extraParams["prehashed"] = true + params.extraParams["hash_algorithm"] = "sha2-384" + case "ES512": + params.hasher = sha512.New() + params.extraParams["prehashed"] = true + params.extraParams["hash_algorithm"] = "sha2-512" + case "EdDSA": + // No hashing + params.hasher = nil + default: + return params, fmt.Errorf("unsupported signing algorithm: %s", alg) + } + return params, nil +}