Browse Source

Allow server startup with partial connector failures (#4159)

Signed-off-by: Manoj Vivek <p.manoj.vivek@gmail.com>
pull/4175/head
Manoj Vivek 9 months ago committed by GitHub
parent
commit
87ec9e077e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 35
      cmd/dex/serve.go
  2. 3
      pkg/featureflags/set.go
  3. 4
      server/handlers.go
  4. 14
      server/server.go
  5. 202
      server/server_test.go

35
cmd/dex/serve.go

@ -35,6 +35,7 @@ import (
"google.golang.org/grpc/reflection"
"github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/pkg/featureflags"
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage"
)
@ -280,6 +281,9 @@ func runServe(options serveOptions) error {
if len(c.Web.AllowedOrigins) > 0 {
logger.Info("config allowed origins", "origins", c.Web.AllowedOrigins)
}
if featureflags.ContinueOnConnectorFailure.Enabled() {
logger.Info("continue on connector failure feature flag enabled")
}
// explicitly convert to UTC.
now := func() time.Time { return time.Now().UTC() }
@ -287,21 +291,22 @@ func runServe(options serveOptions) error {
healthChecker := gosundheit.New()
serverConfig := server.Config{
AllowedGrantTypes: c.OAuth2.GrantTypes,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
Headers: c.Web.Headers.ToHTTPHeader(),
AllowedOrigins: c.Web.AllowedOrigins,
AllowedHeaders: c.Web.AllowedHeaders,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
AllowedGrantTypes: c.OAuth2.GrantTypes,
SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen,
PasswordConnector: c.OAuth2.PasswordConnector,
Headers: c.Web.Headers.ToHTTPHeader(),
AllowedOrigins: c.Web.AllowedOrigins,
AllowedHeaders: c.Web.AllowedHeaders,
Issuer: c.Issuer,
Storage: s,
Web: c.Frontend,
Logger: logger,
Now: now,
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(),
}
if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)

3
pkg/featureflags/set.go

@ -11,4 +11,7 @@ var (
// APIConnectorsCRUD allows CRUD operations on connectors through the gRPC API
APIConnectorsCRUD = newFlag("api_connectors_crud", false)
// ContinueOnConnectorFailure allows the server to start even if some connectors fail to initialize.
ContinueOnConnectorFailure = newFlag("continue_on_connector_failure", false)
)

4
server/handlers.go

@ -223,7 +223,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
conn, err := s.getConnector(ctx, connID)
if err != nil {
s.logger.ErrorContext(r.Context(), "Failed to get connector", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
s.renderError(r, w, http.StatusBadRequest, "Connector failed to initialize")
return
}
@ -350,7 +350,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
conn, err := s.getConnector(ctx, authReq.ConnectorID)
if err != nil {
s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
s.renderError(r, w, http.StatusInternalServerError, "Connector failed to initialize.")
return
}

14
server/server.go

@ -119,6 +119,10 @@ type Config struct {
PrometheusRegistry *prometheus.Registry
HealthChecker gosundheit.Health
// If enabled, the server will continue starting even if some connectors fail to initialize.
// This allows the server to operate with a subset of connectors if some are misconfigured.
ContinueOnConnectorFailure bool
}
// WebConfig holds the server's frontend templates and asset configuration.
@ -325,12 +329,22 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
return nil, errors.New("server: no connectors specified")
}
var failedCount int
for _, conn := range storageConnectors {
if _, err := s.OpenConnector(conn); err != nil {
failedCount++
if c.ContinueOnConnectorFailure {
s.logger.Error("server: Failed to open connector", "id", conn.ID, "err", err)
continue
}
return nil, fmt.Errorf("server: Failed to open connector %s: %v", conn.ID, err)
}
}
if c.ContinueOnConnectorFailure && failedCount == len(storageConnectors) {
return nil, fmt.Errorf("server: failed to open all connectors (%d/%d)", failedCount, len(storageConnectors))
}
instrumentHandler := func(_ string, handler http.Handler) http.HandlerFunc {
return handler.ServeHTTP
}

202
server/server_test.go

@ -1816,3 +1816,205 @@ func TestHeaders(t *testing.T) {
require.Equal(t, "max-age=31536000; includeSubDomains", resp.Header.Get("Strict-Transport-Security"))
}
func TestConnectorFailureHandling(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
tests := []struct {
name string
connectors []storage.Connector
continueOnConnectorFailure bool
wantErr bool
wantErrContains string
expectConnectors []string // IDs of connectors that should be loaded successfully
}{
{
name: "all connectors succeed with flag enabled",
connectors: []storage.Connector{
{
ID: "mock1",
Type: "mockCallback",
Name: "Mock1",
},
{
ID: "mock2",
Type: "mockCallback",
Name: "Mock2",
},
},
continueOnConnectorFailure: true,
wantErr: false,
expectConnectors: []string{"mock1", "mock2"},
},
{
name: "all connectors succeed with flag disabled",
connectors: []storage.Connector{
{
ID: "mock1",
Type: "mockCallback",
Name: "Mock1",
},
{
ID: "mock2",
Type: "mockCallback",
Name: "Mock2",
},
},
continueOnConnectorFailure: false,
wantErr: false,
expectConnectors: []string{"mock1", "mock2"},
},
{
name: "partial connector failure with flag enabled",
connectors: []storage.Connector{
{
ID: "mock-good",
Type: "mockCallback",
Name: "Good Mock",
},
{
ID: "bad-connector",
Type: "nonexistent",
Name: "Bad Connector",
},
{
ID: "mock-good2",
Type: "mockCallback",
Name: "Good Mock 2",
},
},
continueOnConnectorFailure: true,
wantErr: false,
expectConnectors: []string{"mock-good", "mock-good2"},
},
{
name: "partial connector failure with flag disabled",
connectors: []storage.Connector{
{
ID: "mock-good",
Type: "mockCallback",
Name: "Good Mock",
},
{
ID: "bad-connector",
Type: "nonexistent",
Name: "Bad Connector",
},
{
ID: "mock-good2",
Type: "mockCallback",
Name: "Good Mock 2",
},
},
continueOnConnectorFailure: false,
wantErr: true,
wantErrContains: "Failed to open connector bad-connector",
expectConnectors: []string{}, // Server creation should fail
},
{
name: "all connectors fail with flag enabled",
connectors: []storage.Connector{
{
ID: "bad1",
Type: "nonexistent1",
Name: "Bad 1",
},
{
ID: "bad2",
Type: "nonexistent2",
Name: "Bad 2",
},
},
continueOnConnectorFailure: true,
wantErr: true,
wantErrContains: "failed to open all connectors (2/2)",
},
{
name: "all connectors fail with flag disabled",
connectors: []storage.Connector{
{
ID: "bad1",
Type: "nonexistent1",
Name: "Bad 1",
},
{
ID: "bad2",
Type: "nonexistent2",
Name: "Bad 2",
},
},
continueOnConnectorFailure: false,
wantErr: true,
wantErrContains: "Failed to open connector",
},
{
name: "no connectors",
connectors: []storage.Connector{},
continueOnConnectorFailure: true,
wantErr: true,
wantErrContains: "no connectors specified",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := Config{
Issuer: "http://localhost",
Storage: memory.New(logger),
Web: WebConfig{
Dir: "../web",
},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
HealthChecker: gosundheit.New(),
ContinueOnConnectorFailure: tc.continueOnConnectorFailure,
}
// Create connectors in storage
for _, conn := range tc.connectors {
if err := config.Storage.CreateConnector(ctx, conn); err != nil {
t.Fatalf("failed to create connector: %v", err)
}
}
server, err := newServer(ctx, config, staticRotationStrategy(testKey))
if tc.wantErr {
if err == nil {
t.Errorf("expected error but got none")
} else if tc.wantErrContains != "" && !strings.Contains(err.Error(), tc.wantErrContains) {
t.Errorf("expected error containing %q, got %q", tc.wantErrContains, err.Error())
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
} else {
// Verify expected connectors are loaded
for _, id := range tc.expectConnectors {
if _, exists := server.connectors[id]; !exists {
t.Errorf("expected connector %q to be loaded", id)
}
}
// Verify failed connectors are not loaded
for _, conn := range tc.connectors {
_, shouldExist := false, false
for _, expectedID := range tc.expectConnectors {
if conn.ID == expectedID {
shouldExist = true
break
}
}
_, exists := server.connectors[conn.ID]
if shouldExist && !exists {
t.Errorf("connector %q should have been loaded but wasn't", conn.ID)
} else if !shouldExist && exists {
t.Errorf("connector %q should not have been loaded but was", conn.ID)
}
}
}
}
})
}
}

Loading…
Cancel
Save