1
0
Fork 0

[golang] fixed golang api client design mistakes

This commit is contained in:
Julien Dessaux 2024-05-28 13:13:13 +02:00
parent d0f6c4343e
commit 0d00bf9fd2
Signed by: adyxax
GPG key ID: F92E51B86E07177E
9 changed files with 268 additions and 149 deletions

View file

@ -2,6 +2,7 @@ package main
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
@ -12,67 +13,74 @@ import (
) )
func main() { func main() {
opts := &slog.HandlerOptions{ var opts *slog.HandlerOptions
// //AddSource: true, if os.Getenv("SPACETRADERS_DEBUG") != "" {
opts = &slog.HandlerOptions{
//AddSource: true,
Level: slog.LevelDebug, Level: slog.LevelDebug,
} }
//logger := slog.New(slog.NewJSONHandler(os.Stdout, opts)) }
logger := slog.New(slog.NewTextHandler(os.Stdout, opts)) logger := slog.New(slog.NewJSONHandler(os.Stdout, opts))
slog.SetDefault(logger) slog.SetDefault(logger)
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
defer cancel() defer cancel()
db, err := database.DBInit(ctx, "./spacetraders.db") db, err := database.DBInit(ctx, "./spacetraders.db")
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "DbInit error %+v\n", err) fmt.Fprintf(os.Stderr, "DbInit error %+v\n", err)
os.Exit(1) os.Exit(1)
} }
client := api.NewClient(ctx) client := api.NewClient(ctx)
defer client.Close() defer client.Close()
err = run( //ctx, if err := run(
db, db,
client, client,
//os.Args, ); err != nil {
//os.Getenv,
//os.Getwd,
//os.Stdin,
//os.Stdout,
//os.Stderr,
)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
if err = db.Close(); err != nil { if err := db.Close(); err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
} }
os.Exit(2) os.Exit(2)
} }
} }
func run( // ctx context.Context, func run(
db *database.DB, db *database.DB,
client *api.Client, client *api.Client,
//args []string, ) error {
//getenv func(string) string,
//getwd func() (string, error),
//stdin io.Reader,
//stdout, stderr io.Writer,
) (err error) {
// ----- Get token or register --------------------------------------------- // ----- Get token or register ---------------------------------------------
r, err := client.Register("COSMIC", "ADYXAX-GO")
if err != nil {
apiError := &api.APIError{}
if errors.As(err, &apiError) {
switch apiError.Code {
case 4111: // Agent symbol has already been claimed
token, err := db.GetToken() token, err := db.GetToken()
if err != nil || token == "" { if err != nil || token == "" {
var r api.APIMessage[api.RegisterMessage, any] return fmt.Errorf("failed to register and failed to get a token from the database: someone stole are agent's callsign: %w", err)
if r, err = client.Register("COSMIC", "ADYXAX-GO"); err != nil {
// TODO handle server reset
fmt.Printf("%+v, %+v\n", r, err)
return err
}
if err = db.AddToken(r.Data.Token); err != nil {
return err
}
} }
client.SetToken(token) client.SetToken(token)
default:
return fmt.Errorf("failed to register: %w\n", err)
}
} else {
return fmt.Errorf("failed to register: %w\n", err)
}
} else {
token, err := db.GetToken()
if err != nil || token == "" {
if err := db.AddToken(r.Token); err != nil {
return fmt.Errorf("failed to save token: %w", err)
}
client.SetToken(r.Token)
} else {
return fmt.Errorf("TODO server reset not implemented yet")
}
}
// ----- Update agent ------------------------------------------------------ // ----- Update agent ------------------------------------------------------
agent, err := client.MyAgent() agent, err := client.MyAgent()
slog.Info("agent", "agent", agent, "err", err) slog.Info("agent", "agent", agent, "err", err)
return err return nil
} }

View file

