summaryrefslogtreecommitdiff
path: root/pkg/middlewares
diff options
context:
space:
mode:
authorJulien Dessaux2025-01-03 00:54:15 +0100
committerJulien Dessaux2025-01-03 00:54:15 +0100
commitc18d03d4049e7fe2e032ba448e88a44671dfdbeb (patch)
tree09cfe305f3bb468b9a3a750c92c9b54e574a21ab /pkg/middlewares
parentfeat(tfstated): bootstrap webui listening on a second port (diff)
downloadtfstated-c18d03d4049e7fe2e032ba448e88a44671dfdbeb.tar.gz
tfstated-c18d03d4049e7fe2e032ba448e88a44671dfdbeb.tar.bz2
tfstated-c18d03d4049e7fe2e032ba448e88a44671dfdbeb.zip
chore(tfstated): refactor middlewares
Diffstat (limited to 'pkg/middlewares')
-rw-r--r--pkg/middlewares/basic_auth/middleware.go39
-rw-r--r--pkg/middlewares/logger/body_writer.go65
-rw-r--r--pkg/middlewares/logger/middleware.go72
3 files changed, 176 insertions, 0 deletions
diff --git a/pkg/middlewares/basic_auth/middleware.go b/pkg/middlewares/basic_auth/middleware.go
new file mode 100644
index 0000000..cb2dcf0
--- /dev/null
+++ b/pkg/middlewares/basic_auth/middleware.go
@@ -0,0 +1,39 @@
+package basic_auth
+
+import (
+ "context"
+ "fmt"
+ "net/http"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
+ "git.adyxax.org/adyxax/tfstated/pkg/helpers"
+ "git.adyxax.org/adyxax/tfstated/pkg/model"
+)
+
+func Middleware(db *database.DB) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ username, password, ok := r.BasicAuth()
+ if !ok {
+ w.Header().Set("WWW-Authenticate", `Basic realm="tfstated", charset="UTF-8"`)
+ helpers.ErrorResponse(w, http.StatusUnauthorized, fmt.Errorf("Unauthorized"))
+ return
+ }
+ account, err := db.LoadAccountByUsername(username)
+ if err != nil {
+ helpers.ErrorResponse(w, http.StatusInternalServerError, err)
+ return
+ }
+ if account == nil || !account.CheckPassword(password) {
+ helpers.ErrorResponse(w, http.StatusForbidden, fmt.Errorf("Forbidden"))
+ return
+ }
+ if err := db.TouchAccount(account); err != nil {
+ helpers.ErrorResponse(w, http.StatusInternalServerError, err)
+ return
+ }
+ ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+ }
+}
diff --git a/pkg/middlewares/logger/body_writer.go b/pkg/middlewares/logger/body_writer.go
new file mode 100644
index 0000000..f57a03e
--- /dev/null
+++ b/pkg/middlewares/logger/body_writer.go
@@ -0,0 +1,65 @@
+package logger
+
+import (
+ "bufio"
+ "bytes"
+ "errors"
+ "net"
+ "net/http"
+)
+
+type bodyWriter struct {
+ http.ResponseWriter
+ body *bytes.Buffer
+ maxSize int
+ bytes int
+ status int
+}
+
+func newBodyWriter(writer http.ResponseWriter, recordBody bool) *bodyWriter {
+ var body *bytes.Buffer
+ if recordBody {
+ body = bytes.NewBufferString("")
+ }
+ return &bodyWriter{
+ body: body,
+ maxSize: 64000,
+ ResponseWriter: writer,
+ status: http.StatusNotImplemented,
+ }
+}
+
+// implements http.ResponseWriter
+func (w *bodyWriter) Write(b []byte) (int, error) {
+ if w.body != nil {
+ if w.body.Len()+len(b) > w.maxSize {
+ w.body.Write(b[:w.maxSize-w.body.Len()])
+ } else {
+ w.body.Write(b)
+ }
+ }
+
+ w.bytes += len(b)
+ return w.ResponseWriter.Write(b)
+}
+
+// implements http.ResponseWriter
+func (w *bodyWriter) WriteHeader(code int) {
+ w.status = code
+ w.ResponseWriter.WriteHeader(code)
+}
+
+// implements http.Flusher
+func (w *bodyWriter) Flush() {
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
+
+// implements http.Hijacker
+func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ if hi, ok := w.ResponseWriter.(http.Hijacker); ok {
+ return hi.Hijack()
+ }
+ return nil, nil, errors.New("Hijack not supported")
+}
diff --git a/pkg/middlewares/logger/middleware.go b/pkg/middlewares/logger/middleware.go
new file mode 100644
index 0000000..54cca96
--- /dev/null
+++ b/pkg/middlewares/logger/middleware.go
@@ -0,0 +1,72 @@
+package logger
+
+import (
+ "log/slog"
+ "net/http"
+ "runtime/debug"
+ "strconv"
+ "time"
+)
+
+func Middleware(next http.Handler, recordBody bool) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method == "GET" && r.URL.Path == "/healthz" {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ defer func() {
+ if err := recover(); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ slog.Error(
+ "panic",
+ "err", err,
+ "trace", string(debug.Stack()),
+ )
+ }
+ }()
+ start := time.Now()
+ path := r.URL.Path
+ query := r.URL.RawQuery
+
+ bw := newBodyWriter(w, recordBody)
+
+ next.ServeHTTP(bw, r)
+
+ end := time.Now()
+ requestAttributes := []slog.Attr{
+ slog.Time("time", start.UTC()),
+ slog.String("method", r.Method),
+ slog.String("host", r.Host),
+ slog.String("path", path),
+ slog.String("query", query),
+ slog.String("ip", r.RemoteAddr),
+ }
+ responseAttributes := []slog.Attr{
+ slog.Time("time", end.UTC()),
+ slog.Duration("latency", end.Sub(start)),
+ slog.Int("length", bw.bytes),
+ slog.Int("status", bw.status),
+ }
+ if recordBody {
+ responseAttributes = append(responseAttributes, slog.String("body", bw.body.String()))
+ }
+ attributes := []slog.Attr{
+ {
+ Key: "request",
+ Value: slog.GroupValue(requestAttributes...),
+ },
+ {
+ Key: "response",
+ Value: slog.GroupValue(responseAttributes...),
+ },
+ }
+ level := slog.LevelInfo
+ if bw.status >= http.StatusInternalServerError {
+ level = slog.LevelError
+ } else if bw.status >= http.StatusBadRequest && bw.status < http.StatusInternalServerError {
+ level = slog.LevelWarn
+ }
+ slog.LogAttrs(r.Context(), level, strconv.Itoa(bw.status)+": "+http.StatusText(bw.status), attributes...)
+ })
+}