From e5feba1d3b37b2d99d10b374af124c9c623b106b Mon Sep 17 00:00:00 2001 From: "maksim.nabokikh" Date: Mon, 16 Mar 2026 13:18:33 +0100 Subject: [PATCH] Fixes according to code review comments Signed-off-by: maksim.nabokikh --- storage/conformance/conformance.go | 9 +++++++ storage/ent/client/authsession.go | 17 +++++++++++-- storage/etcd/etcd.go | 17 +++++++++++++ storage/kubernetes/storage.go | 14 +++++++++++ storage/memory/memory.go | 9 +++++++ storage/sql/crud.go | 38 +++++++++++++++++++++++++----- storage/storage.go | 1 + 7 files changed, 97 insertions(+), 8 deletions(-) diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 4b067df1..94b23745 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -1248,6 +1248,15 @@ func testAuthSessionCRUD(t *testing.T, s storage.Storage) { t.Errorf("expected client2 user_id to be user2, got %s", got.ClientStates["client2"].UserID) } + // List and verify. + sessions, err := s.ListAuthSessions(ctx) + if err != nil { + t.Fatalf("list auth sessions: %v", err) + } + if len(sessions) != 1 { + t.Fatalf("expected 1 auth session, got %d", len(sessions)) + } + // Delete. if err := s.DeleteAuthSession(ctx, session.ID); err != nil { t.Fatalf("delete auth session: %v", err) diff --git a/storage/ent/client/authsession.go b/storage/ent/client/authsession.go index a20cfe76..439bdbe3 100644 --- a/storage/ent/client/authsession.go +++ b/storage/ent/client/authsession.go @@ -41,6 +41,20 @@ func (d *Database) GetAuthSession(ctx context.Context, sessionID string) (storag return toStorageAuthSession(authSession), nil } +// ListAuthSessions extracts all auth sessions from the database. +func (d *Database) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + authSessions, err := d.client.AuthSession.Query().All(ctx) + if err != nil { + return nil, convertDBError("list auth sessions: %w", err) + } + + storageAuthSessions := make([]storage.AuthSession, 0, len(authSessions)) + for _, s := range authSessions { + storageAuthSessions = append(storageAuthSessions, toStorageAuthSession(s)) + } + return storageAuthSessions, nil +} + // DeleteAuthSession deletes an auth session from the database by session ID. func (d *Database) DeleteAuthSession(ctx context.Context, sessionID string) error { err := d.client.AuthSession.DeleteOneID(sessionID).Exec(ctx) @@ -78,13 +92,12 @@ func (d *Database) UpdateAuthSession(ctx context.Context, sessionID string, upda _, err = tx.AuthSession.UpdateOneID(sessionID). SetClientStates(encodedStates). - SetCreatedAt(newSession.CreatedAt). SetLastActivity(newSession.LastActivity). SetIPAddress(newSession.IPAddress). SetUserAgent(newSession.UserAgent). Save(ctx) if err != nil { - return rollback(tx, "update auth session uploading: %w", err) + return rollback(tx, "update auth session updating: %w", err) } if err = tx.Commit(); err != nil { diff --git a/storage/etcd/etcd.go b/storage/etcd/etcd.go index ff2e12b3..13aa20a1 100644 --- a/storage/etcd/etcd.go +++ b/storage/etcd/etcd.go @@ -455,6 +455,23 @@ func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater }) } +func (c *conn) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) { + ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) + defer cancel() + res, err := c.db.Get(ctx, authSessionPrefix, clientv3.WithPrefix()) + if err != nil { + return sessions, err + } + for _, v := range res.Kvs { + var s AuthSession + if err = json.Unmarshal(v.Value, &s); err != nil { + return sessions, err + } + sessions = append(sessions, toStorageAuthSession(s)) + } + return sessions, nil +} + func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error { ctx, cancel := context.WithTimeout(ctx, defaultStorageTimeout) defer cancel() diff --git a/storage/kubernetes/storage.go b/storage/kubernetes/storage.go index c9b47d0c..55ea7455 100644 --- a/storage/kubernetes/storage.go +++ b/storage/kubernetes/storage.go @@ -841,6 +841,20 @@ func (cli *client) UpdateAuthSession(ctx context.Context, sessionID string, upda }) } +func (cli *client) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + var authSessionList AuthSessionList + if err := cli.list(resourceAuthSession, &authSessionList); err != nil { + return nil, fmt.Errorf("failed to list auth sessions: %v", err) + } + + sessions := make([]storage.AuthSession, len(authSessionList.AuthSessions)) + for i, s := range authSessionList.AuthSessions { + sessions[i] = toStorageAuthSession(s) + } + + return sessions, nil +} + func (cli *client) DeleteAuthSession(ctx context.Context, sessionID string) error { return cli.delete(resourceAuthSession, sessionID) } diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 9757cf71..483ed246 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -248,6 +248,15 @@ func (s *memStorage) ListUserIdentities(ctx context.Context) (identities []stora return } +func (s *memStorage) ListAuthSessions(ctx context.Context) (sessions []storage.AuthSession, err error) { + s.tx(func() { + for _, session := range s.authSessions { + sessions = append(sessions, session) + } + }) + return +} + func (s *memStorage) CreateAuthSession(ctx context.Context, session storage.AuthSession) (err error) { s.tx(func() { if _, ok := s.authSessions[session.ID]; ok { diff --git a/storage/sql/crud.go b/storage/sql/crud.go index 50c51f66..ab11713a 100644 --- a/storage/sql/crud.go +++ b/storage/sql/crud.go @@ -963,14 +963,13 @@ func (c *conn) UpdateAuthSession(ctx context.Context, sessionID string, updater update auth_session set client_states = $1, - created_at = $2, - last_activity = $3, - ip_address = $4, - user_agent = $5 - where id = $6; + last_activity = $2, + ip_address = $3, + user_agent = $4 + where id = $5; `, encoder(newSession.ClientStates), - newSession.CreatedAt, newSession.LastActivity, + newSession.LastActivity, newSession.IPAddress, newSession.UserAgent, sessionID, ) @@ -1014,6 +1013,33 @@ func scanAuthSession(s scanner) (session storage.AuthSession, err error) { return session, nil } +func (c *conn) ListAuthSessions(ctx context.Context) ([]storage.AuthSession, error) { + rows, err := c.Query(` + select + id, client_states, + created_at, last_activity, + ip_address, user_agent + from auth_session; + `) + if err != nil { + return nil, fmt.Errorf("query: %v", err) + } + defer rows.Close() + + var sessions []storage.AuthSession + for rows.Next() { + s, err := scanAuthSession(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, s) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("scan: %v", err) + } + return sessions, nil +} + func (c *conn) DeleteAuthSession(ctx context.Context, sessionID string) error { result, err := c.Exec(`delete from auth_session where id = $1`, sessionID) if err != nil { diff --git a/storage/storage.go b/storage/storage.go index 4e89a6b3..963c7c67 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -109,6 +109,7 @@ type Storage interface { ListPasswords(ctx context.Context) ([]Password, error) ListConnectors(ctx context.Context) ([]Connector, error) ListUserIdentities(ctx context.Context) ([]UserIdentity, error) + ListAuthSessions(ctx context.Context) ([]AuthSession, error) // Delete methods MUST be atomic. DeleteAuthRequest(ctx context.Context, id string) error