chore(webui): rewrite all the web session code
All checks were successful
main / main (push) Successful in 3m13s
main / deploy (push) Has been skipped
main / publish (push) Has been skipped

#60
This commit is contained in:
Julien Dessaux 2025-04-29 01:25:11 +02:00
parent 929657fd34
commit 3bb5e735c6
Signed by: adyxax
GPG key ID: F92E51B86E07177E
19 changed files with 173 additions and 123 deletions

2
.gitignore vendored
View file

@ -2,3 +2,5 @@ cover.out
tfstated.db tfstated.db
tfstated.db-shm tfstated.db-shm
tfstated.db-wal tfstated.db-wal
tfstated_data_encryption_key
tfstated_sessions_salt

View file

@ -26,13 +26,15 @@ var adminPasswordMutex sync.Mutex
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
getenv := func(key string) string { getenv := func(key string) string {
switch key { switch key {
case "DATA_ENCRYPTION_KEY": case "TFSTATED_DATA_ENCRYPTION_KEY":
return "hP3ZSCnY3LMgfTQjwTaGrhKwdA0yXMXIfv67OJnntqM=" return "hP3ZSCnY3LMgfTQjwTaGrhKwdA0yXMXIfv67OJnntqM="
case "TFSTATED_HOST": case "TFSTATED_HOST":
return "127.0.0.1" return "127.0.0.1"
case "TFSTATED_PORT": case "TFSTATED_PORT":
return "8082" return "8082"
case "VERSIONS_HISTORY_LIMIT": case "TFSTATED_SESSIONS_SALT":
return "a528D1m9q3IZxLinSmHmeKxrx3Pmm7GQ3nBzIDxjr0A="
case "TFSTATED_VERSIONS_HISTORY_LIMIT":
return "3" return "3"
default: default:
return "" return ""

View file

@ -151,9 +151,12 @@ func (db *DB) LoadAccountUsernames() (map[string]string, error) {
return accounts, nil return accounts, nil
} }
func (db *DB) LoadAccountById(id uuid.UUID) (*model.Account, error) { func (db *DB) LoadAccountById(id *uuid.UUID) (*model.Account, error) {
if id == nil {
return nil, nil
}
account := model.Account{ account := model.Account{
Id: id, Id: *id,
} }
var ( var (
created int64 created int64
@ -239,11 +242,15 @@ func (db *DB) SaveAccount(account *model.Account) error {
func (db *DB) SaveAccountSettings(account *model.Account, settings *model.Settings) error { func (db *DB) SaveAccountSettings(account *model.Account, settings *model.Settings) error {
data, err := json.Marshal(settings) data, err := json.Marshal(settings)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal settings for user %s: %w", account.Username, err) return fmt.Errorf("failed to marshal settings for user account %s: %w", account.Username, err)
} }
_, err = db.Exec(`UPDATE accounts SET settings = ? WHERE id = ?`, data, account.Id) _, err = db.Exec(`UPDATE accounts SET settings = ? WHERE id = ?`, data, account.Id)
if err != nil { if err != nil {
return fmt.Errorf("failed to update settings for user %s: %w", account.Username, err) return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err)
}
_, err = db.Exec(`UPDATE sessions SET settings = ? WHERE account_id = ?`, data, account.Id)
if err != nil {
return fmt.Errorf("failed to update account settings for user account %s: %w", account.Username, err)
} }
return nil return nil
} }

View file

@ -31,6 +31,7 @@ type DB struct {
ctx context.Context ctx context.Context
dataEncryptionKey scrypto.AES256Key dataEncryptionKey scrypto.AES256Key
readDB *sql.DB readDB *sql.DB
sessionsSalt scrypto.AES256Key
versionsHistoryLimit int versionsHistoryLimit int
writeDB *sql.DB writeDB *sql.DB
} }
@ -82,17 +83,24 @@ func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, er
return nil, fmt.Errorf("failed to migrate: %w", err) return nil, fmt.Errorf("failed to migrate: %w", err)
} }
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY") dataEncryptionKey := getenv("TFSTATED_DATA_ENCRYPTION_KEY")
if dataEncryptionKey == "" { if dataEncryptionKey == "" {
return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set") return nil, fmt.Errorf("the TFSTATED_DATA_ENCRYPTION_KEY environment variable is not set")
} }
if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil { if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil {
return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err) return nil, fmt.Errorf("failed to decode the TFSTATED_DATA_ENCRYPTION_KEY environment variable, expected 32 bytes base64 encoded: %w", err)
} }
versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") sessionsSalt := getenv("TFSTATED_SESSIONS_SALT")
if sessionsSalt == "" {
return nil, fmt.Errorf("the TFSTATED_SESSIONS_SALT environment variable is not set")
}
if err := db.sessionsSalt.FromBase64(dataEncryptionKey); err != nil {
return nil, fmt.Errorf("failed to decode the TFSTATED_SESSIONS_SALT environment variable, expected 32 bytes base64 encoded: %w", err)
}
versionsHistoryLimit := getenv("TFSTATED_VERSIONS_HISTORY_LIMIT")
if versionsHistoryLimit != "" { if versionsHistoryLimit != "" {
if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil { if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil {
return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err) return nil, fmt.Errorf("failed to parse the TFSTATED_VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err)
} }
} }

View file

@ -7,19 +7,30 @@ import (
"fmt" "fmt"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
"git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/model"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto" "git.adyxax.org/adyxax/tfstated/pkg/scrypto"
"go.n16f.net/uuid"
) )
func (db *DB) CreateSession(account *model.Account) (string, error) { func (db *DB) CreateSession(account *model.Account, settingsData []byte) (string, error) {
sessionBytes := scrypto.RandomBytes(32) sessionBytes := scrypto.RandomBytes(32)
sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:]) sessionId := base64.RawURLEncoding.EncodeToString(sessionBytes[:])
sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes())
var accountId *uuid.UUID = nil
var settings = []byte("{}")
if account != nil {
accountId = &account.Id
settings = account.Settings
} else if settingsData != nil {
settings = settingsData
}
if _, err := db.Exec( if _, err := db.Exec(
`INSERT INTO sessions(id, account_id, data) `INSERT INTO sessions(id, account_id, settings)
VALUES (?, ?, ?);`, VALUES (?, ?, ?);`,
sessionId, sessionHash,
account.Id, accountId,
"", settings,
); err != nil { ); err != nil {
return "", fmt.Errorf("failed insert new session in database: %w", err) return "", fmt.Errorf("failed insert new session in database: %w", err)
} }
@ -27,7 +38,8 @@ func (db *DB) CreateSession(account *model.Account) (string, error) {
} }
func (db *DB) DeleteExpiredSessions() error { func (db *DB) DeleteExpiredSessions() error {
_, err := db.Exec(`DELETE FROM sessions WHERE created < ?`, time.Now().Unix()) expires := time.Now().Add(-12 * time.Hour)
_, err := db.Exec(`DELETE FROM sessions WHERE created < ?`, expires.Unix())
if err != nil { if err != nil {
return fmt.Errorf("failed to delete expired session: %w", err) return fmt.Errorf("failed to delete expired session: %w", err)
} }
@ -43,25 +55,30 @@ func (db *DB) DeleteSession(session *model.Session) error {
} }
func (db *DB) LoadSessionById(id string) (*model.Session, error) { func (db *DB) LoadSessionById(id string) (*model.Session, error) {
sessionBytes, err := base64.RawURLEncoding.DecodeString(id)
if err != nil {
return nil, fmt.Errorf("failed to decode base64 session id: %w", err)
}
sessionHash := helpers.HashSessionId(sessionBytes, db.sessionsSalt.Bytes())
session := model.Session{ session := model.Session{
Id: id, Id: sessionHash,
} }
var ( var (
created int64 created int64
updated int64 updated int64
) )
err := db.QueryRow( err = db.QueryRow(
`SELECT account_id, `SELECT account_id,
created, created,
updated, updated,
data settings
FROM sessions FROM sessions
WHERE id = ?;`, WHERE id = ?;`,
id, sessionHash,
).Scan(&session.AccountId, ).Scan(&session.AccountId,
&created, &created,
&updated, &updated,
&session.Data, &session.Settings,
) )
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
@ -74,11 +91,13 @@ func (db *DB) LoadSessionById(id string) (*model.Session, error) {
return &session, nil return &session, nil
} }
func (db *DB) TouchSession(sessionId string) error { func (db *DB) MigrateSession(session *model.Session, account *model.Account) (string, error) {
now := time.Now().UTC() if err := db.DeleteSession(session); err != nil {
_, err := db.Exec(`UPDATE sessions SET updated = ? WHERE id = ?`, now.Unix(), sessionId) return "", fmt.Errorf("failed to delete session: %w", err)
if err != nil {
return fmt.Errorf("failed to touch updated for session %s: %w", sessionId, err)
} }
return nil sessionId, err := db.CreateSession(account, session.Settings)
if err != nil {
return "", fmt.Errorf("failed to create session: %w", err)
}
return sessionId, nil
} }

View file

@ -16,12 +16,11 @@ CREATE TABLE accounts (
CREATE UNIQUE INDEX accounts_username on accounts(username); CREATE UNIQUE INDEX accounts_username on accounts(username);
CREATE TABLE sessions ( CREATE TABLE sessions (
id TEXT PRIMARY KEY, id BLOB PRIMARY KEY,
account_id TEXT NOT NULL, account_id TEXT,
created INTEGER NOT NULL DEFAULT (unixepoch()), created INTEGER NOT NULL DEFAULT (unixepoch()),
updated INTEGER NOT NULL DEFAULT (unixepoch()), updated INTEGER NOT NULL DEFAULT (unixepoch()),
data TEXT NOT NULL, settings BLOB NOT NULL
FOREIGN KEY(account_id) REFERENCES accounts(id) ON DELETE CASCADE
) STRICT; ) STRICT;
CREATE TABLE states ( CREATE TABLE states (

View file

@ -8,8 +8,9 @@ import (
) )
const ( const (
PBKDF2Iterations = 600000 PBKDF2PasswordIterations = 600000
SaltSize = 32 PBKDF2SessionIterations = 12
SaltSize = 32
) )
func GenerateSalt() []byte { func GenerateSalt() []byte {
@ -17,5 +18,9 @@ func GenerateSalt() []byte {
} }
func HashPassword(password string, salt []byte) []byte { func HashPassword(password string, salt []byte) []byte {
return pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, 32, sha256.New) return pbkdf2.Key([]byte(password), salt, PBKDF2PasswordIterations, 32, sha256.New)
}
func HashSessionId(id []byte, salt []byte) []byte {
return pbkdf2.Key(id, salt, PBKDF2SessionIterations, 32, sha256.New)
} }

View file

@ -1,6 +1,7 @@
package model package model
import ( import (
"encoding/json"
"time" "time"
"go.n16f.net/uuid" "go.n16f.net/uuid"
@ -9,11 +10,11 @@ import (
type SessionContextKey struct{} type SessionContextKey struct{}
type Session struct { type Session struct {
Id string Id []byte
AccountId uuid.UUID AccountId *uuid.UUID
Created time.Time Created time.Time
Updated time.Time Updated time.Time
Data any Settings json.RawMessage
} }
func (session *Session) IsExpired() bool { func (session *Session) IsExpired() bool {

View file

@ -29,7 +29,7 @@ func handleAccountsIdGET(db *database.DB) http.Handler {
errorResponse(w, r, http.StatusBadRequest, err) errorResponse(w, r, http.StatusBadRequest, err)
return return
} }
account, err := db.LoadAccountById(accountId) account, err := db.LoadAccountById(&accountId)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError, err)
return return

View file

@ -30,7 +30,7 @@ func processAccountsIdResetPasswordPathValues(db *database.DB, w http.ResponseWr
errorResponse(w, r, http.StatusBadRequest, err) errorResponse(w, r, http.StatusBadRequest, err)
return nil, false return nil, false
} }
account, err := db.LoadAccountById(accountId) account, err := db.LoadAccountById(&accountId)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError, err)
return nil, false return nil, false
@ -55,7 +55,7 @@ func handleAccountsIdResetPasswordGET(db *database.DB) http.Handler {
render(w, accountsIdResetPasswordTemplates, http.StatusOK, render(w, accountsIdResetPasswordTemplates, http.StatusOK,
AccountsIdResetPasswordPage{ AccountsIdResetPasswordPage{
Account: account, Account: account,
Page: &Page{Title: "Password Reset", Section: "reset"}, Page: makePage(r, &Page{Title: "Password Reset", Section: "reset"}),
Token: r.PathValue("token"), Token: r.PathValue("token"),
}) })
}) })
@ -80,7 +80,7 @@ func handleAccountsIdResetPasswordPOST(db *database.DB) http.Handler {
render(w, accountsIdResetPasswordTemplates, http.StatusOK, render(w, accountsIdResetPasswordTemplates, http.StatusOK,
AccountsIdResetPasswordPage{ AccountsIdResetPasswordPage{
Account: account, Account: account,
Page: &Page{Title: "Password Reset", Section: "reset"}, Page: makePage(r, &Page{Title: "Password Reset", Section: "reset"}),
PasswordChanged: true, PasswordChanged: true,
}) })
}) })

