diff --git a/server/signer_vault.go b/server/signer_vault.go index c16e25d1..7071d2eb 100644 --- a/server/signer_vault.go +++ b/server/signer_vault.go @@ -15,6 +15,7 @@ import ( "encoding/pem" "fmt" "hash" + "os" "github.com/go-jose/go-jose/v4" vault "github.com/openbao/openbao/api/v2" @@ -27,6 +28,37 @@ type VaultSignerConfig struct { KeyName string `json:"keyName"` } +// UnmarshalJSON unmarshals a VaultSignerConfig and applies environment variables. +// If Addr or Token are not provided in the config, they are read from VAULT_ADDR +// and VAULT_TOKEN environment variables respectively. +func (c *VaultSignerConfig) UnmarshalJSON(data []byte) error { + type Alias VaultSignerConfig + aux := &struct { + *Alias + }{ + Alias: (*Alias)(c), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + // Apply environment variables if config values are empty + if c.Addr == "" { + if addr := os.Getenv("VAULT_ADDR"); addr != "" { + c.Addr = addr + } + } + + if c.Token == "" { + if token := os.Getenv("VAULT_TOKEN"); token != "" { + c.Token = token + } + } + + return nil +} + // vaultSigner signs payloads using HashiCorp Vault's Transit backend. type vaultSigner struct { client *vault.Client diff --git a/server/signer_vault_test.go b/server/signer_vault_test.go new file mode 100644 index 00000000..050a672c --- /dev/null +++ b/server/signer_vault_test.go @@ -0,0 +1,196 @@ +package server + +import ( + "encoding/json" + "os" + "testing" +) + +func TestVaultSignerConfigUnmarshalJSON_WithEnvVars(t *testing.T) { + // Save original environment variables + originalAddr := os.Getenv("VAULT_ADDR") + originalToken := os.Getenv("VAULT_TOKEN") + defer func() { + os.Setenv("VAULT_ADDR", originalAddr) + os.Setenv("VAULT_TOKEN", originalToken) + }() + + // Set environment variables + os.Setenv("VAULT_ADDR", "http://vault.example.com:8200") + os.Setenv("VAULT_TOKEN", "s.xxxxxxxxxxxxxxxx") + + tests := []struct { + name string + json string + want VaultSignerConfig + wantErr bool + }{ + { + name: "empty config uses env vars", + json: `{"keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "http://vault.example.com:8200", + Token: "s.xxxxxxxxxxxxxxxx", + KeyName: "signing-key", + }, + wantErr: false, + }, + { + name: "config values override env vars", + json: `{"addr": "http://custom.vault.com:8200", "token": "s.custom", "keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "http://custom.vault.com:8200", + Token: "s.custom", + KeyName: "signing-key", + }, + wantErr: false, + }, + { + name: "partial config uses env vars for missing values", + json: `{"addr": "http://custom.vault.com:8200", "keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "http://custom.vault.com:8200", + Token: "s.xxxxxxxxxxxxxxxx", + KeyName: "signing-key", + }, + wantErr: false, + }, + { + name: "empty token in config uses env var", + json: `{"addr": "http://custom.vault.com:8200", "token": "", "keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "http://custom.vault.com:8200", + Token: "s.xxxxxxxxxxxxxxxx", + KeyName: "signing-key", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got VaultSignerConfig + err := json.Unmarshal([]byte(tt.json), &got) + + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if got.Addr != tt.want.Addr { + t.Errorf("Addr: got %q, want %q", got.Addr, tt.want.Addr) + } + if got.Token != tt.want.Token { + t.Errorf("Token: got %q, want %q", got.Token, tt.want.Token) + } + if got.KeyName != tt.want.KeyName { + t.Errorf("KeyName: got %q, want %q", got.KeyName, tt.want.KeyName) + } + }) + } +} + +func TestVaultSignerConfigUnmarshalJSON_WithoutEnvVars(t *testing.T) { + // Save original environment variables + originalAddr := os.Getenv("VAULT_ADDR") + originalToken := os.Getenv("VAULT_TOKEN") + defer func() { + os.Setenv("VAULT_ADDR", originalAddr) + os.Setenv("VAULT_TOKEN", originalToken) + }() + + // Unset environment variables + os.Unsetenv("VAULT_ADDR") + os.Unsetenv("VAULT_TOKEN") + + tests := []struct { + name string + json string + want VaultSignerConfig + wantErr bool + }{ + { + name: "config values used when env vars not set", + json: `{"addr": "http://vault.example.com:8200", "token": "s.xxxxxxxxxxxxxxxx", "keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "http://vault.example.com:8200", + Token: "s.xxxxxxxxxxxxxxxx", + KeyName: "signing-key", + }, + wantErr: false, + }, + { + name: "empty config when env vars not set", + json: `{"keyName": "signing-key"}`, + want: VaultSignerConfig{ + Addr: "", + Token: "", + KeyName: "signing-key", + }, + wantErr: false, + }, + { + name: "only keyName required in config", + json: `{"keyName": "my-key"}`, + want: VaultSignerConfig{ + Addr: "", + Token: "", + KeyName: "my-key", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got VaultSignerConfig + err := json.Unmarshal([]byte(tt.json), &got) + + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if got.Addr != tt.want.Addr { + t.Errorf("Addr: got %q, want %q", got.Addr, tt.want.Addr) + } + if got.Token != tt.want.Token { + t.Errorf("Token: got %q, want %q", got.Token, tt.want.Token) + } + if got.KeyName != tt.want.KeyName { + t.Errorf("KeyName: got %q, want %q", got.KeyName, tt.want.KeyName) + } + }) + } +} + +func TestVaultSignerConfigUnmarshalJSON_InvalidJSON(t *testing.T) { + tests := []struct { + name string + json string + wantErr bool + }{ + { + name: "invalid json", + json: `{invalid json}`, + wantErr: true, + }, + { + name: "empty json", + json: `{}`, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got VaultSignerConfig + err := json.Unmarshal([]byte(tt.json), &got) + + if (err != nil) != tt.wantErr { + t.Errorf("UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +}