From 449f66477c3d2f426162cdc99f15b3bbde474ee8 Mon Sep 17 00:00:00 2001 From: Maksim Nabokikh Date: Fri, 20 Mar 2026 20:06:43 +0100 Subject: [PATCH] feat: Add AuthSession GC (#4667) Signed-off-by: maksim.nabokikh --- server/server.go | 3 +- server/session.go | 21 ++-- server/session_test.go | 116 ++++++++++-------- storage/conformance/conformance.go | 106 +++++++++++++++- storage/ent/client/authsession.go | 4 + storage/ent/client/main.go | 14 +++ storage/ent/client/types.go | 16 +-- storage/ent/db/authsession.go | 26 +++- storage/ent/db/authsession/authsession.go | 16 +++ storage/ent/db/authsession/where.go | 90 ++++++++++++++ storage/ent/db/authsession_create.go | 26 ++++ storage/ent/db/authsession_update.go | 68 +++++++++++ storage/ent/db/migrate/schema.go | 2 + storage/ent/db/mutation.go | 140 +++++++++++++++++++--- storage/ent/schema/authsession.go | 4 + storage/etcd/etcd.go | 43 +++++-- storage/etcd/types.go | 54 +++++---- storage/kubernetes/storage.go | 16 +++ storage/kubernetes/types.go | 54 +++++---- storage/memory/memory.go | 6 + storage/sql/crud.go | 23 +++- storage/sql/migrate.go | 10 ++ storage/storage.go | 9 +- 23 files changed, 718 insertions(+), 149 deletions(-) diff --git a/server/server.go b/server/server.go index 669c16a0..53ef2d81 100644 --- a/server/server.go +++ b/server/server.go @@ -728,7 +728,8 @@ func (s *Server) startGarbageCollection(ctx context.Context, frequency time.Dura } else if !r.IsEmpty() { s.logger.InfoContext(ctx, "garbage collection run, delete auth", "requests", r.AuthRequests, "auth_codes", r.AuthCodes, - "device_requests", r.DeviceRequests, "device_tokens", r.DeviceTokens) + "device_requests", r.DeviceRequests, "device_tokens", r.DeviceTokens, + "auth_sessions", r.AuthSessions) } } } diff --git a/server/session.go b/server/session.go index 750b970d..5f65ea00 100644 --- a/server/session.go +++ b/server/session.go @@ -134,8 +134,8 @@ func (s *Server) getValidAuthSession(ctx context.Context, w http.ResponseWriter, now := s.now() - // Check absolute lifetime. - if now.After(session.CreatedAt.Add(s.sessionConfig.AbsoluteLifetime)) { + // Check absolute lifetime using the stored expiry (set once at creation). + if !session.AbsoluteExpiry.IsZero() && now.After(session.AbsoluteExpiry) { s.logger.InfoContext(ctx, "auth session expired (absolute lifetime)", "user_id", session.UserID, "connector_id", session.ConnectorID) if err := s.storage.DeleteAuthSession(ctx, session.UserID, session.ConnectorID); err != nil { @@ -145,8 +145,8 @@ func (s *Server) getValidAuthSession(ctx context.Context, w http.ResponseWriter, return nil } - // Check idle timeout. - if now.After(session.LastActivity.Add(s.sessionConfig.ValidIfNotUsedFor)) { + // Check idle timeout using the stored expiry (updated on every activity). + if !session.IdleExpiry.IsZero() && now.After(session.IdleExpiry) { s.logger.InfoContext(ctx, "auth session expired (idle timeout)", "user_id", session.UserID, "connector_id", session.ConnectorID) if err := s.storage.DeleteAuthSession(ctx, session.UserID, session.ConnectorID); err != nil { @@ -191,6 +191,7 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request, if err := s.storage.UpdateAuthSession(ctx, userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { old.LastActivity = now + old.IdleExpiry = now.Add(s.sessionConfig.ValidIfNotUsedFor) if old.ClientStates == nil { old.ClientStates = make(map[string]*storage.ClientAuthState) } @@ -217,10 +218,12 @@ func (s *Server) createOrUpdateAuthSession(ctx context.Context, r *http.Request, ClientStates: map[string]*storage.ClientAuthState{ authReq.ClientID: clientState, }, - CreatedAt: now, - LastActivity: now, - IPAddress: remoteIP(r), - UserAgent: r.UserAgent(), + CreatedAt: now, + LastActivity: now, + IPAddress: remoteIP(r), + UserAgent: r.UserAgent(), + AbsoluteExpiry: now.Add(s.sessionConfig.AbsoluteLifetime), + IdleExpiry: now.Add(s.sessionConfig.ValidIfNotUsedFor), } if err := s.storage.CreateAuthSession(ctx, newSession); err != nil { @@ -300,6 +303,7 @@ func (s *Server) trySessionLoginWithSession(ctx context.Context, r *http.Request // Update session activity. _ = s.storage.UpdateAuthSession(ctx, session.UserID, session.ConnectorID, func(old storage.AuthSession) (storage.AuthSession, error) { old.LastActivity = now + old.IdleExpiry = now.Add(s.sessionConfig.ValidIfNotUsedFor) if cs, ok := old.ClientStates[authReq.ClientID]; ok { cs.LastActivity = now } @@ -346,6 +350,7 @@ func (s *Server) updateSessionTokenIssuedAt(r *http.Request, clientID string) { now := s.now() _ = s.storage.UpdateAuthSession(r.Context(), userID, connectorID, func(old storage.AuthSession) (storage.AuthSession, error) { old.LastActivity = now + old.IdleExpiry = now.Add(s.sessionConfig.ValidIfNotUsedFor) if cs, ok := old.ClientStates[clientID]; ok { cs.LastTokenIssuedAt = now cs.LastActivity = now diff --git a/server/session_test.go b/server/session_test.go index 6d32d123..f00eed59 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -155,14 +155,16 @@ func TestGetValidAuthSession(t *testing.T) { nonce := "test-nonce" session := storage.AuthSession{ - UserID: "user1", - ConnectorID: "conn1", - Nonce: nonce, - ClientStates: map[string]*storage.ClientAuthState{}, - CreatedAt: now.Add(-30 * time.Minute), - LastActivity: now.Add(-5 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + UserID: "user1", + ConnectorID: "conn1", + Nonce: nonce, + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-5 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), } require.NoError(t, s.storage.CreateAuthSession(ctx, session)) @@ -181,14 +183,16 @@ func TestGetValidAuthSession(t *testing.T) { nonce := "test-nonce-conn" session := storage.AuthSession{ - UserID: "user1", - ConnectorID: "ldap", - Nonce: nonce, - ClientStates: map[string]*storage.ClientAuthState{}, - CreatedAt: now.Add(-30 * time.Minute), - LastActivity: now.Add(-5 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + UserID: "user1", + ConnectorID: "ldap", + Nonce: nonce, + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-5 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), } require.NoError(t, s.storage.CreateAuthSession(ctx, session)) @@ -204,14 +208,16 @@ func TestGetValidAuthSession(t *testing.T) { now := s.now() session := storage.AuthSession{ - UserID: "user2", - ConnectorID: "conn2", - Nonce: "correct-nonce", - ClientStates: map[string]*storage.ClientAuthState{}, - CreatedAt: now.Add(-30 * time.Minute), - LastActivity: now.Add(-5 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + UserID: "user2", + ConnectorID: "conn2", + Nonce: "correct-nonce", + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-5 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), } require.NoError(t, s.storage.CreateAuthSession(ctx, session)) @@ -230,14 +236,16 @@ func TestGetValidAuthSession(t *testing.T) { nonce := "expired-nonce" session := storage.AuthSession{ - UserID: "user3", - ConnectorID: "conn3", - Nonce: nonce, - ClientStates: map[string]*storage.ClientAuthState{}, - CreatedAt: now.Add(-25 * time.Hour), - LastActivity: now.Add(-1 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + UserID: "user3", + ConnectorID: "conn3", + Nonce: nonce, + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now.Add(-25 * time.Hour), + LastActivity: now.Add(-1 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(-1 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), } require.NoError(t, s.storage.CreateAuthSession(ctx, session)) @@ -260,14 +268,16 @@ func TestGetValidAuthSession(t *testing.T) { nonce := "idle-nonce" session := storage.AuthSession{ - UserID: "user4", - ConnectorID: "conn4", - Nonce: nonce, - ClientStates: map[string]*storage.ClientAuthState{}, - CreatedAt: now.Add(-2 * time.Hour), - LastActivity: now.Add(-2 * time.Hour), - IPAddress: "127.0.0.1", - UserAgent: "test", + UserID: "user4", + ConnectorID: "conn4", + Nonce: nonce, + ClientStates: map[string]*storage.ClientAuthState{}, + CreatedAt: now.Add(-2 * time.Hour), + LastActivity: now.Add(-2 * time.Hour), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(22 * time.Hour), + IdleExpiry: now.Add(-1 * time.Hour), } require.NoError(t, s.storage.CreateAuthSession(ctx, session)) @@ -338,10 +348,12 @@ func TestCreateOrUpdateAuthSession(t *testing.T) { LastActivity: now.Add(-10 * time.Minute), }, }, - CreatedAt: now.Add(-30 * time.Minute), - LastActivity: now.Add(-10 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-10 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(50 * time.Minute), } require.NoError(t, s.storage.CreateAuthSession(ctx, existingSession)) @@ -402,10 +414,12 @@ func setupSessionLoginFixture(t *testing.T, s *Server) storage.AuthRequest { LastActivity: now.Add(-1 * time.Minute), }, }, - CreatedAt: now.Add(-30 * time.Minute), - LastActivity: now.Add(-1 * time.Minute), - IPAddress: "127.0.0.1", - UserAgent: "test", + CreatedAt: now.Add(-30 * time.Minute), + LastActivity: now.Add(-1 * time.Minute), + IPAddress: "127.0.0.1", + UserAgent: "test", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(59 * time.Minute), })) require.NoError(t, s.storage.CreateUserIdentity(ctx, storage.UserIdentity{ @@ -536,8 +550,10 @@ func TestTrySessionLogin(t *testing.T) { ExpiresAt: now.Add(-1 * time.Hour), }, }, - CreatedAt: now.Add(-2 * time.Hour), - LastActivity: now.Add(-1 * time.Minute), + CreatedAt: now.Add(-2 * time.Hour), + LastActivity: now.Add(-1 * time.Minute), + AbsoluteExpiry: now.Add(22 * time.Hour), + IdleExpiry: now.Add(59 * time.Minute), })) require.NoError(t, s.storage.CreateUserIdentity(t.Context(), storage.UserIdentity{ diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index e2570b47..c81c9484 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -966,6 +966,100 @@ func testGC(t *testing.T, s storage.Storage) { } else if err != storage.ErrNotFound { t.Errorf("expected storage.ErrNotFound, got %v", err) } + + // Test auth session GC. + authSession := storage.AuthSession{ + UserID: "gc-user", + ConnectorID: "gc-conn", + Nonce: storage.NewID(), + ClientStates: map[string]*storage.ClientAuthState{ + "client1": {Active: true, ExpiresAt: expiry.Add(time.Hour), LastActivity: expiry}, + }, + CreatedAt: expiry.Add(-time.Hour), + LastActivity: expiry.Add(-time.Hour), + AbsoluteExpiry: expiry, + IdleExpiry: expiry, + } + + if err := s.CreateAuthSession(ctx, authSession); err != nil { + t.Fatalf("failed creating auth session: %v", err) + } + + // GC before expiry should not delete. + for _, tz := range []*time.Location{time.UTC, est, pst} { + result, err := s.GarbageCollect(ctx, expiry.Add(-time.Hour).In(tz)) + if err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if result.AuthSessions != 0 { + t.Errorf("expected no auth session garbage collection results, got %#v", result) + } + if _, err := s.GetAuthSession(ctx, authSession.UserID, authSession.ConnectorID); err != nil { + t.Errorf("expected to be able to get auth session after GC: %v", err) + } + } + + // GC after expiry should delete. + if r, err := s.GarbageCollect(ctx, expiry.Add(time.Hour)); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AuthSessions != 1 { + t.Errorf("expected to garbage collect 1 auth session, got %d", r.AuthSessions) + } + + if _, err := s.GetAuthSession(ctx, authSession.UserID, authSession.ConnectorID); err == nil { + t.Errorf("expected auth session to be GC'd") + } else if err != storage.ErrNotFound { + t.Errorf("expected storage.ErrNotFound, got %v", err) + } + + // Test auth session GC: absolute expired, idle still valid. + absExpiredSession := storage.AuthSession{ + UserID: "gc-abs-expired", + ConnectorID: "gc-conn", + Nonce: storage.NewID(), + ClientStates: map[string]*storage.ClientAuthState{ + "client1": {Active: true, ExpiresAt: expiry.Add(time.Hour), LastActivity: expiry}, + }, + CreatedAt: expiry.Add(-25 * time.Hour), + LastActivity: expiry.Add(-time.Minute), + AbsoluteExpiry: expiry.Add(-time.Hour), // expired + IdleExpiry: expiry.Add(time.Hour), // still valid + } + if err := s.CreateAuthSession(ctx, absExpiredSession); err != nil { + t.Fatalf("failed creating abs-expired auth session: %v", err) + } + if r, err := s.GarbageCollect(ctx, expiry); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AuthSessions != 1 { + t.Errorf("expected to garbage collect 1 auth session (absolute expired), got %d", r.AuthSessions) + } + if _, err := s.GetAuthSession(ctx, absExpiredSession.UserID, absExpiredSession.ConnectorID); err == nil { + t.Errorf("expected abs-expired auth session to be GC'd") + } + + // Test auth session GC: absolute still valid, idle expired. + idleExpiredSession := storage.AuthSession{ + UserID: "gc-idle-expired", + ConnectorID: "gc-conn", + Nonce: storage.NewID(), + ClientStates: map[string]*storage.ClientAuthState{ + "client1": {Active: true, ExpiresAt: expiry.Add(time.Hour), LastActivity: expiry}, + }, + CreatedAt: expiry.Add(-time.Hour), + LastActivity: expiry.Add(-2 * time.Hour), + AbsoluteExpiry: expiry.Add(23 * time.Hour), // still valid + IdleExpiry: expiry.Add(-time.Hour), // expired + } + if err := s.CreateAuthSession(ctx, idleExpiredSession); err != nil { + t.Fatalf("failed creating idle-expired auth session: %v", err) + } + if r, err := s.GarbageCollect(ctx, expiry); err != nil { + t.Errorf("garbage collection failed: %v", err) + } else if r.AuthSessions != 1 { + t.Errorf("expected to garbage collect 1 auth session (idle expired), got %d", r.AuthSessions) + } + if _, err := s.GetAuthSession(ctx, idleExpiredSession.UserID, idleExpiredSession.ConnectorID); err == nil { + t.Errorf("expected idle-expired auth session to be GC'd") + } } // testTimezones tests that backends either fully support timezones or @@ -1197,10 +1291,12 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { LastTokenIssuedAt: now, }, }, - CreatedAt: now, - LastActivity: now, - IPAddress: "192.168.1.1", - UserAgent: "TestBrowser/1.0", + CreatedAt: now, + LastActivity: now, + IPAddress: "192.168.1.1", + UserAgent: "TestBrowser/1.0", + AbsoluteExpiry: now.Add(24 * time.Hour), + IdleExpiry: now.Add(1 * time.Hour), } // Create. @@ -1220,6 +1316,8 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { got.CreatedAt = got.CreatedAt.UTC().Round(time.Millisecond) got.LastActivity = got.LastActivity.UTC().Round(time.Millisecond) + got.AbsoluteExpiry = got.AbsoluteExpiry.UTC().Round(time.Millisecond) + got.IdleExpiry = got.IdleExpiry.UTC().Round(time.Millisecond) for _, cs := range got.ClientStates { cs.ExpiresAt = cs.ExpiresAt.UTC().Round(time.Millisecond) cs.LastActivity = cs.LastActivity.UTC().Round(time.Millisecond) diff --git a/storage/ent/client/authsession.go b/storage/ent/client/authsession.go index 14120f1b..b4cdfe81 100644 --- a/storage/ent/client/authsession.go +++ b/storage/ent/client/authsession.go @@ -29,6 +29,8 @@ func (d *Database) CreateAuthSession(ctx context.Context, session storage.AuthSe SetLastActivity(session.LastActivity). SetIPAddress(session.IPAddress). SetUserAgent(session.UserAgent). + SetAbsoluteExpiry(session.AbsoluteExpiry.UTC()). + SetIdleExpiry(session.IdleExpiry.UTC()). Save(ctx) if err != nil { return convertDBError("create auth session: %w", err) @@ -102,6 +104,8 @@ func (d *Database) UpdateAuthSession(ctx context.Context, userID, connectorID st SetLastActivity(newSession.LastActivity). SetIPAddress(newSession.IPAddress). SetUserAgent(newSession.UserAgent). + SetAbsoluteExpiry(newSession.AbsoluteExpiry.UTC()). + SetIdleExpiry(newSession.IdleExpiry.UTC()). Save(ctx) if err != nil { return rollback(tx, "update auth session updating: %w", err) diff --git a/storage/ent/client/main.go b/storage/ent/client/main.go index a78830fc..ea9ca297 100644 --- a/storage/ent/client/main.go +++ b/storage/ent/client/main.go @@ -10,6 +10,7 @@ import ( "github.com/dexidp/dex/storage/ent/db" "github.com/dexidp/dex/storage/ent/db/authcode" "github.com/dexidp/dex/storage/ent/db/authrequest" + "github.com/dexidp/dex/storage/ent/db/authsession" "github.com/dexidp/dex/storage/ent/db/devicerequest" "github.com/dexidp/dex/storage/ent/db/devicetoken" "github.com/dexidp/dex/storage/ent/db/migrate" @@ -106,5 +107,18 @@ func (d *Database) GarbageCollect(ctx context.Context, now time.Time) (storage.G } result.DeviceTokens = int64(q) + q, err = d.client.AuthSession.Delete(). + Where( + authsession.Or( + authsession.AbsoluteExpiryLT(utcNow), + authsession.IdleExpiryLT(utcNow), + ), + ). + Exec(ctx) + if err != nil { + return result, convertDBError("gc auth session: %w", err) + } + result.AuthSessions = int64(q) + return result, err } diff --git a/storage/ent/client/types.go b/storage/ent/client/types.go index d58d9f41..91689619 100644 --- a/storage/ent/client/types.go +++ b/storage/ent/client/types.go @@ -216,13 +216,15 @@ func toStorageUserIdentity(u *db.UserIdentity) storage.UserIdentity { func toStorageAuthSession(s *db.AuthSession) storage.AuthSession { result := storage.AuthSession{ - UserID: s.UserID, - ConnectorID: s.ConnectorID, - Nonce: s.Nonce, - CreatedAt: s.CreatedAt, - LastActivity: s.LastActivity, - IPAddress: s.IPAddress, - UserAgent: s.UserAgent, + UserID: s.UserID, + ConnectorID: s.ConnectorID, + Nonce: s.Nonce, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + AbsoluteExpiry: s.AbsoluteExpiry, + IdleExpiry: s.IdleExpiry, } if s.ClientStates != nil { diff --git a/storage/ent/db/authsession.go b/storage/ent/db/authsession.go index 26882da8..6ced0680 100644 --- a/storage/ent/db/authsession.go +++ b/storage/ent/db/authsession.go @@ -32,7 +32,11 @@ type AuthSession struct { // IPAddress holds the value of the "ip_address" field. IPAddress string `json:"ip_address,omitempty"` // UserAgent holds the value of the "user_agent" field. - UserAgent string `json:"user_agent,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + // AbsoluteExpiry holds the value of the "absolute_expiry" field. + AbsoluteExpiry time.Time `json:"absolute_expiry,omitempty"` + // IdleExpiry holds the value of the "idle_expiry" field. + IdleExpiry time.Time `json:"idle_expiry,omitempty"` selectValues sql.SelectValues } @@ -45,7 +49,7 @@ func (*AuthSession) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case authsession.FieldID, authsession.FieldUserID, authsession.FieldConnectorID, authsession.FieldNonce, authsession.FieldIPAddress, authsession.FieldUserAgent: values[i] = new(sql.NullString) - case authsession.FieldCreatedAt, authsession.FieldLastActivity: + case authsession.FieldCreatedAt, authsession.FieldLastActivity, authsession.FieldAbsoluteExpiry, authsession.FieldIdleExpiry: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -116,6 +120,18 @@ func (_m *AuthSession) assignValues(columns []string, values []any) error { } else if value.Valid { _m.UserAgent = value.String } + case authsession.FieldAbsoluteExpiry: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field absolute_expiry", values[i]) + } else if value.Valid { + _m.AbsoluteExpiry = value.Time + } + case authsession.FieldIdleExpiry: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field idle_expiry", values[i]) + } else if value.Valid { + _m.IdleExpiry = value.Time + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -175,6 +191,12 @@ func (_m *AuthSession) String() string { builder.WriteString(", ") builder.WriteString("user_agent=") builder.WriteString(_m.UserAgent) + builder.WriteString(", ") + builder.WriteString("absolute_expiry=") + builder.WriteString(_m.AbsoluteExpiry.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("idle_expiry=") + builder.WriteString(_m.IdleExpiry.Format(time.ANSIC)) builder.WriteByte(')') return builder.String() } diff --git a/storage/ent/db/authsession/authsession.go b/storage/ent/db/authsession/authsession.go index 8e5bdfc2..fc1cd5ed 100644 --- a/storage/ent/db/authsession/authsession.go +++ b/storage/ent/db/authsession/authsession.go @@ -27,6 +27,10 @@ const ( FieldIPAddress = "ip_address" // FieldUserAgent holds the string denoting the user_agent field in the database. FieldUserAgent = "user_agent" + // FieldAbsoluteExpiry holds the string denoting the absolute_expiry field in the database. + FieldAbsoluteExpiry = "absolute_expiry" + // FieldIdleExpiry holds the string denoting the idle_expiry field in the database. + FieldIdleExpiry = "idle_expiry" // Table holds the table name of the authsession in the database. Table = "auth_sessions" ) @@ -42,6 +46,8 @@ var Columns = []string{ FieldLastActivity, FieldIPAddress, FieldUserAgent, + FieldAbsoluteExpiry, + FieldIdleExpiry, } // ValidColumn reports if the column name is valid (part of the table columns). @@ -111,3 +117,13 @@ func ByIPAddress(opts ...sql.OrderTermOption) OrderOption { func ByUserAgent(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldUserAgent, opts...).ToFunc() } + +// ByAbsoluteExpiry orders the results by the absolute_expiry field. +func ByAbsoluteExpiry(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAbsoluteExpiry, opts...).ToFunc() +} + +// ByIdleExpiry orders the results by the idle_expiry field. +func ByIdleExpiry(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdleExpiry, opts...).ToFunc() +} diff --git a/storage/ent/db/authsession/where.go b/storage/ent/db/authsession/where.go index cdda7fb8..193f1133 100644 --- a/storage/ent/db/authsession/where.go +++ b/storage/ent/db/authsession/where.go @@ -104,6 +104,16 @@ func UserAgent(v string) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldUserAgent, v)) } +// AbsoluteExpiry applies equality check predicate on the "absolute_expiry" field. It's identical to AbsoluteExpiryEQ. +func AbsoluteExpiry(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldAbsoluteExpiry, v)) +} + +// IdleExpiry applies equality check predicate on the "idle_expiry" field. It's identical to IdleExpiryEQ. +func IdleExpiry(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldIdleExpiry, v)) +} + // UserIDEQ applies the EQ predicate on the "user_id" field. func UserIDEQ(v string) predicate.AuthSession { return predicate.AuthSession(sql.FieldEQ(FieldUserID, v)) @@ -549,6 +559,86 @@ func UserAgentContainsFold(v string) predicate.AuthSession { return predicate.AuthSession(sql.FieldContainsFold(FieldUserAgent, v)) } +// AbsoluteExpiryEQ applies the EQ predicate on the "absolute_expiry" field. +func AbsoluteExpiryEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldAbsoluteExpiry, v)) +} + +// AbsoluteExpiryNEQ applies the NEQ predicate on the "absolute_expiry" field. +func AbsoluteExpiryNEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldAbsoluteExpiry, v)) +} + +// AbsoluteExpiryIn applies the In predicate on the "absolute_expiry" field. +func AbsoluteExpiryIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldAbsoluteExpiry, vs...)) +} + +// AbsoluteExpiryNotIn applies the NotIn predicate on the "absolute_expiry" field. +func AbsoluteExpiryNotIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldAbsoluteExpiry, vs...)) +} + +// AbsoluteExpiryGT applies the GT predicate on the "absolute_expiry" field. +func AbsoluteExpiryGT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldAbsoluteExpiry, v)) +} + +// AbsoluteExpiryGTE applies the GTE predicate on the "absolute_expiry" field. +func AbsoluteExpiryGTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldAbsoluteExpiry, v)) +} + +// AbsoluteExpiryLT applies the LT predicate on the "absolute_expiry" field. +func AbsoluteExpiryLT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldAbsoluteExpiry, v)) +} + +// AbsoluteExpiryLTE applies the LTE predicate on the "absolute_expiry" field. +func AbsoluteExpiryLTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldAbsoluteExpiry, v)) +} + +// IdleExpiryEQ applies the EQ predicate on the "idle_expiry" field. +func IdleExpiryEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldEQ(FieldIdleExpiry, v)) +} + +// IdleExpiryNEQ applies the NEQ predicate on the "idle_expiry" field. +func IdleExpiryNEQ(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNEQ(FieldIdleExpiry, v)) +} + +// IdleExpiryIn applies the In predicate on the "idle_expiry" field. +func IdleExpiryIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldIn(FieldIdleExpiry, vs...)) +} + +// IdleExpiryNotIn applies the NotIn predicate on the "idle_expiry" field. +func IdleExpiryNotIn(vs ...time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldNotIn(FieldIdleExpiry, vs...)) +} + +// IdleExpiryGT applies the GT predicate on the "idle_expiry" field. +func IdleExpiryGT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGT(FieldIdleExpiry, v)) +} + +// IdleExpiryGTE applies the GTE predicate on the "idle_expiry" field. +func IdleExpiryGTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldGTE(FieldIdleExpiry, v)) +} + +// IdleExpiryLT applies the LT predicate on the "idle_expiry" field. +func IdleExpiryLT(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLT(FieldIdleExpiry, v)) +} + +// IdleExpiryLTE applies the LTE predicate on the "idle_expiry" field. +func IdleExpiryLTE(v time.Time) predicate.AuthSession { + return predicate.AuthSession(sql.FieldLTE(FieldIdleExpiry, v)) +} + // And groups predicates with the AND operator between them. func And(predicates ...predicate.AuthSession) predicate.AuthSession { return predicate.AuthSession(sql.AndPredicates(predicates...)) diff --git a/storage/ent/db/authsession_create.go b/storage/ent/db/authsession_create.go index 080b094c..0dc99e76 100644 --- a/storage/ent/db/authsession_create.go +++ b/storage/ent/db/authsession_create.go @@ -84,6 +84,18 @@ func (_c *AuthSessionCreate) SetNillableUserAgent(v *string) *AuthSessionCreate return _c } +// SetAbsoluteExpiry sets the "absolute_expiry" field. +func (_c *AuthSessionCreate) SetAbsoluteExpiry(v time.Time) *AuthSessionCreate { + _c.mutation.SetAbsoluteExpiry(v) + return _c +} + +// SetIdleExpiry sets the "idle_expiry" field. +func (_c *AuthSessionCreate) SetIdleExpiry(v time.Time) *AuthSessionCreate { + _c.mutation.SetIdleExpiry(v) + return _c +} + // SetID sets the "id" field. func (_c *AuthSessionCreate) SetID(v string) *AuthSessionCreate { _c.mutation.SetID(v) @@ -176,6 +188,12 @@ func (_c *AuthSessionCreate) check() error { if _, ok := _c.mutation.UserAgent(); !ok { return &ValidationError{Name: "user_agent", err: errors.New(`db: missing required field "AuthSession.user_agent"`)} } + if _, ok := _c.mutation.AbsoluteExpiry(); !ok { + return &ValidationError{Name: "absolute_expiry", err: errors.New(`db: missing required field "AuthSession.absolute_expiry"`)} + } + if _, ok := _c.mutation.IdleExpiry(); !ok { + return &ValidationError{Name: "idle_expiry", err: errors.New(`db: missing required field "AuthSession.idle_expiry"`)} + } if v, ok := _c.mutation.ID(); ok { if err := authsession.IDValidator(v); err != nil { return &ValidationError{Name: "id", err: fmt.Errorf(`db: validator failed for field "AuthSession.id": %w`, err)} @@ -248,6 +266,14 @@ func (_c *AuthSessionCreate) createSpec() (*AuthSession, *sqlgraph.CreateSpec) { _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) _node.UserAgent = value } + if value, ok := _c.mutation.AbsoluteExpiry(); ok { + _spec.SetField(authsession.FieldAbsoluteExpiry, field.TypeTime, value) + _node.AbsoluteExpiry = value + } + if value, ok := _c.mutation.IdleExpiry(); ok { + _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) + _node.IdleExpiry = value + } return _node, _spec } diff --git a/storage/ent/db/authsession_update.go b/storage/ent/db/authsession_update.go index 5457b04c..d80e682b 100644 --- a/storage/ent/db/authsession_update.go +++ b/storage/ent/db/authsession_update.go @@ -132,6 +132,34 @@ func (_u *AuthSessionUpdate) SetNillableUserAgent(v *string) *AuthSessionUpdate return _u } +// SetAbsoluteExpiry sets the "absolute_expiry" field. +func (_u *AuthSessionUpdate) SetAbsoluteExpiry(v time.Time) *AuthSessionUpdate { + _u.mutation.SetAbsoluteExpiry(v) + return _u +} + +// SetNillableAbsoluteExpiry sets the "absolute_expiry" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableAbsoluteExpiry(v *time.Time) *AuthSessionUpdate { + if v != nil { + _u.SetAbsoluteExpiry(*v) + } + return _u +} + +// SetIdleExpiry sets the "idle_expiry" field. +func (_u *AuthSessionUpdate) SetIdleExpiry(v time.Time) *AuthSessionUpdate { + _u.mutation.SetIdleExpiry(v) + return _u +} + +// SetNillableIdleExpiry sets the "idle_expiry" field if the given value is not nil. +func (_u *AuthSessionUpdate) SetNillableIdleExpiry(v *time.Time) *AuthSessionUpdate { + if v != nil { + _u.SetIdleExpiry(*v) + } + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdate) Mutation() *AuthSessionMutation { return _u.mutation @@ -220,6 +248,12 @@ func (_u *AuthSessionUpdate) sqlSave(ctx context.Context) (_node int, err error) if value, ok := _u.mutation.UserAgent(); ok { _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) } + if value, ok := _u.mutation.AbsoluteExpiry(); ok { + _spec.SetField(authsession.FieldAbsoluteExpiry, field.TypeTime, value) + } + if value, ok := _u.mutation.IdleExpiry(); ok { + _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) + } if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { if _, ok := err.(*sqlgraph.NotFoundError); ok { err = &NotFoundError{authsession.Label} @@ -344,6 +378,34 @@ func (_u *AuthSessionUpdateOne) SetNillableUserAgent(v *string) *AuthSessionUpda return _u } +// SetAbsoluteExpiry sets the "absolute_expiry" field. +func (_u *AuthSessionUpdateOne) SetAbsoluteExpiry(v time.Time) *AuthSessionUpdateOne { + _u.mutation.SetAbsoluteExpiry(v) + return _u +} + +// SetNillableAbsoluteExpiry sets the "absolute_expiry" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableAbsoluteExpiry(v *time.Time) *AuthSessionUpdateOne { + if v != nil { + _u.SetAbsoluteExpiry(*v) + } + return _u +} + +// SetIdleExpiry sets the "idle_expiry" field. +func (_u *AuthSessionUpdateOne) SetIdleExpiry(v time.Time) *AuthSessionUpdateOne { + _u.mutation.SetIdleExpiry(v) + return _u +} + +// SetNillableIdleExpiry sets the "idle_expiry" field if the given value is not nil. +func (_u *AuthSessionUpdateOne) SetNillableIdleExpiry(v *time.Time) *AuthSessionUpdateOne { + if v != nil { + _u.SetIdleExpiry(*v) + } + return _u +} + // Mutation returns the AuthSessionMutation object of the builder. func (_u *AuthSessionUpdateOne) Mutation() *AuthSessionMutation { return _u.mutation @@ -462,6 +524,12 @@ func (_u *AuthSessionUpdateOne) sqlSave(ctx context.Context) (_node *AuthSession if value, ok := _u.mutation.UserAgent(); ok { _spec.SetField(authsession.FieldUserAgent, field.TypeString, value) } + if value, ok := _u.mutation.AbsoluteExpiry(); ok { + _spec.SetField(authsession.FieldAbsoluteExpiry, field.TypeTime, value) + } + if value, ok := _u.mutation.IdleExpiry(); ok { + _spec.SetField(authsession.FieldIdleExpiry, field.TypeTime, value) + } _node = &AuthSession{config: _u.config} _spec.Assign = _node.assignValues _spec.ScanValues = _node.scanValues diff --git a/storage/ent/db/migrate/schema.go b/storage/ent/db/migrate/schema.go index 3fc0f834..e8958ec9 100644 --- a/storage/ent/db/migrate/schema.go +++ b/storage/ent/db/migrate/schema.go @@ -79,6 +79,8 @@ var ( {Name: "last_activity", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, {Name: "ip_address", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, {Name: "user_agent", Type: field.TypeString, Size: 2147483647, Default: "", SchemaType: map[string]string{"mysql": "varchar(384)", "postgres": "text", "sqlite3": "text"}}, + {Name: "absolute_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, + {Name: "idle_expiry", Type: field.TypeTime, SchemaType: map[string]string{"mysql": "datetime(3)", "postgres": "timestamptz", "sqlite3": "timestamp"}}, } // AuthSessionsTable holds the schema information for the "auth_sessions" table. AuthSessionsTable = &schema.Table{ diff --git a/storage/ent/db/mutation.go b/storage/ent/db/mutation.go index 46f204eb..4c7858f0 100644 --- a/storage/ent/db/mutation.go +++ b/storage/ent/db/mutation.go @@ -3068,21 +3068,23 @@ func (m *AuthRequestMutation) ResetEdge(name string) error { // AuthSessionMutation represents an operation that mutates the AuthSession nodes in the graph. type AuthSessionMutation struct { config - op Op - typ string - id *string - user_id *string - connector_id *string - nonce *string - client_states *[]byte - created_at *time.Time - last_activity *time.Time - ip_address *string - user_agent *string - clearedFields map[string]struct{} - done bool - oldValue func(context.Context) (*AuthSession, error) - predicates []predicate.AuthSession + op Op + typ string + id *string + user_id *string + connector_id *string + nonce *string + client_states *[]byte + created_at *time.Time + last_activity *time.Time + ip_address *string + user_agent *string + absolute_expiry *time.Time + idle_expiry *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*AuthSession, error) + predicates []predicate.AuthSession } var _ ent.Mutation = (*AuthSessionMutation)(nil) @@ -3477,6 +3479,78 @@ func (m *AuthSessionMutation) ResetUserAgent() { m.user_agent = nil } +// SetAbsoluteExpiry sets the "absolute_expiry" field. +func (m *AuthSessionMutation) SetAbsoluteExpiry(t time.Time) { + m.absolute_expiry = &t +} + +// AbsoluteExpiry returns the value of the "absolute_expiry" field in the mutation. +func (m *AuthSessionMutation) AbsoluteExpiry() (r time.Time, exists bool) { + v := m.absolute_expiry + if v == nil { + return + } + return *v, true +} + +// OldAbsoluteExpiry returns the old "absolute_expiry" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldAbsoluteExpiry(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldAbsoluteExpiry is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldAbsoluteExpiry requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldAbsoluteExpiry: %w", err) + } + return oldValue.AbsoluteExpiry, nil +} + +// ResetAbsoluteExpiry resets all changes to the "absolute_expiry" field. +func (m *AuthSessionMutation) ResetAbsoluteExpiry() { + m.absolute_expiry = nil +} + +// SetIdleExpiry sets the "idle_expiry" field. +func (m *AuthSessionMutation) SetIdleExpiry(t time.Time) { + m.idle_expiry = &t +} + +// IdleExpiry returns the value of the "idle_expiry" field in the mutation. +func (m *AuthSessionMutation) IdleExpiry() (r time.Time, exists bool) { + v := m.idle_expiry + if v == nil { + return + } + return *v, true +} + +// OldIdleExpiry returns the old "idle_expiry" field's value of the AuthSession entity. +// If the AuthSession object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AuthSessionMutation) OldIdleExpiry(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdleExpiry is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdleExpiry requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdleExpiry: %w", err) + } + return oldValue.IdleExpiry, nil +} + +// ResetIdleExpiry resets all changes to the "idle_expiry" field. +func (m *AuthSessionMutation) ResetIdleExpiry() { + m.idle_expiry = nil +} + // Where appends a list predicates to the AuthSessionMutation builder. func (m *AuthSessionMutation) Where(ps ...predicate.AuthSession) { m.predicates = append(m.predicates, ps...) @@ -3511,7 +3585,7 @@ func (m *AuthSessionMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *AuthSessionMutation) Fields() []string { - fields := make([]string, 0, 8) + fields := make([]string, 0, 10) if m.user_id != nil { fields = append(fields, authsession.FieldUserID) } @@ -3536,6 +3610,12 @@ func (m *AuthSessionMutation) Fields() []string { if m.user_agent != nil { fields = append(fields, authsession.FieldUserAgent) } + if m.absolute_expiry != nil { + fields = append(fields, authsession.FieldAbsoluteExpiry) + } + if m.idle_expiry != nil { + fields = append(fields, authsession.FieldIdleExpiry) + } return fields } @@ -3560,6 +3640,10 @@ func (m *AuthSessionMutation) Field(name string) (ent.Value, bool) { return m.IPAddress() case authsession.FieldUserAgent: return m.UserAgent() + case authsession.FieldAbsoluteExpiry: + return m.AbsoluteExpiry() + case authsession.FieldIdleExpiry: + return m.IdleExpiry() } return nil, false } @@ -3585,6 +3669,10 @@ func (m *AuthSessionMutation) OldField(ctx context.Context, name string) (ent.Va return m.OldIPAddress(ctx) case authsession.FieldUserAgent: return m.OldUserAgent(ctx) + case authsession.FieldAbsoluteExpiry: + return m.OldAbsoluteExpiry(ctx) + case authsession.FieldIdleExpiry: + return m.OldIdleExpiry(ctx) } return nil, fmt.Errorf("unknown AuthSession field %s", name) } @@ -3650,6 +3738,20 @@ func (m *AuthSessionMutation) SetField(name string, value ent.Value) error { } m.SetUserAgent(v) return nil + case authsession.FieldAbsoluteExpiry: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetAbsoluteExpiry(v) + return nil + case authsession.FieldIdleExpiry: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdleExpiry(v) + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } @@ -3723,6 +3825,12 @@ func (m *AuthSessionMutation) ResetField(name string) error { case authsession.FieldUserAgent: m.ResetUserAgent() return nil + case authsession.FieldAbsoluteExpiry: + m.ResetAbsoluteExpiry() + return nil + case authsession.FieldIdleExpiry: + m.ResetIdleExpiry() + return nil } return fmt.Errorf("unknown AuthSession field %s", name) } diff --git a/storage/ent/schema/authsession.go b/storage/ent/schema/authsession.go index ff76ab61..0b641b7f 100644 --- a/storage/ent/schema/authsession.go +++ b/storage/ent/schema/authsession.go @@ -37,6 +37,10 @@ func (AuthSession) Fields() []ent.Field { field.Text("user_agent"). SchemaType(textSchema). Default(""), + field.Time("absolute_expiry"). + SchemaType(timeSchema), + field.Time("idle_expiry"). + SchemaType(timeSchema), } } diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index e21acc9b..ad6244c3 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -108,6 +108,23 @@ func (c *conn) GarbageCollect(ctx context.Context, now time.Time) (result storag result.DeviceTokens++ } } + + authSessions, err := c.listAuthSessionsInternal(ctx) + if err != nil { + return result, err + } + + for _, authSession := range authSessions { + if now.After(authSession.AbsoluteExpiry) || now.After(authSession.IdleExpiry) { + if err := c.deleteKey(ctx, keyAuthSession(authSession.UserID, authSession.ConnectorID)); err != nil { + c.logger.Error("failed to delete auth session", "err", err) + delErr = fmt.Errorf("failed to delete auth session: %v", err) + } else { + result.AuthSessions++ + } + } + } + return result, delErr } @@ -458,16 +475,13 @@ func (c *conn) UpdateAuthSession(ctx context.Context, userID, connectorID string func (c *conn) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) { ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() - res, err := c.db.Get(ctx, authSessionPrefix, clientv3.WithPrefix()) + + res, err := c.listAuthSessionsInternal(ctx) if err != nil { return sessions, err } - for _, v := range res.Kvs { - var s AuthSession - if err = json.Unmarshal(v.Value, &s); err != nil { - return sessions, err - } - sessions = append(sessions, toStorageAuthSession(s)) + for _, v := range res { + sessions = append(sessions, toStorageAuthSession(v)) } return sessions, nil } @@ -613,6 +627,21 @@ func (c *conn) listAuthCodes(ctx context.Context) (codes []AuthCode, err error) return codes, nil } +func (c *conn) listAuthSessionsInternal(ctx context.Context) (sessions []AuthSession, err error) { + res, err := c.db.Get(ctx, authSessionPrefix, clientv3.WithPrefix()) + if err != nil { + return sessions, err + } + for _, v := range res.Kvs { + var s AuthSession + if err = json.Unmarshal(v.Value, &s); err != nil { + return sessions, err + } + sessions = append(sessions, s) + } + return sessions, nil +} + func (c *conn) txnCreate(ctx context.Context, key string, value interface{}) error { b, err := json.Marshal(value) if err != nil { diff --git a/storage/etcd/types.go b/storage/etcd/types.go index 0a73e9ea..7be8bcf8 100644 --- a/storage/etcd/types.go +++ b/storage/etcd/types.go @@ -319,39 +319,45 @@ func toStorageUserIdentity(u UserIdentity) storage.UserIdentity { // AuthSession is a mirrored struct from storage with JSON struct tags. type AuthSession struct { - UserID string `json:"user_id,omitempty"` - ConnectorID string `json:"connector_id,omitempty"` - Nonce string `json:"nonce,omitempty"` - ClientStates map[string]*storage.ClientAuthState `json:"client_states,omitempty"` - CreatedAt time.Time `json:"created_at"` - LastActivity time.Time `json:"last_activity"` - IPAddress string `json:"ip_address,omitempty"` - UserAgent string `json:"user_agent,omitempty"` + UserID string `json:"user_id,omitempty"` + ConnectorID string `json:"connector_id,omitempty"` + Nonce string `json:"nonce,omitempty"` + ClientStates map[string]*storage.ClientAuthState `json:"client_states,omitempty"` + CreatedAt time.Time `json:"created_at"` + LastActivity time.Time `json:"last_activity"` + IPAddress string `json:"ip_address,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + AbsoluteExpiry time.Time `json:"absolute_expiry"` + IdleExpiry time.Time `json:"idle_expiry"` } func fromStorageAuthSession(s storage.AuthSession) AuthSession { return AuthSession{ - UserID: s.UserID, - ConnectorID: s.ConnectorID, - Nonce: s.Nonce, - ClientStates: s.ClientStates, - CreatedAt: s.CreatedAt, - LastActivity: s.LastActivity, - IPAddress: s.IPAddress, - UserAgent: s.UserAgent, + UserID: s.UserID, + ConnectorID: s.ConnectorID, + Nonce: s.Nonce, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + AbsoluteExpiry: s.AbsoluteExpiry, + IdleExpiry: s.IdleExpiry, } } func toStorageAuthSession(s AuthSession) storage.AuthSession { result := storage.AuthSession{ - UserID: s.UserID, - ConnectorID: s.ConnectorID, - Nonce: s.Nonce, - ClientStates: s.ClientStates, - CreatedAt: s.CreatedAt, - LastActivity: s.LastActivity, - IPAddress: s.IPAddress, - UserAgent: s.UserAgent, + UserID: s.UserID, + ConnectorID: s.ConnectorID, + Nonce: s.Nonce, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + AbsoluteExpiry: s.AbsoluteExpiry, + IdleExpiry: s.IdleExpiry, } if result.ClientStates == nil { result.ClientStates = make(map[string]*storage.ClientAuthState) diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index b6d2990b..867cbed2 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -694,6 +694,22 @@ func (cli *client) GarbageCollect(ctx context.Context, now time.Time) (result st } } + var authSessions AuthSessionList + if err := cli.listN(resourceAuthSession, &authSessions, gcResultLimit); err != nil { + return result, fmt.Errorf("failed to list auth sessions: %v", err) + } + + for _, authSession := range authSessions.AuthSessions { + if now.After(authSession.AbsoluteExpiry) || now.After(authSession.IdleExpiry) { + if err := cli.delete(resourceAuthSession, authSession.ObjectMeta.Name); err != nil { + cli.logger.Error("failed to delete auth session", "err", err) + delErr = fmt.Errorf("failed to delete auth session: %v", err) + } else { + result.AuthSessions++ + } + } + } + if delErr != nil { return result, delErr } diff --git a/storage/kubernetes/types.go b/storage/kubernetes/types.go index 344ae75e..c945874a 100644 --- a/storage/kubernetes/types.go +++ b/storage/kubernetes/types.go @@ -996,14 +996,16 @@ type AuthSession struct { k8sapi.TypeMeta `json:",inline"` k8sapi.ObjectMeta `json:"metadata,omitempty"` - UserID string `json:"userID,omitempty"` - ConnectorID string `json:"connectorID,omitempty"` - Nonce string `json:"nonce,omitempty"` - ClientStates map[string]*storage.ClientAuthState `json:"clientStates,omitempty"` - CreatedAt time.Time `json:"createdAt,omitempty"` - LastActivity time.Time `json:"lastActivity,omitempty"` - IPAddress string `json:"ipAddress,omitempty"` - UserAgent string `json:"userAgent,omitempty"` + UserID string `json:"userID,omitempty"` + ConnectorID string `json:"connectorID,omitempty"` + Nonce string `json:"nonce,omitempty"` + ClientStates map[string]*storage.ClientAuthState `json:"clientStates,omitempty"` + CreatedAt time.Time `json:"createdAt,omitempty"` + LastActivity time.Time `json:"lastActivity,omitempty"` + IPAddress string `json:"ipAddress,omitempty"` + UserAgent string `json:"userAgent,omitempty"` + AbsoluteExpiry time.Time `json:"absoluteExpiry,omitempty"` + IdleExpiry time.Time `json:"idleExpiry,omitempty"` } // AuthSessionList is a list of AuthSessions. @@ -1023,27 +1025,31 @@ func (cli *client) fromStorageAuthSession(s storage.AuthSession) AuthSession { Name: offlineTokenName(s.UserID, s.ConnectorID, cli.hash), Namespace: cli.namespace, }, - UserID: s.UserID, - ConnectorID: s.ConnectorID, - Nonce: s.Nonce, - ClientStates: s.ClientStates, - CreatedAt: s.CreatedAt, - LastActivity: s.LastActivity, - IPAddress: s.IPAddress, - UserAgent: s.UserAgent, + UserID: s.UserID, + ConnectorID: s.ConnectorID, + Nonce: s.Nonce, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + AbsoluteExpiry: s.AbsoluteExpiry, + IdleExpiry: s.IdleExpiry, } } func toStorageAuthSession(s AuthSession) storage.AuthSession { result := storage.AuthSession{ - UserID: s.UserID, - ConnectorID: s.ConnectorID, - Nonce: s.Nonce, - ClientStates: s.ClientStates, - CreatedAt: s.CreatedAt, - LastActivity: s.LastActivity, - IPAddress: s.IPAddress, - UserAgent: s.UserAgent, + UserID: s.UserID, + ConnectorID: s.ConnectorID, + Nonce: s.Nonce, + ClientStates: s.ClientStates, + CreatedAt: s.CreatedAt, + LastActivity: s.LastActivity, + IPAddress: s.IPAddress, + UserAgent: s.UserAgent, + AbsoluteExpiry: s.AbsoluteExpiry, + IdleExpiry: s.IdleExpiry, } if result.ClientStates == nil { result.ClientStates = make(map[string]*storage.ClientAuthState) diff --git a/storage/memory/memory.go b/storage/memory/memory.go index e507340c..495ce6bb 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -101,6 +101,12 @@ func (s *memStorage) GarbageCollect(ctx context.Context, now time.Time) (result result.DeviceTokens++ } } + for id, a := range s.authSessions { + if now.After(a.AbsoluteExpiry) || now.After(a.IdleExpiry) { + delete(s.authSessions, id) + result.AuthSessions++ + } + } }) return result, nil } diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 82b9bc9b..83308672 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -121,7 +121,15 @@ func (c *conn) GarbageCollect(ctc context.Context, now time.Time) (storage.GCRes result.DeviceTokens = n } - return result, err + r, err = c.Exec(`delete from auth_session where absolute_expiry < $1 OR idle_expiry < $2`, now, now) + if err != nil { + return result, fmt.Errorf("gc auth_session: %v", err) + } + if n, err := r.RowsAffected(); err == nil { + result.AuthSessions = n + } + + return result, nil } func (c *conn) CreateAuthRequest(ctx context.Context, a storage.AuthRequest) error { @@ -962,14 +970,16 @@ func (c *conn) CreateAuthSession(ctx context.Context, s storage.AuthSession) err user_id, connector_id, nonce, client_states, created_at, last_activity, - ip_address, user_agent + ip_address, user_agent, + absolute_expiry, idle_expiry ) - values ($1, $2, $3, $4, $5, $6, $7, $8); + values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10); `, s.UserID, s.ConnectorID, s.Nonce, encoder(s.ClientStates), s.CreatedAt, s.LastActivity, s.IPAddress, s.UserAgent, + s.AbsoluteExpiry, s.IdleExpiry, ) if err != nil { if c.alreadyExistsCheck(err) { @@ -1022,7 +1032,8 @@ func getAuthSession(ctx context.Context, q querier, userID, connectorID string) user_id, connector_id, nonce, client_states, created_at, last_activity, - ip_address, user_agent + ip_address, user_agent, + absolute_expiry, idle_expiry from auth_session where user_id = $1 AND connector_id = $2; `, userID, connectorID)) @@ -1034,6 +1045,7 @@ func scanAuthSession(s scanner) (session storage.AuthSession, err error) { decoder(&session.ClientStates), &session.CreatedAt, &session.LastActivity, &session.IPAddress, &session.UserAgent, + &session.AbsoluteExpiry, &session.IdleExpiry, ) if err != nil { if err == sql.ErrNoRows { @@ -1053,7 +1065,8 @@ func (c *conn) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, err user_id, connector_id, nonce, client_states, created_at, last_activity, - ip_address, user_agent + ip_address, user_agent, + absolute_expiry, idle_expiry from auth_session; `) if err != nil { diff --git a/storage/sql/migrate.go b/storage/sql/migrate.go index 9131d284..80b81b4f 100644 --- a/storage/sql/migrate.go +++ b/storage/sql/migrate.go @@ -442,4 +442,14 @@ var migrations = []migration{ `alter table auth_code add column auth_time timestamptz not null default '1970-01-01 00:00:00';`, }, }, + { + stmts: []string{ + ` + alter table auth_session + add column absolute_expiry timestamptz not null default '1970-01-01 00:00:00';`, + ` + alter table auth_session + add column idle_expiry timestamptz not null default '1970-01-01 00:00:00';`, + }, + }, } diff --git a/storage/storage.go b/storage/storage.go index 8161eb83..0889e6b6 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -60,6 +60,7 @@ type GCResult struct { AuthCodes int64 DeviceRequests int64 DeviceTokens int64 + AuthSessions int64 } // IsEmpty returns whether the garbage collection result is empty or not. @@ -67,7 +68,8 @@ func (g *GCResult) IsEmpty() bool { return g.AuthRequests == 0 && g.AuthCodes == 0 && g.DeviceRequests == 0 && - g.DeviceTokens == 0 + g.DeviceTokens == 0 && + g.AuthSessions == 0 } // Storage is the storage interface used by the server. Implementations are @@ -402,6 +404,11 @@ type AuthSession struct { LastActivity time.Time IPAddress string UserAgent string + + // AbsoluteExpiry is CreatedAt + AbsoluteLifetime, set once at creation. + AbsoluteExpiry time.Time + // IdleExpiry is LastActivity + ValidIfNotUsedFor, updated on every activity. + IdleExpiry time.Time } // OfflineSessions objects are sessions pertaining to users with refresh tokens.