View file

@ -15,7 +15,7 @@ func errorResponse(w http.ResponseWriter, r *http.Request, status int, err error
StatusText string StatusText string
} }
render(w, errorTemplates, status, &ErrorData{ render(w, errorTemplates, status, &ErrorData{
Page: &Page{Title: "Error", Section: "error"}, Page: makePage(r, &Page{Title: "Error", Section: "error"}),
Err: err, Err: err,
Status: status, Status: status,
StatusText: http.StatusText(status), StatusText: http.StatusText(status),

View file

@ -55,6 +55,5 @@
</div> </div>
<footer> <footer>
</footer> </footer>
<script type="module" src="/static/main.js"></script>
</body> </body>
</html> </html>

View file

@ -9,7 +9,7 @@ import (
) )
type Page struct { type Page struct {
AccountId uuid.UUID AccountId *uuid.UUID
IsAdmin bool IsAdmin bool
LightMode bool LightMode bool
Section string Section string
@ -17,9 +17,12 @@ type Page struct {
} }
func makePage(r *http.Request, page *Page) *Page { func makePage(r *http.Request, page *Page) *Page {
account := r.Context().Value(model.AccountContextKey{}).(*model.Account) accountCtx := r.Context().Value(model.AccountContextKey{})
page.AccountId = account.Id if accountCtx != nil {
page.IsAdmin = account.IsAdmin account := accountCtx.(*model.Account)
page.AccountId = &account.Id
page.IsAdmin = account.IsAdmin
}
settings := r.Context().Value(model.SettingsContextKey{}).(*model.Settings) settings := r.Context().Value(model.SettingsContextKey{}).(*model.Settings)
page.LightMode = settings.LightMode page.LightMode = settings.LightMode
return page return page

View file

@ -2,7 +2,7 @@ package webui
import ( import (
"context" "context"
"encoding/json" "fmt"
"html/template" "html/template"
"log/slog" "log/slog"
"net/http" "net/http"
@ -26,29 +26,30 @@ func handleLoginGET() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store, no-cache") w.Header().Set("Cache-Control", "no-store, no-cache")
session := r.Context().Value(model.SessionContextKey{}) account := r.Context().Value(model.AccountContextKey{})
if session != nil { if account != nil {
http.Redirect(w, r, "/states", http.StatusFound) http.Redirect(w, r, "/states", http.StatusFound)
return return
} }
render(w, loginTemplate, http.StatusOK, loginPage{ render(w, loginTemplate, http.StatusOK, loginPage{
Page: &Page{Title: "Login", Section: "login"}, Page: makePage(r, &Page{Title: "Login", Section: "login"}),
}) })
}) })
} }
func handleLoginPOST(db *database.DB) http.Handler { func handleLoginPOST(db *database.DB) http.Handler {
renderForbidden := func(w http.ResponseWriter, username string) { renderForbidden := func(w http.ResponseWriter, r *http.Request, username string) {
render(w, loginTemplate, http.StatusForbidden, loginPage{ render(w, loginTemplate, http.StatusForbidden, loginPage{
Page: &Page{Title: "Login", Section: "login"}, Page: makePage(r, &Page{Title: "Login", Section: "login"}),
Forbidden: true, Forbidden: true,
Username: username, Username: username,
}) })
} }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
errorResponse(w, r, http.StatusBadRequest, err) errorResponse(w, r, http.StatusBadRequest,
fmt.Errorf("failed to parse form: %w", err))
return return
} }
username := r.FormValue("username") username := r.FormValue("username")
@ -59,37 +60,32 @@ func handleLoginPOST(db *database.DB) http.Handler {
return return
} }
if ok := validUsername.MatchString(username); !ok { if ok := validUsername.MatchString(username); !ok {
renderForbidden(w, username) renderForbidden(w, r, username)
return return
} }
account, err := db.LoadAccountByUsername(username) account, err := db.LoadAccountByUsername(username)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to load account by username %s: %w", username, err))
return return
} }
if account == nil || !account.CheckPassword(password) { if account == nil || !account.CheckPassword(password) {
renderForbidden(w, username) renderForbidden(w, r, username)
return return
} }
if err := db.TouchAccount(account); err != nil { if err := db.TouchAccount(account); err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to touch account %s: %w", username, err))
return return
} }
sessionId, err := db.CreateSession(account) session := r.Context().Value(model.SessionContextKey{}).(*model.Session)
sessionId, err := db.MigrateSession(session, account)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to migrate session: %w", err))
return return
} }
http.SetCookie(w, &http.Cookie{ setSessionCookie(w, sessionId)
Name: cookieName,
Value: sessionId,
Quoted: false,
Path: "/",
MaxAge: 12 * 3600, // 12 hours sessions
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Secure: true,
})
if err := db.DeleteExpiredSessions(); err != nil { if err := db.DeleteExpiredSessions(); err != nil {
slog.Error("failed to delete expired sessions after user login", "err", err, "accountId", account.Id) slog.Error("failed to delete expired sessions after user login", "err", err, "accountId", account.Id)
} }
@ -97,32 +93,26 @@ func handleLoginPOST(db *database.DB) http.Handler {
}) })
} }
func loginMiddleware(db *database.DB, processSession func(http.Handler) http.Handler) func(http.Handler) http.Handler { func loginMiddleware(db *database.DB, requireSession func(http.Handler) http.Handler) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return processSession(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return requireSession(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store, no-cache") w.Header().Set("Cache-Control", "no-store, no-cache")
session := r.Context().Value(model.SessionContextKey{}) session := r.Context().Value(model.SessionContextKey{}).(*model.Session)
if session == nil { if session.AccountId == nil {
http.Redirect(w, r, "/login", http.StatusFound) http.Redirect(w, r, "/login", http.StatusFound)
return return
} }
account, err := db.LoadAccountById(session.(*model.Session).AccountId) account, err := db.LoadAccountById(session.AccountId)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to load account by Id: %w", err))
return return
} }
if account == nil { if account == nil {
// this could happen if the account was deleted in the short
// time between retrieving the session and here
http.Redirect(w, r, "/login", http.StatusFound) http.Redirect(w, r, "/login", http.StatusFound)
return return
} }
ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account) ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account)
var settings model.Settings
if err := json.Unmarshal(account.Settings, &settings); err != nil {
slog.Error("failed to unmarshal account settings", "err", err, "accountId", account.Id)
}
ctx = context.WithValue(ctx, model.SettingsContextKey{}, &settings)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
})) }))
} }

