diff --git a/hash.go b/hash.go new file mode 100644 index 0000000..951aa20 --- /dev/null +++ b/hash.go @@ -0,0 +1,36 @@ +package httpcache + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "slices" + "strings" +) + +type RequestHashFn func(req *http.Request) string + +func simpleRequestHash(req *http.Request) string { + return fmt.Sprintf("%s:%s:%s", req.Method, req.URL.String(), hash(req.Header)) +} + +const delimiter = "|" + +func hash(headers http.Header) string { + keys := make([]string, 0, len(headers)) + + for key := range headers { + keys = append(keys, key) + } + + slices.Sort(keys) + + var sb strings.Builder + for _, key := range keys { + sb.WriteString(fmt.Sprintf("%s:%s%s", key, headers.Get(key), delimiter)) + } + + hash := sha256.Sum256([]byte(sb.String())) + return hex.EncodeToString(hash[:]) +} diff --git a/db.go b/internal/sqlite/db.go similarity index 97% rename from db.go rename to internal/sqlite/db.go index 89af509..3c39218 100644 --- a/db.go +++ b/internal/sqlite/db.go @@ -2,7 +2,7 @@ // versions: // sqlc v1.30.0 -package httpcache +package sqlite import ( "context" diff --git a/models.go b/internal/sqlite/models.go similarity index 93% rename from models.go rename to internal/sqlite/models.go index d4c08f6..a3f162b 100644 --- a/models.go +++ b/internal/sqlite/models.go @@ -2,7 +2,7 @@ // versions: // sqlc v1.30.0 -package httpcache +package sqlite import ( "database/sql" diff --git a/queries.sql b/internal/sqlite/queries.sql similarity index 100% rename from queries.sql rename to internal/sqlite/queries.sql diff --git a/queries.sql.go b/internal/sqlite/queries.sql.go similarity index 98% rename from queries.sql.go rename to internal/sqlite/queries.sql.go index 9b48b97..76dd271 100644 --- a/queries.sql.go +++ b/internal/sqlite/queries.sql.go @@ -3,7 +3,7 @@ // sqlc v1.30.0 // source: queries.sql -package httpcache +package sqlite import ( "context" diff --git a/schema.sql b/internal/sqlite/schema.sql similarity index 100% rename from schema.sql rename to internal/sqlite/schema.sql diff --git a/internal/sqlite/source.go b/internal/sqlite/source.go new file mode 100644 index 0000000..f4918df --- /dev/null +++ b/internal/sqlite/source.go @@ -0,0 +1,8 @@ +package sqlite + +import ( + _ "embed" +) + +//go:embed schema.sql +var Schema string diff --git a/readcloser.go b/readcloser.go index 8c08977..70d2059 100644 --- a/readcloser.go +++ b/readcloser.go @@ -3,7 +3,6 @@ package httpcache import ( "bytes" "context" - "database/sql" "io" "net/http" ) @@ -12,8 +11,8 @@ type cachedReadCloser struct { ctx context.Context original io.ReadCloser buffer *bytes.Buffer - cache *Queries - data func() CacheResponseParams + cache Querier + data func() Params tee io.Reader } @@ -27,19 +26,19 @@ func (b *cachedReadCloser) Close() error { return b.original.Close() } -func newCachedReadCloser(hash string, cache *Queries, resp *http.Response) (*cachedReadCloser) { +func newCachedReadCloser(hash string, cache Querier, resp *http.Response) *cachedReadCloser { buffer := &bytes.Buffer{} return &cachedReadCloser{ ctx: resp.Request.Context(), original: resp.Body, buffer: buffer, cache: cache, - data: func() CacheResponseParams { - return CacheResponseParams{ + data: func() Params { + return Params{ ReqHash: hash, - Body: sql.NullString{String: buffer.String(), Valid: true}, - Headers: sql.NullString{String: "", Valid: true}, - StatusCode: sql.NullInt64{Int64: int64(resp.StatusCode), Valid: true}, + Body: buffer.String(), + Headers: "", + StatusCode: resp.StatusCode, } }, tee: io.TeeReader(resp.Body, buffer), diff --git a/sqlc.yaml b/sqlc.yaml index 475812a..d588f69 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -1,9 +1,9 @@ version: '2' sql: - engine: sqlite - schema: schema.sql - queries: queries.sql + schema: ./internal/sqlite/schema.sql + queries: ./internal/sqlite/queries.sql gen: go: - package: httpcache - out: . + package: sqlite + out: ./internal/sqlite diff --git a/sqlite.go b/sqlite.go index db3cb50..121fdd1 100644 --- a/sqlite.go +++ b/sqlite.go @@ -3,30 +3,66 @@ package httpcache import ( "context" "database/sql" - _ "embed" "fmt" + "strings" + "github.com/cyberbeast/httpcache/internal/sqlite" _ "modernc.org/sqlite" ) -//go:embed schema.sql -var ddl string +type sqliteStore struct{ queries *sqlite.Queries } + +func (s *sqliteStore) CacheResponse(ctx context.Context, arg Params) (Response, error) { + return wrapSQLiteResponse(s.queries.CacheResponse(ctx, sqlite.CacheResponseParams{ + ReqHash: arg.ReqHash, + Body: sql.NullString{String: arg.Body, Valid: true}, + Headers: sql.NullString{String: arg.Headers, Valid: true}, + StatusCode: sql.NullInt64{Int64: int64(arg.StatusCode), Valid: true}, + })) +} + +func (s *sqliteStore) DeleteAllResponses(ctx context.Context) error { + return s.queries.DeleteAllResponses(ctx) +} + +func (s *sqliteStore) GetResponse(ctx context.Context, reqHash string) (Response, error) { + return wrapSQLiteResponse(s.queries.GetResponse(ctx, reqHash)) +} + +func wrapSQLiteResponse(res sqlite.Response, err error) (Response, error) { + return Response{ + ReqHash: res.ReqHash, + Body: res.Body.String, + Headers: res.Headers.String, + StatusCode: int(res.StatusCode.Int64), + UpdatedAt: res.UpdatedAt.String, + }, err +} + +const filePrefix = "file://" type SQLiteSource string func (s SQLiteSource) name() string { return "sqlite" } -func (s SQLiteSource) filepath() string { return "file://" + string(s) } +func (s SQLiteSource) filepath() string { + file := string(s) + if !strings.HasPrefix(file, filePrefix) { + file = filePrefix + file + } + + return file +} -func initSQLiteDB(ctx context.Context, src SQLiteSource) (*sql.DB, error) { - db, err := sql.Open(src.name(), src.filepath()) +func (s SQLiteSource) Init(ctx context.Context) (Querier, error) { + db, err := sql.Open(s.name(), s.filepath()) if err != nil { return nil, fmt.Errorf("opening db: %w", err) } - if _, err := db.ExecContext(ctx, ddl); err != nil { - return db, fmt.Errorf("creating db schema: %w", err) + if _, err := db.ExecContext(ctx, sqlite.Schema); err != nil { + return nil, fmt.Errorf("creating db schema: %w", err) } - return db, nil + return &sqliteStore{queries: sqlite.New(db)}, nil } diff --git a/transport.go b/transport.go index f974e27..6d4a993 100644 --- a/transport.go +++ b/transport.go @@ -1,20 +1,53 @@ package httpcache import ( + "cmp" "context" - "crypto/sha256" - "encoding/hex" - "fmt" "io" "net/http" - "slices" "strings" ) -type RequestHashFn func(req *http.Request) string +type Cache interface { + Init(ctx context.Context) (Querier, error) +} + +type Querier interface { + GetResponse(ctx context.Context, reqHash string) (Response, error) + CacheResponse(ctx context.Context, arg Params) (Response, error) + DeleteAllResponses(ctx context.Context) error +} + +func NewTransport(ctx context.Context, cache Cache, rt http.RoundTripper) (*cachedTransport, error) { + store, err := cache.Init(ctx) + if err != nil { + return nil, err + } + + return &cachedTransport{ + rt: cmp.Or(rt, http.DefaultTransport), + queries: store, + reqHashFn: simpleRequestHash, + }, nil +} + +type Params struct { + ReqHash string + Body string + Headers string + StatusCode int +} + +type Response struct { + ReqHash string + Body string + Headers string + StatusCode int + UpdatedAt string +} type cachedTransport struct { - queries *Queries + queries Querier rt http.RoundTripper reqHashFn RequestHashFn } @@ -30,9 +63,9 @@ func (ct cachedTransport) cachedRoundTrip(req *http.Request) *http.Response { } return &http.Response{ - Body: io.NopCloser(strings.NewReader(res.Body.String)), - StatusCode: int(res.StatusCode.Int64), - Status: http.StatusText(int(res.StatusCode.Int64)), + Body: io.NopCloser(strings.NewReader(res.Body)), + StatusCode: res.StatusCode, + Status: http.StatusText(res.StatusCode), // Header: res.Headers.String, } } @@ -49,44 +82,4 @@ func (ct cachedTransport) RoundTrip(req *http.Request) (*http.Response, error) { res.Body = newCachedReadCloser(ct.reqHashFn(req), ct.queries, res) return res, nil - -} - -func NewTransport(ctx context.Context, src SQLiteSource, rt http.RoundTripper) (*cachedTransport, error) { - db, err := initSQLiteDB(ctx, src) - if err != nil { - return nil, err - } - - if rt == nil { - rt = http.DefaultTransport - } - - return &cachedTransport{ - rt: rt, - queries: New(db), - reqHashFn: func(req *http.Request) string { - return fmt.Sprintf("%s:%s:%s", req.Method, req.URL.String(), hash(req.Header)) - }, - }, nil -} - -const delimiter = "|" - -func hash(headers http.Header) string { - keys := make([]string, 0, len(headers)) - - for key := range headers { - keys = append(keys, key) - } - - slices.Sort(keys) - - var sb strings.Builder - for _, key := range keys { - sb.WriteString(fmt.Sprintf("%s:%s%s", key, headers.Get(key), delimiter)) - } - - hash := sha256.Sum256([]byte(sb.String())) - return hex.EncodeToString(hash[:]) }