diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index cb675869..a01afbb4 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -32,6 +32,8 @@ import ( "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + healthgrpc "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/reflection" "github.com/dexidp/dex/api/v2" @@ -61,6 +63,11 @@ var buildInfo = prometheus.NewGaugeVec( []string{"version", "go_version", "platform"}, ) +var ( + healthCheckPeriod = 15 * time.Second + shutdownTimeout = 1 * time.Minute +) + func commandServe() *cobra.Command { options := serveOptions{} @@ -442,7 +449,7 @@ func runServe(options serveOptions) error { CheckName: "storage", CheckFunc: storage.NewCustomHealthCheckFunc(serverConfig.Storage, serverConfig.Now), }, - gosundheit.ExecutionPeriod(15*time.Second), + gosundheit.ExecutionPeriod(healthCheckPeriod), gosundheit.InitiallyPassing(true), ) @@ -471,7 +478,7 @@ func runServe(options serveOptions) error { group.Add(func() error { return server.Serve(l) }, func(err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() logger.Debug("starting graceful shutdown", "server", name) @@ -500,7 +507,7 @@ func runServe(options serveOptions) error { group.Add(func() error { return server.Serve(l) }, func(err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() logger.Debug("starting graceful shutdown", "server", name) @@ -551,7 +558,7 @@ func runServe(options serveOptions) error { group.Add(func() error { return server.ServeTLS(l, "", "") }, func(err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) defer cancel() logger.Debug("starting graceful shutdown", "server", name) @@ -571,6 +578,8 @@ func runServe(options serveOptions) error { } grpcSrv := grpc.NewServer(grpcOptions...) + healthcheck := health.NewServer() + healthgrpc.RegisterHealthServer(grpcSrv, healthcheck) api.RegisterDexServer(grpcSrv, server.NewAPI(serverConfig.Storage, logger, version, serv)) grpcMetrics.InitializeMetrics(grpcSrv) @@ -578,12 +587,44 @@ func runServe(options serveOptions) error { logger.Info("enabling reflection in grpc service") reflection.Register(grpcSrv) } + ctx, cancelHealthcheck := context.WithCancel(context.Background()) + defer cancelHealthcheck() + + group.Add(func() error { + setHealthCheckStatus(healthcheck, healthChecker) + for { + select { + case <-ctx.Done(): + logger.Debug("stopping health check status updater", "server", "grpc") + return nil + case <-time.After(healthCheckPeriod): + setHealthCheckStatus(healthcheck, healthChecker) + } + } + }, func(err error) { + logger.Debug("stopped health check status updater", "server", "grpc") + }) group.Add(func() error { return grpcSrv.Serve(grpcListener) }, func(err error) { logger.Debug("starting graceful shutdown", "server", "grpc") - grpcSrv.GracefulStop() + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + + go func() { + healthcheck.Shutdown() + cancelHealthcheck() + grpcSrv.GracefulStop() + cancel() + }() + + <-ctx.Done() + if ctx.Err() != nil { + logger.Debug("Graceful shutdown timed out. forcing shutdown", "server", "grpc") + grpcSrv.Stop() + } else { + logger.Debug("Graceful shutdown completed", "server", "grpc") + } }) } @@ -597,6 +638,16 @@ func runServe(options serveOptions) error { return nil } +func setHealthCheckStatus(healthServer *health.Server, healthChecker gosundheit.Health) { + var status healthgrpc.HealthCheckResponse_ServingStatus + if healthChecker.IsHealthy() { + status = healthgrpc.HealthCheckResponse_SERVING + } else { + status = healthgrpc.HealthCheckResponse_NOT_SERVING + } + healthServer.SetServingStatus("", status) +} + func applyConfigOverrides(options serveOptions, config *Config) { if options.webHTTPAddr != "" { config.Web.HTTP = options.webHTTPAddr