From 3bb5e735c6644709807594850e67365b60762718 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Tue, 29 Apr 2025 01:25:11 +0200 Subject: [PATCH] chore(webui): rewrite all the web session code #60 --- .gitignore | 2 + cmd/tfstated/main_test.go | 6 ++- pkg/database/accounts.go | 15 +++++-- pkg/database/db.go | 18 +++++--- pkg/database/sessions.go | 53 ++++++++++++++++-------- pkg/database/sql/000_init.sql | 7 ++-- pkg/helpers/crypto.go | 11 +++-- pkg/model/session.go | 7 ++-- pkg/webui/accountsId.go | 2 +- pkg/webui/accountsIdResetPassword.go | 6 +-- pkg/webui/error.go | 2 +- pkg/webui/html/base.html | 1 - pkg/webui/index.go | 11 +++-- pkg/webui/login.go | 62 ++++++++++++---------------- pkg/webui/logout.go | 14 ++++--- pkg/webui/routes.go | 14 +++---- pkg/webui/sessions.go | 58 +++++++++++++++----------- pkg/webui/settings.go | 5 ++- pkg/webui/versions.go | 2 +- 19 files changed, 173 insertions(+), 123 deletions(-) diff --git a/.gitignore b/.gitignore index 8e16b05..aa5f590 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ cover.out tfstated.db tfstated.db-shm tfstated.db-wal +tfstated_data_encryption_key +tfstated_sessions_salt diff --git a/cmd/tfstated/main_test.go b/cmd/tfstated/main_test.go index b28edbd..ecb82ee 100644 --- a/cmd/tfstated/main_test.go +++ b/cmd/tfstated/main_test.go @@ -26,13 +26,15 @@ var adminPasswordMutex sync.Mutex func TestMain(m *testing.M) { getenv := func(key string) string { switch key { - case "DATA_ENCRYPTION_KEY": + case "TFSTATED_DATA_ENCRYPTION_KEY": return "hP3ZSCnY3LMgfTQjwTaGrhKwdA0yXMXIfv67OJnntqM=" case "TFSTATED_HOST": return "127.0.0.1" case "TFSTATED_PORT": return "8082" - case "VERSIONS_HISTORY_LIMIT": + case "TFSTATED_SESSIONS_SALT": + return "a528D1m9q3IZxLinSmHmeKxrx3Pmm7GQ3nBzIDxjr0A=" + case "TFSTATED_VERSIONS_HISTORY_LIMIT": return "3" default: return "" diff --git a/pkg/database/accounts.go b/pkg/database/accounts.go index 5f589c8..ca10c15 100644 --- a/pkg/database/accounts.go +++ b/pkg/database/accounts.go @@ -151,9 +151,12 @@ func (db *DB) LoadAccountUsernames() (map[string]string, error) { return accounts, nil } -func (db *DB) LoadAccountById(id uuid.UUID) (*model.Account, error) { +func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) { + if id == nil { + return nil, nil + } account := model.Account{ - Id: id, + Id: *id, } var ( created int64 @@ -239,11 +242,15 @@ func (db *DB) SaveAccount(account *model.Account) error { func (db *DB) SaveAccountSettings(account *model.Account, settings *model.Settings) error { data, err := json.Marshal(settings) if err != nil { - return fmt.Errorf("failed to marshal settings for user %s: %w", account.Username, err) + return fmt.Errorf("failed to marshal settings for user account %s: %w", account.Username, err) } _, err = db.Exec(`UPDATE accounts SET settings = ? WHERE id = ?`, data, account.Id) if err != nil { - return fmt.Errorf("failed to update settings for user %s: %w", account.Username, err) + return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err) + } + _, err = db.Exec(`UPDATE sessions SET settings = ? WHERE account_id = ?`, data, account.Id) + if err != nil { + return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err) } return nil } diff --git a/pkg/database/db.go b/pkg/database/db.go index c155ecc..09d9d19 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -31,6 +31,7 @@ type DB struct { ctx context.Context dataEncryptionKey scrypto.AES256Key readDB *sql.DB + sessionsSalt scrypto.AES256Key versionsHistoryLimit int writeDB *sql.DB } @@ -82,17 +83,24 @@ func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, er return nil, fmt.Errorf("failed to migrate: %w", err) } - dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY") + dataEncryptionKey := getenv("TFSTATED_DATA_ENCRYPTION_KEY") if dataEncryptionKey == "" { - return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set") + return nil, fmt.Errorf("the TFSTATED_DATA_ENCRYPTION_KEY environment variable is not set") } if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil { - return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err) + return nil, fmt.Errorf("failed to decode the TFSTATED_DATA_ENCRYPTION_KEY environment variable, expected 32 bytes base64 encoded: %w", err) } - versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") + sessionsSalt := getenv("TFSTATED_SESSIONS_SALT") + if sessionsSalt == "" { + return nil, fmt.Errorf("the TFSTATED_SESSIONS_SALT environment variable is not set") + } + if err := db.sessionsSalt.FromBase64(dataEncryptionKey); err != nil { + return nil, fmt.Errorf("failed to decode the TFSTATED_SESSIONS_SALT environment variable, expected 32 bytes base64 encoded: %w", err) + } + versionsHistoryLimit := getenv("TFSTATED_VERSIONS_HISTORY_LIMIT") if versionsHistoryLimit != "" { if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil { - return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err) + return nil, fmt.Errorf("failed to parse the TFSTATED_VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err) } } diff --git a/pkg/database/sessions.go b/pkg/database/sessions.go index e8570d2..f579b4e 100644 --- a/pkg/database/sessions.go +++ b/pkg/database/sessions.go @@ -7,19 +7,30 @@ import ( "fmt" "time" + "git.adyxax.org/adyxax/tfstated/pkg/helpers" "git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" + "go.n16f.net/uuid" ) -func (db *DB) CreateSession(account *model.Account) (string, error) { +func (db *DB) CreateSession(account *model.Account, settingsData []byte) (string, error) { sessionBytes := scrypto.RandomBytes(32) sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:]) + sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes()) + var accountId *uuid.UUID = nil + var settings = []byte("{}") + if account != nil { + accountId = &account.Id + settings = account.Settings + } else if settingsData != nil { + settings = settingsData + } if _, err := db.Exec( - `INSERT INTO sessions(id, account_id, data) + `INSERT INTO sessions(id, account_id, settings) VALUES (?, ?, ?);`, - sessionId, - account.Id, - "", + sessionHash, + accountId, + settings, ); err != nil { return "", fmt.Errorf("failed insert new session in database: %w", err) } @@ -27,7 +38,8 @@ func (db *DB) CreateSession(account *model.Account) (string, error) { } func (db *DB) DeleteExpiredSessions() error { - _, err := db.Exec(`DELETE FROM sessions WHERE created < ?`, time.Now().Unix()) + expires := time.Now().Add(-12 * time.Hour) + _, err := db.Exec(`DELETE FROM sessions WHERE created < ?`, expires.Unix()) if err != nil { return fmt.Errorf("failed to delete expired session: %w", err) } @@ -43,25 +55,30 @@ func (db *DB) DeleteSession(session *model.Session) error { } func (db *DB) LoadSessionById(id string) (*model.Session, error) { + sessionBytes, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return nil, fmt.Errorf("failed to decode base64 session id: %w", err) + } + sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes()) session := model.Session{ - Id: id, + Id: sessionHash, } var ( created int64 updated int64 ) - err := db.QueryRow( + err = db.QueryRow( `SELECT account_id, created, updated, - data + settings FROM sessions WHERE id = ?;`, - id, + sessionHash, ).Scan(&session.AccountId, &created, &updated, - &session.Data, + &session.Settings, ) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -74,11 +91,13 @@ func (db *DB) LoadSessionById(id string) (*model.Session, error) { return &session, nil } -func (db *DB) TouchSession(sessionId string) error { - now := time.Now().UTC() - _, err := db.Exec(`UPDATE sessions SET updated = ? WHERE id = ?`, now.Unix(), sessionId) - if err != nil { - return fmt.Errorf("failed to touch updated for session %s: %w", sessionId, err) +func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, error) { + if err := db.DeleteSession(session); err != nil { + return "", fmt.Errorf("failed to delete session: %w", err) } - return nil + sessionId, err := db.CreateSession(account, session.Settings) + if err != nil { + return "", fmt.Errorf("failed to create session: %w", err) + } + return sessionId, nil } diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql index bc9f7ce..2e8d530 100644 --- a/pkg/database/sql/000_init.sql +++ b/pkg/database/sql/000_init.sql @@ -16,12 +16,11 @@ CREATE TABLE accounts ( CREATE UNIQUE INDEX accounts_username on accounts(username); CREATE TABLE sessions ( - id TEXT PRIMARY KEY, - account_id TEXT NOT NULL, + id BLOB PRIMARY KEY, + account_id TEXT, created INTEGER NOT NULL DEFAULT (unixepoch()), updated INTEGER NOT NULL DEFAULT (unixepoch()), - data TEXT NOT NULL, - FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE + settings BLOB NOT NULL ) STRICT; CREATE TABLE states ( diff --git a/pkg/helpers/crypto.go b/pkg/helpers/crypto.go index ce73cd3..326e657 100644 --- a/pkg/helpers/crypto.go +++ b/pkg/helpers/crypto.go @@ -8,8 +8,9 @@ import ( ) const ( - PBKDF2Iterations = 600000 - SaltSize = 32 + PBKDF2PasswordIterations = 600000 + PBKDF2SessionIterations = 12 + SaltSize = 32 ) func GenerateSalt() []byte { @@ -17,5 +18,9 @@ func GenerateSalt() []byte { } func HashPassword(password string, salt []byte) []byte { - return pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, 32, sha256.New) + return pbkdf2.Key([]byte(password), salt, PBKDF2PasswordIterations, 32, sha256.New) +} + +func HashSessionId(id []byte, salt []byte) []byte { + return pbkdf2.Key(id, salt, PBKDF2SessionIterations, 32, sha256.New) } diff --git a/pkg/model/session.go b/pkg/model/session.go index ac6da89..6bf6e6b 100644 --- a/pkg/model/session.go +++ b/pkg/model/session.go @@ -1,6 +1,7 @@ package model import ( + "encoding/json" "time" "go.n16f.net/uuid" @@ -9,11 +10,11 @@ import ( type SessionContextKey struct{} type Session struct { - Id string - AccountId uuid.UUID + Id []byte + AccountId *uuid.UUID Created time.Time Updated time.Time - Data any + Settings json.RawMessage } func (session *Session) IsExpired() bool { diff --git a/pkg/webui/accountsId.go b/pkg/webui/accountsId.go index 73a93c1..a3e6064 100644 --- a/pkg/webui/accountsId.go +++ b/pkg/webui/accountsId.go @@ -29,7 +29,7 @@ func handleAccountsIdGET(db *database.DB) http.Handler { errorResponse(w, r, http.StatusBadRequest, err) return } - account, err := db.LoadAccountById(accountId) + account, err := db.LoadAccountById(&accountId) if err != nil { errorResponse(w, r, http.StatusInternalServerError, err) return diff --git a/pkg/webui/accountsIdResetPassword.go b/pkg/webui/accountsIdResetPassword.go index f8e341e..4cae081 100644 --- a/pkg/webui/accountsIdResetPassword.go +++ b/pkg/webui/accountsIdResetPassword.go @@ -30,7 +30,7 @@ func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWr errorResponse(w, r, http.StatusBadRequest, err) return nil, false } - account, err := db.LoadAccountById(accountId) + account, err := db.LoadAccountById(&accountId) if err != nil { errorResponse(w, r, http.StatusInternalServerError, err) return nil, false @@ -55,7 +55,7 @@ func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler { render(w, accountsIdResetPasswordTemplates, http.StatusOK, AccountsIdResetPasswordPage{ Account: account, - Page: &Page{Title: "Password Reset", Section: "reset"}, + Page: makePage(r, &Page{Title: "Password Reset", Section: "reset"}), Token: r.PathValue("token"), }) }) @@ -80,7 +80,7 @@ func handleAccountsIdResetPasswordPOST(db *database.DB) http.Handler { render(w, accountsIdResetPasswordTemplates, http.StatusOK, AccountsIdResetPasswordPage{ Account: account, - Page: &Page{Title: "Password Reset", Section: "reset"}, + Page: makePage(r, &Page{Title: "Password Reset", Section: "reset"}), PasswordChanged: true, }) }) diff --git a/pkg/webui/error.go b/pkg/webui/error.go index b1fcf4a..6e9815c 100644 --- a/pkg/webui/error.go +++ b/pkg/webui/error.go @@ -15,7 +15,7 @@ func errorResponse(w http.ResponseWriter, r *http.Request, status int, err error StatusText string } render(w, errorTemplates, status, &ErrorData{ - Page: &Page{Title: "Error", Section: "error"}, + Page: makePage(r, &Page{Title: "Error", Section: "error"}), Err: err, Status: status, StatusText: http.StatusText(status), diff --git a/pkg/webui/html/base.html b/pkg/webui/html/base.html index 2a5b5a9..ff0b356 100644 --- a/pkg/webui/html/base.html +++ b/pkg/webui/html/base.html @@ -55,6 +55,5 @@ - diff --git a/pkg/webui/index.go b/pkg/webui/index.go index 580d5a8..bb22ae0 100644 --- a/pkg/webui/index.go +++ b/pkg/webui/index.go @@ -9,7 +9,7 @@ import ( ) type Page struct { - AccountId uuid.UUID + AccountId *uuid.UUID IsAdmin bool LightMode bool Section string @@ -17,9 +17,12 @@ type Page struct { } func makePage(r *http.Request, page *Page) *Page { - account := r.Context().Value(model.AccountContextKey{}).(*model.Account) - page.AccountId = account.Id - page.IsAdmin = account.IsAdmin + accountCtx := r.Context().Value(model.AccountContextKey{}) + if accountCtx != nil { + account := accountCtx.(*model.Account) + page.AccountId = &account.Id + page.IsAdmin = account.IsAdmin + } settings := r.Context().Value(model.SettingsContextKey{}).(*model.Settings) page.LightMode = settings.LightMode return page diff --git a/pkg/webui/login.go b/pkg/webui/login.go index 7563aad..ba02279 100644 --- a/pkg/webui/login.go +++ b/pkg/webui/login.go @@ -2,7 +2,7 @@ package webui import ( "context" - "encoding/json" + "fmt" "html/template" "log/slog" "net/http" @@ -26,29 +26,30 @@ func handleLoginGET() http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store, no-cache") - session := r.Context().Value(model.SessionContextKey{}) - if session != nil { + account := r.Context().Value(model.AccountContextKey{}) + if account != nil { http.Redirect(w, r, "/states", http.StatusFound) return } render(w, loginTemplate, http.StatusOK, loginPage{ - Page: &Page{Title: "Login", Section: "login"}, + Page: makePage(r, &Page{Title: "Login", Section: "login"}), }) }) } func handleLoginPOST(db *database.DB) http.Handler { - renderForbidden := func(w http.ResponseWriter, username string) { + renderForbidden := func(w http.ResponseWriter, r *http.Request, username string) { render(w, loginTemplate, http.StatusForbidden, loginPage{ - Page: &Page{Title: "Login", Section: "login"}, + Page: makePage(r, &Page{Title: "Login", Section: "login"}), Forbidden: true, Username: username, }) } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { - errorResponse(w, r, http.StatusBadRequest, err) + errorResponse(w, r, http.StatusBadRequest, + fmt.Errorf("failed to parse form: %w", err)) return } username := r.FormValue("username") @@ -59,37 +60,32 @@ func handleLoginPOST(db *database.DB) http.Handler { return } if ok := validUsername.MatchString(username); !ok { - renderForbidden(w, username) + renderForbidden(w, r, username) return } account, err := db.LoadAccountByUsername(username) if err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to load account by username %s: %w", username, err)) return } if account == nil || !account.CheckPassword(password) { - renderForbidden(w, username) + renderForbidden(w, r, username) return } if err := db.TouchAccount(account); err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to touch account %s: %w", username, err)) return } - sessionId, err := db.CreateSession(account) + session := r.Context().Value(model.SessionContextKey{}).(*model.Session) + sessionId, err := db.MigrateSession(session, account) if err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to migrate session: %w", err)) return } - http.SetCookie(w, &http.Cookie{ - Name: cookieName, - Value: sessionId, - Quoted: false, - Path: "/", - MaxAge: 12 * 3600, // 12 hours sessions - HttpOnly: true, - SameSite: http.SameSiteStrictMode, - Secure: true, - }) + setSessionCookie(w, sessionId) if err := db.DeleteExpiredSessions(); err != nil { slog.Error("failed to delete expired sessions after user login", "err", err, "accountId", account.Id) } @@ -97,32 +93,26 @@ func handleLoginPOST(db *database.DB) http.Handler { }) } -func loginMiddleware(db *database.DB, processSession func(http.Handler) http.Handler) func(http.Handler) http.Handler { +func loginMiddleware(db *database.DB, requireSession func(http.Handler) http.Handler) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { - return processSession(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + return requireSession(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store, no-cache") - session := r.Context().Value(model.SessionContextKey{}) - if session == nil { + session := r.Context().Value(model.SessionContextKey{}).(*model.Session) + if session.AccountId == nil { http.Redirect(w, r, "/login", http.StatusFound) return } - account, err := db.LoadAccountById(session.(*model.Session).AccountId) + account, err := db.LoadAccountById(session.AccountId) if err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to load account by Id: %w", err)) return } if account == nil { - // this could happen if the account was deleted in the short - // time between retrieving the session and here http.Redirect(w, r, "/login", http.StatusFound) return } ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account) - var settings model.Settings - if err := json.Unmarshal(account.Settings, &settings); err != nil { - slog.Error("failed to unmarshal account settings", "err", err, "accountId", account.Id) - } - ctx = context.WithValue(ctx, model.SettingsContextKey{}, &settings) next.ServeHTTP(w, r.WithContext(ctx)) })) } diff --git a/pkg/webui/logout.go b/pkg/webui/logout.go index b7c2bde..b02f749 100644 --- a/pkg/webui/logout.go +++ b/pkg/webui/logout.go @@ -1,6 +1,7 @@ package webui import ( + "fmt" "html/template" "net/http" @@ -12,18 +13,19 @@ var logoutTemplate = template.Must(template.ParseFS(htmlFS, "html/base.html", "h func handleLogoutGET(db *database.DB) http.Handler { type logoutPage struct { - Page + Page *Page } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session := r.Context().Value(model.SessionContextKey{}) - err := db.DeleteSession(session.(*model.Session)) + session := r.Context().Value(model.SessionContextKey{}).(*model.Session) + sessionId, err := db.MigrateSession(session, nil) if err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to migrate session: %w", err)) return } - unsetSesssionCookie(w) + setSessionCookie(w, sessionId) render(w, logoutTemplate, http.StatusOK, logoutPage{ - Page: Page{Title: "Logout", Section: "login"}, + Page: makePage(r, &Page{Title: "Logout", Section: "login"}), }) }) } diff --git a/pkg/webui/routes.go b/pkg/webui/routes.go index fb00aca..a6a1d49 100644 --- a/pkg/webui/routes.go +++ b/pkg/webui/routes.go @@ -10,17 +10,17 @@ func addRoutes( mux *http.ServeMux, db *database.DB, ) { - processSession := sessionsMiddleware(db) - requireLogin := loginMiddleware(db, processSession) + requireSession := sessionsMiddleware(db) + requireLogin := loginMiddleware(db, requireSession) requireAdmin := adminMiddleware(db, requireLogin) mux.Handle("GET /accounts", requireLogin(handleAccountsGET(db))) mux.Handle("GET /accounts/{id}", requireLogin(handleAccountsIdGET(db))) - mux.Handle("GET /accounts/{id}/reset/{token}", handleAccountsIdResetPasswordGET(db)) - mux.Handle("POST /accounts/{id}/reset/{token}", handleAccountsIdResetPasswordPOST(db)) + mux.Handle("GET /accounts/{id}/reset/{token}", requireSession(handleAccountsIdResetPasswordGET(db))) + mux.Handle("POST /accounts/{id}/reset/{token}", requireSession(handleAccountsIdResetPasswordPOST(db))) mux.Handle("POST /accounts", requireAdmin(handleAccountsPOST(db))) mux.Handle("GET /healthz", handleHealthz()) - mux.Handle("GET /login", processSession(handleLoginGET())) - mux.Handle("POST /login", processSession(handleLoginPOST(db))) + mux.Handle("GET /login", requireSession(handleLoginGET())) + mux.Handle("POST /login", requireSession(handleLoginPOST(db))) mux.Handle("GET /logout", requireLogin(handleLogoutGET(db))) mux.Handle("GET /settings", requireLogin(handleSettingsGET(db))) mux.Handle("POST /settings", requireLogin(handleSettingsPOST(db))) @@ -30,5 +30,5 @@ func addRoutes( mux.Handle("POST /states/{id}", requireLogin(handleStatesIdPOST(db))) mux.Handle("GET /static/", cache(http.FileServer(http.FS(staticFS)))) mux.Handle("GET /versions/{id}", requireLogin(handleVersionsGET(db))) - mux.Handle("GET /", requireLogin(handleIndexGET())) + mux.Handle("GET /", requireSession(handleIndexGET())) } diff --git a/pkg/webui/sessions.go b/pkg/webui/sessions.go index f3659e1..4d44710 100644 --- a/pkg/webui/sessions.go +++ b/pkg/webui/sessions.go @@ -2,8 +2,10 @@ package webui import ( "context" + "encoding/json" "errors" "fmt" + "log/slog" "net/http" "git.adyxax.org/adyxax/tfstated/pkg/database" @@ -17,49 +19,59 @@ func sessionsMiddleware(db *database.DB) func(http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie, err := r.Cookie(cookieName) if err != nil && !errors.Is(err, http.ErrNoCookie) { - errorResponse(w, r, http.StatusInternalServerError, fmt.Errorf("failed to get request cookie \"%s\": %w", cookieName, err)) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to get session cookie from request: %w", err)) return } if err == nil { - if len(cookie.Value) != 43 { - unsetSesssionCookie(w) - } else { + if len(cookie.Value) == 43 { session, err := db.LoadSessionById(cookie.Value) if err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to load session by ID: %w", err)) return } - if session == nil { - unsetSesssionCookie(w) - } else if session.IsExpired() { - unsetSesssionCookie(w) - if err := db.DeleteSession(session); err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) + if session != nil { + if session.IsExpired() { + if err := db.DeleteSession(session); err != nil { + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to delete session: %w", err)) + return + } + } else { + ctx := context.WithValue(r.Context(), model.SessionContextKey{}, session) + var settings model.Settings + if err := json.Unmarshal(session.Settings, &settings); err != nil { + slog.Error("failed to unmarshal session settings", "err", err) + } + ctx = context.WithValue(ctx, model.SettingsContextKey{}, &settings) + next.ServeHTTP(w, r.WithContext(ctx)) return } - } else { - if err := db.TouchSession(cookie.Value); err != nil { - errorResponse(w, r, http.StatusInternalServerError, err) - return - } - ctx := context.WithValue(r.Context(), model.SessionContextKey{}, session) - next.ServeHTTP(w, r.WithContext(ctx)) - return } } } - next.ServeHTTP(w, r) + sessionId, err := db.CreateSession(nil, nil) + if err != nil { + errorResponse(w, r, http.StatusInternalServerError, + fmt.Errorf("failed to create session: %w", err)) + return + } + setSessionCookie(w, sessionId) + var settings model.Settings + ctx := context.WithValue(r.Context(), model.SettingsContextKey{}, &settings) + next.ServeHTTP(w, r.WithContext(ctx)) }) } } -func unsetSesssionCookie(w http.ResponseWriter) { +func setSessionCookie(w http.ResponseWriter, sessionId string) { http.SetCookie(w, &http.Cookie{ Name: cookieName, - Value: "", + Value: sessionId, Quoted: false, Path: "/", - MaxAge: 0, // remove invalid cookie + MaxAge: 12 * 3600, // 12 hours sessions HttpOnly: true, SameSite: http.SameSiteStrictMode, Secure: true, diff --git a/pkg/webui/settings.go b/pkg/webui/settings.go index 8b8eccc..74fdb30 100644 --- a/pkg/webui/settings.go +++ b/pkg/webui/settings.go @@ -1,6 +1,7 @@ package webui import ( + "context" "html/template" "net/http" @@ -41,8 +42,8 @@ func handleSettingsPOST(db *database.DB) http.Handler { errorResponse(w, r, http.StatusInternalServerError, err) return } - page := makePage(r, &Page{Title: "Settings", Section: "settings"}) - page.LightMode = settings.LightMode + ctx := context.WithValue(r.Context(), model.SettingsContextKey{}, &settings) + page := makePage(r.WithContext(ctx), &Page{Title: "Settings", Section: "settings"}) render(w, settingsTemplates, http.StatusOK, SettingsPage{ Page: page, Settings: &settings, diff --git a/pkg/webui/versions.go b/pkg/webui/versions.go index 2ad4a28..014fa87 100644 --- a/pkg/webui/versions.go +++ b/pkg/webui/versions.go @@ -39,7 +39,7 @@ func handleVersionsGET(db *database.DB) http.Handler { errorResponse(w, r, http.StatusInternalServerError, err) return } - account, err := db.LoadAccountById(version.AccountId) + account, err := db.LoadAccountById(&version.AccountId) if err != nil { errorResponse(w, r, http.StatusInternalServerError, err) return