Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions hash.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package httpcache

import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"slices"
"strings"
)

type RequestHashFn func(req *http.Request) string

func simpleRequestHash(req *http.Request) string {
return fmt.Sprintf("%s:%s:%s", req.Method, req.URL.String(), hash(req.Header))
}

const delimiter = "|"

func hash(headers http.Header) string {
keys := make([]string, 0, len(headers))

for key := range headers {
keys = append(keys, key)
}

slices.Sort(keys)

var sb strings.Builder
for _, key := range keys {
sb.WriteString(fmt.Sprintf("%s:%s%s", key, headers.Get(key), delimiter))
}

hash := sha256.Sum256([]byte(sb.String()))
return hex.EncodeToString(hash[:])
}
2 changes: 1 addition & 1 deletion db.go → internal/sqlite/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion models.go → internal/sqlite/models.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

File renamed without changes.
2 changes: 1 addition & 1 deletion queries.sql.go → internal/sqlite/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

File renamed without changes.
8 changes: 8 additions & 0 deletions internal/sqlite/source.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package sqlite

import (
_ "embed"
)

//go:embed schema.sql
var Schema string
17 changes: 8 additions & 9 deletions readcloser.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package httpcache
import (
"bytes"
"context"
"database/sql"
"io"
"net/http"
)
Expand All @@ -12,8 +11,8 @@ type cachedReadCloser struct {
ctx context.Context
original io.ReadCloser
buffer *bytes.Buffer
cache *Queries
data func() CacheResponseParams
cache Querier
data func() Params
tee io.Reader
}

Expand All @@ -27,19 +26,19 @@ func (b *cachedReadCloser) Close() error {
return b.original.Close()
}