View file

@ -1,6 +1,7 @@
package webui package webui
import ( import (
"fmt"
"html/template" "html/template"
"net/http" "net/http"
@ -12,18 +13,19 @@ var logoutTemplate = template.Must(template.ParseFS(htmlFS, "html/base.html", "h
func handleLogoutGET(db *database.DB) http.Handler { func handleLogoutGET(db *database.DB) http.Handler {
type logoutPage struct { type logoutPage struct {
Page Page *Page
} }
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value(model.SessionContextKey{}) session := r.Context().Value(model.SessionContextKey{}).(*model.Session)
err := db.DeleteSession(session.(*model.Session)) sessionId, err := db.MigrateSession(session, nil)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to migrate session: %w", err))
return return
} }
unsetSesssionCookie(w) setSessionCookie(w, sessionId)
render(w, logoutTemplate, http.StatusOK, logoutPage{ render(w, logoutTemplate, http.StatusOK, logoutPage{
Page: Page{Title: "Logout", Section: "login"}, Page: makePage(r, &Page{Title: "Logout", Section: "login"}),
}) })
}) })
} }

View file

@ -10,17 +10,17 @@ func addRoutes(
mux *http.ServeMux, mux *http.ServeMux,
db *database.DB, db *database.DB,
) { ) {
processSession := sessionsMiddleware(db) requireSession := sessionsMiddleware(db)
requireLogin := loginMiddleware(db, processSession) requireLogin := loginMiddleware(db, requireSession)
requireAdmin := adminMiddleware(db, requireLogin) requireAdmin := adminMiddleware(db, requireLogin)
mux.Handle("GET /accounts", requireLogin(handleAccountsGET(db))) mux.Handle("GET /accounts", requireLogin(handleAccountsGET(db)))
mux.Handle("GET /accounts/{id}", requireLogin(handleAccountsIdGET(db))) mux.Handle("GET /accounts/{id}", requireLogin(handleAccountsIdGET(db)))
mux.Handle("GET /accounts/{id}/reset/{token}", handleAccountsIdResetPasswordGET(db)) mux.Handle("GET /accounts/{id}/reset/{token}", requireSession(handleAccountsIdResetPasswordGET(db)))
mux.Handle("POST /accounts/{id}/reset/{token}", handleAccountsIdResetPasswordPOST(db)) mux.Handle("POST /accounts/{id}/reset/{token}", requireSession(handleAccountsIdResetPasswordPOST(db)))
mux.Handle("POST /accounts", requireAdmin(handleAccountsPOST(db))) mux.Handle("POST /accounts", requireAdmin(handleAccountsPOST(db)))
mux.Handle("GET /healthz", handleHealthz()) mux.Handle("GET /healthz", handleHealthz())
mux.Handle("GET /login", processSession(handleLoginGET())) mux.Handle("GET /login", requireSession(handleLoginGET()))
mux.Handle("POST /login", processSession(handleLoginPOST(db))) mux.Handle("POST /login", requireSession(handleLoginPOST(db)))
mux.Handle("GET /logout", requireLogin(handleLogoutGET(db))) mux.Handle("GET /logout", requireLogin(handleLogoutGET(db)))
mux.Handle("GET /settings", requireLogin(handleSettingsGET(db))) mux.Handle("GET /settings", requireLogin(handleSettingsGET(db)))
mux.Handle("POST /settings", requireLogin(handleSettingsPOST(db))) mux.Handle("POST /settings", requireLogin(handleSettingsPOST(db)))
@ -30,5 +30,5 @@ func addRoutes(
mux.Handle("POST /states/{id}", requireLogin(handleStatesIdPOST(db))) mux.Handle("POST /states/{id}", requireLogin(handleStatesIdPOST(db)))
mux.Handle("GET /static/", cache(http.FileServer(http.FS(staticFS)))) mux.Handle("GET /static/", cache(http.FileServer(http.FS(staticFS))))
mux.Handle("GET /versions/{id}", requireLogin(handleVersionsGET(db))) mux.Handle("GET /versions/{id}", requireLogin(handleVersionsGET(db)))
mux.Handle("GET /", requireLogin(handleIndexGET())) mux.Handle("GET /", requireSession(handleIndexGET()))
} }