@ -1,5 +1,11 @@
package api package api
import (
"encoding/json"
"fmt"
"net/url"
)
type AgentMessage struct { type AgentMessage struct {
AccountID string `json:"accountId"` AccountID string `json:"accountId"`
Credits int `json:"credits"` Credits int `json:"credits"`
@ -9,6 +15,18 @@ type AgentMessage struct {
Symbol string `json:"symbol"` Symbol string `json:"symbol"`
} }
func (c *Client) MyAgent() (APIMessage[AgentMessage, any], error) { func (c *Client) MyAgent() (*AgentMessage, error) {
return Send[AgentMessage](c, "GET", "/my/agent", nil) uriRef := url.URL{Path: "my/agent"}
msg, err := c.Send("GET", &uriRef, nil)
if err != nil {
return nil, err
}
if msg.Error != nil {
return nil, msg.Error
}
var response AgentMessage
if err := json.Unmarshal(msg.Data, &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal agent data: %w", err)
}
return &response, nil
} }

View file

@ -2,86 +2,143 @@ package api
import ( import (
"bytes" "bytes"
"container/heap"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"log/slog" "log/slog"
"net/http" "net/http"
"net/url"
"time" "time"
) )
type Error[T any] struct { type APIError struct {
Code int `json:"code"` Code int `json:"code"`
Data T `json:"data"` Data json.RawMessage `json:"data"`
Message string `json:"message"` Message string `json:"message"`
} }
type APIMessage[T any, E any] struct { func (e *APIError) Error() string {
Data T `json:"data"` return fmt.Sprintf("unhandled APIError code %d, message \"%s\", data: %s", e.Code, e.Message, string(e.Data))
Error Error[E] `json:"error"` }
type APIMessage struct {
Data json.RawMessage `json:"data"`
Error *APIError `json:"error"`
//meta //meta
} }
type Request struct {
index int
priority int
method string
uri *url.URL
payload any
responseChannel chan *Response
}
type Response struct { type Response struct {
Response []byte Message *APIMessage
Err error Err error
} }
func Send[T any](c *Client, method, path string, payload any) (message APIMessage[T, any], err error) { func (c *Client) Send(method string, uriRef *url.URL, payload any) (*APIMessage, error) {
resp := make(chan *Response) responseChannel := make(chan *Response)
c.channel <- &Request{ c.requestsChannel <- &Request{
method: method, method: method,
path: path,
payload: payload, payload: payload,
priority: 10, priority: 10,
resp: resp, responseChannel: responseChannel,
uri: c.baseURI.ResolveReference(uriRef),
} }
res := <-resp res := <-responseChannel
if res.Err != nil { return res.Message, res.Err
return message, res.Err
}
err = json.Unmarshal(res.Response, &message)
return message, err
} }
func (c *Client) sendOne(method, path string, payload any) (body []byte, err error) { func queueProcessor(client *Client) {
slog.Debug("Request", "method", method, "path", path, "payload", payload) var (
var req *http.Request req *Request
ok bool
)
for {
// The queue is empty so we do this blocking call
select {
case <-client.ctx.Done():
return
case req, ok = <-client.requestsChannel:
if !ok {
return
}
heap.Push(client.pq, req)
}
// we enqueue all values read from the channel and process the queue's
// contents until empty. We keep reading the channel as long as this
// emptying goes on
for {
select {
case <-client.ctx.Done():
return
case req, ok = <-client.requestsChannel:
if !ok {
return
}
heap.Push(client.pq, req)
default:
if client.pq.Len() == 0 {
break
}
// we process one
if req, ok = heap.Pop(client.pq).(*Request); !ok {
panic("queueProcessor got something other than a Request on its channel")
}
msg, err := client.sendOne(req.method, req.uri, req.payload)
req.responseChannel <- &Response{
Message: msg,
Err: err,
}
}
}
}
}
func (c *Client) sendOne(method string, uri *url.URL, payload any) (*APIMessage, error) {
slog.Debug("request", "method", method, "path", uri.Path, "payload", payload)
var payloadReader io.Reader
if payload != nil { if payload != nil {
body, err = json.Marshal(payload) if body, err := json.Marshal(payload); err != nil {
if err == nil { return nil, fmt.Errorf("failed to marshal payload: %w", err)
req, err = http.NewRequest(method, c.baseURL+path, bytes.NewBuffer(body))
} else { } else {
return nil, err payloadReader = bytes.NewReader(body)
} }
} else {
req, err = http.NewRequest(method, c.baseURL+path, nil)
} }
req, err := http.NewRequestWithContext(c.ctx, method, uri.String(), payloadReader)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create request: %w", err)
} }
req.Header = *c.headers req.Header = *c.headers
req = req.WithContext(c.ctx)
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
slog.Error("sendOne Do", "method", method, "path", path, "error", err) return nil, fmt.Errorf("failed to do request: %w", err)
return nil, err
} }
defer func() { defer resp.Body.Close()
if e := resp.Body.Close(); err == nil { body, err := io.ReadAll(resp.Body)
err = e if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
} }
}()
if body, err = io.ReadAll(resp.Body); err != nil { var msg APIMessage
slog.Error("sendOne ReadAll", "method", method, "path", path, "error", err) if err = json.Unmarshal(body, &msg); err != nil {
return nil, err return nil, fmt.Errorf("failed to unmarshal response body: %w", err)
} }
slog.Debug("Response", "body", string(body)) slog.Debug("response", "code", resp.StatusCode, "message", msg)
switch resp.StatusCode { switch resp.StatusCode {
case 429: case 429:
e := decode429(body) e := decodeRateLimitError(msg.Error.Data)
time.Sleep(time.Duration(e.Error.Data.RetryAfter * float64(time.Second))) time.Sleep(e.RetryAfter.Duration() * time.Second)
return c.sendOne(method, path, payload) return c.sendOne(method, uri, payload)
} }
return body, nil return &msg, nil
} }

