From 895615ad6ed1eb846fdeef4b20148cf3338bf9cd Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Wed, 30 Apr 2025 22:31:25 +0200 Subject: [PATCH] chore(webui): rewrite the web session code again while preparing for csrf tokens #60 --- pkg/database/accounts.go | 42 ++++++++++++++----- pkg/database/sessions.go | 63 ++++++++++++++++------------ pkg/database/sql/000_init.sql | 8 ++-- pkg/model/account.go | 19 ++++----- pkg/model/session.go | 32 +++++++++++--- pkg/model/settings.go | 2 - pkg/webui/accounts.go | 10 ++--- pkg/webui/accountsId.go | 3 +- pkg/webui/accountsIdResetPassword.go | 23 +++++----- pkg/webui/admin.go | 13 ++---- pkg/webui/html/accounts.html | 2 +- pkg/webui/html/accountsId.html | 8 ++-- pkg/webui/html/base.html | 2 +- pkg/webui/html/settings.html | 2 +- pkg/webui/index.go | 18 ++------ pkg/webui/login.go | 28 ++++--------- pkg/webui/logout.go | 6 ++- pkg/webui/routes.go | 4 +- pkg/webui/sessions.go | 12 +----- pkg/webui/settings.go | 14 +++---- 20 files changed, 162 insertions(+), 149 deletions(-) diff --git a/pkg/database/accounts.go b/pkg/database/accounts.go index ca10c15..8eb8d16 100644 --- a/pkg/database/accounts.go +++ b/pkg/database/accounts.go @@ -29,7 +29,7 @@ func (db *DB) CreateAccount(username string, isAdmin bool) (*model.Account, erro return nil, fmt.Errorf("failed to generate password reset uuid: %w", err) } _, err := db.Exec(`INSERT INTO accounts(id, username, is_Admin, settings, password_reset) - VALUES (?, ?, ?, ?, ?);`, + VALUES (?, ?, ?, jsonb(?), ?);`, accountId, username, isAdmin, @@ -72,7 +72,7 @@ func (db *DB) InitAdminAccount() error { hash := helpers.HashPassword(password.String(), salt) if _, err := tx.ExecContext(db.ctx, `INSERT INTO accounts(id, username, salt, password_hash, is_admin, settings) - VALUES (:id, "admin", :salt, :hash, TRUE, :settings) + VALUES (:id, "admin", :salt, :hash, TRUE, jsonb(:settings)) ON CONFLICT DO UPDATE SET password_hash = :hash WHERE username = "admin";`, sql.Named("id", accountId), @@ -91,7 +91,8 @@ func (db *DB) InitAdminAccount() error { func (db *DB) LoadAccounts() ([]model.Account, error) { rows, err := db.Query( - `SELECT id, username, salt, password_hash, is_admin, created, last_login, settings, password_reset FROM accounts;`) + `SELECT id, username, salt, password_hash, is_admin, created, last_login, + json_extract(settings, '$'), password_reset FROM accounts;`) if err != nil { return nil, fmt.Errorf("failed to load accounts from database: %w", err) } @@ -102,6 +103,7 @@ func (db *DB) LoadAccounts() ([]model.Account, error) { account model.Account created int64 lastLogin int64 + settings []byte ) err = rows.Scan( &account.Id, @@ -111,11 +113,14 @@ func (db *DB) LoadAccounts() ([]model.Account, error) { &account.IsAdmin, &created, &lastLogin, - &account.Settings, + &settings, &account.PasswordReset) if err != nil { return nil, fmt.Errorf("failed to load account from row: %w", err) } + if err := json.Unmarshal(settings, &account.Settings); err != nil { + return nil, fmt.Errorf("failed to unmarshal account settings: %w", err) + } account.Created = time.Unix(created, 0) account.LastLogin = time.Unix(lastLogin, 0) accounts = append(accounts, account) @@ -161,9 +166,11 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) { var ( created int64 lastLogin int64 + settings []byte ) err := db.QueryRow( - `SELECT username, salt, password_hash, is_admin, created, last_login, settings, password_reset + `SELECT username, salt, password_hash, is_admin, created, last_login, + json_extract(settings, '$'), password_reset FROM accounts WHERE id = ?;`, id, @@ -173,7 +180,7 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) { &account.IsAdmin, &created, &lastLogin, - &account.Settings, + &settings, &account.PasswordReset) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -181,6 +188,9 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) { } return nil, fmt.Errorf("failed to load account by id %s: %w", id, err) } + if err := json.Unmarshal(settings, &account.Settings); err != nil { + return nil, fmt.Errorf("failed to unmarshal account settings: %w", err) + } account.Created = time.Unix(created, 0) account.LastLogin = time.Unix(lastLogin, 0) return &account, nil @@ -193,9 +203,11 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) { var ( created int64 lastLogin int64 + settings []byte ) err := db.QueryRow( - `SELECT id, salt, password_hash, is_admin, created, last_login, settings, password_reset + `SELECT id, salt, password_hash, is_admin, created, last_login, + json_extract(settings, '$'), password_reset FROM accounts WHERE username = ?;`, username, @@ -205,7 +217,7 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) { &account.IsAdmin, &created, &lastLogin, - &account.Settings, + &settings, &account.PasswordReset) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -213,6 +225,9 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) { } return nil, fmt.Errorf("failed to load account by username %s: %w", username, err) } + if err := json.Unmarshal(settings, &account.Settings); err != nil { + return nil, fmt.Errorf("failed to unmarshal account settings: %w", err) + } account.Created = time.Unix(created, 0) account.LastLogin = time.Unix(lastLogin, 0) return &account, nil @@ -242,13 +257,20 @@ func (db *DB) SaveAccount(account *model.Account) error { func (db *DB) SaveAccountSettings(account *model.Account, settings *model.Settings) error { data, err := json.Marshal(settings) if err != nil { - return fmt.Errorf("failed to marshal settings for user account %s: %w", account.Username, err) + return fmt.Errorf("failed to marshal settings for user accont %s: %w", account.Username, err) } _, err = db.Exec(`UPDATE accounts SET settings = ? WHERE id = ?`, data, account.Id) if err != nil { return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err) } - _, err = db.Exec(`UPDATE sessions SET settings = ? WHERE account_id = ?`, data, account.Id) + _, err = db.Exec( + `UPDATE sessions + SET data = jsonb_replace(data, + '$.settings', jsonb(:data), + '$.account.settings', jsonb(:data)) + WHERE data->'account'->>'id' = :id`, + sql.Named("data", data), + sql.Named("id", account.Id)) if err != nil { return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err) } diff --git a/pkg/database/sessions.go b/pkg/database/sessions.go index f579b4e..58a69d1 100644 --- a/pkg/database/sessions.go +++ b/pkg/database/sessions.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "encoding/base64" + "encoding/json" "errors" "fmt" "time" @@ -10,31 +11,35 @@ import ( "git.adyxax.org/adyxax/tfstated/pkg/helpers" "git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" - "go.n16f.net/uuid" ) -func (db *DB) CreateSession(account *model.Account, settingsData []byte) (string, error) { +func (db *DB) CreateSession(sessionData *model.SessionData) (string, *model.Session, error) { sessionBytes := scrypto.RandomBytes(32) sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:]) sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes()) - var accountId *uuid.UUID = nil - var settings = []byte("{}") - if account != nil { - accountId = &account.Id - settings = account.Settings - } else if settingsData != nil { - settings = settingsData + if sessionData == nil { + var err error + sessionData, err = model.NewSessionData(nil, nil) + if err != nil { + return "", nil, fmt.Errorf("failed to generate new session data: %w", err) + } + } + data, err := json.Marshal(sessionData) + if err != nil { + return "", nil, fmt.Errorf("failed to marshal session data: %w", err) } if _, err := db.Exec( - `INSERT INTO sessions(id, account_id, settings) - VALUES (?, ?, ?);`, + `INSERT INTO sessions(id, data) + VALUES (?, jsonb(?));`, sessionHash, - accountId, - settings, + data, ); err != nil { - return "", fmt.Errorf("failed insert new session in database: %w", err) + return "", nil, fmt.Errorf("failed insert new session in database: %w", err) } - return sessionId, nil + return sessionId, &model.Session{ + Id: sessionHash, + Data: sessionData, + }, nil } func (db *DB) DeleteExpiredSessions() error { @@ -66,38 +71,44 @@ func (db *DB) LoadSessionById(id string) (*model.Session, error) { var ( created int64 updated int64 + data []byte ) err = db.QueryRow( - `SELECT account_id, - created, + `SELECT created, updated, - settings + json_extract(data, '$') FROM sessions WHERE id = ?;`, sessionHash, - ).Scan(&session.AccountId, + ).Scan( &created, &updated, - &session.Settings, - ) + &data) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, fmt.Errorf("failed to load session by id %s: %w", id, err) } + if err := json.Unmarshal(data, &session.Data); err != nil { + return nil, fmt.Errorf("failed to unmarshal session data: %w", err) + } session.Created = time.Unix(created, 0) session.Updated = time.Unix(updated, 0) return &session, nil } -func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, error) { +func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, *model.Session, error) { if err := db.DeleteSession(session); err != nil { - return "", fmt.Errorf("failed to delete session: %w", err) + return "", nil, fmt.Errorf("failed to delete session: %w", err) } - sessionId, err := db.CreateSession(account, session.Settings) + sessionData, err := model.NewSessionData(account, session.Data.Settings) if err != nil { - return "", fmt.Errorf("failed to create session: %w", err) + return "", nil, fmt.Errorf("failed to generate new session data: %w", err) } - return sessionId, nil + sessionId, session, err := db.CreateSession(sessionData) + if err != nil { + return "", nil, fmt.Errorf("failed to create session: %w", err) + } + return sessionId, session, nil } diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql index 2e8d530..0441c89 100644 --- a/pkg/database/sql/000_init.sql +++ b/pkg/database/sql/000_init.sql @@ -13,15 +13,15 @@ CREATE TABLE accounts ( settings BLOB NOT NULL, password_reset TEXT ) STRICT; -CREATE UNIQUE INDEX accounts_username on accounts(username); +CREATE UNIQUE INDEX accounts_username ON accounts(username); CREATE TABLE sessions ( id BLOB PRIMARY KEY, - account_id TEXT, created INTEGER NOT NULL DEFAULT (unixepoch()), updated INTEGER NOT NULL DEFAULT (unixepoch()), - settings BLOB NOT NULL + data BLOB NOT NULL ) STRICT; +CREATE INDEX sessions_data_account_id ON sessions(data->'account'->>'id'); CREATE TABLE states ( id TEXT PRIMARY KEY, @@ -30,7 +30,7 @@ CREATE TABLE states ( created INTEGER DEFAULT (unixepoch()), updated INTEGER DEFAULT (unixepoch()) ) STRICT; -CREATE UNIQUE INDEX states_path on states(path); +CREATE UNIQUE INDEX states_path ON states(path); CREATE TABLE versions ( id TEXT PRIMARY KEY, diff --git a/pkg/model/account.go b/pkg/model/account.go index df7ea8a..02d1276 100644 --- a/pkg/model/account.go +++ b/pkg/model/account.go @@ -2,7 +2,6 @@ package model import ( "crypto/subtle" - "encoding/json" "time" "git.adyxax.org/adyxax/tfstated/pkg/helpers" @@ -12,15 +11,15 @@ import ( type AccountContextKey struct{} type Account struct { - Id uuid.UUID - Username string - Salt []byte - PasswordHash []byte - IsAdmin bool - Created time.Time - LastLogin time.Time - Settings json.RawMessage - PasswordReset *uuid.UUID + Id uuid.UUID `json:"id"` + Username string `json:"username"` + Salt []byte `json:"salt"` + PasswordHash []byte `json:"password_hash"` + IsAdmin bool `json:"is_admin"` + Created time.Time `json:"created"` + LastLogin time.Time `json:"last_login"` + Settings *Settings `json:"settings"` + PasswordReset *uuid.UUID `json:"password_reset"` } func (account *Account) CheckPassword(password string) bool { diff --git a/pkg/model/session.go b/pkg/model/session.go index 6bf6e6b..0199b03 100644 --- a/pkg/model/session.go +++ b/pkg/model/session.go @@ -1,20 +1,40 @@ package model import ( - "encoding/json" + "fmt" "time" "go.n16f.net/uuid" ) +type SessionData struct { + Account *Account `json:"account"` + CsrfToken uuid.UUID `json:"csrf_token"` + Settings *Settings `json:"settings"` +} + +func NewSessionData(account *Account, previousSessionSettings *Settings) (*SessionData, error) { + data := SessionData{Account: account} + if err := data.CsrfToken.Generate(uuid.V4); err != nil { + return nil, fmt.Errorf("failed to generate csrf token uuid: %w", err) + } + if account != nil { + data.Settings = account.Settings + } else if previousSessionSettings != nil { + data.Settings = previousSessionSettings + } else { + data.Settings = &Settings{} + } + return &data, nil +} + type SessionContextKey struct{} type Session struct { - Id []byte - AccountId *uuid.UUID - Created time.Time - Updated time.Time - Settings json.RawMessage + Id []byte + Created time.Time + Updated time.Time + Data *SessionData } func (session *Session) IsExpired() bool { diff --git a/pkg/model/settings.go b/pkg/model/settings.go index 9da655b..bebbced 100644 --- a/pkg/model/settings.go +++ b/pkg/model/settings.go @@ -1,7 +1,5 @@ package model -type SettingsContextKey struct{} - type Settings struct { LightMode bool `json:"light_mode"` } diff --git a/pkg/webui/accounts.go b/pkg/webui/accounts.go index 01ed609..dfc4439 100644 --- a/pkg/webui/accounts.go +++ b/pkg/webui/accounts.go @@ -11,7 +11,6 @@ import ( type AccountsPage struct { Accounts []model.Account - ActiveTab int IsAdmin string Page *Page Username string @@ -45,11 +44,10 @@ func handleAccountsPOST(db *database.DB) http.Handler { accountUsername := r.FormValue("username") isAdmin := r.FormValue("isAdmin") page := AccountsPage{ - ActiveTab: 1, - Page: makePage(r, &Page{Title: "New Account", Section: "accounts"}), - Accounts: accounts, - IsAdmin: isAdmin, - Username: accountUsername, + Page: makePage(r, &Page{Title: "New Account", Section: "accounts"}), + Accounts: accounts, + IsAdmin: isAdmin, + Username: accountUsername, } if ok := validUsername.MatchString(accountUsername); !ok { page.UsernameInvalid = true diff --git a/pkg/webui/accountsId.go b/pkg/webui/accountsId.go index a3e6064..90c2774 100644 --- a/pkg/webui/accountsId.go +++ b/pkg/webui/accountsId.go @@ -1,6 +1,7 @@ package webui import ( + "fmt" "html/template" "net/http" @@ -35,7 +36,7 @@ func handleAccountsIdGET(db *database.DB) http.Handler { return } if account == nil { - errorResponse(w, r, http.StatusNotFound, err) + errorResponse(w, r, http.StatusNotFound, fmt.Errorf("The account Id could not be found.")) return } statePaths, err := db.LoadStatePaths() diff --git a/pkg/webui/accountsIdResetPassword.go b/pkg/webui/accountsIdResetPassword.go index 4cae081..a39d287 100644 --- a/pkg/webui/accountsIdResetPassword.go +++ b/pkg/webui/accountsIdResetPassword.go @@ -19,37 +19,36 @@ type AccountsIdResetPasswordPage struct { var accountsIdResetPasswordTemplates = template.Must(template.ParseFS(htmlFS, "html/base.html", "html/accountsIdResetPassword.html")) -func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWriter, r *http.Request) (*model.Account, bool) { +func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWriter, r *http.Request) *model.Account { var accountId uuid.UUID if err := accountId.Parse(r.PathValue("id")); err != nil { - errorResponse(w, r, http.StatusBadRequest, err) - return nil, false + return nil } var token uuid.UUID if err := token.Parse(r.PathValue("token")); err != nil { errorResponse(w, r, http.StatusBadRequest, err) - return nil, false + return nil } account, err := db.LoadAccountById(&accountId) if err != nil { errorResponse(w, r, http.StatusInternalServerError, err) - return nil, false + return nil } if account == nil || account.PasswordReset == nil { errorResponse(w, r, http.StatusBadRequest, err) - return nil, false + return nil } if !account.PasswordReset.Equal(token) { errorResponse(w, r, http.StatusBadRequest, err) - return nil, false + return nil } - return account, true + return account } func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - account, valid := processAccountsIdResetPasswordPathValues(db, w, r) - if !valid { + account := processAccountsIdResetPasswordPathValues(db, w, r) + if account == nil { return } render(w, accountsIdResetPasswordTemplates, http.StatusOK, @@ -63,8 +62,8 @@ func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler { func handleAccountsIdResetPasswordPOST(db *database.DB) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - account, valid := processAccountsIdResetPasswordPathValues(db, w, r) - if !valid { + account := processAccountsIdResetPasswordPathValues(db, w, r) + if account == nil { return } password := r.FormValue("password") diff --git a/pkg/webui/admin.go b/pkg/webui/admin.go index 1906b3e..bcda0b5 100644 --- a/pkg/webui/admin.go +++ b/pkg/webui/admin.go @@ -4,21 +4,14 @@ import ( "fmt" "net/http" - "git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/model" ) -func adminMiddleware(db *database.DB, requireLogin func(http.Handler) http.Handler) func(http.Handler) http.Handler { +func adminMiddleware(requireLogin func(http.Handler) http.Handler) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return requireLogin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - account := r.Context().Value(model.AccountContextKey{}) - if account == nil { - // this could happen if the account was deleted in the short - // time between retrieving the session and here - http.Redirect(w, r, "/login", http.StatusFound) - return - } - if !account.(*model.Account).IsAdmin { + session := r.Context().Value(model.SessionContextKey{}).(*model.Session) + if !session.Data.Account.IsAdmin { errorResponse(w, r, http.StatusForbidden, fmt.Errorf("Only administrators can perform this request.")) return } diff --git a/pkg/webui/html/accounts.html b/pkg/webui/html/accounts.html index da13bca..ff3670a 100644 --- a/pkg/webui/html/accounts.html +++ b/pkg/webui/html/accounts.html @@ -7,7 +7,7 @@ Use this page to inspect user accounts.

- {{ if .Page.IsAdmin }} + {{ if .Page.Session.Data.Account.IsAdmin }}
New User Account diff --git a/pkg/webui/html/accountsId.html b/pkg/webui/html/accountsId.html index 52c4cef..1ae7902 100644 --- a/pkg/webui/html/accountsId.html +++ b/pkg/webui/html/accountsId.html @@ -24,7 +24,7 @@ {{ if .Account.IsAdmin }}

This accounts has admin privileges on TfStated.

{{ end }} -{{ if .Page.IsAdmin }} +{{ if .Page.Session.Data.Account.IsAdmin }}

Operations

@@ -39,7 +39,7 @@ value="{{ .Username }}">
Danger Zone - -