View file

@ -2,8 +2,10 @@ package webui
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"net/http" "net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
@ -17,49 +19,59 @@ func sessionsMiddleware(db *database.DB) func(http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(cookieName) cookie, err := r.Cookie(cookieName)
if err != nil && !errors.Is(err, http.ErrNoCookie) { if err != nil && !errors.Is(err, http.ErrNoCookie) {
errorResponse(w, r, http.StatusInternalServerError, fmt.Errorf("failed to get request cookie \"%s\": %w", cookieName, err)) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to get session cookie from request: %w", err))
return return
} }
if err == nil { if err == nil {
if len(cookie.Value) != 43 { if len(cookie.Value) == 43 {
unsetSesssionCookie(w)
} else {
session, err := db.LoadSessionById(cookie.Value) session, err := db.LoadSessionById(cookie.Value)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to load session by ID: %w", err))
return return
} }
if session == nil { if session != nil {
unsetSesssionCookie(w) if session.IsExpired() {
} else if session.IsExpired() { if err := db.DeleteSession(session); err != nil {
unsetSesssionCookie(w) errorResponse(w, r, http.StatusInternalServerError,
if err := db.DeleteSession(session); err != nil { fmt.Errorf("failed to delete session: %w", err))
errorResponse(w, r, http.StatusInternalServerError, err) return
}
} else {
ctx := context.WithValue(r.Context(), model.SessionContextKey{}, session)
var settings model.Settings
if err := json.Unmarshal(session.Settings, &settings); err != nil {
slog.Error("failed to unmarshal session settings", "err", err)
}
ctx = context.WithValue(ctx, model.SettingsContextKey{}, &settings)
next.ServeHTTP(w, r.WithContext(ctx))
return return
} }
} else {
if err := db.TouchSession(cookie.Value); err != nil {
errorResponse(w, r, http.StatusInternalServerError, err)
return
}
ctx := context.WithValue(r.Context(), model.SessionContextKey{}, session)
next.ServeHTTP(w, r.WithContext(ctx))
return
} }
} }
} }
next.ServeHTTP(w, r) sessionId, err := db.CreateSession(nil, nil)
if err != nil {
errorResponse(w, r, http.StatusInternalServerError,
fmt.Errorf("failed to create session: %w", err))
return
}
setSessionCookie(w, sessionId)
var settings model.Settings
ctx := context.WithValue(r.Context(), model.SettingsContextKey{}, &settings)
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
} }
func unsetSesssionCookie(w http.ResponseWriter) { func setSessionCookie(w http.ResponseWriter, sessionId string) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: cookieName, Name: cookieName,
Value: "", Value: sessionId,
Quoted: false, Quoted: false,
Path: "/", Path: "/",
MaxAge: 0, // remove invalid cookie MaxAge: 12 * 3600, // 12 hours sessions
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
Secure: true, Secure: true,

View file

@ -1,6 +1,7 @@
package webui package webui
import ( import (
"context"
"html/template" "html/template"
"net/http" "net/http"
@ -41,8 +42,8 @@ func handleSettingsPOST(db *database.DB) http.Handler {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError, err)
return return
} }
page := makePage(r, &Page{Title: "Settings", Section: "settings"}) ctx := context.WithValue(r.Context(), model.SettingsContextKey{}, &settings)
page.LightMode = settings.LightMode page := makePage(r.WithContext(ctx), &Page{Title: "Settings", Section: "settings"})
render(w, settingsTemplates, http.StatusOK, SettingsPage{ render(w, settingsTemplates, http.StatusOK, SettingsPage{
Page: page, Page: page,
Settings: &settings, Settings: &settings,

View file

@ -39,7 +39,7 @@ func handleVersionsGET(db *database.DB) http.Handler {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError, err)
return return
} }
account, err := db.LoadAccountById(version.AccountId) account, err := db.LoadAccountById(&version.AccountId)
if err != nil { if err != nil {
errorResponse(w, r, http.StatusInternalServerError, err) errorResponse(w, r, http.StatusInternalServerError, err)
return return