From 895615ad6ed1eb846fdeef4b20148cf3338bf9cd Mon Sep 17 00:00:00 2001
From: Julien Dessaux
Date: Wed, 30 Apr 2025 22:31:25 +0200
Subject: [PATCH] chore(webui): rewrite the web session code again while
preparing for csrf tokens
#60
---
pkg/database/accounts.go | 42 ++++++++++++++-----
pkg/database/sessions.go | 63 ++++++++++++++++------------
pkg/database/sql/000_init.sql | 8 ++--
pkg/model/account.go | 19 ++++-----
pkg/model/session.go | 32 +++++++++++---
pkg/model/settings.go | 2 -
pkg/webui/accounts.go | 10 ++---
pkg/webui/accountsId.go | 3 +-
pkg/webui/accountsIdResetPassword.go | 23 +++++-----
pkg/webui/admin.go | 13 ++----
pkg/webui/html/accounts.html | 2 +-
pkg/webui/html/accountsId.html | 8 ++--
pkg/webui/html/base.html | 2 +-
pkg/webui/html/settings.html | 2 +-
pkg/webui/index.go | 18 ++------
pkg/webui/login.go | 28 ++++---------
pkg/webui/logout.go | 6 ++-
pkg/webui/routes.go | 4 +-
pkg/webui/sessions.go | 12 +-----
pkg/webui/settings.go | 14 +++----
20 files changed, 162 insertions(+), 149 deletions(-)
diff --git a/pkg/database/accounts.go b/pkg/database/accounts.go
index ca10c15..8eb8d16 100644
--- a/pkg/database/accounts.go
+++ b/pkg/database/accounts.go
@@ -29,7 +29,7 @@ func (db *DB) CreateAccount(username string, isAdmin bool) (*model.Account, erro
return nil, fmt.Errorf("failed to generate password reset uuid: %w", err)
}
_, err := db.Exec(`INSERT INTO accounts(id, username, is_Admin, settings, password_reset)
- VALUES (?, ?, ?, ?, ?);`,
+ VALUES (?, ?, ?, jsonb(?), ?);`,
accountId,
username,
isAdmin,
@@ -72,7 +72,7 @@ func (db *DB) InitAdminAccount() error {
hash := helpers.HashPassword(password.String(), salt)
if _, err := tx.ExecContext(db.ctx,
`INSERT INTO accounts(id, username, salt, password_hash, is_admin, settings)
- VALUES (:id, "admin", :salt, :hash, TRUE, :settings)
+ VALUES (:id, "admin", :salt, :hash, TRUE, jsonb(:settings))
ON CONFLICT DO UPDATE SET password_hash = :hash
WHERE username = "admin";`,
sql.Named("id", accountId),
@@ -91,7 +91,8 @@ func (db *DB) InitAdminAccount() error {
func (db *DB) LoadAccounts() ([]model.Account, error) {
rows, err := db.Query(
- `SELECT id, username, salt, password_hash, is_admin, created, last_login, settings, password_reset FROM accounts;`)
+ `SELECT id, username, salt, password_hash, is_admin, created, last_login,
+ json_extract(settings, '$'), password_reset FROM accounts;`)
if err != nil {
return nil, fmt.Errorf("failed to load accounts from database: %w", err)
}
@@ -102,6 +103,7 @@ func (db *DB) LoadAccounts() ([]model.Account, error) {
account model.Account
created int64
lastLogin int64
+ settings []byte
)
err = rows.Scan(
&account.Id,
@@ -111,11 +113,14 @@ func (db *DB) LoadAccounts() ([]model.Account, error) {
&account.IsAdmin,
&created,
&lastLogin,
- &account.Settings,
+ &settings,
&account.PasswordReset)
if err != nil {
return nil, fmt.Errorf("failed to load account from row: %w", err)
}
+ if err := json.Unmarshal(settings, &account.Settings); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal account settings: %w", err)
+ }
account.Created = time.Unix(created, 0)
account.LastLogin = time.Unix(lastLogin, 0)
accounts = append(accounts, account)
@@ -161,9 +166,11 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) {
var (
created int64
lastLogin int64
+ settings []byte
)
err := db.QueryRow(
- `SELECT username, salt, password_hash, is_admin, created, last_login, settings, password_reset
+ `SELECT username, salt, password_hash, is_admin, created, last_login,
+ json_extract(settings, '$'), password_reset
FROM accounts
WHERE id = ?;`,
id,
@@ -173,7 +180,7 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) {
&account.IsAdmin,
&created,
&lastLogin,
- &account.Settings,
+ &settings,
&account.PasswordReset)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -181,6 +188,9 @@ func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) {
}
return nil, fmt.Errorf("failed to load account by id %s: %w", id, err)
}
+ if err := json.Unmarshal(settings, &account.Settings); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal account settings: %w", err)
+ }
account.Created = time.Unix(created, 0)
account.LastLogin = time.Unix(lastLogin, 0)
return &account, nil
@@ -193,9 +203,11 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) {
var (
created int64
lastLogin int64
+ settings []byte
)
err := db.QueryRow(
- `SELECT id, salt, password_hash, is_admin, created, last_login, settings, password_reset
+ `SELECT id, salt, password_hash, is_admin, created, last_login,
+ json_extract(settings, '$'), password_reset
FROM accounts
WHERE username = ?;`,
username,
@@ -205,7 +217,7 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) {
&account.IsAdmin,
&created,
&lastLogin,
- &account.Settings,
+ &settings,
&account.PasswordReset)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
@@ -213,6 +225,9 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) {
}
return nil, fmt.Errorf("failed to load account by username %s: %w", username, err)
}
+ if err := json.Unmarshal(settings, &account.Settings); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal account settings: %w", err)
+ }
account.Created = time.Unix(created, 0)
account.LastLogin = time.Unix(lastLogin, 0)
return &account, nil
@@ -242,13 +257,20 @@ func (db *DB) SaveAccount(account *model.Account) error {
func (db *DB) SaveAccountSettings(account *model.Account, settings *model.Settings) error {
data, err := json.Marshal(settings)
if err != nil {
- return fmt.Errorf("failed to marshal settings for user account %s: %w", account.Username, err)
+ return fmt.Errorf("failed to marshal settings for user accont %s: %w", account.Username, err)
}
_, err = db.Exec(`UPDATE accounts SET settings = ? WHERE id = ?`, data, account.Id)
if err != nil {
return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err)
}
- _, err = db.Exec(`UPDATE sessions SET settings = ? WHERE account_id = ?`, data, account.Id)
+ _, err = db.Exec(
+ `UPDATE sessions
+ SET data = jsonb_replace(data,
+ '$.settings', jsonb(:data),
+ '$.account.settings', jsonb(:data))
+ WHERE data->'account'->>'id' = :id`,
+ sql.Named("data", data),
+ sql.Named("id", account.Id))
if err != nil {
return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err)
}
diff --git a/pkg/database/sessions.go b/pkg/database/sessions.go
index f579b4e..58a69d1 100644
--- a/pkg/database/sessions.go
+++ b/pkg/database/sessions.go
@@ -3,6 +3,7 @@ package database
import (
"database/sql"
"encoding/base64"
+ "encoding/json"
"errors"
"fmt"
"time"
@@ -10,31 +11,35 @@ import (
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
"git.adyxax.org/adyxax/tfstated/pkg/model"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
- "go.n16f.net/uuid"
)
-func (db *DB) CreateSession(account *model.Account, settingsData []byte) (string, error) {
+func (db *DB) CreateSession(sessionData *model.SessionData) (string, *model.Session, error) {
sessionBytes := scrypto.RandomBytes(32)
sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:])
sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes())
- var accountId *uuid.UUID = nil
- var settings = []byte("{}")
- if account != nil {
- accountId = &account.Id
- settings = account.Settings
- } else if settingsData != nil {
- settings = settingsData
+ if sessionData == nil {
+ var err error
+ sessionData, err = model.NewSessionData(nil, nil)
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to generate new session data: %w", err)
+ }
+ }
+ data, err := json.Marshal(sessionData)
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to marshal session data: %w", err)
}
if _, err := db.Exec(
- `INSERT INTO sessions(id, account_id, settings)
- VALUES (?, ?, ?);`,
+ `INSERT INTO sessions(id, data)
+ VALUES (?, jsonb(?));`,
sessionHash,
- accountId,
- settings,
+ data,
); err != nil {
- return "", fmt.Errorf("failed insert new session in database: %w", err)
+ return "", nil, fmt.Errorf("failed insert new session in database: %w", err)
}
- return sessionId, nil
+ return sessionId, &model.Session{
+ Id: sessionHash,
+ Data: sessionData,
+ }, nil
}
func (db *DB) DeleteExpiredSessions() error {
@@ -66,38 +71,44 @@ func (db *DB) LoadSessionById(id string) (*model.Session, error) {
var (
created int64
updated int64
+ data []byte
)
err = db.QueryRow(
- `SELECT account_id,
- created,
+ `SELECT created,
updated,
- settings
+ json_extract(data, '$')
FROM sessions
WHERE id = ?;`,
sessionHash,
- ).Scan(&session.AccountId,
+ ).Scan(
&created,
&updated,
- &session.Settings,
- )
+ &data)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, fmt.Errorf("failed to load session by id %s: %w", id, err)
}
+ if err := json.Unmarshal(data, &session.Data); err != nil {
+ return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
+ }
session.Created = time.Unix(created, 0)
session.Updated = time.Unix(updated, 0)
return &session, nil
}
-func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, error) {
+func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, *model.Session, error) {
if err := db.DeleteSession(session); err != nil {
- return "", fmt.Errorf("failed to delete session: %w", err)
+ return "", nil, fmt.Errorf("failed to delete session: %w", err)
}
- sessionId, err := db.CreateSession(account, session.Settings)
+ sessionData, err := model.NewSessionData(account, session.Data.Settings)
if err != nil {
- return "", fmt.Errorf("failed to create session: %w", err)
+ return "", nil, fmt.Errorf("failed to generate new session data: %w", err)
}
- return sessionId, nil
+ sessionId, session, err := db.CreateSession(sessionData)
+ if err != nil {
+ return "", nil, fmt.Errorf("failed to create session: %w", err)
+ }
+ return sessionId, session, nil
}
diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql
index 2e8d530..0441c89 100644
--- a/pkg/database/sql/000_init.sql
+++ b/pkg/database/sql/000_init.sql
@@ -13,15 +13,15 @@ CREATE TABLE accounts (
settings BLOB NOT NULL,
password_reset TEXT
) STRICT;
-CREATE UNIQUE INDEX accounts_username on accounts(username);
+CREATE UNIQUE INDEX accounts_username ON accounts(username);
CREATE TABLE sessions (
id BLOB PRIMARY KEY,
- account_id TEXT,
created INTEGER NOT NULL DEFAULT (unixepoch()),
updated INTEGER NOT NULL DEFAULT (unixepoch()),
- settings BLOB NOT NULL
+ data BLOB NOT NULL
) STRICT;
+CREATE INDEX sessions_data_account_id ON sessions(data->'account'->>'id');
CREATE TABLE states (
id TEXT PRIMARY KEY,
@@ -30,7 +30,7 @@ CREATE TABLE states (
created INTEGER DEFAULT (unixepoch()),
updated INTEGER DEFAULT (unixepoch())
) STRICT;
-CREATE UNIQUE INDEX states_path on states(path);
+CREATE UNIQUE INDEX states_path ON states(path);
CREATE TABLE versions (
id TEXT PRIMARY KEY,
diff --git a/pkg/model/account.go b/pkg/model/account.go
index df7ea8a..02d1276 100644
--- a/pkg/model/account.go
+++ b/pkg/model/account.go
@@ -2,7 +2,6 @@ package model
import (
"crypto/subtle"
- "encoding/json"
"time"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
@@ -12,15 +11,15 @@ import (
type AccountContextKey struct{}
type Account struct {
- Id uuid.UUID
- Username string
- Salt []byte
- PasswordHash []byte
- IsAdmin bool
- Created time.Time
- LastLogin time.Time
- Settings json.RawMessage
- PasswordReset *uuid.UUID
+ Id uuid.UUID `json:"id"`
+ Username string `json:"username"`
+ Salt []byte `json:"salt"`
+ PasswordHash []byte `json:"password_hash"`
+ IsAdmin bool `json:"is_admin"`
+ Created time.Time `json:"created"`
+ LastLogin time.Time `json:"last_login"`
+ Settings *Settings `json:"settings"`
+ PasswordReset *uuid.UUID `json:"password_reset"`
}
func (account *Account) CheckPassword(password string) bool {
diff --git a/pkg/model/session.go b/pkg/model/session.go
index 6bf6e6b..0199b03 100644
--- a/pkg/model/session.go
+++ b/pkg/model/session.go
@@ -1,20 +1,40 @@
package model
import (
- "encoding/json"
+ "fmt"
"time"
"go.n16f.net/uuid"
)
+type SessionData struct {
+ Account *Account `json:"account"`
+ CsrfToken uuid.UUID `json:"csrf_token"`
+ Settings *Settings `json:"settings"`
+}
+
+func NewSessionData(account *Account, previousSessionSettings *Settings) (*SessionData, error) {
+ data := SessionData{Account: account}
+ if err := data.CsrfToken.Generate(uuid.V4); err != nil {
+ return nil, fmt.Errorf("failed to generate csrf token uuid: %w", err)
+ }
+ if account != nil {
+ data.Settings = account.Settings
+ } else if previousSessionSettings != nil {
+ data.Settings = previousSessionSettings
+ } else {
+ data.Settings = &Settings{}
+ }
+ return &data, nil
+}
+
type SessionContextKey struct{}
type Session struct {
- Id []byte
- AccountId *uuid.UUID
- Created time.Time
- Updated time.Time
- Settings json.RawMessage
+ Id []byte
+ Created time.Time
+ Updated time.Time
+ Data *SessionData
}
func (session *Session) IsExpired() bool {
diff --git a/pkg/model/settings.go b/pkg/model/settings.go
index 9da655b..bebbced 100644
--- a/pkg/model/settings.go
+++ b/pkg/model/settings.go
@@ -1,7 +1,5 @@
package model
-type SettingsContextKey struct{}
-
type Settings struct {
LightMode bool `json:"light_mode"`
}
diff --git a/pkg/webui/accounts.go b/pkg/webui/accounts.go
index 01ed609..dfc4439 100644
--- a/pkg/webui/accounts.go
+++ b/pkg/webui/accounts.go
@@ -11,7 +11,6 @@ import (
type AccountsPage struct {
Accounts []model.Account
- ActiveTab int
IsAdmin string
Page *Page
Username string
@@ -45,11 +44,10 @@ func handleAccountsPOST(db *database.DB) http.Handler {
accountUsername := r.FormValue("username")
isAdmin := r.FormValue("isAdmin")
page := AccountsPage{
- ActiveTab: 1,
- Page: makePage(r, &Page{Title: "New Account", Section: "accounts"}),
- Accounts: accounts,
- IsAdmin: isAdmin,
- Username: accountUsername,
+ Page: makePage(r, &Page{Title: "New Account", Section: "accounts"}),
+ Accounts: accounts,
+ IsAdmin: isAdmin,
+ Username: accountUsername,
}
if ok := validUsername.MatchString(accountUsername); !ok {
page.UsernameInvalid = true
diff --git a/pkg/webui/accountsId.go b/pkg/webui/accountsId.go
index a3e6064..90c2774 100644
--- a/pkg/webui/accountsId.go
+++ b/pkg/webui/accountsId.go
@@ -1,6 +1,7 @@
package webui
import (
+ "fmt"
"html/template"
"net/http"
@@ -35,7 +36,7 @@ func handleAccountsIdGET(db *database.DB) http.Handler {
return
}
if account == nil {
- errorResponse(w, r, http.StatusNotFound, err)
+ errorResponse(w, r, http.StatusNotFound, fmt.Errorf("The account Id could not be found."))
return
}
statePaths, err := db.LoadStatePaths()
diff --git a/pkg/webui/accountsIdResetPassword.go b/pkg/webui/accountsIdResetPassword.go
index 4cae081..a39d287 100644
--- a/pkg/webui/accountsIdResetPassword.go
+++ b/pkg/webui/accountsIdResetPassword.go
@@ -19,37 +19,36 @@ type AccountsIdResetPasswordPage struct {
var accountsIdResetPasswordTemplates = template.Must(template.ParseFS(htmlFS, "html/base.html", "html/accountsIdResetPassword.html"))
-func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWriter, r *http.Request) (*model.Account, bool) {
+func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWriter, r *http.Request) *model.Account {
var accountId uuid.UUID
if err := accountId.Parse(r.PathValue("id")); err != nil {
- errorResponse(w, r, http.StatusBadRequest, err)
- return nil, false
+ return nil
}
var token uuid.UUID
if err := token.Parse(r.PathValue("token")); err != nil {
errorResponse(w, r, http.StatusBadRequest, err)
- return nil, false
+ return nil
}
account, err := db.LoadAccountById(&accountId)
if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err)
- return nil, false
+ return nil
}
if account == nil || account.PasswordReset == nil {
errorResponse(w, r, http.StatusBadRequest, err)
- return nil, false
+ return nil
}
if !account.PasswordReset.Equal(token) {
errorResponse(w, r, http.StatusBadRequest, err)
- return nil, false
+ return nil
}
- return account, true
+ return account
}
func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- account, valid := processAccountsIdResetPasswordPathValues(db, w, r)
- if !valid {
+ account := processAccountsIdResetPasswordPathValues(db, w, r)
+ if account == nil {
return
}
render(w, accountsIdResetPasswordTemplates, http.StatusOK,
@@ -63,8 +62,8 @@ func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler {
func handleAccountsIdResetPasswordPOST(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- account, valid := processAccountsIdResetPasswordPathValues(db, w, r)
- if !valid {
+ account := processAccountsIdResetPasswordPathValues(db, w, r)
+ if account == nil {
return
}
password := r.FormValue("password")
diff --git a/pkg/webui/admin.go b/pkg/webui/admin.go
index 1906b3e..bcda0b5 100644
--- a/pkg/webui/admin.go
+++ b/pkg/webui/admin.go
@@ -4,21 +4,14 @@ import (
"fmt"
"net/http"
- "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/model"
)
-func adminMiddleware(db *database.DB, requireLogin func(http.Handler) http.Handler) func(http.Handler) http.Handler {
+func adminMiddleware(requireLogin func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return requireLogin(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- account := r.Context().Value(model.AccountContextKey{})
- if account == nil {
- // this could happen if the account was deleted in the short
- // time between retrieving the session and here
- http.Redirect(w, r, "/login", http.StatusFound)
- return
- }
- if !account.(*model.Account).IsAdmin {
+ session := r.Context().Value(model.SessionContextKey{}).(*model.Session)
+ if !session.Data.Account.IsAdmin {
errorResponse(w, r, http.StatusForbidden, fmt.Errorf("Only administrators can perform this request."))
return
}
diff --git a/pkg/webui/html/accounts.html b/pkg/webui/html/accounts.html
index da13bca..ff3670a 100644
--- a/pkg/webui/html/accounts.html
+++ b/pkg/webui/html/accounts.html
@@ -7,7 +7,7 @@
Use this page to inspect user accounts.
- {{ if .Page.IsAdmin }}
+ {{ if .Page.Session.Data.Account.IsAdmin }}