Browse Source

test: Add conformance tests for Vault signer integration

Signed-off-by: maksim.nabokikh <max.nabokih@gmail.com>
pull/4520/head
maksim.nabokikh 1 month ago
parent
commit
c882128152
  1. 23
      .github/workflows/ci.yaml
  2. 20
      docker-compose.yaml
  3. 52
      server/signer_vault.go
  4. 540
      server/signer_vault_integration_test.go

23
.github/workflows/ci.yaml

@ -64,6 +64,24 @@ jobs:
- 35357
options: --health-cmd "curl --fail http://localhost:5000/v3" --health-interval 10s --health-timeout 5s --health-retries 5
vault:
image: hashicorp/vault:1.21
ports:
- 8200
env:
VAULT_DEV_ROOT_TOKEN_ID: root-token
VAULT_DEV_LISTEN_ADDRESS: "0.0.0.0:8200"
options: --health-cmd "vault status -address=http://localhost:8200 || exit 1" --health-interval 10s --health-timeout 5s --health-retries 5
openbao:
image: quay.io/openbao/openbao:2.5
ports:
- 8210
env:
BAO_DEV_ROOT_TOKEN_ID: root-token
BAO_DEV_LISTEN_ADDRESS: "0.0.0.0:8210"
options: --health-cmd "bao status -address=http://localhost:8210 || exit 1" --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@ -129,6 +147,11 @@ jobs:
DEX_KEYSTONE_ADMIN_USER: demo
DEX_KEYSTONE_ADMIN_PASS: DEMO_PASS
DEX_VAULT_ADDR: http://localhost:${{ job.services.vault.ports[8200] }}
DEX_VAULT_TOKEN: root-token
DEX_OPENBAO_ADDR: http://localhost:${{ job.services.openbao.ports[8210] }}
DEX_OPENBAO_TOKEN: root-token
DEX_KUBERNETES_CONFIG_PATH: ~/.kube/config
lint:

20
docker-compose.yaml

@ -45,3 +45,23 @@ services:
volumes:
- ./connector/ldap/testdata/certs:/container/service/slapd/assets/certs
- ./connector/ldap/testdata/schema.ldif:/container/service/slapd/assets/config/bootstrap/ldif/99-schema.ldif
vault:
image: hashicorp/vault:1.21
environment:
VAULT_DEV_ROOT_TOKEN_ID: root-token
VAULT_DEV_LISTEN_ADDRESS: "0.0.0.0:8200"
cap_add:
- IPC_LOCK
ports:
- 8200:8200
openbao:
image: quay.io/openbao/openbao:2.5
environment:
BAO_DEV_ROOT_TOKEN_ID: root-token
BAO_DEV_LISTEN_ADDRESS: "0.0.0.0:8200"
cap_add:
- IPC_LOCK
ports:
- 8210:8200

52
server/signer_vault.go

