Browse Source

use slog for structured logging (#3502)

Signed-off-by: Sean Liao <sean+git@liao.dev>
pull/3506/head
Sean Liao 2 years ago committed by GitHub
parent
commit
0b6a78397e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 6
      cmd/dex/config.go
  2. 5
      cmd/dex/config_test.go
  3. 127
      cmd/dex/serve.go
  4. 14
      connector/atlassiancrowd/atlassiancrowd.go
  5. 9
      connector/atlassiancrowd/atlassiancrowd_test.go
  6. 8
      connector/authproxy/authproxy.go
  7. 5
      connector/authproxy/authproxy_test.go
  8. 8
      connector/bitbucketcloud/bitbucketcloud.go
  9. 8
      connector/gitea/gitea.go
  10. 12
      connector/github/github.go
  11. 5
      connector/github/github_test.go
  12. 8
      connector/gitlab/gitlab.go
  13. 19
      connector/google/google.go
  14. 5
      connector/google/google_test.go
  15. 10
      connector/keystone/keystone.go
  16. 37
      connector/ldap/ldap.go
  17. 4
      connector/ldap/ldap_test.go
  18. 8
      connector/linkedin/linkedin.go
  19. 8
      connector/microsoft/microsoft.go
  20. 13
      connector/mock/connectortest.go
  21. 8
      connector/oauth/oauth.go
  22. 5
      connector/oauth/oauth_test.go
  23. 10
      connector/oidc/oidc.go
  24. 5
      connector/oidc/oidc_test.go
  25. 10
      connector/openshift/openshift.go
  26. 5
      connector/openshift/openshift_test.go
  27. 12
      connector/saml/saml.go
  28. 7
      connector/saml/saml_test.go
  29. 1
      go.mod
  30. 3
      go.sum
  31. 5
      pkg/log/deprecated.go
  32. 18
      pkg/log/logger.go
  33. 38
      server/api.go
  34. 31
      server/api_test.go
  35. 45
      server/deviceflowhandlers.go
  36. 175
      server/handlers.go
  37. 20
      server/introspectionhandler.go
  38. 14
      server/oauth2.go
  39. 30
      server/refreshhandlers.go
  40. 26
      server/rotation.go
  41. 16
      server/rotation_test.go
  42. 17
      server/server.go
  43. 9
      server/server_test.go
  44. 4
      storage/ent/mysql.go
  45. 9
      storage/ent/mysql_test.go
  46. 5
      storage/ent/postgres.go
  47. 9
      storage/ent/postgres_test.go
  48. 4
      storage/ent/sqlite.go
  49. 11
      storage/ent/sqlite_test.go
  50. 6
      storage/etcd/config.go
  51. 12
      storage/etcd/etcd.go
  52. 9
      storage/etcd/etcd_test.go
  53. 10
      storage/kubernetes/client.go
  54. 9
      storage/kubernetes/client_test.go
  55. 6
      storage/kubernetes/lock.go
  56. 34
      storage/kubernetes/storage.go
  57. 21
      storage/kubernetes/storage_test.go
  58. 8
      storage/memory/memory.go
  59. 11
      storage/memory/memory_test.go
  60. 23
      storage/memory/static_test.go
  61. 12
      storage/sql/config.go
  62. 13
      storage/sql/config_test.go
  63. 10
      storage/sql/migrate_test.go
  64. 5
      storage/sql/sql.go
  65. 6
      storage/sql/sqlite.go
  66. 9
      storage/static.go

6
cmd/dex/config.go

@ -4,6 +4,7 @@ import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"os" "os"
"strings" "strings"
@ -11,7 +12,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/dexidp/dex/pkg/featureflags" "github.com/dexidp/dex/pkg/featureflags"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server" "github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent" "github.com/dexidp/dex/storage/ent"
@ -236,7 +236,7 @@ type Storage struct {
// StorageConfig is a configuration that can create a storage. // StorageConfig is a configuration that can create a storage.
type StorageConfig interface { type StorageConfig interface {
Open(logger log.Logger) (storage.Storage, error) Open(logger *slog.Logger) (storage.Storage, error)
} }
var ( var (
@ -386,7 +386,7 @@ type Expiry struct {
// Logger holds configuration required to customize logging for dex. // Logger holds configuration required to customize logging for dex.
type Logger struct { type Logger struct {
// Level sets logging level severity. // Level sets logging level severity.
Level string `json:"level"` Level slog.Level `json:"level"`
// Format specifies the format to be used for logging. // Format specifies the format to be used for logging.
Format string `json:"format"` Format string `json:"format"`

5
cmd/dex/config_test.go

@ -1,6 +1,7 @@
package main package main
import ( import (
"log/slog"
"os" "os"
"testing" "testing"
@ -219,7 +220,7 @@ logger:
DeviceRequests: "10m", DeviceRequests: "10m",
}, },
Logger: Logger{ Logger: Logger{
Level: "debug", Level: slog.LevelDebug,
Format: "json", Format: "json",
}, },
} }
@ -426,7 +427,7 @@ logger:
AuthRequests: "25h", AuthRequests: "25h",
}, },
Logger: Logger{ Logger: Logger{
Level: "debug", Level: slog.LevelDebug,
Format: "json", Format: "json",
}, },
} }

127
cmd/dex/serve.go

@ -6,6 +6,7 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net" "net"
"net/http" "net/http"
"net/http/pprof" "net/http/pprof"
@ -28,14 +29,12 @@ import (
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
"github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server" "github.com/dexidp/dex/server"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -98,22 +97,24 @@ func runServe(options serveOptions) error {
return fmt.Errorf("invalid config: %v", err) return fmt.Errorf("invalid config: %v", err)
} }
logger.Infof( logger.Info(
"Dex Version: %s, Go Version: %s, Go OS/ARCH: %s %s", "Version info",
version, "dex_version", version,
runtime.Version(), slog.Group("go",
runtime.GOOS, "version", runtime.Version(),
runtime.GOARCH, "os", runtime.GOOS,
"arch", runtime.GOARCH,
),
) )
if c.Logger.Level != "" { if c.Logger.Level != slog.LevelInfo {
logger.Infof("config using log level: %s", c.Logger.Level) logger.Info("config using log level", "level", c.Logger.Level)
} }
if err := c.Validate(); err != nil { if err := c.Validate(); err != nil {
return err return err
} }
logger.Infof("config issuer: %s", c.Issuer) logger.Info("config issuer", "issuer", c.Issuer)
prometheusRegistry := prometheus.NewRegistry() prometheusRegistry := prometheus.NewRegistry()
err = prometheusRegistry.Register(collectors.NewGoCollector()) err = prometheusRegistry.Register(collectors.NewGoCollector())
@ -188,7 +189,7 @@ func runServe(options serveOptions) error {
} }
defer s.Close() defer s.Close()
logger.Infof("config storage: %s", c.Storage.Type) logger.Info("config storage", "storage_type", c.Storage.Type)
if len(c.StaticClients) > 0 { if len(c.StaticClients) > 0 {
for i, client := range c.StaticClients { for i, client := range c.StaticClients {
@ -213,7 +214,7 @@ func runServe(options serveOptions) error {
} }
c.StaticClients[i].Secret = os.Getenv(client.SecretEnv) c.StaticClients[i].Secret = os.Getenv(client.SecretEnv)
} }
logger.Infof("config static client: %s", client.Name) logger.Info("config static client", "client_name", client.Name)
} }
s = storage.WithStaticClients(s, c.StaticClients) s = storage.WithStaticClients(s, c.StaticClients)
} }
@ -233,7 +234,7 @@ func runServe(options serveOptions) error {
if c.Config == nil { if c.Config == nil {
return fmt.Errorf("invalid config: no config field for connector %q", c.ID) return fmt.Errorf("invalid config: no config field for connector %q", c.ID)
} }
logger.Infof("config connector: %s", c.ID) logger.Info("config connector", "connector_id", c.ID)
// convert to a storage connector object // convert to a storage connector object
conn, err := ToStorageConnector(c) conn, err := ToStorageConnector(c)
@ -249,22 +250,22 @@ func runServe(options serveOptions) error {
Name: "Email", Name: "Email",
Type: server.LocalConnector, Type: server.LocalConnector,
}) })
logger.Infof("config connector: local passwords enabled") logger.Info("config connector: local passwords enabled")
} }
s = storage.WithStaticConnectors(s, storageConnectors) s = storage.WithStaticConnectors(s, storageConnectors)
if len(c.OAuth2.ResponseTypes) > 0 { if len(c.OAuth2.ResponseTypes) > 0 {
logger.Infof("config response types accepted: %s", c.OAuth2.ResponseTypes) logger.Info("config response types accepted", "response_types", c.OAuth2.ResponseTypes)
} }
if c.OAuth2.SkipApprovalScreen { if c.OAuth2.SkipApprovalScreen {
logger.Infof("config skipping approval screen") logger.Info("config skipping approval screen")
} }
if c.OAuth2.PasswordConnector != "" { if c.OAuth2.PasswordConnector != "" {
logger.Infof("config using password grant connector: %s", c.OAuth2.PasswordConnector) logger.Info("config using password grant connector", "password_connector", c.OAuth2.PasswordConnector)
} }
if len(c.Web.AllowedOrigins) > 0 { if len(c.Web.AllowedOrigins) > 0 {
logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins) logger.Info("config allowed origins", "origins", c.Web.AllowedOrigins)
} }
// explicitly convert to UTC. // explicitly convert to UTC.
@ -294,7 +295,7 @@ func runServe(options serveOptions) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid config value %q for signing keys expiry: %v", c.Expiry.SigningKeys, err) return fmt.Errorf("invalid config value %q for signing keys expiry: %v", c.Expiry.SigningKeys, err)
} }
logger.Infof("config signing keys expire after: %v", signingKeys) logger.Info("config signing keys", "expire_after", signingKeys)
serverConfig.RotateKeysAfter = signingKeys serverConfig.RotateKeysAfter = signingKeys
} }
if c.Expiry.IDTokens != "" { if c.Expiry.IDTokens != "" {
@ -302,7 +303,7 @@ func runServe(options serveOptions) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid config value %q for id token expiry: %v", c.Expiry.IDTokens, err) return fmt.Errorf("invalid config value %q for id token expiry: %v", c.Expiry.IDTokens, err)
} }
logger.Infof("config id tokens valid for: %v", idTokens) logger.Info("config id tokens", "valid_for", idTokens)
serverConfig.IDTokensValidFor = idTokens serverConfig.IDTokensValidFor = idTokens
} }
if c.Expiry.AuthRequests != "" { if c.Expiry.AuthRequests != "" {
@ -310,7 +311,7 @@ func runServe(options serveOptions) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid config value %q for auth request expiry: %v", c.Expiry.AuthRequests, err) return fmt.Errorf("invalid config value %q for auth request expiry: %v", c.Expiry.AuthRequests, err)
} }
logger.Infof("config auth requests valid for: %v", authRequests) logger.Info("config auth requests", "valid_for", authRequests)
serverConfig.AuthRequestsValidFor = authRequests serverConfig.AuthRequestsValidFor = authRequests
} }
if c.Expiry.DeviceRequests != "" { if c.Expiry.DeviceRequests != "" {
@ -318,7 +319,7 @@ func runServe(options serveOptions) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err) return fmt.Errorf("invalid config value %q for device request expiry: %v", c.Expiry.AuthRequests, err)
} }
logger.Infof("config device requests valid for: %v", deviceRequests) logger.Info("config device requests", "valid_for", deviceRequests)
serverConfig.DeviceRequestsValidFor = deviceRequests serverConfig.DeviceRequestsValidFor = deviceRequests
} }
refreshTokenPolicy, err := server.NewRefreshTokenPolicy( refreshTokenPolicy, err := server.NewRefreshTokenPolicy(
@ -368,7 +369,7 @@ func runServe(options serveOptions) error {
if c.Telemetry.HTTP != "" { if c.Telemetry.HTTP != "" {
const name = "telemetry" const name = "telemetry"
logger.Infof("listening (%s) on %s", name, c.Telemetry.HTTP) logger.Info("listening on", "server", name, "address", c.Telemetry.HTTP)
l, err := net.Listen("tcp", c.Telemetry.HTTP) l, err := net.Listen("tcp", c.Telemetry.HTTP)
if err != nil { if err != nil {
@ -390,9 +391,9 @@ func runServe(options serveOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
logger.Debugf("starting graceful shutdown (%s)", name) logger.Debug("starting graceful shutdown", "server", name)
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Errorf("graceful shutdown (%s): %v", name, err) logger.Error("graceful shutdown", "server", name, "err", err)
} }
}) })
} }
@ -401,7 +402,7 @@ func runServe(options serveOptions) error {
if c.Web.HTTP != "" { if c.Web.HTTP != "" {
const name = "http" const name = "http"
logger.Infof("listening (%s) on %s", name, c.Web.HTTP) logger.Info("listening on", "server", name, "address", c.Web.HTTP)
l, err := net.Listen("tcp", c.Web.HTTP) l, err := net.Listen("tcp", c.Web.HTTP)
if err != nil { if err != nil {
@ -419,9 +420,9 @@ func runServe(options serveOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
logger.Debugf("starting graceful shutdown (%s)", name) logger.Debug("starting graceful shutdown", "server", name)
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Errorf("graceful shutdown (%s): %v", name, err) logger.Error("graceful shutdown", "server", name, "err", err)
} }
}) })
} }
@ -430,7 +431,7 @@ func runServe(options serveOptions) error {
if c.Web.HTTPS != "" { if c.Web.HTTPS != "" {
const name = "https" const name = "https"
logger.Infof("listening (%s) on %s", name, c.Web.HTTPS) logger.Info("listening on", "server", name, "address", c.Web.HTTPS)
l, err := net.Listen("tcp", c.Web.HTTPS) l, err := net.Listen("tcp", c.Web.HTTPS)
if err != nil { if err != nil {
@ -470,16 +471,16 @@ func runServe(options serveOptions) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
logger.Debugf("starting graceful shutdown (%s)", name) logger.Debug("starting graceful shutdown", "server", name)
if err := server.Shutdown(ctx); err != nil { if err := server.Shutdown(ctx); err != nil {
logger.Errorf("graceful shutdown (%s): %v", name, err) logger.Error("graceful shutdown", "server", name, "err", err)
} }
}) })
} }
// Set up grpc server // Set up grpc server
if c.GRPC.Addr != "" { if c.GRPC.Addr != "" {
logger.Infof("listening (grpc) on %s", c.GRPC.Addr) logger.Info("listening on", "server", "grpc", "address", c.GRPC.Addr)
grpcListener, err := net.Listen("tcp", c.GRPC.Addr) grpcListener, err := net.Listen("tcp", c.GRPC.Addr)
if err != nil { if err != nil {
@ -498,7 +499,7 @@ func runServe(options serveOptions) error {
group.Add(func() error { group.Add(func() error {
return grpcSrv.Serve(grpcListener) return grpcSrv.Serve(grpcListener)
}, func(err error) { }, func(err error) {
logger.Debugf("starting graceful shutdown (grpc)") logger.Debug("starting graceful shutdown", "server", "grpc")
grpcSrv.GracefulStop() grpcSrv.GracefulStop()
}) })
} }
@ -508,53 +509,29 @@ func runServe(options serveOptions) error {
if _, ok := err.(run.SignalError); !ok { if _, ok := err.(run.SignalError); !ok {
return fmt.Errorf("run groups: %w", err) return fmt.Errorf("run groups: %w", err)
} }
logger.Infof("%v, shutdown now", err) logger.Info("shutdown now", "err", err)
} }
return nil return nil
} }
var ( var logFormats = []string{"json", "text"}
logLevels = []string{"debug", "info", "error"}
logFormats = []string{"json", "text"}
)
type utcFormatter struct {
f logrus.Formatter
}
func (f *utcFormatter) Format(e *logrus.Entry) ([]byte, error) {
e.Time = e.Time.UTC()
return f.f.Format(e)
}
func newLogger(level string, format string) (log.Logger, error) { func newLogger(level slog.Level, format string) (*slog.Logger, error) {
var logLevel logrus.Level var handler slog.Handler
switch strings.ToLower(level) {
case "debug":
logLevel = logrus.DebugLevel
case "", "info":
logLevel = logrus.InfoLevel
case "error":
logLevel = logrus.ErrorLevel
default:
return nil, fmt.Errorf("log level is not one of the supported values (%s): %s", strings.Join(logLevels, ", "), level)
}
var formatter utcFormatter
switch strings.ToLower(format) { switch strings.ToLower(format) {
case "", "text": case "", "text":
formatter.f = &logrus.TextFormatter{DisableColors: true} slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
case "json": case "json":
formatter.f = &logrus.JSONFormatter{} slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
})
default: default:
return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format) return nil, fmt.Errorf("log format is not one of the supported values (%s): %s", strings.Join(logFormats, ", "), format)
} }
return &logrus.Logger{ return slog.New(handler), nil
Out: os.Stderr,
Formatter: &formatter,
Level: logLevel,
}, nil
} }
func applyConfigOverrides(options serveOptions, config *Config) { func applyConfigOverrides(options serveOptions, config *Config) {
@ -600,7 +577,7 @@ func pprofHandler(router *http.ServeMux) {
// newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set // newTLSReloader returns a [tls.Config] with GetCertificate or GetConfigForClient set
// to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename). // to reload certificates from the given paths on SIGHUP or on file creates (atomic update via rename).
func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) { func newTLSReloader(logger *slog.Logger, certFile, keyFile, caFile string, baseConfig *tls.Config) (*tls.Config, error) {
// trigger reload on channel // trigger reload on channel
sigc := make(chan os.Signal, 1) sigc := make(chan os.Signal, 1)
signal.Notify(sigc, syscall.SIGHUP) signal.Notify(sigc, syscall.SIGHUP)
@ -631,7 +608,7 @@ func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseCon
// recommended by fsnotify: watch the dir to handle renames // recommended by fsnotify: watch the dir to handle renames
// https://pkg.go.dev/github.com/fsnotify/fsnotify#hdr-Watching_files // https://pkg.go.dev/github.com/fsnotify/fsnotify#hdr-Watching_files
for dir := range watchDirs { for dir := range watchDirs {
logger.Debugf("watching dir: %v", dir) logger.Debug("watching dir", "dir", dir)
err := watcher.Add(dir) err := watcher.Add(dir)
if err != nil { if err != nil {
return nil, fmt.Errorf("watch dir for TLS reloader: %v", err) return nil, fmt.Errorf("watch dir for TLS reloader: %v", err)
@ -654,19 +631,19 @@ func newTLSReloader(logger log.Logger, certFile, keyFile, caFile string, baseCon
for { for {
select { select {
case sig := <-sigc: case sig := <-sigc:
logger.Debug("reloading cert from signal: %v", sig) logger.Debug("reloading cert from signal", "signal", sig)
case evt := <-watcher.Events: case evt := <-watcher.Events:
if _, ok := watchFiles[evt.Name]; !ok || !evt.Has(fsnotify.Create) { if _, ok := watchFiles[evt.Name]; !ok || !evt.Has(fsnotify.Create) {
continue loop continue loop
} }
logger.Debug("reloading cert from fsnotify: %v %v", evt.Name, evt.Op.String()) logger.Debug("reloading cert from fsnotify", "event", evt.Name, "operation", evt.Op.String())
case err := <-watcher.Errors: case err := <-watcher.Errors:
logger.Errorf("TLS reloader watch: %v", err) logger.Error("TLS reloader watch", "err", err)
} }
loaded, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig) loaded, err := loadTLSConfig(certFile, keyFile, caFile, baseConfig)
if err != nil { if err != nil {
logger.Errorf("reload TLS config: %v", err) logger.Error("reload TLS config", "err", err)
} }
ptr.Store(loaded) ptr.Store(loaded)
} }

14
connector/atlassiancrowd/atlassiancrowd.go

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net" "net"
"net/http" "net/http"
"strings" "strings"
@ -14,7 +15,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
// Config holds configuration options for Atlassian Crowd connector. // Config holds configuration options for Atlassian Crowd connector.
@ -80,16 +80,16 @@ type crowdAuthenticationError struct {
} }
// Open returns a strategy for logging in through Atlassian Crowd // Open returns a strategy for logging in through Atlassian Crowd
func (c *Config) Open(_ string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
if c.BaseURL == "" { if c.BaseURL == "" {
return nil, fmt.Errorf("crowd: no baseURL provided for crowd connector") return nil, fmt.Errorf("crowd: no baseURL provided for crowd connector")
} }
return &crowdConnector{Config: *c, logger: logger}, nil return &crowdConnector{Config: *c, logger: logger.With(slog.Group("connector", "type", "atlassiancrowd", "id", id))}, nil
} }
type crowdConnector struct { type crowdConnector struct {
Config Config
logger log.Logger logger *slog.Logger
} }
var ( var (
@ -375,7 +375,7 @@ func (c *crowdConnector) identityFromCrowdUser(user crowdUser) connector.Identit
identity.PreferredUsername = user.Email identity.PreferredUsername = user.Email
default: default:
if c.PreferredUsernameField != "" { if c.PreferredUsernameField != "" {
c.logger.Warnf("preferred_username left empty. Invalid crowd field mapped to preferred_username: %s", c.PreferredUsernameField) c.logger.Warn("preferred_username left empty. Invalid crowd field mapped to preferred_username", "field", c.PreferredUsernameField)
} }
} }
@ -436,12 +436,12 @@ func (c *crowdConnector) validateCrowdResponse(resp *http.Response) ([]byte, err
} }
if resp.StatusCode == http.StatusForbidden && strings.Contains(string(body), "The server understood the request but refuses to authorize it.") { if resp.StatusCode == http.StatusForbidden && strings.Contains(string(body), "The server understood the request but refuses to authorize it.") {
c.logger.Debugf("crowd response validation failed: %s", string(body)) c.logger.Debug("crowd response validation failed", "response", string(body))
return nil, fmt.Errorf("dex is forbidden from making requests to the Atlassian Crowd application by URL %q", c.BaseURL) return nil, fmt.Errorf("dex is forbidden from making requests to the Atlassian Crowd application by URL %q", c.BaseURL)
} }
if resp.StatusCode == http.StatusUnauthorized && string(body) == "Application failed to authenticate" { if resp.StatusCode == http.StatusUnauthorized && string(body) == "Application failed to authenticate" {
c.logger.Debugf("crowd response validation failed: %s", string(body)) c.logger.Debug("crowd response validation failed", "response", string(body))
return nil, fmt.Errorf("dex failed to authenticate Crowd Application with ID %q", c.ClientID) return nil, fmt.Errorf("dex failed to authenticate Crowd Application with ID %q", c.ClientID)
} }
return body, nil return body, nil

9
connector/atlassiancrowd/atlassiancrowd_test.go

@ -7,12 +7,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
"github.com/sirupsen/logrus"
) )
func TestUserGroups(t *testing.T) { func TestUserGroups(t *testing.T) {
@ -151,11 +150,7 @@ type TestServerResponse struct {
func newTestCrowdConnector(baseURL string) crowdConnector { func newTestCrowdConnector(baseURL string) crowdConnector {
connector := crowdConnector{} connector := crowdConnector{}
connector.BaseURL = baseURL connector.BaseURL = baseURL
connector.logger = &logrus.Logger{ connector.logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: io.Discard,
Level: logrus.DebugLevel,
Formatter: &logrus.TextFormatter{DisableColors: true},
}
return connector return connector
} }

8
connector/authproxy/authproxy.go

@ -5,12 +5,12 @@ package authproxy
import ( import (
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
// Config holds the configuration parameters for a connector which returns an // Config holds the configuration parameters for a connector which returns an
@ -27,7 +27,7 @@ type Config struct {
} }
// Open returns an authentication strategy which requires no user interaction. // Open returns an authentication strategy which requires no user interaction.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
userIDHeader := c.UserIDHeader userIDHeader := c.UserIDHeader
if userIDHeader == "" { if userIDHeader == "" {
userIDHeader = "X-Remote-User-Id" userIDHeader = "X-Remote-User-Id"
@ -51,7 +51,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
emailHeader: emailHeader, emailHeader: emailHeader,
groupHeader: groupHeader, groupHeader: groupHeader,
groups: c.Groups, groups: c.Groups,
logger: logger, logger: logger.With(slog.Group("connector", "type", "authproxy", "id", id)),
pathSuffix: "/" + id, pathSuffix: "/" + id,
}, nil }, nil
} }
@ -64,7 +64,7 @@ type callback struct {
emailHeader string emailHeader string
groupHeader string groupHeader string
groups []string groups []string
logger log.Logger logger *slog.Logger
pathSuffix string pathSuffix string
} }

5
connector/authproxy/authproxy_test.go

@ -2,12 +2,11 @@ package authproxy
import ( import (
"io" "io"
"log/slog"
"net/http" "net/http"
"reflect" "reflect"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
) )
@ -23,7 +22,7 @@ const (
testUserID = "1234567890" testUserID = "1234567890"
) )
var logger = &logrus.Logger{Out: io.Discard, Formatter: &logrus.TextFormatter{}} var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
func TestUser(t *testing.T) { func TestUser(t *testing.T) {
config := Config{ config := Config{

8
connector/bitbucketcloud/bitbucketcloud.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"sync" "sync"
"time" "time"
@ -16,7 +17,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
const ( const (
@ -42,7 +42,7 @@ type Config struct {
} }
// Open returns a strategy for logging in through Bitbucket. // Open returns a strategy for logging in through Bitbucket.
func (c *Config) Open(_ string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
b := bitbucketConnector{ b := bitbucketConnector{
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
teams: c.Teams, teams: c.Teams,
@ -51,7 +51,7 @@ func (c *Config) Open(_ string, logger log.Logger) (connector.Connector, error)
includeTeamGroups: c.IncludeTeamGroups, includeTeamGroups: c.IncludeTeamGroups,
apiURL: apiURL, apiURL: apiURL,
legacyAPIURL: legacyAPIURL, legacyAPIURL: legacyAPIURL,
logger: logger, logger: logger.With(slog.Group("connector", "type", "bitbucketcloud", "id", id)),
} }
return &b, nil return &b, nil
@ -73,7 +73,7 @@ type bitbucketConnector struct {
teams []string teams []string
clientID string clientID string
clientSecret string clientSecret string
logger log.Logger logger *slog.Logger
apiURL string apiURL string
legacyAPIURL string legacyAPIURL string

8
connector/gitea/gitea.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strconv" "strconv"
"sync" "sync"
@ -15,7 +16,6 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
// Config holds configuration options for gitea logins. // Config holds configuration options for gitea logins.
@ -51,7 +51,7 @@ type giteaUser struct {
} }
// Open returns a strategy for logging in through Gitea // Open returns a strategy for logging in through Gitea
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
if c.BaseURL == "" { if c.BaseURL == "" {
c.BaseURL = "https://gitea.com" c.BaseURL = "https://gitea.com"
} }
@ -61,7 +61,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
orgs: c.Orgs, orgs: c.Orgs,
clientID: c.ClientID, clientID: c.ClientID,
clientSecret: c.ClientSecret, clientSecret: c.ClientSecret,
logger: logger, logger: logger.With(slog.Group("connector", "type", "gitea", "id", id)),
loadAllGroups: c.LoadAllGroups, loadAllGroups: c.LoadAllGroups,
useLoginAsID: c.UseLoginAsID, useLoginAsID: c.UseLoginAsID,
}, nil }, nil
@ -84,7 +84,7 @@ type giteaConnector struct {
orgs []Org orgs []Org
clientID string clientID string
clientSecret string clientSecret string
logger log.Logger logger *slog.Logger
httpClient *http.Client httpClient *http.Client
// if set to true and no orgs are configured then connector loads all user claims (all orgs and team) // if set to true and no orgs are configured then connector loads all user claims (all orgs and team)
loadAllGroups bool loadAllGroups bool

12
connector/github/github.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
@ -18,7 +19,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
groups_pkg "github.com/dexidp/dex/pkg/groups" groups_pkg "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/httpclient" "github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
) )
const ( const (
@ -66,7 +66,7 @@ type Org struct {
} }
// Open returns a strategy for logging in through GitHub. // Open returns a strategy for logging in through GitHub.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
if c.Org != "" { if c.Org != "" {
// Return error if both 'org' and 'orgs' fields are used. // Return error if both 'org' and 'orgs' fields are used.
if len(c.Orgs) > 0 { if len(c.Orgs) > 0 {
@ -82,7 +82,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
clientID: c.ClientID, clientID: c.ClientID,
clientSecret: c.ClientSecret, clientSecret: c.ClientSecret,
apiURL: apiURL, apiURL: apiURL,
logger: logger, logger: logger.With(slog.Group("connector", "type", "github", "id", id)),
useLoginAsID: c.UseLoginAsID, useLoginAsID: c.UseLoginAsID,
preferredEmailDomain: c.PreferredEmailDomain, preferredEmailDomain: c.PreferredEmailDomain,
} }
@ -142,7 +142,7 @@ type githubConnector struct {
orgs []Org orgs []Org
clientID string clientID string
clientSecret string clientSecret string
logger log.Logger logger *slog.Logger
// apiURL defaults to "https://api.github.com" // apiURL defaults to "https://api.github.com"
apiURL string apiURL string
// hostName of the GitHub enterprise account. // hostName of the GitHub enterprise account.
@ -362,7 +362,7 @@ func (c *githubConnector) groupsForOrgs(ctx context.Context, client *http.Client
if len(org.Teams) == 0 { if len(org.Teams) == 0 {
inOrgNoTeams = true inOrgNoTeams = true
} else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 { } else if teams = groups_pkg.Filter(teams, org.Teams); len(teams) == 0 {
c.logger.Infof("github: user %q in org %q but no teams", userName, org.Name) c.logger.Info("user in org but no teams", "user", userName, "org", org.Name)
} }
for _, teamName := range teams { for _, teamName := range teams {
@ -667,7 +667,7 @@ func (c *githubConnector) userInOrg(ctx context.Context, client *http.Client, us
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusNoContent: case http.StatusNoContent:
case http.StatusFound, http.StatusNotFound: case http.StatusFound, http.StatusNotFound:
c.logger.Infof("github: user %q not in org %q or application not authorized to read org data", userName, orgName) c.logger.Info("user not in org or application not authorized to read org data", "user", userName, "org", orgName)
default: default:
err = fmt.Errorf("github: unexpected return status: %q", resp.Status) err = fmt.Errorf("github: unexpected return status: %q", resp.Status)
} }

5
connector/github/github_test.go

@ -6,6 +6,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -449,6 +451,7 @@ func Test_isPreferredEmailDomain(t *testing.T) {
} }
func Test_Open_PreferredDomainConfig(t *testing.T) { func Test_Open_PreferredDomainConfig(t *testing.T) {
log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
tests := []struct { tests := []struct {
preferredEmailDomain string preferredEmailDomain string
email string email string
@ -476,7 +479,7 @@ func Test_Open_PreferredDomainConfig(t *testing.T) {
c := Config{ c := Config{
PreferredEmailDomain: test.preferredEmailDomain, PreferredEmailDomain: test.preferredEmailDomain,
} }
_, err := c.Open("id", nil) _, err := c.Open("id", log)
expectEquals(t, err, test.expected) expectEquals(t, err, test.expected)
}) })

8
connector/gitlab/gitlab.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strconv" "strconv"
"time" "time"
@ -15,7 +16,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
const ( const (
@ -46,7 +46,7 @@ type gitlabUser struct {
} }
// Open returns a strategy for logging in through GitLab. // Open returns a strategy for logging in through GitLab.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
if c.BaseURL == "" { if c.BaseURL == "" {
c.BaseURL = "https://gitlab.com" c.BaseURL = "https://gitlab.com"
} }
@ -55,7 +55,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
clientID: c.ClientID, clientID: c.ClientID,
clientSecret: c.ClientSecret, clientSecret: c.ClientSecret,
logger: logger, logger: logger.With(slog.Group("connector", "type", "gitlab", "id", id)),
groups: c.Groups, groups: c.Groups,
useLoginAsID: c.UseLoginAsID, useLoginAsID: c.UseLoginAsID,
}, nil }, nil
@ -78,7 +78,7 @@ type gitlabConnector struct {
groups []string groups []string
clientID string clientID string
clientSecret string clientSecret string
logger log.Logger logger *slog.Logger
httpClient *http.Client httpClient *http.Client
// if set to true will use the user's handle rather than their numeric id as the ID // if set to true will use the user's handle rather than their numeric id as the ID
useLoginAsID bool useLoginAsID bool

19
connector/google/google.go

@ -5,6 +5,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"os" "os"
"strings" "strings"
@ -21,7 +22,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
pkg_groups "github.com/dexidp/dex/pkg/groups" pkg_groups "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
const ( const (
@ -67,9 +67,10 @@ type Config struct {
} }
// Open returns a connector which can be used to login users through Google. // Open returns a connector which can be used to login users through Google.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
logger = logger.With(slog.Group("connector", "type", "google", "id", id))
if c.AdminEmail != "" { if c.AdminEmail != "" {
log.Deprecated(logger, `google: use "domainToAdminEmail.*: %s" option instead of "adminEmail: %s".`, c.AdminEmail, c.AdminEmail) logger.Warn(`use "domainToAdminEmail.*" option instead of "adminEmail"`, "deprecated", true)
if c.DomainToAdminEmail == nil { if c.DomainToAdminEmail == nil {
c.DomainToAdminEmail = make(map[string]string) c.DomainToAdminEmail = make(map[string]string)
} }
@ -152,7 +153,7 @@ type googleConnector struct {
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier verifier *oidc.IDTokenVerifier
cancel context.CancelFunc cancel context.CancelFunc
logger log.Logger logger *slog.Logger
hostedDomains []string hostedDomains []string
groups []string groups []string
serviceAccountFilePath string serviceAccountFilePath string
@ -340,7 +341,7 @@ func (c *googleConnector) findAdminService(domain string) (*admin.Service, error
adminSrv, ok := c.adminSrv[domain] adminSrv, ok := c.adminSrv[domain]
if !ok { if !ok {
adminSrv, ok = c.adminSrv[wildcardDomainToAdminEmail] adminSrv, ok = c.adminSrv[wildcardDomainToAdminEmail]
c.logger.Debugf("using wildcard (%s) admin email to fetch groups", c.domainToAdminEmail[wildcardDomainToAdminEmail]) c.logger.Debug("using wildcard admin email to fetch groups", "admin_email", c.domainToAdminEmail[wildcardDomainToAdminEmail])
} }
if !ok { if !ok {
@ -377,7 +378,7 @@ func getCredentialsFromFilePath(serviceAccountFilePath string) ([]byte, error) {
// If the default credential is empty, it attempts to create a new service with metadata credentials. // If the default credential is empty, it attempts to create a new service with metadata credentials.
// If successful, it returns the service and nil error. // If successful, it returns the service and nil error.
// If unsuccessful, it returns the error and a nil service. // If unsuccessful, it returns the error and a nil service.
func getCredentialsFromDefault(ctx context.Context, email string, logger log.Logger) ([]byte, *admin.Service, error) { func getCredentialsFromDefault(ctx context.Context, email string, logger *slog.Logger) ([]byte, *admin.Service, error) {
credential, err := google.FindDefaultCredentials(ctx) credential, err := google.FindDefaultCredentials(ctx)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to fetch application default credentials: %w", err) return nil, nil, fmt.Errorf("failed to fetch application default credentials: %w", err)
@ -397,9 +398,9 @@ func getCredentialsFromDefault(ctx context.Context, email string, logger log.Log
// createServiceWithMetadataServer creates a new service using metadata server. // createServiceWithMetadataServer creates a new service using metadata server.
// If an error occurs during the process, it is returned along with a nil service. // If an error occurs during the process, it is returned along with a nil service.
func createServiceWithMetadataServer(ctx context.Context, adminEmail string, logger log.Logger) (*admin.Service, error) { func createServiceWithMetadataServer(ctx context.Context, adminEmail string, logger *slog.Logger) (*admin.Service, error) {
serviceAccountEmail, err := metadata.Email("default") serviceAccountEmail, err := metadata.Email("default")
logger.Infof("discovered serviceAccountEmail: %s", serviceAccountEmail) logger.Info("discovered serviceAccountEmail", "email", serviceAccountEmail)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to get service account email from metadata server: %v", err) return nil, fmt.Errorf("unable to get service account email from metadata server: %v", err)
@ -423,7 +424,7 @@ func createServiceWithMetadataServer(ctx context.Context, adminEmail string, log
// createDirectoryService sets up super user impersonation and creates an admin client for calling // createDirectoryService sets up super user impersonation and creates an admin client for calling
// the google admin api. If no serviceAccountFilePath is defined, the application default credential // the google admin api. If no serviceAccountFilePath is defined, the application default credential
// is used. // is used.
func createDirectoryService(serviceAccountFilePath, email string, logger log.Logger) (service *admin.Service, err error) { func createDirectoryService(serviceAccountFilePath, email string, logger *slog.Logger) (service *admin.Service, err error) {
var jsonCredentials []byte var jsonCredentials []byte
ctx := context.Background() ctx := context.Background()

5
connector/google/google_test.go

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -11,7 +13,6 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
"google.golang.org/api/option" "google.golang.org/api/option"
@ -51,7 +52,7 @@ func testSetup() *httptest.Server {
} }
func newConnector(config *Config) (*googleConnector, error) { func newConnector(config *Config) (*googleConnector, error) {
log := logrus.New() log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
conn, err := config.Open("id", log) conn, err := config.Open("id", log)
if err != nil { if err != nil {
return nil, err return nil, err

10
connector/keystone/keystone.go

@ -7,10 +7,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
type conn struct { type conn struct {
@ -19,7 +19,7 @@ type conn struct {
AdminUsername string AdminUsername string
AdminPassword string AdminPassword string
client *http.Client client *http.Client
Logger log.Logger Logger *slog.Logger
} }
type userKeystone struct { type userKeystone struct {
@ -111,13 +111,13 @@ var (
) )
// Open returns an authentication strategy using Keystone. // Open returns an authentication strategy using Keystone.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
return &conn{ return &conn{
Domain: c.Domain, Domain: c.Domain,
Host: c.Host, Host: c.Host,
AdminUsername: c.AdminUsername, AdminUsername: c.AdminUsername,
AdminPassword: c.AdminPassword, AdminPassword: c.AdminPassword,
Logger: logger, Logger: logger.With(slog.Group("connector", "type", "keystone", "id", id)),
client: http.DefaultClient, client: http.DefaultClient,
}, nil }, nil
} }
@ -287,7 +287,7 @@ func (p *conn) getUserGroups(ctx context.Context, userID string, token string) (
req = req.WithContext(ctx) req = req.WithContext(ctx)
resp, err := p.client.Do(req) resp, err := p.client.Do(req)
if err != nil { if err != nil {
p.Logger.Errorf("keystone: error while fetching user %q groups\n", userID) p.Logger.Error("error while fetching user groups", "user_id", userID, "err", err)
return nil, err return nil, err
} }

37
connector/ldap/ldap.go

@ -7,6 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"net" "net"
"os" "os"
"strings" "strings"
@ -14,7 +15,6 @@ import (
"github.com/go-ldap/ldap/v3" "github.com/go-ldap/ldap/v3"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
// Config holds the configuration parameters for the LDAP connector. The LDAP // Config holds the configuration parameters for the LDAP connector. The LDAP
@ -188,12 +188,12 @@ func parseScope(s string) (int, bool) {
// Function exists here to allow backward compatibility between old and new // Function exists here to allow backward compatibility between old and new
// group to user matching implementations. // group to user matching implementations.
// See "Config.GroupSearch.UserMatchers" comments for the details // See "Config.GroupSearch.UserMatchers" comments for the details
func userMatchers(c *Config, logger log.Logger) []UserMatcher { func userMatchers(c *Config, logger *slog.Logger) []UserMatcher {
if len(c.GroupSearch.UserMatchers) > 0 && c.GroupSearch.UserMatchers[0].UserAttr != "" { if len(c.GroupSearch.UserMatchers) > 0 && c.GroupSearch.UserMatchers[0].UserAttr != "" {
return c.GroupSearch.UserMatchers return c.GroupSearch.UserMatchers
} }
log.Deprecated(logger, `LDAP: use groupSearch.userMatchers option instead of "userAttr/groupAttr" fields.`) logger.Warn(`use "groupSearch.userMatchers" option instead of "userAttr/groupAttr" fields`, "deprecated", true)
return []UserMatcher{ return []UserMatcher{
{ {
UserAttr: c.GroupSearch.UserAttr, UserAttr: c.GroupSearch.UserAttr,
@ -203,7 +203,8 @@ func userMatchers(c *Config, logger log.Logger) []UserMatcher {
} }
// Open returns an authentication strategy using LDAP. // Open returns an authentication strategy using LDAP.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
logger = logger.With(slog.Group("connector", "type", "ldap", "id", id))
conn, err := c.OpenConnector(logger) conn, err := c.OpenConnector(logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -217,7 +218,7 @@ type refreshData struct {
} }
// OpenConnector is the same as Open but returns a type with all implemented connector interfaces. // OpenConnector is the same as Open but returns a type with all implemented connector interfaces.
func (c *Config) OpenConnector(logger log.Logger) (interface { func (c *Config) OpenConnector(logger *slog.Logger) (interface {
connector.Connector connector.Connector
connector.PasswordConnector connector.PasswordConnector
connector.RefreshConnector connector.RefreshConnector
@ -226,7 +227,7 @@ func (c *Config) OpenConnector(logger log.Logger) (interface {
return c.openConnector(logger) return c.openConnector(logger)
} }
func (c *Config) openConnector(logger log.Logger) (*ldapConnector, error) { func (c *Config) openConnector(logger *slog.Logger) (*ldapConnector, error) {
requiredFields := []struct { requiredFields := []struct {
name string name string
val string val string
@ -300,7 +301,7 @@ type ldapConnector struct {
tlsConfig *tls.Config tlsConfig *tls.Config
logger log.Logger logger *slog.Logger
} }
var ( var (
@ -359,7 +360,7 @@ func (c *ldapConnector) getAttrs(e ldap.Entry, name string) []string {
return []string{e.DN} return []string{e.DN}
} }
c.logger.Debugf("%q attribute is not fround in entry", name) c.logger.Debug("attribute is not fround in entry", "attribute", name)
return nil return nil
} }
@ -438,8 +439,8 @@ func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.E
req.Attributes = append(req.Attributes, c.UserSearch.PreferredUsernameAttrAttr) req.Attributes = append(req.Attributes, c.UserSearch.PreferredUsernameAttrAttr)
} }
c.logger.Infof("performing ldap search %s %s %s", c.logger.Info("performing ldap search",
req.BaseDN, scopeString(req.Scope), req.Filter) "base_dn", req.BaseDN, "scope", scopeString(req.Scope), "filter", req.Filter)
resp, err := conn.Search(req) resp, err := conn.Search(req)
if err != nil { if err != nil {
return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err) return ldap.Entry{}, false, fmt.Errorf("ldap: search with filter %q failed: %v", req.Filter, err)
@ -447,11 +448,11 @@ func (c *ldapConnector) userEntry(conn *ldap.Conn, username string) (user ldap.E
switch n := len(resp.Entries); n { switch n := len(resp.Entries); n {
case 0: case 0:
c.logger.Errorf("ldap: no results returned for filter: %q", filter) c.logger.Error("no results returned for filter", "filter", filter)
return ldap.Entry{}, false, nil return ldap.Entry{}, false, nil
case 1: case 1:
user = *resp.Entries[0] user = *resp.Entries[0]
c.logger.Infof("username %q mapped to entry %s", username, user.DN) c.logger.Info("username mapped to entry", "username", username, "user_dn", user.DN)
return user, true, nil return user, true, nil
default: default:
return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter) return ldap.Entry{}, false, fmt.Errorf("ldap: filter returned multiple (%d) results: %q", n, filter)
@ -491,11 +492,11 @@ func (c *ldapConnector) Login(ctx context.Context, s connector.Scopes, username,
if ldapErr, ok := err.(*ldap.Error); ok { if ldapErr, ok := err.(*ldap.Error); ok {
switch ldapErr.ResultCode { switch ldapErr.ResultCode {
case ldap.LDAPResultInvalidCredentials: case ldap.LDAPResultInvalidCredentials:
c.logger.Errorf("ldap: invalid password for user %q", user.DN) c.logger.Error("invalid password for user", "user_dn", user.DN)
incorrectPass = true incorrectPass = true
return nil return nil
case ldap.LDAPResultConstraintViolation: case ldap.LDAPResultConstraintViolation:
c.logger.Errorf("ldap: constraint violation for user %q: %s", user.DN, ldapErr.Error()) c.logger.Error("constraint violation for user", "user_dn", user.DN, "err", ldapErr.Error())
incorrectPass = true incorrectPass = true
return nil return nil
} }
@ -581,7 +582,7 @@ func (c *ldapConnector) Refresh(ctx context.Context, s connector.Scopes, ident c
func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) { func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string, error) {
if c.GroupSearch.BaseDN == "" { if c.GroupSearch.BaseDN == "" {
c.logger.Debugf("No groups returned for %q because no groups baseDN has been configured.", c.getAttr(user, c.UserSearch.NameAttr)) c.logger.Debug("No groups returned because no groups baseDN has been configured.", "base_dn", c.getAttr(user, c.UserSearch.NameAttr))
return nil, nil return nil, nil
} }
@ -602,8 +603,8 @@ func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string,
gotGroups := false gotGroups := false
if err := c.do(ctx, func(conn *ldap.Conn) error { if err := c.do(ctx, func(conn *ldap.Conn) error {
c.logger.Infof("performing ldap search %s %s %s", c.logger.Info("performing ldap search",
req.BaseDN, scopeString(req.Scope), req.Filter) "base_dn", req.BaseDN, "scope", scopeString(req.Scope), "filter", req.Filter)
resp, err := conn.Search(req) resp, err := conn.Search(req)
if err != nil { if err != nil {
return fmt.Errorf("ldap: search failed: %v", err) return fmt.Errorf("ldap: search failed: %v", err)
@ -616,7 +617,7 @@ func (c *ldapConnector) groups(ctx context.Context, user ldap.Entry) ([]string,
} }
if !gotGroups { if !gotGroups {
// TODO(ericchiang): Is this going to spam the logs? // TODO(ericchiang): Is this going to spam the logs?
c.logger.Errorf("ldap: groups search with filter %q returned no groups", filter) c.logger.Error("groups search returned no groups", "filter", filter)
} }
} }
} }

4
connector/ldap/ldap_test.go

@ -4,11 +4,11 @@ import (
"context" "context"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"testing" "testing"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
) )
@ -567,7 +567,7 @@ func runTests(t *testing.T, connMethod connectionMethod, config *Config, tests [
c.BindDN = "cn=admin,dc=example,dc=org" c.BindDN = "cn=admin,dc=example,dc=org"
c.BindPW = "admin" c.BindPW = "admin"
l := &logrus.Logger{Out: io.Discard, Formatter: &logrus.TextFormatter{}} l := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
conn, err := c.openConnector(l) conn, err := c.openConnector(l)
if err != nil { if err != nil {

8
connector/linkedin/linkedin.go

@ -6,13 +6,13 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strings" "strings"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
const ( const (
@ -29,7 +29,7 @@ type Config struct {
} }
// Open returns a strategy for logging in through LinkedIn // Open returns a strategy for logging in through LinkedIn
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
return &linkedInConnector{ return &linkedInConnector{
oauth2Config: &oauth2.Config{ oauth2Config: &oauth2.Config{
ClientID: c.ClientID, ClientID: c.ClientID,
@ -41,7 +41,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
Scopes: []string{"r_liteprofile", "r_emailaddress"}, Scopes: []string{"r_liteprofile", "r_emailaddress"},
RedirectURL: c.RedirectURI, RedirectURL: c.RedirectURI,
}, },
logger: logger, logger: logger.With(slog.Group("connector", "type", "linkedin", "id", id)),
}, nil }, nil
} }
@ -51,7 +51,7 @@ type connectorData struct {
type linkedInConnector struct { type linkedInConnector struct {
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
logger log.Logger logger *slog.Logger
} }
// LinkedIn doesn't provide refresh tokens, so refresh tokens issued by Dex // LinkedIn doesn't provide refresh tokens, so refresh tokens issued by Dex

8
connector/microsoft/microsoft.go

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
@ -17,7 +18,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
groups_pkg "github.com/dexidp/dex/pkg/groups" groups_pkg "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
// GroupNameFormat represents the format of the group identifier // GroupNameFormat represents the format of the group identifier
@ -66,7 +66,7 @@ type Config struct {
} }
// Open returns a strategy for logging in through Microsoft. // Open returns a strategy for logging in through Microsoft.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
m := microsoftConnector{ m := microsoftConnector{
apiURL: strings.TrimSuffix(c.APIURL, "/"), apiURL: strings.TrimSuffix(c.APIURL, "/"),
graphURL: strings.TrimSuffix(c.GraphURL, "/"), graphURL: strings.TrimSuffix(c.GraphURL, "/"),
@ -78,7 +78,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
groups: c.Groups, groups: c.Groups,
groupNameFormat: c.GroupNameFormat, groupNameFormat: c.GroupNameFormat,
useGroupsAsWhitelist: c.UseGroupsAsWhitelist, useGroupsAsWhitelist: c.UseGroupsAsWhitelist,
logger: logger, logger: logger.With(slog.Group("connector", "type", "microsoft", "id", id)),
emailToLowercase: c.EmailToLowercase, emailToLowercase: c.EmailToLowercase,
promptType: c.PromptType, promptType: c.PromptType,
domainHint: c.DomainHint, domainHint: c.DomainHint,
@ -133,7 +133,7 @@ type microsoftConnector struct {
groupNameFormat GroupNameFormat groupNameFormat GroupNameFormat
groups []string groups []string
useGroupsAsWhitelist bool useGroupsAsWhitelist bool
logger log.Logger logger *slog.Logger
emailToLowercase bool emailToLowercase bool
promptType string promptType string
domainHint string domainHint string

13
connector/mock/connectortest.go

@ -5,16 +5,16 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/log"
) )
// NewCallbackConnector returns a mock connector which requires no user interaction. It always returns // NewCallbackConnector returns a mock connector which requires no user interaction. It always returns
// the same (fake) identity. // the same (fake) identity.
func NewCallbackConnector(logger log.Logger) connector.Connector { func NewCallbackConnector(logger *slog.Logger) connector.Connector {
return &Callback{ return &Callback{
Identity: connector.Identity{ Identity: connector.Identity{
UserID: "0-385-28089-0", UserID: "0-385-28089-0",
@ -39,7 +39,7 @@ var (
type Callback struct { type Callback struct {
// The returned identity. // The returned identity.
Identity connector.Identity Identity connector.Identity
Logger log.Logger Logger *slog.Logger
} }
// LoginURL returns the URL to redirect the user to login with. // LoginURL returns the URL to redirect the user to login with.
@ -74,7 +74,8 @@ func (m *Callback) TokenIdentity(ctx context.Context, subjectTokenType, subjectT
type CallbackConfig struct{} type CallbackConfig struct{}
// Open returns an authentication strategy which requires no user interaction. // Open returns an authentication strategy which requires no user interaction.
func (c *CallbackConfig) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *CallbackConfig) Open(id string, logger *slog.Logger) (connector.Connector, error) {
logger = logger.With(slog.Group("connector", "type", "callback", "id", id))
return NewCallbackConnector(logger), nil return NewCallbackConnector(logger), nil
} }
@ -86,7 +87,7 @@ type PasswordConfig struct {
} }
// Open returns an authentication strategy which prompts for a predefined username and password. // Open returns an authentication strategy which prompts for a predefined username and password.
func (c *PasswordConfig) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *PasswordConfig) Open(id string, logger *slog.Logger) (connector.Connector, error) {
if c.Username == "" { if c.Username == "" {
return nil, errors.New("no username supplied") return nil, errors.New("no username supplied")
} }
@ -99,7 +100,7 @@ func (c *PasswordConfig) Open(id string, logger log.Logger) (connector.Connector
type passwordConnector struct { type passwordConnector struct {
username string username string
password string password string
logger log.Logger logger *slog.Logger
} }
func (p passwordConnector) Close() error { return nil } func (p passwordConnector) Close() error { return nil }

8
connector/oauth/oauth.go

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"strings" "strings"
@ -13,7 +14,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/httpclient" "github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
) )
type oauthConnector struct { type oauthConnector struct {
@ -31,7 +31,7 @@ type oauthConnector struct {
emailVerifiedKey string emailVerifiedKey string
groupsKey string groupsKey string
httpClient *http.Client httpClient *http.Client
logger log.Logger logger *slog.Logger
} }
type connectorData struct { type connectorData struct {
@ -58,7 +58,7 @@ type Config struct {
} `json:"claimMapping"` } `json:"claimMapping"`
} }
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
var err error var err error
userIDKey := c.UserIDKey userIDKey := c.UserIDKey
@ -99,7 +99,7 @@ func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error)
userInfoURL: c.UserInfoURL, userInfoURL: c.UserInfoURL,
scopes: c.Scopes, scopes: c.Scopes,
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
logger: logger, logger: logger.With(slog.Group("connector", "type", "oauth", "id", id)),
userIDKey: userIDKey, userIDKey: userIDKey,
userNameKey: userNameKey, userNameKey: userNameKey,
preferredUsernameKey: preferredUsernameKey, preferredUsernameKey: preferredUsernameKey,

5
connector/oauth/oauth_test.go

@ -6,6 +6,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -13,7 +15,6 @@ import (
"testing" "testing"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
@ -270,7 +271,7 @@ func newConnector(t *testing.T, serverURL string) *oauthConnector {
testConfig.ClaimMapping.EmailKey = "mail" testConfig.ClaimMapping.EmailKey = "mail"
testConfig.ClaimMapping.EmailVerifiedKey = "has_verified_email" testConfig.ClaimMapping.EmailVerifiedKey = "has_verified_email"
log := logrus.New() log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
conn, err := testConfig.Open("id", log) conn, err := testConfig.Open("id", log)
if err != nil { if err != nil {

10
connector/oidc/oidc.go

@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -17,7 +18,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
groups_pkg "github.com/dexidp/dex/pkg/groups" groups_pkg "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/httpclient" "github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
) )
// Config holds configuration options for OpenID Connect logins. // Config holds configuration options for OpenID Connect logins.
@ -206,7 +206,7 @@ func knownBrokenAuthHeaderProvider(issuerURL string) bool {
// Open returns a connector which can be used to login users through an upstream // Open returns a connector which can be used to login users through an upstream
// OpenID Connect provider. // OpenID Connect provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
if len(c.HostedDomains) > 0 { if len(c.HostedDomains) > 0 {
return nil, fmt.Errorf("support for the Hosted domains option had been deprecated and removed, consider switching to the Google connector") return nil, fmt.Errorf("support for the Hosted domains option had been deprecated and removed, consider switching to the Google connector")
} }
@ -225,7 +225,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
return nil, err return nil, err
} }
if !c.ProviderDiscoveryOverrides.Empty() { if !c.ProviderDiscoveryOverrides.Empty() {
logger.Warnf("overrides for connector %q are set, this can be a vulnerability when not properly configured", id) logger.Warn("overrides for connector are set, this can be a vulnerability when not properly configured", "connector_id", id)
} }
endpoint := provider.Endpoint() endpoint := provider.Endpoint()
@ -266,7 +266,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
verifier: provider.Verifier( verifier: provider.Verifier(
&oidc.Config{ClientID: clientID}, &oidc.Config{ClientID: clientID},
), ),
logger: logger, logger: logger.With(slog.Group("connector", "type", "oidc", "id", id)),
cancel: cancel, cancel: cancel,
httpClient: httpClient, httpClient: httpClient,
insecureSkipEmailVerified: c.InsecureSkipEmailVerified, insecureSkipEmailVerified: c.InsecureSkipEmailVerified,
@ -296,7 +296,7 @@ type oidcConnector struct {
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
verifier *oidc.IDTokenVerifier verifier *oidc.IDTokenVerifier
cancel context.CancelFunc cancel context.CancelFunc
logger log.Logger logger *slog.Logger
httpClient *http.Client httpClient *http.Client
insecureSkipEmailVerified bool insecureSkipEmailVerified bool
insecureEnableGroups bool insecureEnableGroups bool

5
connector/oidc/oidc_test.go

@ -10,6 +10,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
@ -18,7 +20,6 @@ import (
"time" "time"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
@ -765,7 +766,7 @@ func newToken(key *jose.JSONWebKey, claims map[string]interface{}) (string, erro
} }
func newConnector(config Config) (*oidcConnector, error) { func newConnector(config Config) (*oidcConnector, error) {
logger := logrus.New() logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
conn, err := config.Open("id", logger) conn, err := config.Open("id", logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to open: %v", err) return nil, fmt.Errorf("unable to open: %v", err)

10
connector/openshift/openshift.go

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"strings" "strings"
@ -13,7 +14,6 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/httpclient" "github.com/dexidp/dex/pkg/httpclient"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage/kubernetes/k8sapi" "github.com/dexidp/dex/storage/kubernetes/k8sapi"
) )
@ -44,7 +44,7 @@ type openshiftConnector struct {
clientID string clientID string
clientSecret string clientSecret string
cancel context.CancelFunc cancel context.CancelFunc
logger log.Logger logger *slog.Logger
httpClient *http.Client httpClient *http.Client
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
insecureCA bool insecureCA bool
@ -62,7 +62,7 @@ type user struct {
// Open returns a connector which can be used to login users through an upstream // Open returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider. // OpenShift OAuth2 provider.
func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, err error) { func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, err error) {
var rootCAs []string var rootCAs []string
if c.RootCA != "" { if c.RootCA != "" {
rootCAs = append(rootCAs, c.RootCA) rootCAs = append(rootCAs, c.RootCA)
@ -78,7 +78,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
// OpenWithHTTPClient returns a connector which can be used to login users through an upstream // OpenWithHTTPClient returns a connector which can be used to login users through an upstream
// OpenShift OAuth2 provider. It provides the ability to inject a http.Client. // OpenShift OAuth2 provider. It provides the ability to inject a http.Client.
func (c *Config) OpenWithHTTPClient(id string, logger log.Logger, func (c *Config) OpenWithHTTPClient(id string, logger *slog.Logger,
httpClient *http.Client, httpClient *http.Client,
) (conn connector.Connector, err error) { ) (conn connector.Connector, err error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -96,7 +96,7 @@ func (c *Config) OpenWithHTTPClient(id string, logger log.Logger,
clientID: c.ClientID, clientID: c.ClientID,
clientSecret: c.ClientSecret, clientSecret: c.ClientSecret,
insecureCA: c.InsecureCA, insecureCA: c.InsecureCA,
logger: logger, logger: logger.With(slog.Group("connector", "type", "openshift", "id", id)),
redirectURI: c.RedirectURI, redirectURI: c.RedirectURI,
rootCA: c.RootCA, rootCA: c.RootCA,
groups: c.Groups, groups: c.Groups,

5
connector/openshift/openshift_test.go

@ -4,6 +4,8 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -11,7 +13,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
@ -37,7 +38,7 @@ func TestOpen(t *testing.T) {
InsecureCA: true, InsecureCA: true,
} }
logger := logrus.New() logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
oconfig, err := c.Open("id", logger) oconfig, err := c.Open("id", logger)

12
connector/saml/saml.go

@ -8,6 +8,7 @@ import (
"encoding/pem" "encoding/pem"
"encoding/xml" "encoding/xml"
"fmt" "fmt"
"log/slog"
"os" "os"
"strings" "strings"
"sync" "sync"
@ -21,10 +22,8 @@ import (
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
"github.com/dexidp/dex/pkg/groups" "github.com/dexidp/dex/pkg/groups"
"github.com/dexidp/dex/pkg/log"
) )
//nolint
const ( const (
bindingRedirect = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" bindingRedirect = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
bindingPOST = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" bindingPOST = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
@ -120,11 +119,12 @@ func (c certStore) Certificates() (roots []*x509.Certificate, err error) {
// Open validates the config and returns a connector. It does not actually // Open validates the config and returns a connector. It does not actually
// validate connectivity with the provider. // validate connectivity with the provider.
func (c *Config) Open(id string, logger log.Logger) (connector.Connector, error) { func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
logger = logger.With(slog.Group("connector", "type", "saml", "id", id))
return c.openConnector(logger) return c.openConnector(logger)
} }
func (c *Config) openConnector(logger log.Logger) (*provider, error) { func (c *Config) openConnector(logger *slog.Logger) (*provider, error) {
requiredFields := []struct { requiredFields := []struct {
name, val string name, val string
}{ }{
@ -252,7 +252,7 @@ type provider struct {
nameIDPolicyFormat string nameIDPolicyFormat string
logger log.Logger logger *slog.Logger
} }
func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) { func (p *provider) POSTData(s connector.Scopes, id string) (action, value string, err error) {
@ -389,7 +389,7 @@ func (p *provider) HandlePOST(s connector.Scopes, samlResponse, inResponseTo str
// Log the actual attributes we got back from the server. This helps debug // Log the actual attributes we got back from the server. This helps debug
// configuration errors on the server side, where the SAML server doesn't // configuration errors on the server side, where the SAML server doesn't
// send us the correct attributes. // send us the correct attributes.
p.logger.Infof("parsed and verified saml response attributes %s", attributes) p.logger.Info("parsed and verified saml response attributes", "attributes", attributes)
// Grab the email. // Grab the email.
if ident.Email, _ = attributes.get(p.emailAttr); ident.Email == "" { if ident.Email, _ = attributes.get(p.emailAttr); ident.Email == "" {

7
connector/saml/saml_test.go

@ -5,6 +5,8 @@ import (
"encoding/base64" "encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"io"
"log/slog"
"os" "os"
"sort" "sort"
"testing" "testing"
@ -12,7 +14,6 @@ import (
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
dsig "github.com/russellhaering/goxmldsig" dsig "github.com/russellhaering/goxmldsig"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/connector" "github.com/dexidp/dex/connector"
) )
@ -420,7 +421,7 @@ func (r responseTest) run(t *testing.T) {
t.Fatalf("parse test time: %v", err) t.Fatalf("parse test time: %v", err)
} }
conn, err := c.openConnector(logrus.New()) conn, err := c.openConnector(slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -454,7 +455,7 @@ func (r responseTest) run(t *testing.T) {
} }
func TestConfigCAData(t *testing.T) { func TestConfigCAData(t *testing.T) {
logger := logrus.New() logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
validPEM, err := os.ReadFile("testdata/ca.crt") validPEM, err := os.ReadFile("testdata/ca.crt")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

1
go.mod

@ -28,7 +28,6 @@ require (
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/prometheus/client_golang v1.19.1 github.com/prometheus/client_golang v1.19.1
github.com/russellhaering/goxmldsig v1.4.0 github.com/russellhaering/goxmldsig v1.4.0
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.0 github.com/spf13/cobra v1.8.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
go.etcd.io/etcd/client/pkg/v3 v3.5.14 go.etcd.io/etcd/client/pkg/v3 v3.5.14

3
go.sum

@ -191,8 +191,6 @@ github.com/sergi/go-diff v1.0.0 h1:Kpca3qRNrduNnOQeazBd0ysaKrUJiIuISHxogkT9RPQ=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.4.1 h1:s0hze+J0196ZfEMTs80N7UlFt0BDuQ7Q+JDnHiMWKdA= github.com/spf13/cast v1.4.1 h1:s0hze+J0196ZfEMTs80N7UlFt0BDuQ7Q+JDnHiMWKdA=
github.com/spf13/cast v1.4.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.4.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
@ -304,7 +302,6 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

5
pkg/log/deprecated.go

@ -1,5 +0,0 @@
package log
func Deprecated(logger Logger, f string, args ...interface{}) {
logger.Warnf("Deprecated: "+f, args...)
}

18
pkg/log/logger.go

@ -1,18 +0,0 @@
// Package log provides a logger interface for logger libraries
// so that dex does not depend on any of them directly.
// It also includes a default implementation using Logrus (used by dex previously).
package log
// Logger serves as an adapter interface for logger libraries
// so that dex does not depend on any of them directly.
type Logger interface {
Debug(args ...interface{})
Info(args ...interface{})
Warn(args ...interface{})
Error(args ...interface{})
Debugf(format string, args ...interface{})
Infof(format string, args ...interface{})
Warnf(format string, args ...interface{})
Errorf(format string, args ...interface{})
}

38
server/api.go

@ -4,11 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -29,10 +29,10 @@ const (
) )
// NewAPI returns a server which implements the gRPC API interface. // NewAPI returns a server which implements the gRPC API interface.
func NewAPI(s storage.Storage, logger log.Logger, version string) api.DexServer { func NewAPI(s storage.Storage, logger *slog.Logger, version string) api.DexServer {
return dexAPI{ return dexAPI{
s: s, s: s,
logger: logger, logger: logger.With("component", "api"),
version: version, version: version,
} }
} }
@ -41,7 +41,7 @@ type dexAPI struct {
api.UnimplementedDexServer api.UnimplementedDexServer
s storage.Storage s storage.Storage
logger log.Logger logger *slog.Logger
version string version string
} }
@ -89,7 +89,7 @@ func (d dexAPI) CreateClient(ctx context.Context, req *api.CreateClientReq) (*ap
if err == storage.ErrAlreadyExists { if err == storage.ErrAlreadyExists {
return &api.CreateClientResp{AlreadyExists: true}, nil return &api.CreateClientResp{AlreadyExists: true}, nil
} }
d.logger.Errorf("api: failed to create client: %v", err) d.logger.Error("failed to create client", "err", err)
return nil, fmt.Errorf("create client: %v", err) return nil, fmt.Errorf("create client: %v", err)
} }
@ -122,7 +122,7 @@ func (d dexAPI) UpdateClient(ctx context.Context, req *api.UpdateClientReq) (*ap
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return &api.UpdateClientResp{NotFound: true}, nil return &api.UpdateClientResp{NotFound: true}, nil
} }
d.logger.Errorf("api: failed to update the client: %v", err) d.logger.Error("failed to update the client", "err", err)
return nil, fmt.Errorf("update client: %v", err) return nil, fmt.Errorf("update client: %v", err)
} }
return &api.UpdateClientResp{}, nil return &api.UpdateClientResp{}, nil
@ -134,7 +134,7 @@ func (d dexAPI) DeleteClient(ctx context.Context, req *api.DeleteClientReq) (*ap
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return &api.DeleteClientResp{NotFound: true}, nil return &api.DeleteClientResp{NotFound: true}, nil
} }
d.logger.Errorf("api: failed to delete client: %v", err) d.logger.Error("failed to delete client", "err", err)
return nil, fmt.Errorf("delete client: %v", err) return nil, fmt.Errorf("delete client: %v", err)
} }
return &api.DeleteClientResp{}, nil return &api.DeleteClientResp{}, nil
@ -181,7 +181,7 @@ func (d dexAPI) CreatePassword(ctx context.Context, req *api.CreatePasswordReq)
if err == storage.ErrAlreadyExists { if err == storage.ErrAlreadyExists {
return &api.CreatePasswordResp{AlreadyExists: true}, nil return &api.CreatePasswordResp{AlreadyExists: true}, nil
} }
d.logger.Errorf("api: failed to create password: %v", err) d.logger.Error("failed to create password", "err", err)
return nil, fmt.Errorf("create password: %v", err) return nil, fmt.Errorf("create password: %v", err)
} }
@ -218,7 +218,7 @@ func (d dexAPI) UpdatePassword(ctx context.Context, req *api.UpdatePasswordReq)
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return &api.UpdatePasswordResp{NotFound: true}, nil return &api.UpdatePasswordResp{NotFound: true}, nil
} }
d.logger.Errorf("api: failed to update password: %v", err) d.logger.Error("failed to update password", "err", err)
return nil, fmt.Errorf("update password: %v", err) return nil, fmt.Errorf("update password: %v", err)
} }
@ -235,7 +235,7 @@ func (d dexAPI) DeletePassword(ctx context.Context, req *api.DeletePasswordReq)
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return &api.DeletePasswordResp{NotFound: true}, nil return &api.DeletePasswordResp{NotFound: true}, nil
} }
d.logger.Errorf("api: failed to delete password: %v", err) d.logger.Error("failed to delete password", "err", err)
return nil, fmt.Errorf("delete password: %v", err) return nil, fmt.Errorf("delete password: %v", err)
} }
return &api.DeletePasswordResp{}, nil return &api.DeletePasswordResp{}, nil
@ -251,7 +251,7 @@ func (d dexAPI) GetVersion(ctx context.Context, req *api.VersionReq) (*api.Versi
func (d dexAPI) ListPasswords(ctx context.Context, req *api.ListPasswordReq) (*api.ListPasswordResp, error) { func (d dexAPI) ListPasswords(ctx context.Context, req *api.ListPasswordReq) (*api.ListPasswordResp, error) {
passwordList, err := d.s.ListPasswords() passwordList, err := d.s.ListPasswords()
if err != nil { if err != nil {
d.logger.Errorf("api: failed to list passwords: %v", err) d.logger.Error("failed to list passwords", "err", err)
return nil, fmt.Errorf("list passwords: %v", err) return nil, fmt.Errorf("list passwords: %v", err)
} }
@ -286,12 +286,12 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
NotFound: true, NotFound: true,
}, nil }, nil
} }
d.logger.Errorf("api: there was an error retrieving the password: %v", err) d.logger.Error("there was an error retrieving the password", "err", err)
return nil, fmt.Errorf("verify password: %v", err) return nil, fmt.Errorf("verify password: %v", err)
} }
if err := bcrypt.CompareHashAndPassword(password.Hash, []byte(req.Password)); err != nil { if err := bcrypt.CompareHashAndPassword(password.Hash, []byte(req.Password)); err != nil {
d.logger.Infof("api: password check failed: %v", err) d.logger.Info("password check failed", "err", err)
return &api.VerifyPasswordResp{ return &api.VerifyPasswordResp{
Verified: false, Verified: false,
}, nil }, nil
@ -304,7 +304,7 @@ func (d dexAPI) VerifyPassword(ctx context.Context, req *api.VerifyPasswordReq)
func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) { func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.ListRefreshResp, error) {
id := new(internal.IDTokenSubject) id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil { if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) d.logger.Error("failed to unmarshal ID Token subject", "err", err)
return nil, err return nil, err
} }
@ -315,7 +315,7 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
// An empty list should be returned instead of an error. // An empty list should be returned instead of an error.
return &api.ListRefreshResp{}, nil return &api.ListRefreshResp{}, nil
} }
d.logger.Errorf("api: failed to list refresh tokens %t here : %v", err == storage.ErrNotFound, err) d.logger.Error("failed to list refresh tokens here", "err", err)
return nil, err return nil, err
} }
@ -338,7 +338,7 @@ func (d dexAPI) ListRefresh(ctx context.Context, req *api.ListRefreshReq) (*api.
func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) { func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*api.RevokeRefreshResp, error) {
id := new(internal.IDTokenSubject) id := new(internal.IDTokenSubject)
if err := internal.Unmarshal(req.UserId, id); err != nil { if err := internal.Unmarshal(req.UserId, id); err != nil {
d.logger.Errorf("api: failed to unmarshal ID Token subject: %v", err) d.logger.Error("failed to unmarshal ID Token subject", "err", err)
return nil, err return nil, err
} }
@ -349,7 +349,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
updater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) { updater := func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
refreshRef := old.Refresh[req.ClientId] refreshRef := old.Refresh[req.ClientId]
if refreshRef == nil || refreshRef.ID == "" { if refreshRef == nil || refreshRef.ID == "" {
d.logger.Errorf("api: refresh token issued to client %q for user %q not found for deletion", req.ClientId, id.UserId) d.logger.Error("refresh token issued to client not found for deletion", "client_id", req.ClientId, "user_id", id.UserId)
notFound = true notFound = true
return old, storage.ErrNotFound return old, storage.ErrNotFound
} }
@ -366,7 +366,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return &api.RevokeRefreshResp{NotFound: true}, nil return &api.RevokeRefreshResp{NotFound: true}, nil
} }
d.logger.Errorf("api: failed to update offline session object: %v", err) d.logger.Error("failed to update offline session object", "err", err)
return nil, err return nil, err
} }
@ -379,7 +379,7 @@ func (d dexAPI) RevokeRefresh(ctx context.Context, req *api.RevokeRefreshReq) (*
// TODO(ericchiang): we don't have any good recourse if this call fails. // TODO(ericchiang): we don't have any good recourse if this call fails.
// Consider garbage collection of refresh tokens with no associated ref. // Consider garbage collection of refresh tokens with no associated ref.
if err := d.s.DeleteRefresh(refreshID); err != nil { if err := d.s.DeleteRefresh(refreshID); err != nil {
d.logger.Errorf("failed to delete refresh token: %v", err) d.logger.Error("failed to delete refresh token", "err", err)
return nil, err return nil, err
} }

31
server/api_test.go

@ -2,17 +2,16 @@ package server
import ( import (
"context" "context"
"io"
"log/slog"
"net" "net"
"os"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/dexidp/dex/api/v2" "github.com/dexidp/dex/api/v2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/server/internal" "github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory" "github.com/dexidp/dex/storage/memory"
@ -30,7 +29,7 @@ type apiClient struct {
} }
// newAPI constructs a gRCP client connected to a backing server. // newAPI constructs a gRCP client connected to a backing server.
func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient { func newAPI(s storage.Storage, logger *slog.Logger, t *testing.T) *apiClient {
l, err := net.Listen("tcp", "127.0.0.1:0") l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -59,11 +58,7 @@ func newAPI(s storage.Storage, logger log.Logger, t *testing.T) *apiClient {
// Attempts to create, update and delete a test Password // Attempts to create, update and delete a test Password
func TestPassword(t *testing.T) { func TestPassword(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t) client := newAPI(s, logger, t)
@ -172,11 +167,7 @@ func TestPassword(t *testing.T) {
// Ensures checkCost returns expected values // Ensures checkCost returns expected values
func TestCheckCost(t *testing.T) { func TestCheckCost(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t) client := newAPI(s, logger, t)
@ -229,11 +220,7 @@ func TestCheckCost(t *testing.T) {
// Attempts to list and revoke an existing refresh token. // Attempts to list and revoke an existing refresh token.
func TestRefreshToken(t *testing.T) { func TestRefreshToken(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t) client := newAPI(s, logger, t)
@ -342,11 +329,7 @@ func TestRefreshToken(t *testing.T) {
} }
func TestUpdateClient(t *testing.T) { func TestUpdateClient(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
s := memory.New(logger) s := memory.New(logger)
client := newAPI(s, logger, t) client := newAPI(s, logger, t)

45
server/deviceflowhandlers.go

@ -13,7 +13,6 @@ import (
"golang.org/x/net/html" "golang.org/x/net/html"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -49,7 +48,7 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
invalidAttempt = false invalidAttempt = false
} }
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil { if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found") s.renderError(r, w, http.StatusNotFound, "Page not found")
} }
default: default:
@ -65,7 +64,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost: case http.MethodPost:
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.logger.Errorf("Could not parse Device Request body: %v", err) s.logger.Error("could not parse Device Request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusNotFound)
return return
} }
@ -86,7 +85,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
return return
} }
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes) s.logger.Info("received device request", "client_id", clientID, "scoped", scopes)
// Make device code // Make device code
deviceCode := storage.NewDeviceCode() deviceCode := storage.NewDeviceCode()
@ -108,7 +107,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
} }
if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil { if err := s.storage.CreateDeviceRequest(ctx, deviceReq); err != nil {
s.logger.Errorf("Failed to store device request; %v", err) s.logger.Error("failed to store device request", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return return
} }
@ -127,14 +126,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
} }
if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil { if err := s.storage.CreateDeviceToken(ctx, deviceToken); err != nil {
s.logger.Errorf("Failed to store device token %v", err) s.logger.Error("failed to store device token", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return return
} }
u, err := url.Parse(s.issuerURL.String()) u, err := url.Parse(s.issuerURL.String())
if err != nil { if err != nil {
s.logger.Errorf("Could not parse issuer URL %v", err) s.logger.Error("could not parse issuer URL", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return return
} }
@ -175,14 +174,14 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) { func (s *Server) handleDeviceTokenDeprecated(w http.ResponseWriter, r *http.Request) {
log.Deprecated(s.logger, `The /device/token endpoint was called. It will be removed, use /token instead.`) s.logger.Warn(`the /device/token endpoint was called. It will be removed, use /token instead.`, "deprecated", true)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
switch r.Method { switch r.Method {
case http.MethodPost: case http.MethodPost:
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.logger.Warnf("Could not parse Device Token Request body: %v", err) s.logger.Warn("could not parse Device Token Request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return return
} }
@ -212,7 +211,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
deviceToken, err := s.storage.GetDeviceToken(deviceCode) deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err) s.logger.Error("failed to get device code", "err", err)
} }
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return return
@ -242,7 +241,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
} }
// Update device token last request time in storage // Update device token last request time in storage
if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil { if err := s.storage.UpdateDeviceToken(deviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err) s.logger.Error("failed to update device token", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "") s.renderError(r, w, http.StatusInternalServerError, "")
return return
} }
@ -259,7 +258,7 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
case providedCodeVerifier != "" && codeChallengeFromStorage != "": case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod) calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, deviceToken.PKCE.CodeChallengeMethod)
if err != nil { if err != nil {
s.logger.Error(err) s.logger.Error("failed to calculate code challenge", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -304,7 +303,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(authCode.Expiry) { if err != nil || s.now().After(authCode.Expiry) {
errCode := http.StatusBadRequest errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound { if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get auth code: %v", err) s.logger.Error("failed to get auth code", "err", err)
errCode = http.StatusInternalServerError errCode = http.StatusInternalServerError
} }
s.renderError(r, w, errCode, "Invalid or expired auth code.") s.renderError(r, w, errCode, "Invalid or expired auth code.")
@ -316,7 +315,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(deviceReq.Expiry) { if err != nil || s.now().After(deviceReq.Expiry) {
errCode := http.StatusBadRequest errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound { if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err) s.logger.Error("failed to get device code", "err", err)
errCode = http.StatusInternalServerError errCode = http.StatusInternalServerError
} }
s.renderError(r, w, errCode, "Invalid or expired user code.") s.renderError(r, w, errCode, "Invalid or expired user code.")
@ -326,7 +325,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
client, err := s.storage.GetClient(deviceReq.ClientID) client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get client: %v", err) s.logger.Error("failed to get client", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else { } else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
@ -340,7 +339,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
resp, err := s.exchangeAuthCode(ctx, w, authCode, client) resp, err := s.exchangeAuthCode(ctx, w, authCode, client)
if err != nil { if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err) s.logger.Error("could not exchange auth code for clien", "client_id", deviceReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.") s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return return
} }
@ -350,7 +349,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
if err != nil || s.now().After(old.Expiry) { if err != nil || s.now().After(old.Expiry) {
errCode := http.StatusBadRequest errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound { if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device token: %v", err) s.logger.Error("failed to get device token", "err", err)
errCode = http.StatusInternalServerError errCode = http.StatusInternalServerError
} }
s.renderError(r, w, errCode, "Invalid or expired device code.") s.renderError(r, w, errCode, "Invalid or expired device code.")
@ -363,7 +362,7 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
} }
respStr, err := json.MarshalIndent(resp, "", " ") respStr, err := json.MarshalIndent(resp, "", " ")
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal device token response: %v", err) s.logger.Error("failed to marshal device token response", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "") s.renderError(r, w, http.StatusInternalServerError, "")
return old, err return old, err
} }
@ -375,13 +374,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
// Update refresh token in the storage, store the token and mark as complete // Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil { if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err) s.logger.Error("failed to update device token", "err", err)
s.renderError(r, w, http.StatusBadRequest, "") s.renderError(r, w, http.StatusBadRequest, "")
return return
} }
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil { if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found") s.renderError(r, w, http.StatusNotFound, "Page not found")
} }
@ -396,7 +395,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost: case http.MethodPost:
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.logger.Warnf("Could not parse user code verification request body : %v", err) s.logger.Warn("could not parse user code verification request body", "err", err)
s.renderError(r, w, http.StatusBadRequest, "") s.renderError(r, w, http.StatusBadRequest, "")
return return
} }
@ -413,10 +412,10 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
deviceRequest, err := s.storage.GetDeviceRequest(userCode) deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) { if err != nil || s.now().After(deviceRequest.Expiry) {
if err != nil && err != storage.ErrNotFound { if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device request: %v", err) s.logger.Error("failed to get device request", "err", err)
} }
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil { if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("Server template error", "err", err)
s.renderError(r, w, http.StatusNotFound, "Page not found") s.renderError(r, w, http.StatusNotFound, "Page not found")
} }
return return

175
server/handlers.go

@ -35,13 +35,13 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Cache this. // TODO(ericchiang): Cache this.
keys, err := s.storage.GetKeys() keys, err := s.storage.GetKeys()
if err != nil { if err != nil {
s.logger.Errorf("failed to get keys: %v", err) s.logger.Error("failed to get keys", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
if keys.SigningKeyPub == nil { if keys.SigningKeyPub == nil {
s.logger.Errorf("No public keys found.") s.logger.Error("no public keys found.")
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
@ -56,7 +56,7 @@ func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
data, err := json.MarshalIndent(jwks, "", " ") data, err := json.MarshalIndent(jwks, "", " ")
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal discovery data: %v", err) s.logger.Error("failed to marshal discovery data", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
@ -132,7 +132,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
// Extract the arguments // Extract the arguments
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.logger.Errorf("Failed to parse arguments: %v", err) s.logger.Error("failed to parse arguments", "err", err)
s.renderError(r, w, http.StatusBadRequest, err.Error()) s.renderError(r, w, http.StatusBadRequest, err.Error())
return return
@ -142,7 +142,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
connectors, err := s.storage.ListConnectors() connectors, err := s.storage.ListConnectors()
if err != nil { if err != nil {
s.logger.Errorf("Failed to get list of connectors: %v", err) s.logger.Error("failed to get list of connectors", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.") s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve connector list.")
return return
} }
@ -185,7 +185,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
} }
if err := s.templates.login(r, w, connectorInfos); err != nil { if err := s.templates.login(r, w, connectorInfos); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
} }
@ -193,7 +193,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
authReq, err := s.parseAuthorizationRequest(r) authReq, err := s.parseAuthorizationRequest(r)
if err != nil { if err != nil {
s.logger.Errorf("Failed to parse authorization request: %v", err) s.logger.Error("failed to parse authorization request", "err", err)
switch authErr := err.(type) { switch authErr := err.(type) {
case *redirectedAuthErr: case *redirectedAuthErr:
@ -209,22 +209,22 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
connID, err := url.PathUnescape(mux.Vars(r)["connector"]) connID, err := url.PathUnescape(mux.Vars(r)["connector"])
if err != nil { if err != nil {
s.logger.Errorf("Failed to parse connector: %v", err) s.logger.Error("failed to parse connector", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return return
} }
conn, err := s.getConnector(connID) conn, err := s.getConnector(connID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get connector: %v", err) s.logger.Error("Failed to get connector", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return return
} }
// Set the connector being used for the login. // Set the connector being used for the login.
if authReq.ConnectorID != "" && authReq.ConnectorID != connID { if authReq.ConnectorID != "" && authReq.ConnectorID != connID {
s.logger.Errorf("Mismatched connector ID in auth request: %s vs %s", s.logger.Error("mismatched connector ID in auth request",
authReq.ConnectorID, connID) "auth_request_connector_id", authReq.ConnectorID, "connector_id", connID)
s.renderError(r, w, http.StatusBadRequest, "Bad connector ID") s.renderError(r, w, http.StatusBadRequest, "Bad connector ID")
return return
} }
@ -234,7 +234,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
// Actually create the auth request // Actually create the auth request
authReq.Expiry = s.now().Add(s.authRequestsValidFor) authReq.Expiry = s.now().Add(s.authRequestsValidFor)
if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil { if err := s.storage.CreateAuthRequest(ctx, *authReq); err != nil {
s.logger.Errorf("Failed to create authorization request: %v", err) s.logger.Error("failed to create authorization request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.") s.renderError(r, w, http.StatusInternalServerError, "Failed to connect to the database.")
return return
} }
@ -260,7 +260,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Is this appropriate or should we also be using a nonce? // TODO(ericchiang): Is this appropriate or should we also be using a nonce?
callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID) callbackURL, err := conn.LoginURL(scopes, s.absURL("/callback"), authReq.ID)
if err != nil { if err != nil {
s.logger.Errorf("Connector %q returned error when creating callback: %v", connID, err) s.logger.Error("connector returned error when creating callback", "connector_id", connID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -278,7 +278,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
case connector.SAMLConnector: case connector.SAMLConnector:
action, value, err := conn.POSTData(scopes, authReq.ID) action, value, err := conn.POSTData(scopes, authReq.ID)
if err != nil { if err != nil {
s.logger.Errorf("Creating SAML data: %v", err) s.logger.Error("creating SAML data", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error") s.renderError(r, w, http.StatusInternalServerError, "Connector Login Error")
return return
} }
@ -321,36 +321,36 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
authReq, err := s.storage.GetAuthRequest(authID) authReq, err := s.storage.GetAuthRequest(authID)
if err != nil { if err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
s.logger.Errorf("Invalid 'state' parameter provided: %v", err) s.logger.Error("invalid 'state' parameter provided", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return return
} }
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Error("failed to get auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
connID, err := url.PathUnescape(mux.Vars(r)["connector"]) connID, err := url.PathUnescape(mux.Vars(r)["connector"])
if err != nil { if err != nil {
s.logger.Errorf("Failed to parse connector: %v", err) s.logger.Error("failed to parse connector", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist")
return return
} else if connID != "" && connID != authReq.ConnectorID { } else if connID != "" && connID != authReq.ConnectorID {
s.logger.Errorf("Connector mismatch: authentication started with id %q, but password login for id %q was triggered", authReq.ConnectorID, connID) s.logger.Error("connector mismatch: password login triggered for different connector from authentication start", "start_connector_id", authReq.ConnectorID, "password_connector_id", connID)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
conn, err := s.getConnector(authReq.ConnectorID) conn, err := s.getConnector(authReq.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
pwConn, ok := conn.Connector.(connector.PasswordConnector) pwConn, ok := conn.Connector.(connector.PasswordConnector)
if !ok { if !ok {
s.logger.Errorf("Expected password connector in handlePasswordLogin(), but got %v", pwConn) s.logger.Error("expected password connector in handlePasswordLogin()", "password_connector", pwConn)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
@ -358,7 +358,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil { if err := s.templates.password(r, w, r.URL.String(), "", usernamePrompt(pwConn), false, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
case http.MethodPost: case http.MethodPost:
username := r.FormValue("login") username := r.FormValue("login")
@ -367,20 +367,20 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
identity, ok, err := pwConn.Login(ctx, scopes, username, password) identity, ok, err := pwConn.Login(ctx, scopes, username, password)
if err != nil { if err != nil {
s.logger.Errorf("Failed to login user: %v", err) s.logger.Error("failed to login user", "err", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err)) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Login error: %v", err))
return return
} }
if !ok { if !ok {
if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil { if err := s.templates.password(r, w, r.URL.String(), username, usernamePrompt(pwConn), true, backLink); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
s.logger.Errorf("Failed login attempt for user: %q. Invalid credentials.", username) s.logger.Error("failed login attempt: Invalid credentials.", "user", username)
return return
} }
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil { if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err) s.logger.Error("failed to finalize login", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -388,7 +388,7 @@ func (s *Server) handlePasswordLogin(w http.ResponseWriter, r *http.Request) {
if canSkipApproval { if canSkipApproval {
authReq, err = s.storage.GetAuthRequest(authReq.ID) authReq, err = s.storage.GetAuthRequest(authReq.ID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get finalized auth request: %v", err) s.logger.Error("failed to get finalized auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -424,29 +424,29 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
authReq, err := s.storage.GetAuthRequest(authID) authReq, err := s.storage.GetAuthRequest(authID)
if err != nil { if err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
s.logger.Errorf("Invalid 'state' parameter provided: %v", err) s.logger.Error("invalid 'state' parameter provided", "err", err)
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.") s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
return return
} }
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Error("failed to get auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
connID, err := url.PathUnescape(mux.Vars(r)["connector"]) connID, err := url.PathUnescape(mux.Vars(r)["connector"])
if err != nil { if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} else if connID != "" && connID != authReq.ConnectorID { } else if connID != "" && connID != authReq.ConnectorID {
s.logger.Errorf("Connector mismatch: authentication started with id %q, but callback for id %q was triggered", authReq.ConnectorID, connID) s.logger.Error("connector mismatch: callback triggered for different connector than authentication start", "authentication_start_connector_id", authReq.ConnectorID, "connector_id", connID)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
conn, err := s.getConnector(authReq.ConnectorID) conn, err := s.getConnector(authReq.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get connector with id %q : %v", authReq.ConnectorID, err) s.logger.Error("failed to get connector", "connector_id", authReq.ConnectorID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.") s.renderError(r, w, http.StatusInternalServerError, "Requested resource does not exist.")
return return
} }
@ -455,14 +455,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
switch conn := conn.Connector.(type) { switch conn := conn.Connector.(type) {
case connector.CallbackConnector: case connector.CallbackConnector:
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
s.logger.Errorf("SAML request mapped to OAuth2 connector") s.logger.Error("SAML request mapped to OAuth2 connector")
s.renderError(r, w, http.StatusBadRequest, "Invalid request") s.renderError(r, w, http.StatusBadRequest, "Invalid request")
return return
} }
identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r) identity, err = conn.HandleCallback(parseScopes(authReq.Scopes), r)
case connector.SAMLConnector: case connector.SAMLConnector:
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
s.logger.Errorf("OAuth2 request mapped to SAML connector") s.logger.Error("OAuth2 request mapped to SAML connector")
s.renderError(r, w, http.StatusBadRequest, "Invalid request") s.renderError(r, w, http.StatusBadRequest, "Invalid request")
return return
} }
@ -473,14 +473,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
} }
if err != nil { if err != nil {
s.logger.Errorf("Failed to authenticate: %v", err) s.logger.Error("failed to authenticate", "err", err)
s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Failed to authenticate: %v", err)) s.renderError(r, w, http.StatusInternalServerError, fmt.Sprintf("Failed to authenticate: %v", err))
return return
} }
redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector) redirectURL, canSkipApproval, err := s.finalizeLogin(ctx, identity, authReq, conn.Connector)
if err != nil { if err != nil {
s.logger.Errorf("Failed to finalize login: %v", err) s.logger.Error("failed to finalize login", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -488,7 +488,7 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request)
if canSkipApproval { if canSkipApproval {
authReq, err = s.storage.GetAuthRequest(authReq.ID) authReq, err = s.storage.GetAuthRequest(authReq.ID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get finalized auth request: %v", err) s.logger.Error("failed to get finalized auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Login error.") s.renderError(r, w, http.StatusInternalServerError, "Login error.")
return return
} }
@ -526,8 +526,9 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
email += " (unverified)" email += " (unverified)"
} }
s.logger.Infof("login successful: connector %q, username=%q, preferred_username=%q, email=%q, groups=%q", s.logger.Info("login successful",
authReq.ConnectorID, claims.Username, claims.PreferredUsername, email, claims.Groups) "connector_id", authReq.ConnectorID, "username", claims.Username,
"preferred_username", claims.PreferredUsername, "email", email, "groups", claims.Groups)
// we can skip the redirect to /approval and go ahead and send code if it's not required // we can skip the redirect to /approval and go ahead and send code if it's not required
if s.skipApproval && !authReq.ForceApprovalPrompt { if s.skipApproval && !authReq.ForceApprovalPrompt {
@ -561,7 +562,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID) session, err := s.storage.GetOfflineSessions(identity.UserID, authReq.ConnectorID)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Error("failed to get offline session", "err", err)
return "", false, err return "", false, err
} }
offlineSessions := storage.OfflineSessions{ offlineSessions := storage.OfflineSessions{
@ -574,7 +575,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
// Create a new OfflineSession object for the user and add a reference object for // Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken. // the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err) s.logger.Error("failed to create offline session", "err", err)
return "", false, err return "", false, err
} }
@ -588,7 +589,7 @@ func (s *Server) finalizeLogin(ctx context.Context, identity connector.Identity,
} }
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Error("failed to update offline session", "err", err)
return "", false, err return "", false, err
} }
@ -609,12 +610,12 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) authReq, err := s.storage.GetAuthRequest(r.FormValue("req"))
if err != nil { if err != nil {
s.logger.Errorf("Failed to get auth request: %v", err) s.logger.Error("failed to get auth request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Database error.") s.renderError(r, w, http.StatusInternalServerError, "Database error.")
return return
} }
if !authReq.LoggedIn { if !authReq.LoggedIn {
s.logger.Errorf("Auth request does not have an identity for approval") s.logger.Error("auth request does not have an identity for approval")
s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.") s.renderError(r, w, http.StatusInternalServerError, "Login process not yet finalized.")
return return
} }
@ -633,12 +634,12 @@ func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
case http.MethodGet: case http.MethodGet:
client, err := s.storage.GetClient(authReq.ClientID) client, err := s.storage.GetClient(authReq.ClientID)
if err != nil { if err != nil {
s.logger.Errorf("Failed to get client %q: %v", authReq.ClientID, err) s.logger.Error("Failed to get client", "client_id", authReq.ClientID, "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.") s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve client.")
return return
} }
if err := s.templates.approval(r, w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes); err != nil { if err := s.templates.approval(r, w, authReq.ID, authReq.Claims.Username, client.Name, authReq.Scopes); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
case http.MethodPost: case http.MethodPost:
if r.FormValue("approval") != "approve" { if r.FormValue("approval") != "approve" {
@ -658,7 +659,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil { if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("Failed to delete authorization request: %v", err) s.logger.Error("Failed to delete authorization request", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
} else { } else {
s.renderError(r, w, http.StatusBadRequest, "User session error.") s.renderError(r, w, http.StatusBadRequest, "User session error.")
@ -704,7 +705,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
PKCE: authReq.PKCE, PKCE: authReq.PKCE,
} }
if err := s.storage.CreateAuthCode(ctx, code); err != nil { if err := s.storage.CreateAuthCode(ctx, code); err != nil {
s.logger.Errorf("Failed to create auth code: %v", err) s.logger.Error("Failed to create auth code", "err", err)
s.renderError(r, w, http.StatusInternalServerError, "Internal server error.") s.renderError(r, w, http.StatusInternalServerError, "Internal server error.")
return return
} }
@ -713,7 +714,7 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
// rejected earlier. If we got here we're using the code flow. // rejected earlier. If we got here we're using the code flow.
if authReq.RedirectURI == redirectURIOOB { if authReq.RedirectURI == redirectURIOOB {
if err := s.templates.oob(r, w, code.ID); err != nil { if err := s.templates.oob(r, w, code.ID); err != nil {
s.logger.Errorf("Server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
return return
} }
@ -725,14 +726,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID) accessToken, _, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create new access token: %v", err) s.logger.Error("failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID) idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, code.ID, authReq.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Error("failed to create ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -807,7 +808,7 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h
client, err := s.storage.GetClient(clientID) client, err := s.storage.GetClient(clientID)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get client: %v", err) s.logger.Error("failed to get client", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else { } else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
@ -817,9 +818,9 @@ func (s *Server) withClientFromStorage(w http.ResponseWriter, r *http.Request, h
if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 { if subtle.ConstantTimeCompare([]byte(client.Secret), []byte(clientSecret)) != 1 {
if clientSecret == "" { if clientSecret == "" {
s.logger.Infof("missing client_secret on token request for client: %s", client.ID) s.logger.Info("missing client_secret on token request", "client_id", client.ID)
} else { } else {
s.logger.Infof("invalid client_secret on token request for client: %s", client.ID) s.logger.Info("invalid client_secret on token request", "client_id", client.ID)
} }
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized) s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return return
@ -837,14 +838,14 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
s.logger.Errorf("Could not parse request body: %v", err) s.logger.Error("could not parse request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return return
} }
grantType := r.PostFormValue("grant_type") grantType := r.PostFormValue("grant_type")
if !contains(s.supportedGrantTypes, grantType) { if !contains(s.supportedGrantTypes, grantType) {
s.logger.Errorf("unsupported grant type: %v", grantType) s.logger.Error("unsupported grant type", "grant_type", grantType)
s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest) s.tokenErrHelper(w, errUnsupportedGrantType, "", http.StatusBadRequest)
return return
} }
@ -890,7 +891,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
authCode, err := s.storage.GetAuthCode(code) authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID { if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get auth code: %v", err) s.logger.Error("failed to get auth code", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else { } else {
s.tokenErrHelper(w, errInvalidGrant, "Invalid or expired code parameter.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidGrant, "Invalid or expired code parameter.", http.StatusBadRequest)
@ -906,7 +907,7 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
case providedCodeVerifier != "" && codeChallengeFromStorage != "": case providedCodeVerifier != "" && codeChallengeFromStorage != "":
calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.PKCE.CodeChallengeMethod) calculatedCodeChallenge, err := s.calculateCodeChallenge(providedCodeVerifier, authCode.PKCE.CodeChallengeMethod)
if err != nil { if err != nil {
s.logger.Error(err) s.logger.Error("failed to calculate code challenge", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -940,20 +941,20 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) { func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, authCode storage.AuthCode, client storage.Client) (*accessTokenResponse, error) {
accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID) accessToken, _, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create new access token: %v", err) s.logger.Error("failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err return nil, err
} }
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID) idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ID, authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Error("failed to create ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err return nil, err
} }
if err := s.storage.DeleteAuthCode(authCode.ID); err != nil { if err := s.storage.DeleteAuthCode(authCode.ID); err != nil {
s.logger.Errorf("failed to delete auth code: %v", err) s.logger.Error("failed to delete auth code", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err return nil, err
} }
@ -964,7 +965,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Connectors like `saml` do not implement RefreshConnector. // Connectors like `saml` do not implement RefreshConnector.
conn, err := s.getConnector(authCode.ConnectorID) conn, err := s.getConnector(authCode.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", authCode.ConnectorID, err) s.logger.Error("connector not found", "connector_id", authCode.ConnectorID, "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return false return false
} }
@ -1000,13 +1001,13 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
Token: refresh.Token, Token: refresh.Token,
} }
if refreshToken, err = internal.Marshal(token); err != nil { if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err) s.logger.Error("failed to marshal refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err return nil, err
} }
if err := s.storage.CreateRefresh(ctx, refresh); err != nil { if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err) s.logger.Error("failed to create refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return nil, err return nil, err
} }
@ -1019,7 +1020,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
if deleteToken { if deleteToken {
// Delete newly created refresh token from storage. // Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil { if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Error("failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1036,7 +1037,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Try to retrieve an existing OfflineSession object for the corresponding user. // Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Error("failed to get offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return nil, err return nil, err
@ -1051,7 +1052,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
// Create a new OfflineSession object for the user and add a reference object for // Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken. // the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err) s.logger.Error("failed to create offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return nil, err return nil, err
@ -1060,7 +1061,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok { if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
// Delete old refresh token from storage. // Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound { if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Error("failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return nil, err return nil, err
@ -1072,7 +1073,7 @@ func (s *Server) exchangeAuthCode(ctx context.Context, w http.ResponseWriter, au
old.Refresh[tokenRef.ClientID] = &tokenRef old.Refresh[tokenRef.ClientID] = &tokenRef
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Error("failed to update offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return nil, err return nil, err
@ -1184,7 +1185,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
password := q.Get("password") password := q.Get("password")
identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password) identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password)
if err != nil { if err != nil {
s.logger.Errorf("Failed to login user: %v", err) s.logger.Error("failed to login user", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
return return
} }
@ -1205,14 +1206,14 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID) accessToken, _, err := s.newAccessToken(client.ID, claims, scopes, nonce, connID)
if err != nil { if err != nil {
s.logger.Errorf("password grant failed to create new access token: %v", err) s.logger.Error("password grant failed to create new access token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID) idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, nonce, accessToken, "", connID)
if err != nil { if err != nil {
s.logger.Errorf("password grant failed to create new ID token: %v", err) s.logger.Error("password grant failed to create new ID token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1252,13 +1253,13 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
Token: refresh.Token, Token: refresh.Token,
} }
if refreshToken, err = internal.Marshal(token); err != nil { if refreshToken, err = internal.Marshal(token); err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err) s.logger.Error("failed to marshal refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
if err := s.storage.CreateRefresh(ctx, refresh); err != nil { if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
s.logger.Errorf("failed to create refresh token: %v", err) s.logger.Error("failed to create refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1271,7 +1272,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
if deleteToken { if deleteToken {
// Delete newly created refresh token from storage. // Delete newly created refresh token from storage.
if err := s.storage.DeleteRefresh(refresh.ID); err != nil { if err := s.storage.DeleteRefresh(refresh.ID); err != nil {
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Error("failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1288,7 +1289,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Try to retrieve an existing OfflineSession object for the corresponding user. // Try to retrieve an existing OfflineSession object for the corresponding user.
if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil { if session, err := s.storage.GetOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID); err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Error("failed to get offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return
@ -1304,7 +1305,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Create a new OfflineSession object for the user and add a reference object for // Create a new OfflineSession object for the user and add a reference object for
// the newly received refreshtoken. // the newly received refreshtoken.
if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil { if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
s.logger.Errorf("failed to create offline session: %v", err) s.logger.Error("failed to create offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return
@ -1314,9 +1315,9 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
// Delete old refresh token from storage. // Delete old refresh token from storage.
if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil { if err := s.storage.DeleteRefresh(oldTokenRef.ID); err != nil {
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
s.logger.Warnf("database inconsistent, refresh token missing: %v", oldTokenRef.ID) s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID)
} else { } else {
s.logger.Errorf("failed to delete refresh token: %v", err) s.logger.Error("failed to delete refresh token", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return
@ -1330,7 +1331,7 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
old.ConnectorData = identity.ConnectorData old.ConnectorData = identity.ConnectorData
return old, nil return old, nil
}); err != nil { }); err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Error("failed to update offline session", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
deleteToken = true deleteToken = true
return return
@ -1346,7 +1347,7 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
ctx := r.Context() ctx := r.Context()
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
s.logger.Errorf("could not parse request body: %v", err) s.logger.Error("could not parse request body", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
return return
} }
@ -1375,19 +1376,19 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
conn, err := s.getConnector(connID) conn, err := s.getConnector(connID)
if err != nil { if err != nil {
s.logger.Errorf("failed to get connector: %v", err) s.logger.Error("failed to get connector", "err", err)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return return
} }
teConn, ok := conn.Connector.(connector.TokenIdentityConnector) teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
if !ok { if !ok {
s.logger.Errorf("connector doesn't implement token exchange: %v", connID) s.logger.Error("connector doesn't implement token exchange", "connector_id", connID)
s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest) s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
return return
} }
identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken) identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken)
if err != nil { if err != nil {
s.logger.Errorf("failed to verify subject token: %v", err) s.logger.Error("failed to verify subject token", "err", err)
s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized) s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
return return
} }
@ -1415,7 +1416,7 @@ func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, cli
return return
} }
if err != nil { if err != nil {
s.logger.Errorf("token exchange failed to create new %v token: %v", requestedTokenType, err) s.logger.Error("token exchange failed to create new token", "requested_token_type", requestedTokenType, "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1451,7 +1452,7 @@ func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string
func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenResponse) { func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenResponse) {
data, err := json.Marshal(resp) data, err := json.Marshal(resp)
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal access token response: %v", err) s.logger.Error("failed to marshal access token response", "err", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError) s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return return
} }
@ -1466,13 +1467,13 @@ func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenRespon
func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) { func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) {
if err := s.templates.err(r, w, status, description); err != nil { if err := s.templates.err(r, w, status, description); err != nil {
s.logger.Errorf("server template error: %v", err) s.logger.Error("server template error", "err", err)
} }
} }
func (s *Server) tokenErrHelper(w http.ResponseWriter, typ string, description string, statusCode int) { func (s *Server) tokenErrHelper(w http.ResponseWriter, typ string, description string, statusCode int) {
if err := tokenErr(w, typ, description, statusCode); err != nil { if err := tokenErr(w, typ, description, statusCode); err != nil {
s.logger.Errorf("token error response: %v", err) s.logger.Error("token error response", "err", err)
} }
} }

20
server/introspectionhandler.go

@ -179,14 +179,14 @@ func (s *Server) getTokenFromRequest(r *http.Request) (string, TokenTypeEnum, er
token := r.PostForm.Get("token") token := r.PostForm.Get("token")
tokenType, err := s.guessTokenType(r.Context(), token) tokenType, err := s.guessTokenType(r.Context(), token)
if err != nil { if err != nil {
s.logger.Error(err) s.logger.Error("failed to guess token type", "err", err)
return "", 0, newIntrospectInternalServerError() return "", 0, newIntrospectInternalServerError()
} }
requestTokenType := r.PostForm.Get("token_type_hint") requestTokenType := r.PostForm.Get("token_type_hint")
if requestTokenType != "" { if requestTokenType != "" {
if tokenType.String() != requestTokenType { if tokenType.String() != requestTokenType {
s.logger.Warnf("Token type hint doesn't match token type: %s != %s", requestTokenType, tokenType) s.logger.Warn("token type hint doesn't match token type", "request_token_type", requestTokenType, "token_type", tokenType)
} }
} }
@ -211,13 +211,13 @@ func (s *Server) introspectRefreshToken(_ context.Context, token string) (*Intro
return nil, newIntrospectInactiveTokenError() return nil, newIntrospectInactiveTokenError()
} }
s.logger.Errorf("failed to get refresh token: %v", err) s.logger.Error("failed to get refresh token", "err", err)
return nil, newIntrospectInternalServerError() return nil, newIntrospectInternalServerError()
} }
subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID) subjectString, sErr := genSubject(rCtx.storageToken.Claims.UserID, rCtx.storageToken.ConnectorID)
if sErr != nil { if sErr != nil {
s.logger.Errorf("failed to marshal offline session ID: %v", err) s.logger.Error("failed to marshal offline session ID", "err", err)
return nil, newIntrospectInternalServerError() return nil, newIntrospectInternalServerError()
} }
@ -253,19 +253,19 @@ func (s *Server) introspectAccessToken(ctx context.Context, token string) (*Intr
var claims IntrospectionExtra var claims IntrospectionExtra
if err := idToken.Claims(&claims); err != nil { if err := idToken.Claims(&claims); err != nil {
s.logger.Errorf("Error while fetching token claims: %s", err.Error()) s.logger.Error("error while fetching token claims", "err", err.Error())
return nil, newIntrospectInternalServerError() return nil, newIntrospectInternalServerError()
} }
clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty) clientID, err := getClientID(idToken.Audience, claims.AuthorizingParty)
if err != nil { if err != nil {
s.logger.Error("Error while fetching client_id from token: %s", err.Error()) s.logger.Error("error while fetching client_id from token:", "err", err.Error())
return nil, newIntrospectInternalServerError() return nil, newIntrospectInternalServerError()
} }
client, err := s.storage.GetClient(clientID) client, err := s.storage.GetClient(clientID)
if err != nil { if err != nil {
s.logger.Error("Error while fetching client from storage: %s", err.Error()) s.logger.Error("error while fetching client from storage", "err", err.Error())
return nil, newIntrospectInternalServerError() return nil, newIntrospectInternalServerError()
} }
@ -299,7 +299,7 @@ func (s *Server) handleIntrospect(w http.ResponseWriter, r *http.Request) {
introspect, err = s.introspectRefreshToken(ctx, token) introspect, err = s.introspectRefreshToken(ctx, token)
default: default:
// Token type is neither handled token types. // Token type is neither handled token types.
s.logger.Errorf("Unknown token type: %s", tokenType) s.logger.Error("unknown token type", "token_type", tokenType)
introspectInactiveErr(w) introspectInactiveErr(w)
return return
} }
@ -309,7 +309,7 @@ func (s *Server) handleIntrospect(w http.ResponseWriter, r *http.Request) {
if intErr, ok := err.(*introspectionError); ok { if intErr, ok := err.(*introspectionError); ok {
s.introspectErrHelper(w, intErr.typ, intErr.desc, intErr.code) s.introspectErrHelper(w, intErr.typ, intErr.desc, intErr.code)
} else { } else {
s.logger.Errorf("An unknown error occurred: %s", err.Error()) s.logger.Error("an unknown error occurred", "err", err.Error())
s.introspectErrHelper(w, errServerError, "An unknown error occurred", http.StatusInternalServerError) s.introspectErrHelper(w, errServerError, "An unknown error occurred", http.StatusInternalServerError)
} }
@ -332,7 +332,7 @@ func (s *Server) introspectErrHelper(w http.ResponseWriter, typ string, descript
} }
if err := tokenErr(w, typ, description, statusCode); err != nil { if err := tokenErr(w, typ, description, statusCode); err != nil {
s.logger.Errorf("introspect error response: %v", err) s.logger.Error("introspect error response", "err", err)
} }
} }

14
server/oauth2.go

@ -353,7 +353,7 @@ func genSubject(userID string, connID string) (string, error) {
func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) { func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, code, connID string) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys() keys, err := s.storage.GetKeys()
if err != nil { if err != nil {
s.logger.Errorf("Failed to get keys: %v", err) s.logger.Error("failed to get keys", "err", err)
return "", expiry, err return "", expiry, err
} }
@ -371,7 +371,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
subjectString, err := genSubject(claims.UserID, connID) subjectString, err := genSubject(claims.UserID, connID)
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal offline session ID: %v", err) s.logger.Error("failed to marshal offline session ID", "err", err)
return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err) return "", expiry, fmt.Errorf("failed to marshal offline session ID: %v", err)
} }
@ -386,7 +386,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
if accessToken != "" { if accessToken != "" {
atHash, err := accessTokenHash(signingAlg, accessToken) atHash, err := accessTokenHash(signingAlg, accessToken)
if err != nil { if err != nil {
s.logger.Errorf("error computing at_hash: %v", err) s.logger.Error("error computing at_hash", "err", err)
return "", expiry, fmt.Errorf("error computing at_hash: %v", err) return "", expiry, fmt.Errorf("error computing at_hash: %v", err)
} }
tok.AccessTokenHash = atHash tok.AccessTokenHash = atHash
@ -395,7 +395,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
if code != "" { if code != "" {
cHash, err := accessTokenHash(signingAlg, code) cHash, err := accessTokenHash(signingAlg, code)
if err != nil { if err != nil {
s.logger.Errorf("error computing c_hash: %v", err) s.logger.Error("error computing c_hash", "err", err)
return "", expiry, fmt.Errorf("error computing c_hash: #{err}") return "", expiry, fmt.Errorf("error computing c_hash: #{err}")
} }
tok.CodeHash = cHash tok.CodeHash = cHash
@ -482,7 +482,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
if err == storage.ErrNotFound { if err == storage.ErrNotFound {
return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID) return nil, newDisplayedErr(http.StatusNotFound, "Invalid client_id (%q).", clientID)
} }
s.logger.Errorf("Failed to get client: %v", err) s.logger.Error("failed to get client", "err", err)
return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.") return nil, newDisplayedErr(http.StatusInternalServerError, "Database error.")
} }
@ -501,7 +501,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (*storage.AuthReques
if connectorID != "" { if connectorID != "" {
connectors, err := s.storage.ListConnectors() connectors, err := s.storage.ListConnectors()
if err != nil { if err != nil {
s.logger.Errorf("Failed to list connectors: %v", err) s.logger.Error("failed to list connectors", "err", err)
return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors") return nil, newRedirectedErr(errServerError, "Unable to retrieve connectors")
} }
if !validateConnectorID(connectors, connectorID) { if !validateConnectorID(connectors, connectorID) {
@ -637,7 +637,7 @@ func (s *Server) validateCrossClientTrust(clientID, peerID string) (trusted bool
peer, err := s.storage.GetClient(peerID) peer, err := s.storage.GetClient(peerID)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("Failed to get client: %v", err) s.logger.Error("failed to get client", "err", err)
return false, err return false, err
} }
return false, nil return false, nil

30
server/refreshhandlers.go

@ -87,7 +87,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re
refresh, err := s.storage.GetRefresh(token.RefreshId) refresh, err := s.storage.GetRefresh(token.RefreshId)
if err != nil { if err != nil {
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get refresh token: %v", err) s.logger.Error("failed to get refresh token", "err", err)
return nil, newInternalServerError() return nil, newInternalServerError()
} }
return nil, invalidErr return nil, invalidErr
@ -95,7 +95,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re
// Only check ClientID if it was provided; // Only check ClientID if it was provided;
if clientID != nil && (refresh.ClientID != *clientID) { if clientID != nil && (refresh.ClientID != *clientID) {
s.logger.Errorf("client %s trying to claim token for client %s", clientID, refresh.ClientID) s.logger.Error("trying to claim token for different client", "client_id", clientID, "refresh_client_id", refresh.ClientID)
// According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an // According to https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 Dex should respond with an
// invalid grant error if token has already been claimed by another client. // invalid grant error if token has already been claimed by another client.
return nil, &refreshError{msg: errInvalidGrant, desc: invalidErr.desc, code: http.StatusBadRequest} return nil, &refreshError{msg: errInvalidGrant, desc: invalidErr.desc, code: http.StatusBadRequest}
@ -108,18 +108,18 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re
case refresh.ObsoleteToken != token.Token: case refresh.ObsoleteToken != token.Token:
fallthrough fallthrough
case refresh.ObsoleteToken == "": case refresh.ObsoleteToken == "":
s.logger.Errorf("refresh token with id %s claimed twice", refresh.ID) s.logger.Error("refresh token claimed twice", "token_id", refresh.ID)
return nil, invalidErr return nil, invalidErr
} }
} }
if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) { if s.refreshTokenPolicy.CompletelyExpired(refresh.CreatedAt) {
s.logger.Errorf("refresh token with id %s expired", refresh.ID) s.logger.Error("refresh token expired", "token_id", refresh.ID)
return nil, expiredErr return nil, expiredErr
} }
if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) { if s.refreshTokenPolicy.ExpiredBecauseUnused(refresh.LastUsed) {
s.logger.Errorf("refresh token with id %s expired due to inactivity", refresh.ID) s.logger.Error("refresh token expired due to inactivity", "token_id", refresh.ID)
return nil, expiredErr return nil, expiredErr
} }
@ -128,7 +128,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re
// Get Connector // Get Connector
refreshCtx.connector, err = s.getConnector(refresh.ConnectorID) refreshCtx.connector, err = s.getConnector(refresh.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("connector with ID %q not found: %v", refresh.ConnectorID, err) s.logger.Error("connector not found", "connector_id", refresh.ConnectorID, "err", err)
return nil, newInternalServerError() return nil, newInternalServerError()
} }
@ -137,7 +137,7 @@ func (s *Server) getRefreshTokenFromStorage(clientID *string, token *internal.Re
switch { switch {
case err != nil: case err != nil:
if err != storage.ErrNotFound { if err != storage.ErrNotFound {
s.logger.Errorf("failed to get offline session: %v", err) s.logger.Error("failed to get offline session", "err", err)
return nil, newInternalServerError() return nil, newInternalServerError()
} }
case len(refresh.ConnectorData) > 0: case len(refresh.ConnectorData) > 0:
@ -191,11 +191,11 @@ func (s *Server) refreshWithConnector(ctx context.Context, rCtx *refreshContext,
if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok { if refreshConn, ok := rCtx.connector.Connector.(connector.RefreshConnector); ok {
// Set connector data to the one received from an offline session // Set connector data to the one received from an offline session
ident.ConnectorData = rCtx.connectorData ident.ConnectorData = rCtx.connectorData
s.logger.Debugf("connector data before refresh: %s", ident.ConnectorData) s.logger.Debug("connector data before refresh", "connector_data", ident.ConnectorData)
newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident) newIdent, err := refreshConn.Refresh(ctx, parseScopes(rCtx.scopes), ident)
if err != nil { if err != nil {
s.logger.Errorf("failed to refresh identity: %v", err) s.logger.Error("failed to refresh identity", "err", err)
return ident, newInternalServerError() return ident, newInternalServerError()
} }
@ -216,7 +216,7 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
old.ConnectorData = ident.ConnectorData old.ConnectorData = ident.ConnectorData
} }
s.logger.Debugf("saved connector data: %s %s", ident.UserID, ident.ConnectorData) s.logger.Debug("saved connector data", "user_id", ident.UserID, "connector_data", ident.ConnectorData)
return old, nil return old, nil
} }
@ -225,7 +225,7 @@ func (s *Server) updateOfflineSession(refresh *storage.RefreshToken, ident conne
// in offline session for the user. // in offline session for the user.
err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater) err := s.storage.UpdateOfflineSessions(refresh.Claims.UserID, refresh.ConnectorID, offlineSessionUpdater)
if err != nil { if err != nil {
s.logger.Errorf("failed to update offline session: %v", err) s.logger.Error("failed to update offline session", "err", err)
return newInternalServerError() return newInternalServerError()
} }
@ -316,7 +316,7 @@ func (s *Server) updateRefreshToken(ctx context.Context, rCtx *refreshContext) (
// Update refresh token in the storage. // Update refresh token in the storage.
err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater) err := s.storage.UpdateRefreshToken(rCtx.storageToken.ID, refreshTokenUpdater)
if err != nil { if err != nil {
s.logger.Errorf("failed to update refresh token: %v", err) s.logger.Error("failed to update refresh token", "err", err)
return nil, ident, newInternalServerError() return nil, ident, newInternalServerError()
} }
@ -366,21 +366,21 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID) accessToken, _, err := s.newAccessToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, rCtx.storageToken.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create new access token: %v", err) s.logger.Error("failed to create new access token", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError()) s.refreshTokenErrHelper(w, newInternalServerError())
return return
} }
idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID) idToken, expiry, err := s.newIDToken(client.ID, claims, rCtx.scopes, rCtx.storageToken.Nonce, accessToken, "", rCtx.storageToken.ConnectorID)
if err != nil { if err != nil {
s.logger.Errorf("failed to create ID token: %v", err) s.logger.Error("failed to create ID token", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError()) s.refreshTokenErrHelper(w, newInternalServerError())
return return
} }
rawNewToken, err := internal.Marshal(newToken) rawNewToken, err := internal.Marshal(newToken)
if err != nil { if err != nil {
s.logger.Errorf("failed to marshal refresh token: %v", err) s.logger.Error("failed to marshal refresh token", "err", err)
s.refreshTokenErrHelper(w, newInternalServerError()) s.refreshTokenErrHelper(w, newInternalServerError())
return return
} }

26
server/rotation.go

@ -8,11 +8,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"time" "time"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -61,7 +61,7 @@ type keyRotator struct {
strategy rotationStrategy strategy rotationStrategy
now func() time.Time now func() time.Time
logger log.Logger logger *slog.Logger
} }
// startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled. // startKeyRotation begins key rotation in a new goroutine, closing once the context is canceled.
@ -74,9 +74,9 @@ func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy
// Try to rotate immediately so properly configured storages will have keys. // Try to rotate immediately so properly configured storages will have keys.
if err := rotator.rotate(); err != nil { if err := rotator.rotate(); err != nil {
if err == errAlreadyRotated { if err == errAlreadyRotated {
s.logger.Infof("Key rotation not needed: %v", err) s.logger.Info("key rotation not needed", "err", err)
} else { } else {
s.logger.Errorf("failed to rotate keys: %v", err) s.logger.Error("failed to rotate keys", "err", err)
} }
} }
@ -87,7 +87,7 @@ func (s *Server) startKeyRotation(ctx context.Context, strategy rotationStrategy
return return
case <-time.After(time.Second * 30): case <-time.After(time.Second * 30):
if err := rotator.rotate(); err != nil { if err := rotator.rotate(); err != nil {
s.logger.Errorf("failed to rotate keys: %v", err) s.logger.Error("failed to rotate keys", "err", err)
} }
} }
} }
@ -102,7 +102,7 @@ func (k keyRotator) rotate() error {
if k.now().Before(keys.NextRotation) { if k.now().Before(keys.NextRotation) {
return nil return nil
} }
k.logger.Infof("keys expired, rotating") k.logger.Info("keys expired, rotating")
// Generate the key outside of a storage transaction. // Generate the key outside of a storage transaction.
key, err := k.strategy.key() key, err := k.strategy.key()
@ -174,7 +174,7 @@ func (k keyRotator) rotate() error {
if err != nil { if err != nil {
return err return err
} }
k.logger.Infof("keys rotated, next rotation: %s", nextRotation) k.logger.Info("keys rotated", "next_rotation", nextRotation)
return nil return nil
} }
@ -187,10 +187,10 @@ type RefreshTokenPolicy struct {
now func() time.Time now func() time.Time
logger log.Logger logger *slog.Logger
} }
func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) { func NewRefreshTokenPolicy(logger *slog.Logger, rotation bool, validIfNotUsedFor, absoluteLifetime, reuseInterval string) (*RefreshTokenPolicy, error) {
r := RefreshTokenPolicy{now: time.Now, logger: logger} r := RefreshTokenPolicy{now: time.Now, logger: logger}
var err error var err error
@ -199,7 +199,7 @@ func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor,
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err) return nil, fmt.Errorf("invalid config value %q for refresh token valid if not used for: %v", validIfNotUsedFor, err)
} }
logger.Infof("config refresh tokens valid if not used for: %v", validIfNotUsedFor) logger.Info("config refresh tokens", "valid_if_not_used_for", validIfNotUsedFor)
} }
if absoluteLifetime != "" { if absoluteLifetime != "" {
@ -207,7 +207,7 @@ func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor,
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err) return nil, fmt.Errorf("invalid config value %q for refresh tokens absolute lifetime: %v", absoluteLifetime, err)
} }
logger.Infof("config refresh tokens absolute lifetime: %v", absoluteLifetime) logger.Info("config refresh tokens", "absolute_lifetime", absoluteLifetime)
} }
if reuseInterval != "" { if reuseInterval != "" {
@ -215,11 +215,11 @@ func NewRefreshTokenPolicy(logger log.Logger, rotation bool, validIfNotUsedFor,
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err) return nil, fmt.Errorf("invalid config value %q for refresh tokens reuse interval: %v", reuseInterval, err)
} }
logger.Infof("config refresh tokens reuse interval: %v", reuseInterval) logger.Info("config refresh tokens", "reuse_interval", reuseInterval)
} }
r.rotateRefreshTokens = !rotation r.rotateRefreshTokens = !rotation
logger.Infof("config refresh tokens rotation enabled: %v", r.rotateRefreshTokens) logger.Info("config refresh tokens rotation", "enabled", r.rotateRefreshTokens)
return &r, nil return &r, nil
} }

16
server/rotation_test.go

@ -1,12 +1,12 @@
package server package server
import ( import (
"os" "io"
"log/slog"
"sort" "sort"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
@ -68,11 +68,7 @@ func TestKeyRotator(t *testing.T) {
// Only the last 5 verification keys are expected to be kept around. // Only the last 5 verification keys are expected to be kept around.
maxVerificationKeys := 5 maxVerificationKeys := 5
l := &logrus.Logger{ l := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
r := &keyRotator{ r := &keyRotator{
Storage: memory.New(l), Storage: memory.New(l),
@ -104,11 +100,7 @@ func TestKeyRotator(t *testing.T) {
func TestRefreshTokenPolicy(t *testing.T) { func TestRefreshTokenPolicy(t *testing.T) {
lastTime := time.Now() lastTime := time.Now()
l := &logrus.Logger{ l := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m") r, err := NewRefreshTokenPolicy(l, true, "1m", "1m", "1m")
require.NoError(t, err) require.NoError(t, err)

17
server/server.go

@ -7,6 +7,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@ -42,7 +43,6 @@ import (
"github.com/dexidp/dex/connector/oidc" "github.com/dexidp/dex/connector/oidc"
"github.com/dexidp/dex/connector/openshift" "github.com/dexidp/dex/connector/openshift"
"github.com/dexidp/dex/connector/saml" "github.com/dexidp/dex/connector/saml"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/web" "github.com/dexidp/dex/web"
) )
@ -108,7 +108,7 @@ type Config struct {
Web WebConfig Web WebConfig
Logger log.Logger Logger *slog.Logger
PrometheusRegistry *prometheus.Registry PrometheusRegistry *prometheus.Registry
@ -189,7 +189,7 @@ type Server struct {
refreshTokenPolicy *RefreshTokenPolicy refreshTokenPolicy *RefreshTokenPolicy
logger log.Logger logger *slog.Logger
} }
// NewServer constructs a server from the provided config. // NewServer constructs a server from the provided config.
@ -556,10 +556,11 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
return return
case <-time.After(frequency): case <-time.After(frequency):
if r, err := s.storage.GarbageCollect(now()); err != nil { if r, err := s.storage.GarbageCollect(now()); err != nil {
s.logger.Errorf("garbage collection failed: %v", err) s.logger.ErrorContext(ctx, "garbage collection failed", "err", err)
} else if !r.IsEmpty() { } else if !r.IsEmpty() {
s.logger.Infof("garbage collection run, delete auth requests=%d, auth codes=%d, device requests=%d, device tokens=%d", s.logger.InfoContext(ctx, "garbage collection run, delete auth",
r.AuthRequests, r.AuthCodes, r.DeviceRequests, r.DeviceTokens) "requests", r.AuthRequests, "auth_codes", r.AuthCodes,
"device_requests", r.DeviceRequests, "device_tokens", r.DeviceTokens)
} }
} }
} }
@ -568,7 +569,7 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura
// ConnectorConfig is a configuration that can open a connector. // ConnectorConfig is a configuration that can open a connector.
type ConnectorConfig interface { type ConnectorConfig interface {
Open(id string, logger log.Logger) (connector.Connector, error) Open(id string, logger *slog.Logger) (connector.Connector, error)
} }
// ConnectorsConfig variable provides an easy way to return a config struct // ConnectorsConfig variable provides an easy way to return a config struct
@ -596,7 +597,7 @@ var ConnectorsConfig = map[string]func() ConnectorConfig{
} }
// openConnector will parse the connector config and open the connector. // openConnector will parse the connector config and open the connector.
func openConnector(logger log.Logger, conn storage.Connector) (connector.Connector, error) { func openConnector(logger *slog.Logger, conn storage.Connector) (connector.Connector, error) {
var c connector.Connector var c connector.Connector
f, ok := ConnectorsConfig[conn.Type] f, ok := ConnectorsConfig[conn.Type]

9
server/server_test.go

@ -9,11 +9,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"os"
"path" "path"
"reflect" "reflect"
"sort" "sort"
@ -26,7 +26,6 @@ import (
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -77,11 +76,7 @@ FDWV28nTP9sqbtsmU8Tem2jzMvZ7C/Q0AuDoKELFUpux8shm8wfIhyaPnXUGZoAZ
Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo= Np4vUwMSYV5mopESLWOg3loBxKyLGFtgGKVCjGiQvy6zISQ4fQo=
-----END RSA PRIVATE KEY-----`) -----END RSA PRIVATE KEY-----`)
var logger = &logrus.Logger{ var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) { func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server var server *Server

4
storage/ent/mysql.go

@ -7,6 +7,7 @@ import (
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"fmt" "fmt"
"log/slog"
"net" "net"
"os" "os"
"strconv" "strconv"
@ -15,7 +16,6 @@ import (
entSQL "entgo.io/ent/dialect/sql" entSQL "entgo.io/ent/dialect/sql"
"github.com/go-sql-driver/mysql" // Register mysql driver. "github.com/go-sql-driver/mysql" // Register mysql driver.
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/client" "github.com/dexidp/dex/storage/ent/client"
"github.com/dexidp/dex/storage/ent/db" "github.com/dexidp/dex/storage/ent/db"
@ -39,7 +39,7 @@ type MySQL struct {
} }
// Open always returns a new in sqlite3 storage. // Open always returns a new in sqlite3 storage.
func (m *MySQL) Open(logger log.Logger) (storage.Storage, error) { func (m *MySQL) Open(logger *slog.Logger) (storage.Storage, error) {
logger.Debug("experimental ent-based storage driver is enabled") logger.Debug("experimental ent-based storage driver is enabled")
drv, err := m.driver() drv, err := m.driver()
if err != nil { if err != nil {

9
storage/ent/mysql_test.go

@ -1,11 +1,12 @@
package ent package ent
import ( import (
"io"
"log/slog"
"os" "os"
"strconv" "strconv"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
@ -39,11 +40,7 @@ func mysqlTestConfig(host string, port uint64) *MySQL {
} }
func newMySQLStorage(host string, port uint64) storage.Storage { func newMySQLStorage(host string, port uint64) storage.Storage {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
cfg := mysqlTestConfig(host, port) cfg := mysqlTestConfig(host, port)
s, err := cfg.Open(logger) s, err := cfg.Open(logger)

5
storage/ent/postgres.go

@ -5,6 +5,7 @@ import (
"crypto/sha256" "crypto/sha256"
"database/sql" "database/sql"
"fmt" "fmt"
"log/slog"
"net" "net"
"regexp" "regexp"
"strconv" "strconv"
@ -14,13 +15,11 @@ import (
entSQL "entgo.io/ent/dialect/sql" entSQL "entgo.io/ent/dialect/sql"
_ "github.com/lib/pq" // Register postgres driver. _ "github.com/lib/pq" // Register postgres driver.
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/client" "github.com/dexidp/dex/storage/ent/client"
"github.com/dexidp/dex/storage/ent/db" "github.com/dexidp/dex/storage/ent/db"
) )
//nolint
const ( const (
// postgres SSL modes // postgres SSL modes
pgSSLDisable = "disable" pgSSLDisable = "disable"
@ -37,7 +36,7 @@ type Postgres struct {
} }
// Open always returns a new in sqlite3 storage. // Open always returns a new in sqlite3 storage.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) { func (p *Postgres) Open(logger *slog.Logger) (storage.Storage, error) {
logger.Debug("experimental ent-based storage driver is enabled") logger.Debug("experimental ent-based storage driver is enabled")
drv, err := p.driver() drv, err := p.driver()
if err != nil { if err != nil {

9
storage/ent/postgres_test.go

@ -1,11 +1,12 @@
package ent package ent
import ( import (
"io"
"log/slog"
"os" "os"
"strconv" "strconv"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
@ -36,11 +37,7 @@ func postgresTestConfig(host string, port uint64) *Postgres {
} }
func newPostgresStorage(host string, port uint64) storage.Storage { func newPostgresStorage(host string, port uint64) storage.Storage {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
cfg := postgresTestConfig(host, port) cfg := postgresTestConfig(host, port)
s, err := cfg.Open(logger) s, err := cfg.Open(logger)

4
storage/ent/sqlite.go

@ -3,12 +3,12 @@ package ent
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"log/slog"
"strings" "strings"
"entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql"
_ "github.com/mattn/go-sqlite3" // Register sqlite driver. _ "github.com/mattn/go-sqlite3" // Register sqlite driver.
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/ent/client" "github.com/dexidp/dex/storage/ent/client"
"github.com/dexidp/dex/storage/ent/db" "github.com/dexidp/dex/storage/ent/db"
@ -20,7 +20,7 @@ type SQLite3 struct {
} }
// Open always returns a new in sqlite3 storage. // Open always returns a new in sqlite3 storage.
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { func (s *SQLite3) Open(logger *slog.Logger) (storage.Storage, error) {
logger.Debug("experimental ent-based storage driver is enabled") logger.Debug("experimental ent-based storage driver is enabled")
// Implicitly set foreign_keys pragma to "on" because it is required by ent // Implicitly set foreign_keys pragma to "on" because it is required by ent

11
storage/ent/sqlite_test.go

@ -1,21 +1,16 @@
package ent package ent
import ( import (
"os" "io"
"log/slog"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance" "github.com/dexidp/dex/storage/conformance"
) )
func newSQLiteStorage() storage.Storage { func newSQLiteStorage() storage.Storage {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
cfg := SQLite3{File: ":memory:"} cfg := SQLite3{File: ":memory:"}
s, err := cfg.Open(logger) s, err := cfg.Open(logger)

6
storage/etcd/config.go

@ -1,13 +1,13 @@
package etcd package etcd
import ( import (
"log/slog"
"time" "time"
"go.etcd.io/etcd/client/pkg/v3/transport" "go.etcd.io/etcd/client/pkg/v3/transport"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"go.etcd.io/etcd/client/v3/namespace" "go.etcd.io/etcd/client/v3/namespace"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -34,11 +34,11 @@ type Etcd struct {
} }
// Open creates a new storage implementation backed by Etcd // Open creates a new storage implementation backed by Etcd
func (p *Etcd) Open(logger log.Logger) (storage.Storage, error) { func (p *Etcd) Open(logger *slog.Logger) (storage.Storage, error) {
return p.open(logger) return p.open(logger)
} }
func (p *Etcd) open(logger log.Logger) (*conn, error) { func (p *Etcd) open(logger *slog.Logger) (*conn, error) {
cfg := clientv3.Config{ cfg := clientv3.Config{
Endpoints: p.Endpoints, Endpoints: p.Endpoints,
DialTimeout: defaultDialTimeout, DialTimeout: defaultDialTimeout,

12
storage/etcd/etcd.go

@ -4,12 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog"
"strings" "strings"
"time" "time"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -33,7 +33,7 @@ var _ storage.Storage = (*conn)(nil)
type conn struct { type conn struct {
db *clientv3.Client db *clientv3.Client
logger log.Logger logger *slog.Logger
} }
func (c *conn) Close() error { func (c *conn) Close() error {
@ -52,7 +52,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
for _, authRequest := range authRequests { for _, authRequest := range authRequests {
if now.After(authRequest.Expiry) { if now.After(authRequest.Expiry) {
if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil { if err := c.deleteKey(ctx, keyID(authRequestPrefix, authRequest.ID)); err != nil {
c.logger.Errorf("failed to delete auth request: %v", err) c.logger.Error("failed to delete auth request", "err", err)
delErr = fmt.Errorf("failed to delete auth request: %v", err) delErr = fmt.Errorf("failed to delete auth request: %v", err)
} }
result.AuthRequests++ result.AuthRequests++
@ -70,7 +70,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
for _, authCode := range authCodes { for _, authCode := range authCodes {
if now.After(authCode.Expiry) { if now.After(authCode.Expiry) {
if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil { if err := c.deleteKey(ctx, keyID(authCodePrefix, authCode.ID)); err != nil {
c.logger.Errorf("failed to delete auth code %v", err) c.logger.Error("failed to delete auth code", "err", err)
delErr = fmt.Errorf("failed to delete auth code: %v", err) delErr = fmt.Errorf("failed to delete auth code: %v", err)
} }
result.AuthCodes++ result.AuthCodes++
@ -85,7 +85,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
for _, deviceRequest := range deviceRequests { for _, deviceRequest := range deviceRequests {
if now.After(deviceRequest.Expiry) { if now.After(deviceRequest.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil { if err := c.deleteKey(ctx, keyID(deviceRequestPrefix, deviceRequest.UserCode)); err != nil {
c.logger.Errorf("failed to delete device request %v", err) c.logger.Error("failed to delete device request", "err", err)
delErr = fmt.Errorf("failed to delete device request: %v", err) delErr = fmt.Errorf("failed to delete device request: %v", err)
} }
result.DeviceRequests++ result.DeviceRequests++
@ -100,7 +100,7 @@ func (c *conn) GarbageCollect(now time.Time) (result storage.GCResult, err error
for _, deviceToken := range deviceTokens { for _, deviceToken := range deviceTokens {
if now.After(deviceToken.Expiry) { if now.After(deviceToken.Expiry) {
if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil { if err := c.deleteKey(ctx, keyID(deviceTokenPrefix, deviceToken.DeviceCode)); err != nil {
c.logger.Errorf("failed to delete device token %v", err) c.logger.Error("failed to delete device token", "err", err)
delErr = fmt.Errorf("failed to delete device token: %v", err) delErr = fmt.Errorf("failed to delete device token: %v", err)
} }
result.DeviceTokens++ result.DeviceTokens++

9
storage/etcd/etcd_test.go

@ -3,13 +3,14 @@ package etcd
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log/slog"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
@ -55,11 +56,7 @@ func cleanDB(c *conn) error {
return nil return nil
} }
var logger = &logrus.Logger{ var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
func TestEtcd(t *testing.T) { func TestEtcd(t *testing.T) {
testEtcdEnv := "DEX_ETCD_ENDPOINTS" testEtcdEnv := "DEX_ETCD_ENDPOINTS"

10
storage/kubernetes/client.go

@ -13,6 +13,7 @@ import (
"hash" "hash"
"hash/fnv" "hash/fnv"
"io" "io"
"log/slog"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -27,7 +28,6 @@ import (
"github.com/ghodss/yaml" "github.com/ghodss/yaml"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/kubernetes/k8sapi" "github.com/dexidp/dex/storage/kubernetes/k8sapi"
) )
@ -36,7 +36,7 @@ type client struct {
client *http.Client client *http.Client
baseURL string baseURL string
namespace string namespace string
logger log.Logger logger *slog.Logger
// Hash function to map IDs (which could span a large range) to Kubernetes names. // Hash function to map IDs (which could span a large range) to Kubernetes names.
// While this is not currently upgradable, it could be in the future. // While this is not currently upgradable, it could be in the future.
@ -268,7 +268,7 @@ func (cli *client) detectKubernetesVersion() error {
clusterVersion, err := semver.NewVersion(version.GitVersion) clusterVersion, err := semver.NewVersion(version.GitVersion)
if err != nil { if err != nil {
cli.logger.Warnf("cannot detect Kubernetes version (%s): %v", clusterVersion, err) cli.logger.Warn("cannot detect Kubernetes version", "version", clusterVersion, "err", err)
return nil return nil
} }
@ -358,7 +358,7 @@ func defaultTLSConfig() *tls.Config {
} }
} }
func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger log.Logger, inCluster bool) (*client, error) { func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, logger *slog.Logger, inCluster bool) (*client, error) {
tlsConfig := defaultTLSConfig() tlsConfig := defaultTLSConfig()
data := func(b string, file string) ([]byte, error) { data := func(b string, file string) ([]byte, error) {
if b != "" { if b != "" {
@ -418,7 +418,7 @@ func newClient(cluster k8sapi.Cluster, user k8sapi.AuthInfo, namespace string, l
apiVersion := "dex.coreos.com/v1" apiVersion := "dex.coreos.com/v1"
logger.Infof("kubernetes client apiVersion = %s", apiVersion) logger.Info("kubernetes client", "api_version", apiVersion)
return &client{ return &client{
client: &http.Client{ client: &http.Client{
Transport: t, Transport: t,

9
storage/kubernetes/client_test.go

@ -3,6 +3,8 @@ package kubernetes
import ( import (
"hash" "hash"
"hash/fnv" "hash/fnv"
"io"
"log/slog"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -10,7 +12,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/dexidp/dex/storage/kubernetes/k8sapi" "github.com/dexidp/dex/storage/kubernetes/k8sapi"
@ -52,11 +53,7 @@ func TestOfflineTokenName(t *testing.T) {
} }
func TestInClusterTransport(t *testing.T) { func TestInClusterTransport(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
user := k8sapi.AuthInfo{Token: "abc"} user := k8sapi.AuthInfo{Token: "abc"}
cli, err := newClient( cli, err := newClient(

6
storage/kubernetes/lock.go

@ -53,14 +53,14 @@ func (l *refreshTokenLock) Unlock(id string) {
r, err := l.cli.getRefreshToken(id) r, err := l.cli.getRefreshToken(id)
if err != nil { if err != nil {
l.cli.logger.Debugf("failed to get resource to release lock for refresh token %s: %v", id, err) l.cli.logger.Debug("failed to get resource to release lock for refresh token", "token_id", id, "err", err)
return return
} }
r.Annotations = nil r.Annotations = nil
err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r) err = l.cli.put(resourceRefreshToken, r.ObjectMeta.Name, r)
if err != nil { if err != nil {
l.cli.logger.Debugf("failed to release lock for refresh token %s: %v", id, err) l.cli.logger.Debug("failed to release lock for refresh token", "token_id", id, "err", err)
} }
} }
@ -114,7 +114,7 @@ func (l *refreshTokenLock) setLockAnnotation(id string) (bool, error) {
return false, nil return false, nil
} }
l.cli.logger.Debugf("break lock annotation error: %v", err) l.cli.logger.Debug("break lock annotation", "error", err)
if isKubernetesAPIConflictError(err) { if isKubernetesAPIConflictError(err) {
l.waitingState = true l.waitingState = true
// after breaking error waiting for the lock to be released // after breaking error waiting for the lock to be released

34
storage/kubernetes/storage.go

@ -4,12 +4,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/kubernetes/k8sapi" "github.com/dexidp/dex/storage/kubernetes/k8sapi"
) )
@ -53,7 +53,7 @@ type Config struct {
} }
// Open returns a storage using Kubernetes third party resource. // Open returns a storage using Kubernetes third party resource.
func (c *Config) Open(logger log.Logger) (storage.Storage, error) { func (c *Config) Open(logger *slog.Logger) (storage.Storage, error) {
cli, err := c.open(logger, false) cli, err := c.open(logger, false)
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,7 +66,7 @@ func (c *Config) Open(logger log.Logger) (storage.Storage, error) {
// //
// waitForResources controls if errors creating the resources cause this method to return // waitForResources controls if errors creating the resources cause this method to return
// immediately (used during testing), or if the client will asynchronously retry. // immediately (used during testing), or if the client will asynchronously retry.
func (c *Config) open(logger log.Logger, waitForResources bool) (*client, error) { func (c *Config) open(logger *slog.Logger, waitForResources bool) (*client, error) {
if c.InCluster && (c.KubeConfigFile != "") { if c.InCluster && (c.KubeConfigFile != "") {
return nil, errors.New("cannot specify both 'inCluster' and 'kubeConfigFile'") return nil, errors.New("cannot specify both 'inCluster' and 'kubeConfigFile'")
} }
@ -155,12 +155,12 @@ func (cli *client) registerCustomResources() (ok bool) {
r := definitions[i] r := definitions[i]
var i interface{} var i interface{}
cli.logger.Infof("checking if custom resource %s has already been created...", r.ObjectMeta.Name) cli.logger.Info("checking if custom resource has already been created...", "object", r.ObjectMeta.Name)
if err := cli.list(r.Spec.Names.Plural, &i); err == nil { if err := cli.list(r.Spec.Names.Plural, &i); err == nil {
cli.logger.Infof("The custom resource %s already available, skipping create", r.ObjectMeta.Name) cli.logger.Info("the custom resource already available, skipping create", "object", r.ObjectMeta.Name)
continue continue
} else { } else {
cli.logger.Infof("failed to list custom resource %s, attempting to create: %v", r.ObjectMeta.Name, err) cli.logger.Info("failed to list custom resource, attempting to create", "object", r.ObjectMeta.Name, "err", err)
} }
err = cli.postResource(cli.crdAPIVersion, "", "customresourcedefinitions", r) err = cli.postResource(cli.crdAPIVersion, "", "customresourcedefinitions", r)
@ -169,17 +169,17 @@ func (cli *client) registerCustomResources() (ok bool) {
if err != nil { if err != nil {
switch err { switch err {
case storage.ErrAlreadyExists: case storage.ErrAlreadyExists:
cli.logger.Infof("custom resource already created %s", resourceName) cli.logger.Info("custom resource already created", "object", resourceName)
case storage.ErrNotFound: case storage.ErrNotFound:
cli.logger.Errorf("custom resources not found, please enable the respective API group") cli.logger.Error("custom resources not found, please enable the respective API group")
ok = false ok = false
default: default:
cli.logger.Errorf("creating custom resource %s: %v", resourceName, err) cli.logger.Error("creating custom resource", "object", resourceName, "err", err)
ok = false ok = false
} }
continue continue
} }
cli.logger.Errorf("create custom resource %s", resourceName) cli.logger.Error("create custom resource", "object", resourceName)
} }
return ok return ok
} }
@ -197,7 +197,7 @@ func (cli *client) waitForCRDs(ctx context.Context) error {
break break
} }
cli.logger.Errorf("checking CRD: %v", err) cli.logger.ErrorContext(ctx, "checking CRD", "err", err)
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -556,7 +556,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
err = cli.post(resourceKeys, newKeys) err = cli.post(resourceKeys, newKeys)
if err != nil && errors.Is(err, storage.ErrAlreadyExists) { if err != nil && errors.Is(err, storage.ErrAlreadyExists) {
// We need to tolerate conflicts here in case of HA mode. // We need to tolerate conflicts here in case of HA mode.
cli.logger.Debugf("Keys creation failed: %v. It is possible that keys have already been created by another dex instance.", err) cli.logger.Debug("Keys creation failed. It is possible that keys have already been created by another dex instance.", "err", err)
return errors.New("keys already created by another server instance") return errors.New("keys already created by another server instance")
} }
@ -569,7 +569,7 @@ func (cli *client) UpdateKeys(updater func(old storage.Keys) (storage.Keys, erro
if isKubernetesAPIConflictError(err) { if isKubernetesAPIConflictError(err) {
// We need to tolerate conflicts here in case of HA mode. // We need to tolerate conflicts here in case of HA mode.
// Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger. // Dex instances run keys rotation at the same time because they use SigningKey.nextRotation CR field as a trigger.
cli.logger.Debugf("Keys rotation failed: %v. It is possible that keys have already been rotated by another dex instance.", err) cli.logger.Debug("Keys rotation failed. It is possible that keys have already been rotated by another dex instance.", "err", err)
return errors.New("keys already rotated by another server instance") return errors.New("keys already rotated by another server instance")
} }
@ -622,7 +622,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
for _, authRequest := range authRequests.AuthRequests { for _, authRequest := range authRequests.AuthRequests {
if now.After(authRequest.Expiry) { if now.After(authRequest.Expiry) {
if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil { if err := cli.delete(resourceAuthRequest, authRequest.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete auth request: %v", err) cli.logger.Error("failed to delete auth request", "err", err)
delErr = fmt.Errorf("failed to delete auth request: %v", err) delErr = fmt.Errorf("failed to delete auth request: %v", err)
} }
result.AuthRequests++ result.AuthRequests++
@ -640,7 +640,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
for _, authCode := range authCodes.AuthCodes { for _, authCode := range authCodes.AuthCodes {
if now.After(authCode.Expiry) { if now.After(authCode.Expiry) {
if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil { if err := cli.delete(resourceAuthCode, authCode.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete auth code %v", err) cli.logger.Error("failed to delete auth code", "err", err)
delErr = fmt.Errorf("failed to delete auth code: %v", err) delErr = fmt.Errorf("failed to delete auth code: %v", err)
} }
result.AuthCodes++ result.AuthCodes++
@ -655,7 +655,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
for _, deviceRequest := range deviceRequests.DeviceRequests { for _, deviceRequest := range deviceRequests.DeviceRequests {
if now.After(deviceRequest.Expiry) { if now.After(deviceRequest.Expiry) {
if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil { if err := cli.delete(resourceDeviceRequest, deviceRequest.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device request: %v", err) cli.logger.Error("failed to delete device request", "err", err)
delErr = fmt.Errorf("failed to delete device request: %v", err) delErr = fmt.Errorf("failed to delete device request: %v", err)
} }
result.DeviceRequests++ result.DeviceRequests++
@ -670,7 +670,7 @@ func (cli *client) GarbageCollect(now time.Time) (result storage.GCResult, err e
for _, deviceToken := range deviceTokens.DeviceTokens { for _, deviceToken := range deviceTokens.DeviceTokens {
if now.After(deviceToken.Expiry) { if now.After(deviceToken.Expiry) {
if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil { if err := cli.delete(resourceDeviceToken, deviceToken.ObjectMeta.Name); err != nil {
cli.logger.Errorf("failed to delete device token: %v", err) cli.logger.Error("failed to delete device token", "err", err)
delErr = fmt.Errorf("failed to delete device token: %v", err) delErr = fmt.Errorf("failed to delete device token: %v", err)
} }
result.DeviceTokens++ result.DeviceTokens++

21
storage/kubernetes/storage_test.go

@ -5,6 +5,8 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io"
"log/slog"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -13,7 +15,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -57,11 +58,7 @@ func (s *StorageTestSuite) SetupTest() {
KubeConfigFile: kubeconfigPath, KubeConfigFile: kubeconfigPath,
} }
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
kubeClient, err := config.open(logger, true) kubeClient, err := config.open(logger, true)
s.Require().NoError(err) s.Require().NoError(err)
@ -253,11 +250,7 @@ func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *
return &client{ return &client{
client: &http.Client{Transport: tr}, client: &http.Client{Transport: tr},
baseURL: s.URL, baseURL: s.URL,
logger: &logrus.Logger{ logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
},
} }
} }
@ -314,11 +307,7 @@ func TestRefreshTokenLock(t *testing.T) {
KubeConfigFile: kubeconfigPath, KubeConfigFile: kubeconfigPath,
} }
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
kubeClient, err := config.open(logger, true) kubeClient, err := config.open(logger, true)
require.NoError(t, err) require.NoError(t, err)

8
storage/memory/memory.go

@ -3,18 +3,18 @@ package memory
import ( import (
"context" "context"
"log/slog"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
var _ storage.Storage = (*memStorage)(nil) var _ storage.Storage = (*memStorage)(nil)
// New returns an in memory storage. // New returns an in memory storage.
func New(logger log.Logger) storage.Storage { func New(logger *slog.Logger) storage.Storage {
return &memStorage{ return &memStorage{
clients: make(map[string]storage.Client), clients: make(map[string]storage.Client),
authCodes: make(map[string]storage.AuthCode), authCodes: make(map[string]storage.AuthCode),
@ -36,7 +36,7 @@ type Config struct { // The in memory implementation has no config.
} }
// Open always returns a new in memory storage. // Open always returns a new in memory storage.
func (c *Config) Open(logger log.Logger) (storage.Storage, error) { func (c *Config) Open(logger *slog.Logger) (storage.Storage, error) {
return New(logger), nil return New(logger), nil
} }
@ -55,7 +55,7 @@ type memStorage struct {
keys storage.Keys keys storage.Keys
logger log.Logger logger *slog.Logger
} }
type offlineSessionID struct { type offlineSessionID struct {

11
storage/memory/memory_test.go

@ -1,21 +1,16 @@
package memory package memory
import ( import (
"os" "io"
"log/slog"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance" "github.com/dexidp/dex/storage/conformance"
) )
func TestStorage(t *testing.T) { func TestStorage(t *testing.T) {
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
newStorage := func() storage.Storage { newStorage := func() storage.Storage {
return New(logger) return New(logger)

23
storage/memory/static_test.go

@ -3,22 +3,17 @@ package memory
import ( import (
"context" "context"
"fmt" "fmt"
"os" "io"
"log/slog"
"strings" "strings"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
func TestStaticClients(t *testing.T) { func TestStaticClients(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
backing := New(logger) backing := New(logger)
c1 := storage.Client{ID: "foo", Secret: "foo_secret"} c1 := storage.Client{ID: "foo", Secret: "foo_secret"}
@ -102,11 +97,7 @@ func TestStaticClients(t *testing.T) {
func TestStaticPasswords(t *testing.T) { func TestStaticPasswords(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
backing := New(logger) backing := New(logger)
p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"} p1 := storage.Password{Email: "foo@example.com", Username: "foo_secret"}
@ -215,11 +206,7 @@ func TestStaticPasswords(t *testing.T) {
func TestStaticConnectors(t *testing.T) { func TestStaticConnectors(t *testing.T) {
ctx := context.Background() ctx := context.Background()
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
backing := New(logger) backing := New(logger)
config1 := []byte(`{"issuer": "https://accounts.google.com"}`) config1 := []byte(`{"issuer": "https://accounts.google.com"}`)

12
storage/sql/config.go

@ -5,6 +5,7 @@ import (
"crypto/x509" "crypto/x509"
"database/sql" "database/sql"
"fmt" "fmt"
"log/slog"
"net" "net"
"os" "os"
"regexp" "regexp"
@ -15,7 +16,6 @@ import (
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/lib/pq" "github.com/lib/pq"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -31,7 +31,6 @@ const (
mysqlErrUnknownSysVar = 1193 mysqlErrUnknownSysVar = 1193
) )
//nolint
const ( const (
// postgres SSL modes // postgres SSL modes
pgSSLDisable = "disable" pgSSLDisable = "disable"
@ -40,7 +39,6 @@ const (
pgSSLVerifyFull = "verify-full" pgSSLVerifyFull = "verify-full"
) )
//nolint
const ( const (
// MySQL SSL modes // MySQL SSL modes
mysqlSSLTrue = "true" mysqlSSLTrue = "true"
@ -84,7 +82,7 @@ type Postgres struct {
} }
// Open creates a new storage implementation backed by Postgres. // Open creates a new storage implementation backed by Postgres.
func (p *Postgres) Open(logger log.Logger) (storage.Storage, error) { func (p *Postgres) Open(logger *slog.Logger) (storage.Storage, error) {
conn, err := p.open(logger) conn, err := p.open(logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -164,7 +162,7 @@ func (p *Postgres) createDataSourceName() string {
return strings.Join(parameters, " ") return strings.Join(parameters, " ")
} }
func (p *Postgres) open(logger log.Logger) (*conn, error) { func (p *Postgres) open(logger *slog.Logger) (*conn, error) {
dataSourceName := p.createDataSourceName() dataSourceName := p.createDataSourceName()
db, err := sql.Open("postgres", dataSourceName) db, err := sql.Open("postgres", dataSourceName)
@ -216,7 +214,7 @@ type MySQL struct {
} }
// Open creates a new storage implementation backed by MySQL. // Open creates a new storage implementation backed by MySQL.
func (s *MySQL) Open(logger log.Logger) (storage.Storage, error) { func (s *MySQL) Open(logger *slog.Logger) (storage.Storage, error) {
conn, err := s.open(logger) conn, err := s.open(logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -224,7 +222,7 @@ func (s *MySQL) Open(logger log.Logger) (storage.Storage, error) {
return conn, nil return conn, nil
} }
func (s *MySQL) open(logger log.Logger) (*conn, error) { func (s *MySQL) open(logger *slog.Logger) (*conn, error) {
cfg := mysql.Config{ cfg := mysql.Config{
User: s.User, User: s.User,
Passwd: s.Password, Passwd: s.Password,

13
storage/sql/config_test.go

@ -2,15 +2,14 @@ package sql
import ( import (
"fmt" "fmt"
"io"
"log/slog"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
"testing" "testing"
"time" "time"
"github.com/sirupsen/logrus"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/conformance" "github.com/dexidp/dex/storage/conformance"
) )
@ -48,14 +47,10 @@ func cleanDB(c *conn) error {
return nil return nil
} }
var logger = &logrus.Logger{ var logger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
type opener interface { type opener interface {
open(logger log.Logger) (*conn, error) open(logger *slog.Logger) (*conn, error)
} }
func testDB(t *testing.T, o opener, withTransactions bool) { func testDB(t *testing.T, o opener, withTransactions bool) {

10
storage/sql/migrate_test.go

@ -5,11 +5,11 @@ package sql
import ( import (
"database/sql" "database/sql"
"os" "io"
"log/slog"
"testing" "testing"
sqlite3 "github.com/mattn/go-sqlite3" sqlite3 "github.com/mattn/go-sqlite3"
"github.com/sirupsen/logrus"
) )
func TestMigrate(t *testing.T) { func TestMigrate(t *testing.T) {
@ -19,11 +19,7 @@ func TestMigrate(t *testing.T) {
} }
defer db.Close() defer db.Close()
logger := &logrus.Logger{ logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
Out: os.Stderr,
Formatter: &logrus.TextFormatter{DisableColors: true},
Level: logrus.DebugLevel,
}
errCheck := func(err error) bool { errCheck := func(err error) bool {
sqlErr, ok := err.(sqlite3.Error) sqlErr, ok := err.(sqlite3.Error)

5
storage/sql/sql.go

@ -3,14 +3,13 @@ package sql
import ( import (
"database/sql" "database/sql"
"log/slog"
"regexp" "regexp"
"time" "time"
// import third party drivers // import third party drivers
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/dexidp/dex/pkg/log"
) )
// flavor represents a specific SQL implementation, and is used to translate query strings // flavor represents a specific SQL implementation, and is used to translate query strings
@ -131,7 +130,7 @@ func (c *conn) translateArgs(args []interface{}) []interface{} {
type conn struct { type conn struct {
db *sql.DB db *sql.DB
flavor *flavor flavor *flavor
logger log.Logger logger *slog.Logger
alreadyExistsCheck func(err error) bool alreadyExistsCheck func(err error) bool
} }

6
storage/sql/sqlite.go

@ -6,10 +6,10 @@ package sql
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"log/slog"
sqlite3 "github.com/mattn/go-sqlite3" sqlite3 "github.com/mattn/go-sqlite3"
"github.com/dexidp/dex/pkg/log"
"github.com/dexidp/dex/storage" "github.com/dexidp/dex/storage"
) )
@ -20,7 +20,7 @@ type SQLite3 struct {
} }
// Open creates a new storage implementation backed by SQLite3 // Open creates a new storage implementation backed by SQLite3
func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) { func (s *SQLite3) Open(logger *slog.Logger) (storage.Storage, error) {
conn, err := s.open(logger) conn, err := s.open(logger)
if err != nil { if err != nil {
return nil, err return nil, err
@ -28,7 +28,7 @@ func (s *SQLite3) Open(logger log.Logger) (storage.Storage, error) {
return conn, nil return conn, nil
} }
func (s *SQLite3) open(logger log.Logger) (*conn, error) { func (s *SQLite3) open(logger *slog.Logger) (*conn, error) {
db, err := sql.Open("sqlite3", s.File) db, err := sql.Open("sqlite3", s.File)
if err != nil { if err != nil {
return nil, err return nil, err

9
storage/static.go

@ -3,9 +3,8 @@ package storage
import ( import (
"context" "context"
"errors" "errors"
"log/slog"
"strings" "strings"
"github.com/dexidp/dex/pkg/log"
) )
// Tests for this code are in the "memory" package, since this package doesn't // Tests for this code are in the "memory" package, since this package doesn't
@ -90,17 +89,17 @@ type staticPasswordsStorage struct {
// A map of passwords that is indexed by lower-case email ids // A map of passwords that is indexed by lower-case email ids
passwordsByEmail map[string]Password passwordsByEmail map[string]Password
logger log.Logger logger *slog.Logger
} }
// WithStaticPasswords returns a storage with a read-only set of passwords. // WithStaticPasswords returns a storage with a read-only set of passwords.
func WithStaticPasswords(s Storage, staticPasswords []Password, logger log.Logger) Storage { func WithStaticPasswords(s Storage, staticPasswords []Password, logger *slog.Logger) Storage {
passwordsByEmail := make(map[string]Password, len(staticPasswords)) passwordsByEmail := make(map[string]Password, len(staticPasswords))
for _, p := range staticPasswords { for _, p := range staticPasswords {
// Enable case insensitive email comparison. // Enable case insensitive email comparison.
lowerEmail := strings.ToLower(p.Email) lowerEmail := strings.ToLower(p.Email)
if _, ok := passwordsByEmail[lowerEmail]; ok { if _, ok := passwordsByEmail[lowerEmail]; ok {
logger.Errorf("Attempting to create StaticPasswords with the same email id: %s", p.Email) logger.Error("attempting to create StaticPasswords with the same email id", "email", p.Email)
} }
passwordsByEmail[lowerEmail] = p passwordsByEmail[lowerEmail] = p
} }

Loading…
Cancel
Save