chore(tfstated): refactor helpers to their own package

This commit is contained in:
Julien Dessaux 2024-11-17 00:05:22 +01:00
parent 5b6da56089
commit 25ed1188ed
Signed by: adyxax
GPG key ID: F92E51B86E07177E
11 changed files with 78 additions and 62 deletions

View file

@ -5,22 +5,23 @@ import (
"net/http" "net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
) )
func handleDelete(db *database.DB) http.Handler { func handleDelete(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
_ = errorResponse(w, http.StatusBadRequest, helpers.ErrorResponse(w, http.StatusBadRequest,
fmt.Errorf("no state path provided, cannot DELETE /")) fmt.Errorf("no state path provided, cannot DELETE /"))
return return
} }
if success, err := db.DeleteState(r.URL.Path); err != nil { if success, err := db.DeleteState(r.URL.Path); err != nil {
_ = errorResponse(w, http.StatusInternalServerError, err) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
} else if success { } else if success {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
_ = errorResponse(w, http.StatusNotFound, helpers.ErrorResponse(w, http.StatusNotFound,
fmt.Errorf("state path not found: %s", r.URL.Path)) fmt.Errorf("state path not found: %s", r.URL.Path))
} }
}) })

View file

@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
) )
func handleGet(db *database.DB) http.Handler { func handleGet(db *database.DB) http.Handler {
@ -12,13 +13,13 @@ func handleGet(db *database.DB) http.Handler {
w.Header().Set("Cache-Control", "no-store, no-cache") w.Header().Set("Cache-Control", "no-store, no-cache")
if r.URL.Path == "/" { if r.URL.Path == "/" {
_ = errorResponse(w, http.StatusBadRequest, helpers.ErrorResponse(w, http.StatusBadRequest,
fmt.Errorf("no state path provided, cannot GET /")) fmt.Errorf("no state path provided, cannot GET /"))
return return
} }
if data, err := db.GetState(r.URL.Path); err != nil { if data, err := db.GetState(r.URL.Path); err != nil {
_ = errorResponse(w, http.StatusInternalServerError, err) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
} else { } else {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(data) _, _ = w.Write(data)

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
) )
type lockRequest struct { type lockRequest struct {
@ -34,27 +35,27 @@ func (l *lockRequest) valid() []error {
func handleLock(db *database.DB) http.Handler { func handleLock(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
_ = encode(w, http.StatusBadRequest, _ = helpers.Encode(w, http.StatusBadRequest,
fmt.Errorf("no state path provided, cannot LOCK /")) fmt.Errorf("no state path provided, cannot LOCK /"))
return return
} }
var lock lockRequest var lock lockRequest
if err := decode(r, &lock); err != nil { if err := helpers.Decode(r, &lock); err != nil {
_ = encode(w, http.StatusBadRequest, err) _ = helpers.Encode(w, http.StatusBadRequest, err)
return return
} }
if errs := lock.valid(); len(errs) > 0 { if errs := lock.valid(); len(errs) > 0 {
_ = encode(w, http.StatusBadRequest, _ = helpers.Encode(w, http.StatusBadRequest,
fmt.Errorf("invalid lock: %+v", errs)) fmt.Errorf("invalid lock: %+v", errs))
return return
} }
if success, err := db.SetLockOrGetExistingLock(r.URL.Path, &lock); err != nil { if success, err := db.SetLockOrGetExistingLock(r.URL.Path, &lock); err != nil {
_ = errorResponse(w, http.StatusInternalServerError, err) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
} else if success { } else if success {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
_ = encode(w, http.StatusConflict, lock) _ = helpers.Encode(w, http.StatusConflict, lock)
} }
}) })
} }

View file