@ -247,7 +247,57 @@ func (v *vaultSigner) getTransitKeysMap(ctx context.Context) (map[int64]*jose.JS
func parsePEMToJWK(pemStr string) (*jose.JSONWebKey, error) {
block, _ := pem.Decode([]byte(pemStr))
if block == nil {
return nil, fmt.Errorf("failed to parse PEM block")
// OpenBao may return ED25519 keys as raw base64-encoded strings instead of PEM
// Try to decode as raw base64 ED25519 key
keyBytes, err := base64.StdEncoding.DecodeString(pemStr)
if err != nil {
return nil, fmt.Errorf("failed to parse PEM block or base64: %v", err)
}
// Check if it's a raw 32-byte ED25519 key
if len(keyBytes) == 32 {
ed25519Key := ed25519.PublicKey(keyBytes)
jwk := &jose.JSONWebKey{
Key: ed25519Key,
Algorithm: "EdDSA",
Use: "sig",
}
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
}
jwk.KeyID = base64.RawURLEncoding.EncodeToString(thumbprint)
return jwk, nil
}
// Try to parse as PKIX public key
pub, err := x509.ParsePKIXPublicKey(keyBytes)
if err != nil {
return nil, fmt.Errorf("failed to parse raw key: %v", err)
}
// Create JWK for ED25519 key
ed25519Key, ok := pub.(ed25519.PublicKey)
if !ok {
return nil, fmt.Errorf("expected ED25519 key, got %T", pub)
}
jwk := &jose.JSONWebKey{
Key: ed25519Key,
Algorithm: "EdDSA",
Use: "sig",
}
thumbprint, err := jwk.Thumbprint(crypto.SHA256)
if err != nil {
return nil, err
}
jwk.KeyID = base64.RawURLEncoding.EncodeToString(thumbprint)
return jwk, nil
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)

540
server/signer_vault_integration_test.go

@ -0,0 +1,540 @@
package server
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"os"
"testing"
"time"
"github.com/go-jose/go-jose/v4"
vault "github.com/openbao/openbao/api/v2"
)
// Conformance tests verify that Vault and OpenBao behave identically with the signer.
// These tests use a single SDK (OpenBao API) that works with both systems.
//
// To run tests for a specific system, set the environment variables:
//
// For Vault:
// DEX_VAULT_ADDR=http://localhost:8200
// DEX_VAULT_TOKEN=root-token
// go test -v -run TestVaultSignerConformance
//
// For OpenBao:
// DEX_OPENBAO_ADDR=http://localhost:8210
// DEX_OPENBAO_TOKEN=root-token
// go test -v -run TestVaultSignerConformance
//
// To test both systems in parallel, set both sets of environment variables.
type conformanceTestConfig struct {
name string
addr string
token string
addrEnv string
tokenEnv string
}
// getTestConfigs returns list of test configs based on environment variables
func getTestConfigs(t *testing.T) []conformanceTestConfig {
var configs []conformanceTestConfig
// Check for Vault
vaultAddr := os.Getenv("DEX_VAULT_ADDR")
vaultToken := os.Getenv("DEX_VAULT_TOKEN")
if vaultAddr != "" && vaultToken != "" {
configs = append(configs, conformanceTestConfig{
name: "Vault",
addr: vaultAddr,
token: vaultToken,
addrEnv: "DEX_VAULT_ADDR",
tokenEnv: "DEX_VAULT_TOKEN",
})
}
// Check for OpenBao
openbaoAddr := os.Getenv("DEX_OPENBAO_ADDR")
openbaoToken := os.Getenv("DEX_OPENBAO_TOKEN")
if openbaoAddr != "" && openbaoToken != "" {
configs = append(configs, conformanceTestConfig{
name: "OpenBao",
addr: openbaoAddr,
token: openbaoToken,
addrEnv: "DEX_OPENBAO_ADDR",
tokenEnv: "DEX_OPENBAO_TOKEN",
})
}
if len(configs) == 0 {
t.Skip("Skipping conformance tests. Set DEX_VAULT_TOKEN+DEX_VAULT_ADDR or DEX_OPENBAO_TOKEN+DEX_OPENBAO_ADDR to run.")
}
return configs
}
// TestVaultSignerConformance_SigningAndVerification tests that signing and verification work the same way
// across Vault and OpenBao implementations.
func TestVaultSignerConformance_SigningAndVerification(t *testing.T) {
configs := getTestConfigs(t)
testCases := []struct {
name string
keyType string
alg string
}{
{
name: "RSA-2048",
keyType: "rsa-2048",
alg: "RS256",
},
{
name: "ECDSA-P256",
keyType: "ecdsa-p256",
alg: "ES256",
},
{
name: "ECDSA-P384",
keyType: "ecdsa-p384",
alg: "ES384",
},
{
name: "ED25519",
keyType: "ed25519",
alg: "EdDSA",
},
}
for _, config := range configs {
t.Run(config.name, func(t *testing.T) {
ctx := context.Background()
// Create client
vaultConfig := vault.DefaultConfig()
vaultConfig.Address = config.addr
client, err := vault.NewClient(vaultConfig)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
client.SetToken(config.token)
// Enable transit engine
if err := enableTransitEngine(client); err != nil {
t.Fatalf("failed to enable transit engine: %v", err)
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
keyName := fmt.Sprintf("test-key-%s-%s-%d", config.name, tc.keyType, time.Now().Unix())
// Create key
keyData := map[string]interface{}{
"type": tc.keyType,
}
_, err := client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName), keyData)
if err != nil {
t.Fatalf("failed to create key: %v", err)
}
// Clean up
defer func() {
updateData := map[string]interface{}{
"deletion_allowed": true,
}
client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s/config", keyName), updateData)
client.Logical().DeleteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName))
}()
// Create signer
signerConfig := VaultSignerConfig{
Addr: config.addr,
Token: config.token,
KeyName: keyName,
}
signer, err := newVaultSigner(signerConfig)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
// Test 1: Verify algorithm
alg, err := signer.Algorithm(ctx)
if err != nil {
t.Fatalf("failed to get algorithm: %v", err)
}
if string(alg) != tc.alg {
t.Errorf("expected algorithm %s, got %s", tc.alg, alg)
}
// Test 2: Get validation keys
keys, err := signer.ValidationKeys(ctx)
if err != nil {
t.Fatalf("failed to get validation keys: %v", err)
}
if len(keys) == 0 {
t.Fatal("expected at least one validation key")
}
if keys[0].Algorithm != tc.alg {
t.Errorf("expected key algorithm %s, got %s", tc.alg, keys[0].Algorithm)
}
if keys[0].Use != "sig" {
t.Errorf("expected key use 'sig', got %s", keys[0].Use)
}
// Test 3: Sign and verify JWT
payload := map[string]interface{}{
"iss": "https://dex.example.com",
"sub": "user123",
"aud": "client-app",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
payloadBytes, _ := json.Marshal(payload)
jwtString, err := signer.Sign(ctx, payloadBytes)
if err != nil {
t.Fatalf("failed to sign payload: %v", err)
}
// Verify JWT signature
jws, err := jose.ParseSigned(jwtString, []jose.SignatureAlgorithm{jose.SignatureAlgorithm(tc.alg)})
if err != nil {
t.Fatalf("failed to parse signed JWT: %v", err)
}
verifiedPayload, err := jws.Verify(keys[0])
if err != nil {
t.Fatalf("failed to verify JWT signature: %v", err)
}
var decodedPayload map[string]interface{}
if err := json.Unmarshal(verifiedPayload, &decodedPayload); err != nil {
t.Fatalf("failed to unmarshal verified payload: %v", err)
}
if decodedPayload["sub"] != payload["sub"] {
t.Errorf("payload mismatch: expected sub=%s, got %s", payload["sub"], decodedPayload["sub"])
}
// Test 4: Multiple signatures with same key
for i := 0; i < 3; i++ {
randomPayload := make([]byte, 32)
rand.Read(randomPayload)
payloadData := map[string]interface{}{
"data": base64.StdEncoding.EncodeToString(randomPayload),
"iat": time.Now().Unix(),
}
payloadBytes, _ := json.Marshal(payloadData)
jwtString, err := signer.Sign(ctx, payloadBytes)
if err != nil {
t.Fatalf("sign attempt %d failed: %v", i+1, err)
}
jws, err := jose.ParseSigned(jwtString, []jose.SignatureAlgorithm{jose.SignatureAlgorithm(tc.alg)})
if err != nil {
t.Fatalf("parse attempt %d failed: %v", i+1, err)
}
_, err = jws.Verify(keys[0])
if err != nil {
t.Fatalf("verify attempt %d failed: %v", i+1, err)
}
}
})
}
})
}
}
// TestVaultSignerConformance_KeyRotation tests that key rotation works identically
// across Vault and OpenBao implementations.
func TestVaultSignerConformance_KeyRotation(t *testing.T) {
configs := getTestConfigs(t)
for _, config := range configs {
t.Run(config.name, func(t *testing.T) {
ctx := context.Background()
// Create client
vaultConfig := vault.DefaultConfig()
vaultConfig.Address = config.addr
client, err := vault.NewClient(vaultConfig)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
client.SetToken(config.token)
// Enable transit engine
if err := enableTransitEngine(client); err != nil {
t.Fatalf("failed to enable transit engine: %v", err)
}
keyName := fmt.Sprintf("test-rotation-key-%s-%d", config.name, time.Now().Unix())
// Create initial key
keyData := map[string]interface{}{
"type": "ecdsa-p256",
}
_, err = client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName), keyData)
if err != nil {
t.Fatalf("failed to create key: %v", err)
}
// Clean up
defer func() {
updateData := map[string]interface{}{
"deletion_allowed": true,
}
client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s/config", keyName), updateData)
client.Logical().DeleteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName))
}()
// Create signer
signerConfig := VaultSignerConfig{
Addr: config.addr,
Token: config.token,
KeyName: keyName,
}
signer, err := newVaultSigner(signerConfig)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
// Sign with initial key version
payload1 := map[string]interface{}{"version": "v1", "iat": time.Now().Unix()}
payload1Bytes, _ := json.Marshal(payload1)
jwt1, err := signer.Sign(ctx, payload1Bytes)
if err != nil {
t.Fatalf("failed to sign with v1: %v", err)
}
// Get keys before rotation
keysBefore, err := signer.ValidationKeys(ctx)
if err != nil {
t.Fatalf("failed to get keys before rotation: %v", err)
}
if len(keysBefore) != 1 {
t.Errorf("expected 1 key before rotation, got %d", len(keysBefore))
}
// Rotate key
_, err = client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s/rotate", keyName), nil)
if err != nil {
t.Fatalf("failed to rotate key: %v", err)
}
// Sign with new key version
payload2 := map[string]interface{}{"version": "v2", "iat": time.Now().Unix()}
payload2Bytes, _ := json.Marshal(payload2)
jwt2, err := signer.Sign(ctx, payload2Bytes)
if err != nil {
t.Fatalf("failed to sign with v2: %v", err)
}
// Get keys after rotation
keysAfter, err := signer.ValidationKeys(ctx)
if err != nil {
t.Fatalf("failed to get keys after rotation: %v", err)
}
if len(keysAfter) != 2 {
t.Errorf("expected 2 keys after rotation, got %d", len(keysAfter))
}
// Verify both JWTs can be validated with the current keyset
jws1, err := jose.ParseSigned(jwt1, []jose.SignatureAlgorithm{jose.ES256})
if err != nil {
t.Fatalf("failed to parse jwt1: %v", err)
}
jws2, err := jose.ParseSigned(jwt2, []jose.SignatureAlgorithm{jose.ES256})
if err != nil {
t.Fatalf("failed to parse jwt2: %v", err)
}
// Find matching keys and verify
verified1 := false
verified2 := false
for _, key := range keysAfter {
if _, err := jws1.Verify(key); err == nil {
verified1 = true
}
if _, err := jws2.Verify(key); err == nil {
verified2 = true
}
}
if !verified1 {
t.Error("failed to verify JWT signed with version 1")
}
if !verified2 {
t.Error("failed to verify JWT signed with version 2")
}
})
}
}
// TestVaultSignerConformance_PublicKeyDiscovery tests that public key discovery works identically
// across Vault and OpenBao implementations.
func TestVaultSignerConformance_PublicKeyDiscovery(t *testing.T) {
configs := getTestConfigs(t)
for _, config := range configs {
t.Run(config.name, func(t *testing.T) {
ctx := context.Background()
// Create client
vaultConfig := vault.DefaultConfig()
vaultConfig.Address = config.addr
client, err := vault.NewClient(vaultConfig)
if err != nil {
t.Fatalf("failed to create client: %v", err)
}
client.SetToken(config.token)
// Enable transit engine
if err := enableTransitEngine(client); err != nil {
t.Fatalf("failed to enable transit engine: %v", err)
}
keyName := fmt.Sprintf("test-discovery-key-%s-%d", config.name, time.Now().Unix())
// Create key
keyData := map[string]interface{}{
"type": "rsa-2048",
}
_, err = client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName), keyData)
if err != nil {
t.Fatalf("failed to create key: %v", err)
}
// Clean up
defer func() {
updateData := map[string]interface{}{
"deletion_allowed": true,
}
client.Logical().WriteWithContext(ctx, fmt.Sprintf("transit/keys/%s/config", keyName), updateData)
client.Logical().DeleteWithContext(ctx, fmt.Sprintf("transit/keys/%s", keyName))
}()
// Create signer
signerConfig := VaultSignerConfig{
Addr: config.addr,
Token: config.token,
KeyName: keyName,
}
signer, err := newVaultSigner(signerConfig)
if err != nil {
t.Fatalf("failed to create signer: %v", err)
}
// Get public keys (simulating JWKS endpoint)
keys, err := signer.ValidationKeys(ctx)
if err != nil {
t.Fatalf("failed to get validation keys: %v", err)
}
// Verify keys have required JWKS fields
for i, key := range keys {
if key.KeyID == "" {
t.Errorf("key %d missing KeyID", i)
}
if key.Algorithm == "" {
t.Errorf("key %d missing Algorithm", i)
}
if key.Use != "sig" {
t.Errorf("key %d has wrong Use field: expected 'sig', got '%s'", i, key.Use)
}
if key.Key == nil {
t.Errorf("key %d missing public key", i)
}
// Verify key can be marshaled to JWKS format
jwksData, err := json.Marshal(key)
if err != nil {
t.Errorf("key %d cannot be marshaled to JSON: %v", i, err)
}
var jwksCheck map[string]interface{}
if err := json.Unmarshal(jwksData, &jwksCheck); err != nil {
t.Errorf("key %d JWKS data is invalid: %v", i, err)
}
// Check for standard JWKS fields
requiredFields := []string{"kty", "use", "kid", "alg"}
for _, field := range requiredFields {
if _, ok := jwksCheck[field]; !ok {
t.Errorf("key %d missing required JWKS field: %s", i, field)
}
}
}
// Sign a JWT
payload := map[string]interface{}{
"iss": "https://dex.example.com",
"sub": "test-user",
"aud": "test-client",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
payloadBytes, _ := json.Marshal(payload)
jwtString, err := signer.Sign(ctx, payloadBytes)
if err != nil {
t.Fatalf("failed to sign JWT: %v", err)
}
// Parse JWT and verify it has correct kid in header
jws, err := jose.ParseSigned(jwtString, []jose.SignatureAlgorithm{jose.RS256})
if err != nil {
t.Fatalf("failed to parse JWT: %v", err)
}
if len(jws.Signatures) == 0 {
t.Fatal("JWT has no signatures")
}
kid := jws.Signatures[0].Header.KeyID
if kid == "" {
t.Error("JWT header missing kid")
}
// Verify kid matches one of the public keys
kidFound := false
for _, key := range keys {
if key.KeyID == kid {
kidFound = true
break
}
}
if !kidFound {
t.Errorf("JWT kid '%s' not found in public keys", kid)
}
})
}
}
// enableTransitEngine enables the transit secrets engine if not already enabled.
func enableTransitEngine(client *vault.Client) error {
// Check if already enabled
mounts, err := client.Sys().ListMounts()
if err != nil {
return fmt.Errorf("failed to list mounts: %v", err)
}
if _, exists := mounts["transit/"]; exists {
return nil
}
// Enable transit engine
mountInput := &vault.MountInput{
Type: "transit",
}
if err := client.Sys().Mount("transit", mountInput); err != nil {
return fmt.Errorf("failed to mount transit: %v", err)
}
return nil
}
Loading…
Cancel
Save