View file

@ -4,12 +4,13 @@ import (
"container/heap" "container/heap"
"context" "context"
"net/http" "net/http"
"net/url"
"time" "time"
) )
type Client struct { type Client struct {
baseURL string baseURI *url.URL
channel chan *Request requestsChannel chan *Request
ctx context.Context ctx context.Context
headers *http.Header headers *http.Header
httpClient *http.Client httpClient *http.Client
@ -17,11 +18,15 @@ type Client struct {
} }
func NewClient(ctx context.Context) *Client { func NewClient(ctx context.Context) *Client {
baseURI, err := url.Parse("https://api.spacetraders.io/v2/")
if err != nil {
panic("baseURI failed to parse")
}
pq := make(PriorityQueue, 0) pq := make(PriorityQueue, 0)
heap.Init(&pq) heap.Init(&pq)
client := &Client{ client := &Client{
baseURL: "https://api.spacetraders.io/v2", baseURI: baseURI,
channel: make(chan *Request), requestsChannel: make(chan *Request),
ctx: ctx, ctx: ctx,
headers: &http.Header{ headers: &http.Header{
"Content-Type": {"application/json"}, "Content-Type": {"application/json"},
@ -36,40 +41,9 @@ func NewClient(ctx context.Context) *Client {
} }
func (c *Client) Close() { func (c *Client) Close() {
close(c.channel) close(c.requestsChannel)
} }
func (c *Client) SetToken(token string) { func (c *Client) SetToken(token string) {
c.headers.Set("Authorization", "Bearer "+token) c.headers.Set("Authorization", "Bearer "+token)
} }
func queueProcessor(client *Client) {
var ok bool
for {
// The queue is empty so we do this blocking call
req := <-client.channel
heap.Push(client.pq, req)
// we enqueue all values read from the channel and process the queue's
// contents until empty. We keep reading the channel as long as this
// emptying goes on
for {
select {
case req = <-client.channel:
heap.Push(client.pq, req)
default:
if client.pq.Len() == 0 {
break
}
// we process one
if req, ok = heap.Pop(client.pq).(*Request); !ok {
panic("queueProcessor got something other than a Request on its channel")
}
response, err := client.sendOne(req.method, req.path, req.payload)
req.resp <- &Response{
Response: response,
Err: err,
}
}
}
}
}

View file

@ -0,0 +1,38 @@
package api
import (
"encoding/json"
"errors"
"time"
)
type Duration time.Duration
func (d *Duration) Duration() time.Duration {
return time.Duration(*d)
}
func (d *Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Duration(*d).String())
}
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
*d = Duration(time.Duration(value))
return nil
case string:
tmp, err := time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(tmp)
return nil
default:
return errors.New("invalid duration")
}
}

View file

@ -6,18 +6,37 @@ import (
"time" "time"
) )
// ----- 429 --------------------------------------------------------------------
type RateLimitError struct { type RateLimitError struct {
LimitType string `json:"type"` LimitType string `json:"type"`
RetryAfter float64 `json:"retryAfter"` RetryAfter Duration `json:"retryAfter"`
LimitBurst int `json:"limitBurst"` LimitBurst int `json:"limitBurst"`
LimitPerSecond int `json:"limitPerSecond"` LimitPerSecond int `json:"limitPerSecond"`
Remaining int `json:"remaining"` Remaining int `json:"remaining"`
Reset time.Time `json:"reset"` Reset time.Time `json:"reset"`
} }
func decode429(msg []byte) (e APIMessage[any, RateLimitError]) { func decodeRateLimitError(msg json.RawMessage) RateLimitError {
var e RateLimitError
if err := json.Unmarshal(msg, &e); err != nil { if err := json.Unmarshal(msg, &e); err != nil {
panic(fmt.Sprintf("Failed to decode419: %+v", err)) panic(fmt.Errorf("Failed to decode iapi error code 429 RateLimitError: %v, %w", msg, err))
}
return e
}
// ----- 4214 -------------------------------------------------------------------
type ShipInTransitError struct {
Arrival time.Time `json:"arrival"`
DepartureSymbol string `json:"departureSymbol"`
DepartureTime time.Time `json:"departureTime"`
DestinationSymbol string `json:"destinationSymbol"`
SecondsToArrival Duration `json:"secondsToArrival"`
}
func decodeShipInTransitError(msg json.RawMessage) ShipInTransitError {
var e ShipInTransitError
if err := json.Unmarshal(msg, &e); err != nil {
panic(fmt.Errorf("Failed to decode api error code 4214 ShipInTransitError: %v, %w", msg, err))
} }
return e return e
} }

View file

@ -1,15 +1,5 @@
package api package api
type Request struct {
index int
priority int
method string
path string
payload any
resp chan *Response
}
type PriorityQueue []*Request type PriorityQueue []*Request
func (pq PriorityQueue) Len() int { func (pq PriorityQueue) Len() int {

View file

@ -1,5 +1,11 @@
package api package api
import (
"encoding/json"
"fmt"
"net/url"
)
type RegisterMessage struct { type RegisterMessage struct {
//agent //agent
//contract //contract
@ -8,13 +14,25 @@ type RegisterMessage struct {
Token string `json:"token"` Token string `json:"token"`
} }
func (c *Client) Register(faction, symbol string) (APIMessage[RegisterMessage, any], error) { func (c *Client) Register(faction, symbol string) (*RegisterMessage, error) {
type RegisterRequest struct { type RegisterRequest struct {
Faction string `json:"faction"` Faction string `json:"faction"`
Symbol string `json:"symbol"` Symbol string `json:"symbol"`
} }
return Send[RegisterMessage](c, "POST", "/register", RegisterRequest{ uriRef := url.URL{Path: "register"}
msg, err := c.Send("POST", &uriRef, RegisterRequest{
Faction: faction, Faction: faction,
Symbol: symbol, Symbol: symbol,
}) })
if err != nil {
return nil, err
}
if msg.Error != nil {
return nil, msg.Error
}
var response RegisterMessage
if err := json.Unmarshal(msg.Data, &response); err != nil {
return nil, fmt.Errorf("failed to unmarshal register data: %w", err)
}
return &response, nil
} }

View file

@ -1,14 +1,11 @@
package database package database
func (db DB) AddToken(token string) error { func (db *DB) AddToken(token string) error {
_, err := db.db.ExecContext(db.ctx, `INSERT INTO tokens(data) VALUES (?);`, token) _, err := db.db.ExecContext(db.ctx, `INSERT INTO tokens(data) VALUES (?);`, token)
return err return err
} }
func (db DB) GetToken() (string, error) { func (db *DB) GetToken() (token string, err error) {
var token string err = db.db.QueryRowContext(db.ctx, `SELECT data FROM tokens;`).Scan(&token)
if err := db.db.QueryRowContext(db.ctx, `SELECT data FROM tokens;`).Scan(&token); err != nil { return token, err
return "", err
}
return token, nil
} }