@ -6,13 +6,14 @@ import (
"net/http" "net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
"git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/model"
) )
func handlePost(db *database.DB) http.Handler { func handlePost(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
_ = errorResponse(w, http.StatusBadRequest, helpers.ErrorResponse(w, http.StatusBadRequest,
fmt.Errorf("no state path provided, cannot POST /"), fmt.Errorf("no state path provided, cannot POST /"),
) )
return return
@ -22,15 +23,15 @@ func handlePost(db *database.DB) http.Handler {
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil || len(data) == 0 { if err != nil || len(data) == 0 {
_ = errorResponse(w, http.StatusBadRequest, err) helpers.ErrorResponse(w, http.StatusBadRequest, err)
return return
} }
account := r.Context().Value(model.AccountContextKey{}).(*model.Account) account := r.Context().Value(model.AccountContextKey{}).(*model.Account)
if idMismatch, err := db.SetState(r.URL.Path, account.Id, data, id); err != nil { if idMismatch, err := db.SetState(r.URL.Path, account.Id, data, id); err != nil {
if idMismatch { if idMismatch {
_ = errorResponse(w, http.StatusConflict, err) helpers.ErrorResponse(w, http.StatusConflict, err)
} else { } else {
_ = errorResponse(w, http.StatusInternalServerError, err) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
} }
} else { } else {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)

View file

@ -5,27 +5,28 @@ import (
"net/http" "net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
) )
func handleUnlock(db *database.DB) http.Handler { func handleUnlock(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
_ = encode(w, http.StatusBadRequest, _ = helpers.Encode(w, http.StatusBadRequest,
fmt.Errorf("no state path provided, cannot LOCK /")) fmt.Errorf("no state path provided, cannot LOCK /"))
return return
} }
var lock lockRequest var lock lockRequest
if err := decode(r, &lock); err != nil { if err := helpers.Decode(r, &lock); err != nil {
_ = encode(w, http.StatusBadRequest, err) _ = helpers.Encode(w, http.StatusBadRequest, err)
return return
} }
if success, err := db.Unlock(r.URL.Path, &lock); err != nil { if success, err := db.Unlock(r.URL.Path, &lock); err != nil {
_ = errorResponse(w, http.StatusInternalServerError, err) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
} else if success { } else if success {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {
_ = encode(w, http.StatusConflict, lock) _ = helpers.Encode(w, http.StatusConflict, lock)
} }
}) })
} }

View file

@ -2,10 +2,12 @@ package basic_auth
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
"git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/model"
) )
@ -15,26 +17,22 @@ func Middleware(db *database.DB) func(http.Handler) http.Handler {
username, password, ok := r.BasicAuth() username, password, ok := r.BasicAuth()
if !ok { if !ok {
w.Header().Set("WWW-Authenticate", `Basic realm="tfstated", charset="UTF-8"`) w.Header().Set("WWW-Authenticate", `Basic realm="tfstated", charset="UTF-8"`)
http.Error(w, "Unauthorized", http.StatusUnauthorized) helpers.ErrorResponse(w, http.StatusUnauthorized, fmt.Errorf("Unauthorized"))
return return
} }
account, err := db.LoadAccountByUsername(username) account, err := db.LoadAccountByUsername(username)
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
return return
} }
if account == nil { if account == nil || !account.CheckPassword(password) {
http.Error(w, "Forbidden", http.StatusForbidden) helpers.ErrorResponse(w, http.StatusForbidden, fmt.Errorf("Forbidden"))
return
}
if !account.CheckPassword(password) {
http.Error(w, "Forbidden", http.StatusForbidden)
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
_, err = db.Exec(`UPDATE accounts SET last_login = ? WHERE id = ?`, now.Unix(), account.Id) _, err = db.Exec(`UPDATE accounts SET last_login = ? WHERE id = ?`, now.Unix(), account.Id)
if err != nil { if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError) helpers.ErrorResponse(w, http.StatusInternalServerError, err)
return return
} }
ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account) ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account)

View file

