Browse Source

feat: Add Vault signer for JWT (#4512)

Signed-off-by: Maksim Nabokikh <max.nabokih@gmail.com>
pull/4522/head
Maksim Nabokikh 1 month ago committed by GitHub
parent
commit
56958b1ad2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 3
      cmd/dex/config.go
  2. 1
      cmd/dex/serve.go
  3. 10
      examples/config-dev.yaml
  4. 18
      go.mod
  5. 45
      go.sum
  6. 2
      server/api.go
  7. 35
      server/handlers.go
  8. 2
      server/introspectionhandler.go
  9. 39
      server/oauth2.go
  10. 9
      server/oauth2_test.go
  11. 30
      server/rotation.go
  12. 27
      server/server.go
  13. 22
      server/signer.go
  14. 105
      server/signer_local.go
  15. 359
      server/signer_vault.go
  16. 196
      server/signer_vault_test.go

3
cmd/dex/config.go

@ -35,6 +35,9 @@ type Config struct {
Frontend server.WebConfig `json:"frontend"`
// Signer configuration controls signing of JWT tokens issued by Dex.
Signer server.SignerConfig `json:"signer"`
// StaticConnectors are user defined connectors specified in the ConfigMap
// Write operations, like updating a connector, will fail.
StaticConnectors []Connector `json:"connectors"`

1
cmd/dex/serve.go

@ -307,6 +307,7 @@ func runServe(options serveOptions) error {
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(),
Signer: c.Signer,
}
if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)

10
examples/config-dev.yaml

@ -178,3 +178,13 @@ staticPasswords:
- "team-a"
- "team-a/admins"
userID: "08a8684b-db88-4b73-90a9-3cd1661f5466"
# Settings for signing JWT tokens. Available options:
# - "local": use local keys (only RSA keys supported)
# - "vault": use Vault Transit backend (RSA and EC keys supported)
# signer:
# type: vault
# vault:
# addr: http://127.0.0.1:8200
# token: root
# keyName: dex-key

18
go.mod

@ -25,6 +25,7 @@ require (
github.com/mattermost/xml-roundtrip-validator v0.1.0
github.com/mattn/go-sqlite3 v1.14.33
github.com/oklog/run v1.2.0
github.com/openbao/openbao/api/v2 v2.5.1
github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.23.2
github.com/russellhaering/goxmldsig v1.5.0
@ -54,15 +55,17 @@ require (
github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bmatcuk/doublestar v1.3.4 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/coreos/go-semver v0.3.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-cmp v0.7.0 // indirect
@ -70,6 +73,14 @@ require (
github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect
github.com/googleapis/gax-go/v2 v2.16.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 // indirect
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect
github.com/hashicorp/go-sockaddr v1.0.7 // indirect
github.com/hashicorp/hcl v1.0.1-vault-7 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@ -77,13 +88,15 @@ require (
github.com/mattn/go-runewidth v0.0.9 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/go-wordwrap v1.0.1 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/ryanuber/go-glob v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/spf13/pflag v1.0.9 // indirect
@ -102,6 +115,7 @@ require (
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
golang.org/x/time v0.14.0 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect

45
go.sum

@ -38,6 +38,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0=
github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
@ -48,8 +50,11 @@ github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw=
@ -75,8 +80,10 @@ github.com/go-openapi/inflect v0.19.0 h1:9jCH9scKIbHeV9m12SmPilScz6krDxKRasNNSNP
github.com/go-openapi/inflect v0.19.0/go.mod h1:lHpZVlpIQqLyKwJ4N+YSc9hchQy/i12fJykb83CRBH4=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
github.com/go-test/deep v1.0.3 h1:ZrJSEWsXzPOxaZnFteGEfooLba+ju3FYIbOrS+rQd68=
github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
@ -100,8 +107,27 @@ github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 h1:Ovs26xHkKqVztRpIrF/92Bcuy
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3 h1:5ZPtiqj0JL5oKWmcsq4VMaAW5ukBEgSGXEN89zeH1Jo=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.26.3/go.mod h1:ndYquD05frm2vACXE1nsccT4oJzjhw2arTS2cpUD1PI=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/hashicorp/go-retryablehttp v0.7.8 h1:ylXZWnqa7Lhqpk0L1P1LzDtGcCR0rPVUrx/c8Unxc48=
github.com/hashicorp/go-retryablehttp v0.7.8/go.mod h1:rjiScheydd+CxvumBsIrFKlx3iS0jrZ7LvzFGFmuKbw=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0 h1:U+kC2dOhMFQctRfhK0gRctKAPTloZdMU5ZJxaesJ/VM=
github.com/hashicorp/go-secure-stdlib/parseutil v0.2.0/go.mod h1:Ll013mhdmsVDuoIXVfBtvgGJsXDYkTw1kooNcoCXuE0=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts=
github.com/hashicorp/go-secure-stdlib/strutil v0.1.2/go.mod h1:Gou2R9+il93BqX25LAKCLuM+y9U2T4hlwvT1yprcna4=
github.com/hashicorp/go-sockaddr v1.0.7 h1:G+pTkSO01HpR5qCxg7lxfsFEZaG+C0VssTy/9dbT+Fw=
github.com/hashicorp/go-sockaddr v1.0.7/go.mod h1:FZQbEYa1pxkQ7WLpyXJ6cbjpT8q0YgQaK/JakXqGyWw=
github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/hcl v1.0.1-vault-7 h1:ag5OxFVy3QYTFTJODRzTKVZ6xvdfLLCA1cy/Y6xGI0I=
github.com/hashicorp/hcl v1.0.1-vault-7/go.mod h1:XYhtn6ijBSAj6n4YqAaf7RBPS4I06AItNorpy+MoQNM=
github.com/hashicorp/hcl/v2 v2.18.1 h1:6nxnOJFku1EuSawSD81fuviYUV8DxFr3fp2dUi3ZYSo=
github.com/hashicorp/hcl/v2 v2.18.1/go.mod h1:ThLC89FV4p9MPW804KVbe/cEXoQ8NZEh+JtMeeGErHE=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
@ -139,6 +165,10 @@ github.com/lib/pq v1.11.1 h1:wuChtj2hfsGmmx3nf1m7xC2XpK6OtelS2shMY+bGMtI=
github.com/lib/pq v1.11.1/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9 h1:Lm995f3rfxdpd6TSmuVCHVb/QhupuXlYr8sCI/QdE+0=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
@ -147,6 +177,8 @@ github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa1
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0=
github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
@ -155,11 +187,14 @@ github.com/oklog/run v1.2.0 h1:O8x3yXwah4A73hJdlrwo/2X6J62gE5qTMusH0dvz60E=
github.com/oklog/run v1.2.0/go.mod h1:mgDbKRSwPhJfesJ4PntqFUbKQRZ50NgmZTSPlFA0YFk=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/openbao/openbao/api/v2 v2.5.1 h1:Br79D6L20SbAa5P7xqENxmvv8LyI4HoKosPy7klhn4o=
github.com/openbao/openbao/api/v2 v2.5.1/go.mod h1:Dh5un77tqGgMbmlVEqjqN+8/dMyUohnkaQVg/wXW0Ig=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
@ -173,6 +208,8 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/russellhaering/goxmldsig v1.5.0 h1:AU2UkkYIUOTyZRbe08XMThaOCelArgvNfYapcmSjBNw=
github.com/russellhaering/goxmldsig v1.5.0/go.mod h1:x98CjQNFJcWfMxeOrMnMKg70lvDP6tE0nTaeUnjXDmk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk=
github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc=
github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8=
github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=

2
server/api.go

@ -279,7 +279,7 @@ func (d dexAPI) GetVersion(ctx context.Context, req *api.VersionReq) (*api.Versi
}
func (d dexAPI) GetDiscovery(ctx context.Context, req *api.DiscoveryReq) (*api.DiscoveryResp, error) {
discoveryDoc := d.server.constructDiscovery()
discoveryDoc := d.server.constructDiscovery(ctx)
data, err := json.Marshal(discoveryDoc)
if err != nil {
return nil, fmt.Errorf("failed to marshal discovery data: %v", err)

35
server/handlers.go

@ -34,25 +34,24 @@ const (
func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// TODO(ericchiang): Cache this.
keys, err := s.storage.GetKeys(ctx)
keys, err := s.signer.ValidationKeys(ctx)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get keys", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
if keys.SigningKeyPub == nil {
if len(keys) == 0 {
s.logger.ErrorContext(r.Context(), "no public keys found.")
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
jwks := jose.JSONWebKeySet{
Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1),
Keys: make([]jose.JSONWebKey, len(keys)),
}
jwks.Keys[0] = *keys.SigningKeyPub
for i, verificationKey := range keys.VerificationKeys {
jwks.Keys[i+1] = *verificationKey.PublicKey
for i, key := range keys {
jwks.Keys[i] = *key
}
data, err := json.MarshalIndent(jwks, "", " ")
@ -61,10 +60,10 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return
}
maxAge := keys.NextRotation.Sub(s.now())
if maxAge < (time.Minute * 2) {
maxAge = time.Minute * 2
}
// We don't have NextRotation info from Signer interface easily,
// so we'll just set a reasonable default cache time.
maxAge := time.Minute * 10
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", int(maxAge.Seconds())))
w.Header().Set("Content-Type", "application/json")
@ -90,8 +89,8 @@ type discovery struct {
Claims []string `json:"claims_supported"`
}
func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
d := s.constructDiscovery()
func (s *Server) discoveryHandler(ctx context.Context) (http.HandlerFunc, error) {
d := s.constructDiscovery(ctx)
data, err := json.MarshalIndent(d, "", " ")
if err != nil {
@ -105,7 +104,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
}), nil
}
func (s *Server) constructDiscovery() discovery {
func (s *Server) constructDiscovery(ctx context.Context) discovery {
d := discovery{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
@ -125,6 +124,14 @@ func (s *Server) constructDiscovery() discovery {
},
}
// Determine signing algorithm from signer
signingAlg, err := s.signer.Algorithm(ctx)
if err != nil {
s.logger.Error("failed to get signing algorithm", "err", err)
} else {
d.IDTokenAlgs = []string{string(signingAlg)}
}
for responseType := range s.supportedResponseTypes {
d.ResponseTypes = append(d.ResponseTypes, responseType)
}
@ -1099,7 +1106,7 @@ func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
}
rawIDToken := auth[len(prefix):]
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
verifier := oidc.NewVerifier(s.issuerURL.String(), &signerKeySet{s.signer}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to verify ID token", "err", err)

2
server/introspectionhandler.go

@ -245,7 +245,7 @@ func (s *Server) introspectRefreshToken(ctx context.Context, token string) (*Int
}
func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Introspection, error) {
verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
verifier := oidc.NewVerifier(s.issuerURL.String(), &signerKeySet{s.signer}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(ctx, token)
if err != nil {
return nil, newIntrospectInactiveTokenError()

39
server/oauth2.go

@ -351,21 +351,6 @@ func genSubject(userID string, connID string) (string, error) {
}
func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys(ctx)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get keys", "err", err)
return "", expiry, err
}
signingKey := keys.SigningKey
if signingKey == nil {
return "", expiry, fmt.Errorf("no key to sign payload with")
}
signingAlg, err := signatureAlgorithm(signingKey)
if err != nil {
return "", expiry, err
}
issuedAt := s.now()
expiry = issuedAt.Add(s.idTokensValidFor)
@ -383,6 +368,13 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage
IssuedAt: issuedAt.Unix(),
}
// Determine signing algorithm from signer
signingAlg, err := s.signer.Algorithm(ctx)
if err != nil {
s.logger.ErrorContext(ctx, "failed to get signing algorithm", "err", err)
return "", expiry, fmt.Errorf("failed to get signing algorithm: %v", err)
}
if accessToken != "" {
atHash, err := accessTokenHash(signingAlg, accessToken)
if err != nil {
@ -445,7 +437,7 @@ func (s *Server) newIDToken(ctx context.Context, clientID string, claims storage
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}
if idToken, err = signPayload(signingKey, signingAlg, payload); err != nil {
if idToken, err = s.signer.Sign(ctx, payload); err != nil {
return "", expiry, fmt.Errorf("failed to sign payload: %v", err)
}
return idToken, expiry, nil
@ -705,12 +697,12 @@ func validateConnectorID(connectors []storage.Connector, connectorID string) boo
return false
}
// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
// signerKeySet implements the oidc.KeySet interface backed by the Dex signer
type signerKeySet struct {
signer Signer
}
func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
func (s *signerKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
jws, err := jose.ParseSigned(jwt, []jose.SignatureAlgorithm{jose.RS256, jose.RS384, jose.RS512, jose.ES256, jose.ES384, jose.ES512})
if err != nil {
return nil, err
@ -722,16 +714,11 @@ func (s *storageKeySet) VerifySignature(ctx context.Context, jwt string) (payloa
break
}
skeys, err := s.Storage.GetKeys(ctx)
keys, err := s.signer.ValidationKeys(ctx)
if err != nil {
return nil, err
}
keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
for _, vk := range skeys.VerificationKeys {
keys = append(keys, vk.PublicKey)
}
for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(key); err == nil {

9
server/oauth2_test.go

@ -593,7 +593,7 @@ func TestValidRedirectURI(t *testing.T) {
}
}
func TestStorageKeySet(t *testing.T) {
func TestSignerKeySet(t *testing.T) {
logger := newLogger(t)
s := memory.New(logger)
if err := s.UpdateKeys(t.Context(), func(keys storage.Keys) (storage.Keys, error) {
@ -668,7 +668,12 @@ func TestStorageKeySet(t *testing.T) {
t.Fatal(err)
}
keySet := &storageKeySet{s}
// We use a localSigner here to bridge the storage to the Signer interface.
// Since VerifySignature only needs ValidationKeys (which only needs storage),
// we don't need to initialize the rotator or other fields.
keySet := &signerKeySet{
signer: &localSigner{storage: s},
}
_, err = keySet.VerifySignature(t.Context(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {

30
server/rotation.go

@ -64,36 +64,6 @@ type keyRotator struct {
logger *slog.Logger
}
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
//
// The method blocks until after the first attempt to rotate keys has completed. That way
// healthy storages will return from this call with valid keys.
func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy, now func() time.Time) {
rotator := keyRotator{s.storage, strategy, now, s.logger}
// Try to rotate immediately so properly configured storages will have keys.
if err := rotator.rotate(); err != nil {
if err == errAlreadyRotated {
s.logger.Info("key rotation not needed", "err", err)
} else {
s.logger.Error("failed to rotate keys", "err", err)
}
}
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
if err := rotator.rotate(); err != nil {
s.logger.Error("failed to rotate keys", "err", err)
}
}
}
}()
}
func (k keyRotator) rotate() error {
keys, err := k.GetKeys(context.Background())
if err != nil && err != storage.ErrNotFound {

27
server/server.go

@ -116,6 +116,8 @@ type Config struct {
Logger *slog.Logger
Signer SignerConfig
PrometheusRegistry *prometheus.Registry
HealthChecker gosundheit.Health
@ -156,6 +158,12 @@ type WebConfig struct {
Extra map[string]string
}
// SignerConfig holds the server's signer configuration.
type SignerConfig struct {
Type string `json:"type"`
Vault VaultSignerConfig `json:"vault"`
}
func value(val, defaultValue time.Duration) time.Duration {
if val == 0 {
return defaultValue
@ -200,6 +208,8 @@ type Server struct {
refreshTokenPolicy *RefreshTokenPolicy
logger *slog.Logger
signer Signer
}
// NewServer constructs a server from the provided config.
@ -318,6 +328,19 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
logger: c.Logger,
}
// Initialize signer
if c.Signer.Type == "vault" {
s.signer, err = newVaultSigner(c.Signer.Vault)
if err != nil {
return nil, fmt.Errorf("failed to initialize vault signer: %v", err)
}
s.logger.Info("signer configured", "type", "vault")
} else {
// Default to local signer
s.signer = newLocalSigner(c.Storage, rotationStrategy, now, c.Logger)
s.logger.Info("signer configured", "type", "local")
}
// Retrieves connector objects in backend storage. This list includes the static connectors
// defined in the ConfigMap and dynamic connectors retrieved from the storage.
storageConnectors, err := c.Storage.ListConnectors(ctx)
@ -452,7 +475,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
}
r.NotFoundHandler = http.NotFoundHandler()
discoveryHandler, err := s.discoveryHandler()
discoveryHandler, err := s.discoveryHandler(ctx)
if err != nil {
return nil, err
}
@ -514,7 +537,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
s.mux = r
s.startKeyRotation(ctx, rotationStrategy, now)
s.signer.Start(ctx)
s.startGarbageCollection(ctx, value(c.GCFrequency, 5*time.Minute), now)
return s, nil

22
server/signer.go

@ -0,0 +1,22 @@
package server
import (
"context"
"github.com/go-jose/go-jose/v4"
)
// Signer is an interface for signing payloads and retrieving validation keys.
type Signer interface {
// Sign signs the provided payload.
Sign(ctx context.Context, payload []byte) (string, error)
// ValidationKeys returns the current public keys used for signature validation.
ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error)
// Algorithm returns the signing algorithm used by this signer.
Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error)
// Start starts any background tasks required by the signer (e.g., key rotation).
Start(ctx context.Context)
}

105
server/signer_local.go

@ -0,0 +1,105 @@
package server
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/go-jose/go-jose/v4"
"github.com/dexidp/dex/storage"
)
// localSigner signs payloads using keys stored in the Dex storage.
// It manages key rotation and storage using the existing keyRotator logic.
type localSigner struct {
storage storage.Storage
rotator *keyRotator
logger *slog.Logger
}
// newLocalSigner creates a new local signer and starts the key rotation loop.
func newLocalSigner(s storage.Storage, strategy rotationStrategy, now func() time.Time, logger *slog.Logger) *localSigner {
r := &keyRotator{s, strategy, now, logger}
return &localSigner{
storage: s,
rotator: r,
logger: logger,
}
}
// Start begins key rotation in a new goroutine, closing once the context is canceled.
//
// The method blocks until after the first attempt to rotate keys has completed. That way
// healthy storages will return from this call with valid keys.
func (l *localSigner) Start(ctx context.Context) {
// Try to rotate immediately so properly configured storages will have keys.
if err := l.rotator.rotate(); err != nil {
if err == errAlreadyRotated {
l.logger.Info("key rotation not needed", "err", err)
} else {
l.logger.Error("failed to rotate keys", "err", err)
}
}
go func() {
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second * 30):
if err := l.rotator.rotate(); err != nil {
l.logger.Error("failed to rotate keys", "err", err)
}
}
}
}()
}
func (l *localSigner) Sign(ctx context.Context, payload []byte) (string, error) {
keys, err := l.storage.GetKeys(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys: %v", err)
}
signingKey := keys.SigningKey
if signingKey == nil {
return "", fmt.Errorf("no key to sign payload with")
}
signingAlg, err := signatureAlgorithm(signingKey)
if err != nil {
return "", err
}
return signPayload(signingKey, signingAlg, payload)
}
func (l *localSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) {
keys, err := l.storage.GetKeys(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get keys: %v", err)
}
if keys.SigningKeyPub == nil {
return nil, fmt.Errorf("no public keys found")
}
jwks := make([]*jose.JSONWebKey, len(keys.VerificationKeys)+1)
jwks[0] = keys.SigningKeyPub
for i, verificationKey := range keys.VerificationKeys {
jwks[i+1] = verificationKey.PublicKey
}
return jwks, nil
}
func (l *localSigner) Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error) {
keys, err := l.storage.GetKeys(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys: %v", err)
}
if keys.SigningKey == nil {
return "", fmt.Errorf("no signing key found")
}
return signatureAlgorithm(keys.SigningKey)
}

359
server/signer_vault.go

@ -0,0 +1,359 @@
package server
import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"hash"
"os"
"github.com/go-jose/go-jose/v4"
vault "github.com/openbao/openbao/api/v2"
)
// VaultSignerConfig holds configuration for the Vault signer.
type VaultSignerConfig struct {
Addr string `json:"addr"`
Token string `json:"token"`
KeyName string `json:"keyName"`
}
// UnmarshalJSON unmarshals a VaultSignerConfig and applies environment variables.
// If Addr or Token are not provided in the config, they are read from VAULT_ADDR
// and VAULT_TOKEN environment variables respectively.
func (c *VaultSignerConfig) UnmarshalJSON(data []byte) error {
type Alias VaultSignerConfig
aux := &struct {
*Alias
}{
Alias: (*Alias)(c),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
// Apply environment variables if config values are empty
if c.Addr == "" {
if addr := os.Getenv("VAULT_ADDR"); addr != "" {
c.Addr = addr
}
}
if c.Token == "" {
if token := os.Getenv("VAULT_TOKEN"); token != "" {
c.Token = token
}
}
return nil
}
// vaultSigner signs payloads using HashiCorp Vault's Transit backend.
type vaultSigner struct {
client *vault.Client
keyName string
}
// newVaultSigner creates a new Vault signer that uses Transit backend for signing.
func newVaultSigner(c VaultSignerConfig) (*vaultSigner, error) {
config := vault.DefaultConfig()
config.Address = c.Addr
client, err := vault.NewClient(config)
if err != nil {
return nil, fmt.Errorf("failed to create vault client: %v", err)
}
if c.Token != "" {
client.SetToken(c.Token)
}
return &vaultSigner{
client: client,
keyName: c.KeyName,
}, nil
}
func (v *vaultSigner) Start(ctx context.Context) {
// Vault signer does not need background rotation tasks
}
func (v *vaultSigner) Sign(ctx context.Context, payload []byte) (string, error) {
// 1. Fetch keys to determine the key to use (latest version) and its ID.
keysMap, latestVersion, err := v.getTransitKeysMap(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys for signing context: %v", err)
}
// Determine the key version and ID to use
// We use the latest version by default
signingJWK, ok := keysMap[latestVersion]
if !ok {
return "", fmt.Errorf("latest key version %d not found in public keys", latestVersion)
}
// 2. Construct JWS Header and Payload first (Signing Input)
header := map[string]interface{}{
"alg": signingJWK.Algorithm,
"kid": signingJWK.KeyID,
}
headerBytes, err := json.Marshal(header)
if err != nil {
return "", fmt.Errorf("failed to marshal header: %v", err)
}
headerB64 := base64.RawURLEncoding.EncodeToString(headerBytes)
payloadB64 := base64.RawURLEncoding.EncodeToString(payload)
// The input to the signature is "header.payload"
signingInput := fmt.Sprintf("%s.%s", headerB64, payloadB64)
// 3. Sign the signingInput using Vault
var vaultInput string
data := map[string]interface{}{}
// Determine Vault params based on JWS algorithm
params, err := getVaultParams(signingJWK.Algorithm)
if err != nil {
return "", err
}
// Apply params to data map
for k, v := range params.extraParams {
data[k] = v
}
// Hash input if needed
if params.hasher != nil {
params.hasher.Write([]byte(signingInput))
hash := params.hasher.Sum(nil)
vaultInput = base64.StdEncoding.EncodeToString(hash)
} else {
// No pre-hashing (EdDSA)
vaultInput = base64.StdEncoding.EncodeToString([]byte(signingInput))
}
data["input"] = vaultInput
signPath := fmt.Sprintf("transit/sign/%s", v.keyName)
signSecret, err := v.client.Logical().WriteWithContext(ctx, signPath, data)
if err != nil {
return "", fmt.Errorf("vault sign: %v", err)
}
signatureString, ok := signSecret.Data["signature"].(string)
if !ok {
return "", fmt.Errorf("vault response missing signature")
}
// Parse vault signature: "vault:v1:base64sig"
var signatureB64 []byte
if len(signatureString) > 8 && signatureString[:6] == "vault:" {
parts := splitVaultSignature(signatureString)
if len(parts) == 3 {
// part 1 is "vault", part 2 is "v1", part 3 is signature
// The signature is already base64 encoded, decoding it is not needed and
// will make the code failing.
signatureB64 = []byte(parts[2])
}
} else {
return "", fmt.Errorf("unexpected signature format: %s", signatureString)
}
return fmt.Sprintf("%s.%s.%s", headerB64, payloadB64, signatureB64), nil
}
func (v *vaultSigner) ValidationKeys(ctx context.Context) ([]*jose.JSONWebKey, error) {
keysMap, _, err := v.getTransitKeysMap(ctx)
if err != nil {
return nil, err
}
keys := make([]*jose.JSONWebKey, 0, len(keysMap))
for _, k := range keysMap {
keys = append(keys, k)
}
return keys, nil
}
// getTransitKeysMap returns a map of key_version -> JWK and the latest version number
func (v *vaultSigner) getTransitKeysMap(ctx context.Context) (map[int64]*jose.JSONWebKey, int64, error) {
path := fmt.Sprintf("transit/keys/%s", v.keyName)
secret, err := v.client.Logical().ReadWithContext(ctx, path)
if err != nil {
return nil, 0, fmt.Errorf("failed to read key from vault: %v", err)
}
if secret == nil {
return nil, 0, fmt.Errorf("key %q not found in vault", v.keyName)
}
latestVersion, ok := secret.Data["latest_version"].(json.Number)
if !ok {
// Try float64 which is default for unmarshal interface{}
if lv, ok := secret.Data["latest_version"].(float64); ok {
latestVersion = json.Number(fmt.Sprintf("%d", int(lv)))
} else if lv, ok := secret.Data["latest_version"].(int); ok {
latestVersion = json.Number(fmt.Sprintf("%d", lv))
}
}
latestVerInt, err := latestVersion.Int64()
if err != nil {
return nil, 0, fmt.Errorf("failed to get latest version: %v", err)
}
keysObj, ok := secret.Data["keys"].(map[string]interface{})
if !ok {
return nil, 0, fmt.Errorf("invalid response from vault")
}
jwksMap := make(map[int64]*jose.JSONWebKey)
for verStr, data := range keysObj {
d, ok := data.(map[string]interface{})
if !ok {
continue
}
var ver int64
fmt.Sscanf(verStr, "%d", &ver)
pemStr, ok := d["public_key"].(string)
if !ok {
continue
}
jwk, err := parsePEMToJWK(pemStr)
if err != nil {
continue
}
jwksMap[ver] = jwk
}
return jwksMap, latestVerInt, nil
}
func parsePEMToJWK(pemStr string) (*jose.JSONWebKey, error) {
block, _ := pem.Decode([]byte(pemStr))
if block == nil {
return nil, fmt.Errorf("failed to parse PEM block")
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("failed to parse public key: %v", err)
}
alg := ""
switch k := pub.(type) {
case *rsa.PublicKey:
alg = "RS256"
case *ecdsa.PublicKey:
switch k.Curve {
case elliptic.P256():
alg = "ES256"
case elliptic.P384():
alg = "ES384"
case elliptic.P521():
alg = "ES512"
default:
return nil, fmt.Errorf("unsupported ECDSA curve")
}
case ed25519.PublicKey:
alg = "EdDSA"
default:
return nil, fmt.Errorf("unsupported key type %T", pub)
}
jwk := &jose.JSONWebKey{
Key: pub,
Algorithm: alg,
Use: "sig",
}
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
}
jwk.KeyID = base64.RawURLEncoding.EncodeToString(thumbprint)
return jwk, nil
}
func splitVaultSignature(sig string) []string {
// Basic split implementation
// "vault:v1:signature"
var parts []string
start := 0
for i := 0; i < len(sig); i++ {
if sig[i] == ':' {
parts = append(parts, sig[start:i])
start = i + 1
}
}
parts = append(parts, sig[start:])
return parts
}
func (v *vaultSigner) Algorithm(ctx context.Context) (jose.SignatureAlgorithm, error) {
keysMap, latestVersion, err := v.getTransitKeysMap(ctx)
if err != nil {
return "", fmt.Errorf("failed to get keys: %v", err)
}
signingJWK, ok := keysMap[latestVersion]
if !ok {
return "", fmt.Errorf("latest key version %d not found", latestVersion)
}
return jose.SignatureAlgorithm(signingJWK.Algorithm), nil
}
type vaultAlgoParams struct {
hasher hash.Hash
extraParams map[string]interface{}
}
func getVaultParams(alg string) (vaultAlgoParams, error) {
params := vaultAlgoParams{
extraParams: map[string]interface{}{
"marshaling_algorithm": "jws",
"signature_algorithm": "pkcs1v15",
},
}
switch alg {
case "RS256":
params.hasher = sha256.New()
params.extraParams["prehashed"] = true
params.extraParams["hash_algorithm"] = "sha2-256"
case "ES256":
params.hasher = sha256.New()
params.extraParams["prehashed"] = true
params.extraParams["hash_algorithm"] = "sha2-256"
case "ES384":
params.hasher = sha512.New384()
params.extraParams["prehashed"] = true
params.extraParams["hash_algorithm"] = "sha2-384"
case "ES512":
params.hasher = sha512.New()
params.extraParams["prehashed"] = true
params.extraParams["hash_algorithm"] = "sha2-512"
case "EdDSA":
// No hashing
params.hasher = nil
default:
return params, fmt.Errorf("unsupported signing algorithm: %s", alg)
}
return params, nil
}

196
server/signer_vault_test.go

@ -0,0 +1,196 @@
package server
import (
"encoding/json"
"os"
"testing"
)
func TestVaultSignerConfigUnmarshalJSON_WithEnvVars(t *testing.T) {
// Save original environment variables
originalAddr := os.Getenv("VAULT_ADDR")
originalToken := os.Getenv("VAULT_TOKEN")
defer func() {
os.Setenv("VAULT_ADDR", originalAddr)
os.Setenv("VAULT_TOKEN", originalToken)
}()
// Set environment variables
os.Setenv("VAULT_ADDR", "http://vault.example.com:8200")
os.Setenv("VAULT_TOKEN", "s.xxxxxxxxxxxxxxxx")
tests := []struct {
name string
json string
want VaultSignerConfig
wantErr bool
}{
{
name: "empty config uses env vars",
json: `{"keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "http://vault.example.com:8200",
Token: "s.xxxxxxxxxxxxxxxx",
KeyName: "signing-key",
},
wantErr: false,
},
{
name: "config values override env vars",
json: `{"addr": "http://custom.vault.com:8200", "token": "s.custom", "keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "http://custom.vault.com:8200",
Token: "s.custom",
KeyName: "signing-key",
},
wantErr: false,
},
{
name: "partial config uses env vars for missing values",
json: `{"addr": "http://custom.vault.com:8200", "keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "http://custom.vault.com:8200",
Token: "s.xxxxxxxxxxxxxxxx",
KeyName: "signing-key",
},
wantErr: false,
},
{
name: "empty token in config uses env var",
json: `{"addr": "http://custom.vault.com:8200", "token": "", "keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "http://custom.vault.com:8200",
Token: "s.xxxxxxxxxxxxxxxx",
KeyName: "signing-key",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got VaultSignerConfig
err := json.Unmarshal([]byte(tt.json), &got)
if (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.Addr != tt.want.Addr {
t.Errorf("Addr: got %q, want %q", got.Addr, tt.want.Addr)
}
if got.Token != tt.want.Token {
t.Errorf("Token: got %q, want %q", got.Token, tt.want.Token)
}
if got.KeyName != tt.want.KeyName {
t.Errorf("KeyName: got %q, want %q", got.KeyName, tt.want.KeyName)
}
})
}
}
func TestVaultSignerConfigUnmarshalJSON_WithoutEnvVars(t *testing.T) {
// Save original environment variables
originalAddr := os.Getenv("VAULT_ADDR")
originalToken := os.Getenv("VAULT_TOKEN")
defer func() {
os.Setenv("VAULT_ADDR", originalAddr)
os.Setenv("VAULT_TOKEN", originalToken)
}()
// Unset environment variables
os.Unsetenv("VAULT_ADDR")
os.Unsetenv("VAULT_TOKEN")
tests := []struct {
name string
json string
want VaultSignerConfig
wantErr bool
}{
{
name: "config values used when env vars not set",
json: `{"addr": "http://vault.example.com:8200", "token": "s.xxxxxxxxxxxxxxxx", "keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "http://vault.example.com:8200",
Token: "s.xxxxxxxxxxxxxxxx",
KeyName: "signing-key",
},
wantErr: false,
},
{
name: "empty config when env vars not set",
json: `{"keyName": "signing-key"}`,
want: VaultSignerConfig{
Addr: "",
Token: "",
KeyName: "signing-key",
},
wantErr: false,
},
{
name: "only keyName required in config",
json: `{"keyName": "my-key"}`,
want: VaultSignerConfig{
Addr: "",
Token: "",
KeyName: "my-key",
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got VaultSignerConfig
err := json.Unmarshal([]byte(tt.json), &got)
if (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got.Addr != tt.want.Addr {
t.Errorf("Addr: got %q, want %q", got.Addr, tt.want.Addr)
}
if got.Token != tt.want.Token {
t.Errorf("Token: got %q, want %q", got.Token, tt.want.Token)
}
if got.KeyName != tt.want.KeyName {
t.Errorf("KeyName: got %q, want %q", got.KeyName, tt.want.KeyName)
}
})
}
}
func TestVaultSignerConfigUnmarshalJSON_InvalidJSON(t *testing.T) {
tests := []struct {
name string
json string
wantErr bool
}{
{
name: "invalid json",
json: `{invalid json}`,
wantErr: true,
},
{
name: "empty json",
json: `{}`,
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got VaultSignerConfig
err := json.Unmarshal([]byte(tt.json), &got)
if (err != nil) != tt.wantErr {
t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
Loading…
Cancel
Save