From c18d03d4049e7fe2e032ba448e88a44671dfdbeb Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Fri, 3 Jan 2025 00:54:15 +0100 Subject: chore(tfstated): refactor middlewares --- pkg/backend/routes.go | 2 +- pkg/backend/run.go | 2 +- pkg/basic_auth/middleware.go | 39 ----------------- pkg/logger/body_writer.go | 65 ---------------------------- pkg/logger/middleware.go | 72 -------------------------------- pkg/middlewares/basic_auth/middleware.go | 39 +++++++++++++++++ pkg/middlewares/logger/body_writer.go | 65 ++++++++++++++++++++++++++++ pkg/middlewares/logger/middleware.go | 72 ++++++++++++++++++++++++++++++++ pkg/webui/run.go | 2 +- 9 files changed, 179 insertions(+), 179 deletions(-) delete mode 100644 pkg/basic_auth/middleware.go delete mode 100644 pkg/logger/body_writer.go delete mode 100644 pkg/logger/middleware.go create mode 100644 pkg/middlewares/basic_auth/middleware.go create mode 100644 pkg/middlewares/logger/body_writer.go create mode 100644 pkg/middlewares/logger/middleware.go diff --git a/pkg/backend/routes.go b/pkg/backend/routes.go index 058febd..960a2e8 100644 --- a/pkg/backend/routes.go +++ b/pkg/backend/routes.go @@ -3,8 +3,8 @@ package backend import ( "net/http" - "git.adyxax.org/adyxax/tfstated/pkg/basic_auth" "git.adyxax.org/adyxax/tfstated/pkg/database" + "git.adyxax.org/adyxax/tfstated/pkg/middlewares/basic_auth" ) func addRoutes( diff --git a/pkg/backend/run.go b/pkg/backend/run.go index 2af47ab..d76d189 100644 --- a/pkg/backend/run.go +++ b/pkg/backend/run.go @@ -7,7 +7,7 @@ import ( "net/http" "git.adyxax.org/adyxax/tfstated/pkg/database" - "git.adyxax.org/adyxax/tfstated/pkg/logger" + "git.adyxax.org/adyxax/tfstated/pkg/middlewares/logger" ) func Run( diff --git a/pkg/basic_auth/middleware.go b/pkg/basic_auth/middleware.go deleted file mode 100644 index cb2dcf0..0000000 --- a/pkg/basic_auth/middleware.go +++ /dev/null @@ -1,39 +0,0 @@ -package basic_auth - -import ( - "context" - "fmt" - "net/http" - - "git.adyxax.org/adyxax/tfstated/pkg/database" - "git.adyxax.org/adyxax/tfstated/pkg/helpers" - "git.adyxax.org/adyxax/tfstated/pkg/model" -) - -func Middleware(db *database.DB) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - username, password, ok := r.BasicAuth() - if !ok { - w.Header().Set("WWW-Authenticate", `Basic realm="tfstated", charset="UTF-8"`) - helpers.ErrorResponse(w, http.StatusUnauthorized, fmt.Errorf("Unauthorized")) - return - } - account, err := db.LoadAccountByUsername(username) - if err != nil { - helpers.ErrorResponse(w, http.StatusInternalServerError, err) - return - } - if account == nil || !account.CheckPassword(password) { - helpers.ErrorResponse(w, http.StatusForbidden, fmt.Errorf("Forbidden")) - return - } - if err := db.TouchAccount(account); err != nil { - helpers.ErrorResponse(w, http.StatusInternalServerError, err) - return - } - ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account) - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} diff --git a/pkg/logger/body_writer.go b/pkg/logger/body_writer.go deleted file mode 100644 index f57a03e..0000000 --- a/pkg/logger/body_writer.go +++ /dev/null @@ -1,65 +0,0 @@ -package logger - -import ( - "bufio" - "bytes" - "errors" - "net" - "net/http" -) - -type bodyWriter struct { - http.ResponseWriter - body *bytes.Buffer - maxSize int - bytes int - status int -} - -func newBodyWriter(writer http.ResponseWriter, recordBody bool) *bodyWriter { - var body *bytes.Buffer - if recordBody { - body = bytes.NewBufferString("") - } - return &bodyWriter{ - body: body, - maxSize: 64000, - ResponseWriter: writer, - status: http.StatusNotImplemented, - } -} - -// implements http.ResponseWriter -func (w *bodyWriter) Write(b []byte) (int, error) { - if w.body != nil { - if w.body.Len()+len(b) > w.maxSize { - w.body.Write(b[:w.maxSize-w.body.Len()]) - } else { - w.body.Write(b) - } - } - - w.bytes += len(b) - return w.ResponseWriter.Write(b) -} - -// implements http.ResponseWriter -func (w *bodyWriter) WriteHeader(code int) { - w.status = code - w.ResponseWriter.WriteHeader(code) -} - -// implements http.Flusher -func (w *bodyWriter) Flush() { - if f, ok := w.ResponseWriter.(http.Flusher); ok { - f.Flush() - } -} - -// implements http.Hijacker -func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - if hi, ok := w.ResponseWriter.(http.Hijacker); ok { - return hi.Hijack() - } - return nil, nil, errors.New("Hijack not supported") -} diff --git a/pkg/logger/middleware.go b/pkg/logger/middleware.go deleted file mode 100644 index 54cca96..0000000 --- a/pkg/logger/middleware.go +++ /dev/null @@ -1,72 +0,0 @@ -package logger - -import ( - "log/slog" - "net/http" - "runtime/debug" - "strconv" - "time" -) - -func Middleware(next http.Handler, recordBody bool) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == "GET" && r.URL.Path == "/healthz" { - next.ServeHTTP(w, r) - return - } - - defer func() { - if err := recover(); err != nil { - w.WriteHeader(http.StatusInternalServerError) - slog.Error( - "panic", - "err", err, - "trace", string(debug.Stack()), - ) - } - }() - start := time.Now() - path := r.URL.Path - query := r.URL.RawQuery - - bw := newBodyWriter(w, recordBody) - - next.ServeHTTP(bw, r) - - end := time.Now() - requestAttributes := []slog.Attr{ - slog.Time("time", start.UTC()), - slog.String("method", r.Method), - slog.String("host", r.Host), - slog.String("path", path), - slog.String("query", query), - slog.String("ip", r.RemoteAddr), - } - responseAttributes := []slog.Attr{ - slog.Time("time", end.UTC()), - slog.Duration("latency", end.Sub(start)), - slog.Int("length", bw.bytes), - slog.Int("status", bw.status), - } - if recordBody { - responseAttributes = append(responseAttributes, slog.String("body", bw.body.String())) - } - attributes := []slog.Attr{ - { - Key: "request", - Value: slog.GroupValue(requestAttributes...), - }, - { - Key: "response", - Value: slog.GroupValue(responseAttributes...), - }, - } - level := slog.LevelInfo - if bw.status >= http.StatusInternalServerError { - level = slog.LevelError - } else if bw.status >= http.StatusBadRequest && bw.status < http.StatusInternalServerError { - level = slog.LevelWarn - } - slog.LogAttrs(r.Context(), level, strconv.Itoa(bw.status)+": "+http.StatusText(bw.status), attributes...) - }) -} diff --git a/pkg/middlewares/basic_auth/middleware.go b/pkg/middlewares/basic_auth/middleware.go new file mode 100644 index 0000000..cb2dcf0 --- /dev/null +++ b/pkg/middlewares/basic_auth/middleware.go @@ -0,0 +1,39 @@ +package basic_auth + +import ( + "context" + "fmt" + "net/http" + + "git.adyxax.org/adyxax/tfstated/pkg/database" + "git.adyxax.org/adyxax/tfstated/pkg/helpers" + "git.adyxax.org/adyxax/tfstated/pkg/model" +) + +func Middleware(db *database.DB) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="tfstated", charset="UTF-8"`) + helpers.ErrorResponse(w, http.StatusUnauthorized, fmt.Errorf("Unauthorized")) + return + } + account, err := db.LoadAccountByUsername(username) + if err != nil { + helpers.ErrorResponse(w, http.StatusInternalServerError, err) + return + } + if account == nil || !account.CheckPassword(password) { + helpers.ErrorResponse(w, http.StatusForbidden, fmt.Errorf("Forbidden")) + return + } + if err := db.TouchAccount(account); err != nil { + helpers.ErrorResponse(w, http.StatusInternalServerError, err) + return + } + ctx := context.WithValue(r.Context(), model.AccountContextKey{}, account) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} diff --git a/pkg/middlewares/logger/body_writer.go b/pkg/middlewares/logger/body_writer.go new file mode 100644 index 0000000..f57a03e --- /dev/null +++ b/pkg/middlewares/logger/body_writer.go @@ -0,0 +1,65 @@ +package logger + +import ( + "bufio" + "bytes" + "errors" + "net" + "net/http" +) + +type bodyWriter struct { + http.ResponseWriter + body *bytes.Buffer + maxSize int + bytes int + status int +} + +func newBodyWriter(writer http.ResponseWriter, recordBody bool) *bodyWriter { + var body *bytes.Buffer + if recordBody { + body = bytes.NewBufferString("") + } + return &bodyWriter{ + body: body, + maxSize: 64000, + ResponseWriter: writer, + status: http.StatusNotImplemented, + } +} + +// implements http.ResponseWriter +func (w *bodyWriter) Write(b []byte) (int, error) { + if w.body != nil { + if w.body.Len()+len(b) > w.maxSize { + w.body.Write(b[:w.maxSize-w.body.Len()]) + } else { + w.body.Write(b) + } + } + + w.bytes += len(b) + return w.ResponseWriter.Write(b) +} + +// implements http.ResponseWriter +func (w *bodyWriter) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +// implements http.Flusher +func (w *bodyWriter) Flush() { + if f, ok := w.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// implements http.Hijacker +func (w *bodyWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hi, ok := w.ResponseWriter.(http.Hijacker); ok { + return hi.Hijack() + } + return nil, nil, errors.New("Hijack not supported") +} diff --git a/pkg/middlewares/logger/middleware.go b/pkg/middlewares/logger/middleware.go new file mode 100644 index 0000000..54cca96 --- /dev/null +++ b/pkg/middlewares/logger/middleware.go @@ -0,0 +1,72 @@ +package logger + +import ( + "log/slog" + "net/http" + "runtime/debug" + "strconv" + "time" +) + +func Middleware(next http.Handler, recordBody bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.URL.Path == "/healthz" { + next.ServeHTTP(w, r) + return + } + + defer func() { + if err := recover(); err != nil { + w.WriteHeader(http.StatusInternalServerError) + slog.Error( + "panic", + "err", err, + "trace", string(debug.Stack()), + ) + } + }() + start := time.Now() + path := r.URL.Path + query := r.URL.RawQuery + + bw := newBodyWriter(w, recordBody) + + next.ServeHTTP(bw, r) + + end := time.Now() + requestAttributes := []slog.Attr{ + slog.Time("time", start.UTC()), + slog.String("method", r.Method), + slog.String("host", r.Host), + slog.String("path", path), + slog.String("query", query), + slog.String("ip", r.RemoteAddr), + } + responseAttributes := []slog.Attr{ + slog.Time("time", end.UTC()), + slog.Duration("latency", end.Sub(start)), + slog.Int("length", bw.bytes), + slog.Int("status", bw.status), + } + if recordBody { + responseAttributes = append(responseAttributes, slog.String("body", bw.body.String())) + } + attributes := []slog.Attr{ + { + Key: "request", + Value: slog.GroupValue(requestAttributes...), + }, + { + Key: "response", + Value: slog.GroupValue(responseAttributes...), + }, + } + level := slog.LevelInfo + if bw.status >= http.StatusInternalServerError { + level = slog.LevelError + } else if bw.status >= http.StatusBadRequest && bw.status < http.StatusInternalServerError { + level = slog.LevelWarn + } + slog.LogAttrs(r.Context(), level, strconv.Itoa(bw.status)+": "+http.StatusText(bw.status), attributes...) + }) +} diff --git a/pkg/webui/run.go b/pkg/webui/run.go index 3aaec55..f1b20cb 100644 --- a/pkg/webui/run.go +++ b/pkg/webui/run.go @@ -7,7 +7,7 @@ import ( "net/http" "git.adyxax.org/adyxax/tfstated/pkg/database" - "git.adyxax.org/adyxax/tfstated/pkg/logger" + "git.adyxax.org/adyxax/tfstated/pkg/middlewares/logger" ) func Run( -- cgit v1.2.3