diff options
Diffstat (limited to '')
-rw-r--r-- | cmd/tfstated/main.go | 9 | ||||
-rw-r--r-- | cmd/tfstated/main_test.go | 6 | ||||
-rw-r--r-- | cmd/tfstated/post_test.go | 14 | ||||
-rw-r--r-- | pkg/database/db.go | 20 | ||||
-rw-r--r-- | pkg/database/states.go | 3 |
5 files changed, 43 insertions, 9 deletions
diff --git a/cmd/tfstated/main.go b/cmd/tfstated/main.go index 5ead1ed..722ec82 100644 --- a/cmd/tfstated/main.go +++ b/cmd/tfstated/main.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "os/signal" + "strconv" "sync" "time" @@ -42,6 +43,14 @@ func run( if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil { return err } + versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") + if versionsHistoryLimit != "" { + n, err := strconv.Atoi(versionsHistoryLimit) + if err != nil { + return fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable: %w", err) + } + db.SetVersionsHistoryLimit(n) + } mux := http.NewServeMux() addRoutes( diff --git a/cmd/tfstated/main_test.go b/cmd/tfstated/main_test.go index 4b819b3..c52b924 100644 --- a/cmd/tfstated/main_test.go +++ b/cmd/tfstated/main_test.go @@ -18,6 +18,7 @@ var baseURI = url.URL{ Path: "/", Scheme: "http", } +var db *database.DB func TestMain(m *testing.M) { ctx := context.Background() @@ -27,7 +28,8 @@ func TestMain(m *testing.M) { Port: "8081", } _ = os.Remove("./test.db") - db, err := database.NewDB(ctx, "./test.db") + var err error + db, err = database.NewDB(ctx, "./test.db") if err != nil { fmt.Fprintf(os.Stderr, "%+v\n", err) os.Exit(1) @@ -36,6 +38,8 @@ func TestMain(m *testing.M) { switch key { case "DATA_ENCRYPTION_KEY": return "hP3ZSCnY3LMgfTQjwTaGrhKwdA0yXMXIfv67OJnntqM=" + case "VERSIONS_HISTORY_LIMIT": + return "3" default: return "" } diff --git a/cmd/tfstated/post_test.go b/cmd/tfstated/post_test.go index 55308c0..6eb68f2 100644 --- a/cmd/tfstated/post_test.go +++ b/cmd/tfstated/post_test.go @@ -30,6 +30,9 @@ func TestPost(t *testing.T) { {"GET", url.URL{Path: "/test_post"}, nil, "the_test_post4", http.StatusOK, "/test_post"}, {"POST", url.URL{Path: "/test_post"}, strings.NewReader("the_test_post5"), "", http.StatusOK, "without lock ID in query string on a locked state"}, {"GET", url.URL{Path: "/test_post"}, nil, "the_test_post5", http.StatusOK, "/test_post"}, + {"POST", url.URL{Path: "/test_post"}, strings.NewReader("the_test_post6"), "", http.StatusOK, "another post just to make sure the history limit works"}, + {"POST", url.URL{Path: "/test_post"}, strings.NewReader("the_test_post7"), "", http.StatusOK, "another post just to make sure the history limit works"}, + {"POST", url.URL{Path: "/test_post"}, strings.NewReader("the_test_post8"), "", http.StatusOK, "another post just to make sure the history limit works"}, } for _, tt := range tests { runHTTPRequest(tt.method, &tt.uri, tt.body, func(r *http.Response, err error) { @@ -46,4 +49,15 @@ func TestPost(t *testing.T) { } }) } + var n int + err := db.QueryRow(`SELECT COUNT(versions.id) + FROM versions + JOIN states ON states.id = versions.state_id + WHERE states.name = "/test_post"`).Scan(&n) + if err != nil { + t.Fatalf("failed to count versions for the /test_post state: %s", err) + } + if n != 3 { + t.Fatalf("there should only be 3 versions of the /test_post state, got %d", n) + } } diff --git a/pkg/database/db.go b/pkg/database/db.go index 2744570..38265ba 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -26,10 +26,11 @@ func initDB(ctx context.Context, url string) (*sql.DB, error) { } type DB struct { - ctx context.Context - dataEncryptionKey scrypto.AES256Key - readDB *sql.DB - writeDB *sql.DB + ctx context.Context + dataEncryptionKey scrypto.AES256Key + readDB *sql.DB + versionsHistoryLimit int + writeDB *sql.DB } func NewDB(ctx context.Context, url string) (*DB, error) { @@ -56,9 +57,10 @@ func NewDB(ctx context.Context, url string) (*DB, error) { writeDB.SetMaxOpenConns(1) db := DB{ - ctx: ctx, - readDB: readDB, - writeDB: writeDB, + ctx: ctx, + readDB: readDB, + versionsHistoryLimit: 64, + writeDB: writeDB, } if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil { return nil, err @@ -101,3 +103,7 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row { func (db *DB) SetDataEncryptionKey(s string) error { return db.dataEncryptionKey.FromBase64(s) } + +func (db *DB) SetVersionsHistoryLimit(n int) { + db.versionsHistoryLimit = n +} diff --git a/pkg/database/states.go b/pkg/database/states.go index 26a1021..5536533 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -102,7 +102,8 @@ func (db *DB) SetState(name string, data []byte, lockID string) (bool, error) { JOIN states ON states.id = versions.state_id WHERE states.name = :name ORDER BY versions.id DESC - LIMIT 64));`, + LIMIT :limit));`, + sql.Named("limit", db.versionsHistoryLimit), sql.Named("name", name), ) if err != nil { |