func newCachedReadCloser(hash string, cache *Queries, resp *http.Response) (*cachedReadCloser) {
func newCachedReadCloser(hash string, cache Querier, resp *http.Response) *cachedReadCloser {
buffer := &bytes.Buffer{}
return &cachedReadCloser{
ctx: resp.Request.Context(),
original: resp.Body,
buffer: buffer,
cache: cache,
data: func() CacheResponseParams {
return CacheResponseParams{
data: func() Params {
return Params{
ReqHash: hash,
Body: sql.NullString{String: buffer.String(), Valid: true},
Headers: sql.NullString{String: "", Valid: true},
StatusCode: sql.NullInt64{Int64: int64(resp.StatusCode), Valid: true},
Body: buffer.String(),
Headers: "",
StatusCode: resp.StatusCode,
}
},
tee: io.TeeReader(resp.Body, buffer),
Expand Down
8 changes: 4 additions & 4 deletions sqlc.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
version: '2'
sql:
- engine: sqlite
schema: schema.sql
queries: queries.sql
schema: ./internal/sqlite/schema.sql
queries: ./internal/sqlite/queries.sql
gen:
go:
package: httpcache
out: .
package: sqlite
out: ./internal/sqlite
54 changes: 45 additions & 9 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,66 @@ package httpcache
import (
"context"
"database/sql"
_ "embed"
"fmt"
"strings"

"github.com/cyberbeast/httpcache/internal/sqlite"
_ "modernc.org/sqlite"
)

//go:embed schema.sql
var ddl string
type sqliteStore struct{ queries *sqlite.Queries }

func (s *sqliteStore) CacheResponse(ctx context.Context, arg Params) (Response, error) {
return wrapSQLiteResponse(s.queries.CacheResponse(ctx, sqlite.CacheResponseParams{
ReqHash: arg.ReqHash,
Body: sql.NullString{String: arg.Body, Valid: true},
Headers: sql.NullString{String: arg.Headers, Valid: true},
StatusCode: sql.NullInt64{Int64: int64(arg.StatusCode), Valid: true},
}))
}

func (s *sqliteStore) DeleteAllResponses(ctx context.Context) error {
return s.queries.DeleteAllResponses(ctx)
}

func (s *sqliteStore) GetResponse(ctx context.Context, reqHash string) (Response, error) {
return wrapSQLiteResponse(s.queries.GetResponse(ctx, reqHash))
}

func wrapSQLiteResponse(res sqlite.Response, err error) (Response, error) {
return Response{
ReqHash: res.ReqHash,
Body: res.Body.String,
Headers: res.Headers.String,
StatusCode: int(res.StatusCode.Int64),
UpdatedAt: res.UpdatedAt.String,
}, err
}

const filePrefix = "file://"

type SQLiteSource string

func (s SQLiteSource) name() string { return "sqlite" }

func (s SQLiteSource) filepath() string { return "file://" + string(s) }
func (s SQLiteSource) filepath() string {
file := string(s)
if !strings.HasPrefix(file, filePrefix) {
file = filePrefix + file
}

return file
}

func initSQLiteDB(ctx context.Context, src SQLiteSource) (*sql.DB, error) {
db, err := sql.Open(src.name(), src.filepath())
func (s SQLiteSource) Init(ctx context.Context) (Querier, error) {
db, err := sql.Open(s.name(), s.filepath())
if err != nil {
return nil, fmt.Errorf("opening db: %w", err)
}

if _, err := db.ExecContext(ctx, ddl); err != nil {
return db, fmt.Errorf("creating db schema: %w", err)
if _, err := db.ExecContext(ctx, sqlite.Schema); err != nil {
return nil, fmt.Errorf("creating db schema: %w", err)
}

return db, nil
return &sqliteStore{queries: sqlite.New(db)}, nil
}
91 changes: 42 additions & 49 deletions transport.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
package httpcache

import (
"cmp"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"slices"
"strings"
)

type RequestHashFn func(req *http.Request) string
type Cache interface {
Init(ctx context.Context) (Querier, error)
}

type Querier interface {
GetResponse(ctx context.Context, reqHash string) (Response, error)
CacheResponse(ctx context.Context, arg Params) (Response, error)
DeleteAllResponses(ctx context.Context) error
}

func NewTransport(ctx context.Context, cache Cache, rt http.RoundTripper) (*cachedTransport, error) {
store, err := cache.Init(ctx)
if err != nil {
return nil, err
}

return &cachedTransport{
rt: cmp.Or(rt, http.DefaultTransport),
queries: store,
reqHashFn: simpleRequestHash,
}, nil
}

type Params struct {
ReqHash string
Body string
Headers string
StatusCode int
}

type Response struct {
ReqHash string
Body string
Headers string
StatusCode int
UpdatedAt string
}

type cachedTransport struct {
queries *Queries
queries Querier
rt http.RoundTripper
reqHashFn RequestHashFn
}
Expand All @@ -30,9 +63,9 @@ func (ct cachedTransport) cachedRoundTrip(req *http.Request) *http.Response {
}

return &http.Response{
Body: io.NopCloser(strings.NewReader(res.Body.String)),
StatusCode: int(res.StatusCode.Int64),
Status: http.StatusText(int(res.StatusCode.Int64)),
Body: io.NopCloser(strings.NewReader(res.Body)),
StatusCode: res.StatusCode,
Status: http.StatusText(res.StatusCode),
// Header: res.Headers.String,
}
}
Expand All @@ -49,44 +82,4 @@ func (ct cachedTransport) RoundTrip(req *http.Request) (*http.Response, error) {
res.Body = newCachedReadCloser(ct.reqHashFn(req), ct.queries, res)

return res, nil

}

func NewTransport(ctx context.Context, src SQLiteSource, rt http.RoundTripper) (*cachedTransport, error) {
db, err := initSQLiteDB(ctx, src)
if err != nil {
return nil, err
}

if rt == nil {
rt = http.DefaultTransport
}

return &cachedTransport{
rt: rt,
queries: New(db),
reqHashFn: func(req *http.Request) string {
return fmt.Sprintf("%s:%s:%s", req.Method, req.URL.String(), hash(req.Header))
},
}, nil
}

const delimiter = "|"

func hash(headers http.Header) string {
keys := make([]string, 0, len(headers))

for key := range headers {
keys = append(keys, key)
}

slices.Sort(keys)

var sb strings.Builder
for _, key := range keys {
sb.WriteString(fmt.Sprintf("%s:%s%s", key, headers.Get(key), delimiter))
}

hash := sha256.Sum256([]byte(sb.String()))
return hex.EncodeToString(hash[:])
}