@ -7,6 +7,7 @@ import (
"log/slog" "log/slog"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/helpers"
"git.adyxax.org/adyxax/tfstated/pkg/model" "git.adyxax.org/adyxax/tfstated/pkg/model"
"go.n16f.net/uuid" "go.n16f.net/uuid"
) )
@ -69,8 +70,8 @@ func (db *DB) InitAdminAccount() error {
if err = password.Generate(uuid.V4); err != nil { if err = password.Generate(uuid.V4); err != nil {
return fmt.Errorf("failed to generate initial admin password: %w", err) return fmt.Errorf("failed to generate initial admin password: %w", err)
} }
salt := model.GenerateSalt() salt := helpers.GenerateSalt()
hash := model.HashPassword(password.String(), salt) hash := helpers.HashPassword(password.String(), salt)
if _, err = tx.ExecContext(db.ctx, if _, err = tx.ExecContext(db.ctx,
`INSERT INTO accounts(username, salt, password_hash, is_admin) `INSERT INTO accounts(username, salt, password_hash, is_admin)
VALUES ("admin", :salt, :hash, TRUE) VALUES ("admin", :salt, :hash, TRUE)

21
pkg/helpers/crypto.go Normal file
View file

@ -0,0 +1,21 @@
package helpers
import (
"crypto/sha256"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
"golang.org/x/crypto/pbkdf2"
)
const (
PBKDF2Iterations = 600000
SaltSize = 32
)
func GenerateSalt() []byte {
return scrypto.RandomBytes(SaltSize)
}
func HashPassword(password string, salt []byte) []byte {
return pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, 32, sha256.New)
}

17
pkg/helpers/error.go Normal file
View file

@ -0,0 +1,17 @@
package helpers
import (
"fmt"
"net/http"
)
func ErrorResponse(w http.ResponseWriter, status int, err error) {
type errorResponse struct {
Msg string `json:"msg"`
Status int `json:"status"`
}
_ = Encode(w, status, &errorResponse{
Msg: fmt.Sprintf("%+v", err),
Status: status,
})
}

View file

@ -1,4 +1,4 @@
package main package helpers
import ( import (
"encoding/json" "encoding/json"
@ -7,14 +7,14 @@ import (
"net/http" "net/http"
) )
func decode(r *http.Request, data any) error { func Decode(r *http.Request, data any) error {
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
return fmt.Errorf("failed to decode json: %w", err) return fmt.Errorf("failed to decode json: %w", err)
} }
return nil return nil
} }
func encode(w http.ResponseWriter, status int, data any) error { func Encode(w http.ResponseWriter, status int, data any) error {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
if err := json.NewEncoder(w).Encode(data); err != nil { if err := json.NewEncoder(w).Encode(data); err != nil {
@ -23,14 +23,3 @@ func encode(w http.ResponseWriter, status int, data any) error {
} }
return nil return nil
} }
func errorResponse(w http.ResponseWriter, status int, err error) error {
type errorResponse struct {
Msg string `json:"msg"`
Status int `json:"status"`
}
return encode(w, status, &errorResponse{
Msg: fmt.Sprintf("%+v", err),
Status: status,
})
}

View file

@ -1,17 +1,10 @@
package model package model
import ( import (
"crypto/sha256"
"crypto/subtle" "crypto/subtle"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto" "git.adyxax.org/adyxax/tfstated/pkg/helpers"
"golang.org/x/crypto/pbkdf2"
)
const (
PBKDF2Iterations = 600000
SaltSize = 32
) )
type AccountContextKey struct{} type AccountContextKey struct{}
@ -28,14 +21,6 @@ type Account struct {
} }
func (account *Account) CheckPassword(password string) bool { func (account *Account) CheckPassword(password string) bool {
hash := HashPassword(password, account.Salt) hash := helpers.HashPassword(password, account.Salt)
return subtle.ConstantTimeCompare(hash, account.PasswordHash) == 1 return subtle.ConstantTimeCompare(hash, account.PasswordHash) == 1
} }
func GenerateSalt() []byte {
return scrypto.RandomBytes(SaltSize)
}
func HashPassword(password string, salt []byte) []byte {
return pbkdf2.Key([]byte(password), salt, PBKDF2Iterations, 32, sha256.New)
}