From 87ec9e077ea541bfba9f1c2160aac109537c15f4 Mon Sep 17 00:00:00 2001 From: Manoj Vivek Date: Mon, 16 Jun 2025 19:23:20 +0530 Subject: [PATCH] Allow server startup with partial connector failures (#4159) Signed-off-by: Manoj Vivek --- cmd/dex/serve.go | 35 ++++--- pkg/featureflags/set.go | 3 + server/handlers.go | 4 +- server/server.go | 14 +++ server/server_test.go | 202 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 241 insertions(+), 17 deletions(-) diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index 8a69c7ee..ac715e60 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -35,6 +35,7 @@ import ( "google.golang.org/grpc/reflection" "github.com/dexidp/dex/api/v2" + "github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/server" "github.com/dexidp/dex/storage" ) @@ -280,6 +281,9 @@ func runServe(options serveOptions) error { if len(c.Web.AllowedOrigins) > 0 { logger.Info("config allowed origins", "origins", c.Web.AllowedOrigins) } + if featureflags.ContinueOnConnectorFailure.Enabled() { + logger.Info("continue on connector failure feature flag enabled") + } // explicitly convert to UTC. now := func() time.Time { return time.Now().UTC() } @@ -287,21 +291,22 @@ func runServe(options serveOptions) error { healthChecker := gosundheit.New() serverConfig := server.Config{ - AllowedGrantTypes: c.OAuth2.GrantTypes, - SupportedResponseTypes: c.OAuth2.ResponseTypes, - SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, - AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen, - PasswordConnector: c.OAuth2.PasswordConnector, - Headers: c.Web.Headers.ToHTTPHeader(), - AllowedOrigins: c.Web.AllowedOrigins, - AllowedHeaders: c.Web.AllowedHeaders, - Issuer: c.Issuer, - Storage: s, - Web: c.Frontend, - Logger: logger, - Now: now, - PrometheusRegistry: prometheusRegistry, - HealthChecker: healthChecker, + AllowedGrantTypes: c.OAuth2.GrantTypes, + SupportedResponseTypes: c.OAuth2.ResponseTypes, + SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, + AlwaysShowLoginScreen: c.OAuth2.AlwaysShowLoginScreen, + PasswordConnector: c.OAuth2.PasswordConnector, + Headers: c.Web.Headers.ToHTTPHeader(), + AllowedOrigins: c.Web.AllowedOrigins, + AllowedHeaders: c.Web.AllowedHeaders, + Issuer: c.Issuer, + Storage: s, + Web: c.Frontend, + Logger: logger, + Now: now, + PrometheusRegistry: prometheusRegistry, + HealthChecker: healthChecker, + ContinueOnConnectorFailure: featureflags.ContinueOnConnectorFailure.Enabled(), } if c.Expiry.SigningKeys != "" { signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) diff --git a/pkg/featureflags/set.go b/pkg/featureflags/set.go index a86c4fa9..bd86f1e7 100644 --- a/pkg/featureflags/set.go +++ b/pkg/featureflags/set.go @@ -11,4 +11,7 @@ var ( // APIConnectorsCRUD allows CRUD operations on connectors through the gRPC API APIConnectorsCRUD = newFlag("api_connectors_crud", false) + + // ContinueOnConnectorFailure allows the server to start even if some connectors fail to initialize. + ContinueOnConnectorFailure = newFlag("continue_on_connector_failure", false) ) diff --git a/server/handlers.go b/server/handlers.go index c87038cc..f8d0ed64 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -223,7 +223,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { conn, err := s.getConnector(ctx, connID) if err != nil { s.logger.ErrorContext(r.Context(), "Failed to get connector", "err", err) - s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") + s.renderError(r, w, http.StatusBadRequest, "Connector failed to initialize") return } @@ -350,7 +350,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) { conn, err := s.getConnector(ctx, authReq.ConnectorID) if err != nil { s.logger.ErrorContext(r.Context(), "failed to get connector", "connector_id", authReq.ConnectorID, "err", err) - s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") + s.renderError(r, w, http.StatusInternalServerError, "Connector failed to initialize.") return } diff --git a/server/server.go b/server/server.go index 8c046296..0f48fc11 100644 --- a/server/server.go +++ b/server/server.go @@ -119,6 +119,10 @@ type Config struct { PrometheusRegistry *prometheus.Registry HealthChecker gosundheit.Health + + // If enabled, the server will continue starting even if some connectors fail to initialize. + // This allows the server to operate with a subset of connectors if some are misconfigured. + ContinueOnConnectorFailure bool } // WebConfig holds the server's frontend templates and asset configuration. @@ -325,12 +329,22 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) return nil, errors.New("server: no connectors specified") } + var failedCount int for _, conn := range storageConnectors { if _, err := s.OpenConnector(conn); err != nil { + failedCount++ + if c.ContinueOnConnectorFailure { + s.logger.Error("server: Failed to open connector", "id", conn.ID, "err", err) + continue + } return nil, fmt.Errorf("server: Failed to open connector %s: %v", conn.ID, err) } } + if c.ContinueOnConnectorFailure && failedCount == len(storageConnectors) { + return nil, fmt.Errorf("server: failed to open all connectors (%d/%d)", failedCount, len(storageConnectors)) + } + instrumentHandler := func(_ string, handler http.Handler) http.HandlerFunc { return handler.ServeHTTP } diff --git a/server/server_test.go b/server/server_test.go index cb4f491b..c414eb88 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1816,3 +1816,205 @@ func TestHeaders(t *testing.T) { require.Equal(t, "max-age=31536000; includeSubDomains", resp.Header.Get("Strict-Transport-Security")) } + +func TestConnectorFailureHandling(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tests := []struct { + name string + connectors []storage.Connector + continueOnConnectorFailure bool + wantErr bool + wantErrContains string + expectConnectors []string // IDs of connectors that should be loaded successfully + }{ + { + name: "all connectors succeed with flag enabled", + connectors: []storage.Connector{ + { + ID: "mock1", + Type: "mockCallback", + Name: "Mock1", + }, + { + ID: "mock2", + Type: "mockCallback", + Name: "Mock2", + }, + }, + continueOnConnectorFailure: true, + wantErr: false, + expectConnectors: []string{"mock1", "mock2"}, + }, + { + name: "all connectors succeed with flag disabled", + connectors: []storage.Connector{ + { + ID: "mock1", + Type: "mockCallback", + Name: "Mock1", + }, + { + ID: "mock2", + Type: "mockCallback", + Name: "Mock2", + }, + }, + continueOnConnectorFailure: false, + wantErr: false, + expectConnectors: []string{"mock1", "mock2"}, + }, + { + name: "partial connector failure with flag enabled", + connectors: []storage.Connector{ + { + ID: "mock-good", + Type: "mockCallback", + Name: "Good Mock", + }, + { + ID: "bad-connector", + Type: "nonexistent", + Name: "Bad Connector", + }, + { + ID: "mock-good2", + Type: "mockCallback", + Name: "Good Mock 2", + }, + }, + continueOnConnectorFailure: true, + wantErr: false, + expectConnectors: []string{"mock-good", "mock-good2"}, + }, + { + name: "partial connector failure with flag disabled", + connectors: []storage.Connector{ + { + ID: "mock-good", + Type: "mockCallback", + Name: "Good Mock", + }, + { + ID: "bad-connector", + Type: "nonexistent", + Name: "Bad Connector", + }, + { + ID: "mock-good2", + Type: "mockCallback", + Name: "Good Mock 2", + }, + }, + continueOnConnectorFailure: false, + wantErr: true, + wantErrContains: "Failed to open connector bad-connector", + expectConnectors: []string{}, // Server creation should fail + }, + { + name: "all connectors fail with flag enabled", + connectors: []storage.Connector{ + { + ID: "bad1", + Type: "nonexistent1", + Name: "Bad 1", + }, + { + ID: "bad2", + Type: "nonexistent2", + Name: "Bad 2", + }, + }, + continueOnConnectorFailure: true, + wantErr: true, + wantErrContains: "failed to open all connectors (2/2)", + }, + { + name: "all connectors fail with flag disabled", + connectors: []storage.Connector{ + { + ID: "bad1", + Type: "nonexistent1", + Name: "Bad 1", + }, + { + ID: "bad2", + Type: "nonexistent2", + Name: "Bad 2", + }, + }, + continueOnConnectorFailure: false, + wantErr: true, + wantErrContains: "Failed to open connector", + }, + { + name: "no connectors", + connectors: []storage.Connector{}, + continueOnConnectorFailure: true, + wantErr: true, + wantErrContains: "no connectors specified", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + config := Config{ + Issuer: "http://localhost", + Storage: memory.New(logger), + Web: WebConfig{ + Dir: "../web", + }, + Logger: logger, + PrometheusRegistry: prometheus.NewRegistry(), + HealthChecker: gosundheit.New(), + ContinueOnConnectorFailure: tc.continueOnConnectorFailure, + } + + // Create connectors in storage + for _, conn := range tc.connectors { + if err := config.Storage.CreateConnector(ctx, conn); err != nil { + t.Fatalf("failed to create connector: %v", err) + } + } + + server, err := newServer(ctx, config, staticRotationStrategy(testKey)) + + if tc.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } else if tc.wantErrContains != "" && !strings.Contains(err.Error(), tc.wantErrContains) { + t.Errorf("expected error containing %q, got %q", tc.wantErrContains, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } else { + // Verify expected connectors are loaded + for _, id := range tc.expectConnectors { + if _, exists := server.connectors[id]; !exists { + t.Errorf("expected connector %q to be loaded", id) + } + } + + // Verify failed connectors are not loaded + for _, conn := range tc.connectors { + _, shouldExist := false, false + for _, expectedID := range tc.expectConnectors { + if conn.ID == expectedID { + shouldExist = true + break + } + } + _, exists := server.connectors[conn.ID] + if shouldExist && !exists { + t.Errorf("connector %q should have been loaded but wasn't", conn.ID) + } else if !shouldExist && exists { + t.Errorf("connector %q should not have been loaded but was", conn.ID) + } + } + } + } + }) + } +}