From 5a7d2b8db73fffc59653a869ab600f5864cdc8ac Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 16 Mar 2026 21:39:44 +0100 Subject: [PATCH] Fixes and refactoring: Update session Signed-off-by: maksim.nabokikh --- cmd/dex/config.go | 10 +++------- cmd/dex/serve.go | 39 ++++++++++++++++++++++----------------- server/session.go | 4 +++- server/session_test.go | 3 ++- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/cmd/dex/config.go b/cmd/dex/config.go index c3d381e1..32f8e832 100644 --- a/cmd/dex/config.go +++ b/cmd/dex/config.go @@ -68,7 +68,7 @@ type Config struct { // Sessions holds authentication session configuration. // Requires DEX_SESSIONS_ENABLED=true feature flag. - Sessions Sessions `json:"sessions"` + Sessions *Sessions `json:"sessions"` } // Validate the configuration @@ -108,7 +108,7 @@ func (c Config) Validate() error { return fmt.Errorf("invalid Config:\n\t-\t%s", strings.Join(checkErrors, "\n\t-\t")) } - if c.Sessions.isSet() && !featureflags.SessionsEnabled.Enabled() { + if c.Sessions != nil && !featureflags.SessionsEnabled.Enabled() { return fmt.Errorf("sessions config requires sessions to be enabled (DEX_SESSIONS_ENABLED=true)") } @@ -604,9 +604,5 @@ type Sessions struct { // ValidIfNotUsedFor is the idle timeout. Defaults to "1h". ValidIfNotUsedFor string `json:"validIfNotUsedFor"` // RememberMeCheckedByDefault controls the default state of the "remember me" checkbox. - RememberMeCheckedByDefault bool `json:"rememberMeCheckedByDefault"` -} - -func (s Sessions) isSet() bool { - return s.CookieName != "" || s.AbsoluteLifetime != "" || s.ValidIfNotUsedFor != "" || s.RememberMeCheckedByDefault + RememberMeCheckedByDefault *bool `json:"rememberMeCheckedByDefault"` } diff --git a/cmd/dex/serve.go b/cmd/dex/serve.go index e12c3551..8c616931 100644 --- a/cmd/dex/serve.go +++ b/cmd/dex/serve.go @@ -773,29 +773,34 @@ func recordBuildInfo() { buildInfo.WithLabelValues(version, runtime.Version(), fmt.Sprintf("%s/%s", runtime.GOOS, runtime.GOARCH)).Set(1) } -func parseSessionConfig(c Sessions) (*server.SessionConfig, error) { +func parseSessionConfig(s *Sessions) (*server.SessionConfig, error) { sc := &server.SessionConfig{ - CookieName: c.CookieName, + CookieName: "dex_session", AbsoluteLifetime: 24 * time.Hour, ValidIfNotUsedFor: 1 * time.Hour, - RememberMeCheckedByDefault: c.RememberMeCheckedByDefault, + RememberMeCheckedByDefault: true, } - if sc.CookieName == "" { - sc.CookieName = "dex_session" - } - if c.AbsoluteLifetime != "" { - d, err := time.ParseDuration(c.AbsoluteLifetime) - if err != nil { - return nil, fmt.Errorf("invalid absoluteLifetime %q: %v", c.AbsoluteLifetime, err) + if s != nil { + if s.CookieName != "" { + sc.CookieName = s.CookieName } - sc.AbsoluteLifetime = d - } - if c.ValidIfNotUsedFor != "" { - d, err := time.ParseDuration(c.ValidIfNotUsedFor) - if err != nil { - return nil, fmt.Errorf("invalid validIfNotUsedFor %q: %v", c.ValidIfNotUsedFor, err) + if s.AbsoluteLifetime != "" { + d, err := time.ParseDuration(s.AbsoluteLifetime) + if err != nil { + return nil, fmt.Errorf("invalid absoluteLifetime %q: %v", s.AbsoluteLifetime, err) + } + sc.AbsoluteLifetime = d + } + if s.ValidIfNotUsedFor != "" { + d, err := time.ParseDuration(s.ValidIfNotUsedFor) + if err != nil { + return nil, fmt.Errorf("invalid validIfNotUsedFor %q: %v", s.ValidIfNotUsedFor, err) + } + sc.ValidIfNotUsedFor = d + } + if s.RememberMeCheckedByDefault != nil { + sc.RememberMeCheckedByDefault = *s.RememberMeCheckedByDefault } - sc.ValidIfNotUsedFor = d } return sc, nil } diff --git a/server/session.go b/server/session.go index 73364926..b307c99d 100644 --- a/server/session.go +++ b/server/session.go @@ -5,6 +5,7 @@ import ( "crypto/hmac" "crypto/sha256" "encoding/base64" + "errors" "fmt" "net/http" "path" @@ -24,6 +25,7 @@ func (s *Server) rememberMeDefault() *bool { // sessionCookieValue encodes session identity into a cookie value. // Format: base64url(userID) + "." + base64url(connectorID) + "." + nonce +// TODO(nabokihms): consider cookie encoding func sessionCookieValue(userID, connectorID, nonce string) string { return base64.RawURLEncoding.EncodeToString([]byte(userID)) + "." + base64.RawURLEncoding.EncodeToString([]byte(connectorID)) + @@ -98,7 +100,7 @@ func (s *Server) getValidAuthSession(ctx context.Context, r *http.Request) *stor session, err := s.storage.GetAuthSession(ctx, userID, connectorID) if err != nil { - if err != storage.ErrNotFound { + if errors.Is(err, storage.ErrNotFound) { s.logger.ErrorContext(ctx, "failed to get auth session", "err", err) } return nil diff --git a/server/session_test.go b/server/session_test.go index 8c0a60ab..e568685c 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -103,9 +103,10 @@ func TestSessionCookieValueRoundtrip(t *testing.T) { } func TestParseSessionCookie_Invalid(t *testing.T) { + //nolint:dogsled // only for tests _, _, _, err := parseSessionCookie("invalid") assert.Error(t, err) - + //nolint:dogsled // only for tests _, _, _, err = parseSessionCookie("a.b") assert.Error(t, err) }