Skip to content

Commit 20b398a

Browse files
authored
Merge pull request #1061 from stanislavHamara/fake-flag-for-e2e-session-replay
Add --fake flag for e2e test session replay
2 parents a0242cc + d8862d5 commit 20b398a

File tree

3 files changed

+274
-170
lines changed

3 files changed

+274
-170
lines changed

cmd/root/api.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/docker/cagent/pkg/cli"
1212
"github.com/docker/cagent/pkg/config"
13+
"github.com/docker/cagent/pkg/fake"
1314
"github.com/docker/cagent/pkg/server"
1415
"github.com/docker/cagent/pkg/session"
1516
"github.com/docker/cagent/pkg/telemetry"
@@ -19,6 +20,7 @@ type apiFlags struct {
1920
listenAddr string
2021
sessionDB string
2122
pullIntervalMins int
23+
fakeResponses string
2224
runConfig config.RuntimeConfig
2325
}
2426

@@ -37,6 +39,7 @@ func newAPICmd() *cobra.Command {
3739
cmd.PersistentFlags().StringVarP(&flags.listenAddr, "listen", "l", ":8080", "Address to listen on")
3840
cmd.PersistentFlags().StringVarP(&flags.sessionDB, "session-db", "s", "session.db", "Path to the session database")
3941
cmd.PersistentFlags().IntVar(&flags.pullIntervalMins, "pull-interval", 0, "Auto-pull OCI reference every N minutes (0 = disabled)")
42+
cmd.PersistentFlags().StringVar(&flags.fakeResponses, "fake", "", "Replay AI responses from cassette file (for testing)")
4043
addRuntimeConfigFlags(cmd, &flags.runConfig)
4144

4245
return cmd
@@ -52,6 +55,22 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) error {
5255
// Make sure no question is ever asked to the user in api mode.
5356
os.Stdin = nil
5457

58+
// Start fake proxy if --fake is specified
59+
if f.fakeResponses != "" {
60+
proxyURL, cleanup, err := fake.StartProxy(f.fakeResponses)
61+
if err != nil {
62+
return fmt.Errorf("failed to start fake proxy: %w", err)
63+
}
64+
defer func() {
65+
if err := cleanup(); err != nil {
66+
slog.Error("Failed to cleanup fake proxy", "error", err)
67+
}
68+
}()
69+
70+
f.runConfig.ModelsGateway = proxyURL
71+
slog.Info("Fake mode enabled", "cassette", f.fakeResponses, "proxy", proxyURL)
72+
}
73+
5574
if f.pullIntervalMins > 0 && !config.IsOCIReference(agentsPath) {
5675
return fmt.Errorf("--pull-interval flag can only be used with OCI references, not local files")
5776
}

e2e/proxy_test.go

Lines changed: 30 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,209 +1,69 @@
11
package e2e_test
22

33
import (
4-
"bytes"
54
"context"
6-
"io"
7-
"log/slog"
8-
"maps"
95
"net/http"
106
"net/http/httptest"
117
"os"
128
"path/filepath"
13-
"regexp"
14-
"strings"
159
"testing"
1610

17-
"github.com/labstack/echo/v4"
1811
"github.com/stretchr/testify/require"
19-
"gopkg.in/dnaeon/go-vcr.v4/pkg/cassette"
2012
"gopkg.in/dnaeon/go-vcr.v4/pkg/recorder"
2113

2214
"github.com/docker/cagent/pkg/config"
2315
"github.com/docker/cagent/pkg/environment"
16+
"github.com/docker/cagent/pkg/fake"
2417
)
2518

26-
func removeHeadersHook(i *cassette.Interaction) error {
27-
i.Request.Headers = map[string][]string{}
28-
i.Response.Headers = map[string][]string{}
29-
return nil
30-
}
31-
32-
func customMatcher(t *testing.T) recorder.MatcherFunc {
19+
func startRecordingAIProxy(t *testing.T) (*httptest.Server, *config.RuntimeConfig) {
3320
t.Helper()
3421

35-
callIDRegex := regexp.MustCompile(`call_[a-z0-9\-]+`)
36-
37-
return func(r *http.Request, i cassette.Request) bool {
38-
if r.Body == nil || r.Body == http.NoBody {
39-
return cassette.DefaultMatcher(r, i)
40-
}
41-
if r.Method != i.Method {
42-
return false
43-
}
44-
if r.URL.String() != i.URL {
45-
return false
46-
}
22+
cassettePath := filepath.Join("testdata", "cassettes", t.Name())
4723

48-
reqBody, err := io.ReadAll(r.Body)
24+
// Create a matcher that fails the test on error
25+
matcher := fake.CustomMatcher(func(err error) {
4926
require.NoError(t, err)
50-
r.Body.Close()
51-
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
27+
})
5228

53-
// Ignore Gemini's functionResponse's names
54-
return callIDRegex.ReplaceAllString(string(reqBody), "call_ID") == callIDRegex.ReplaceAllString(i.Body, "call_ID")
29+
// Header updater that adds real API keys for recording
30+
headerUpdater := func(host string, req *http.Request) {
31+
switch host {
32+
case "https://api.openai.com/v1":
33+
req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY"))
34+
case "https://api.anthropic.com":
35+
req.Header.Del("Authorization")
36+
req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY"))
37+
case "https://generativelanguage.googleapis.com":
38+
req.Header.Del("Authorization")
39+
req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY"))
40+
case "https://api.mistral.ai/v1":
41+
req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY"))
42+
}
5543
}
56-
}
5744

58-
func startRecordingAIProxy(t *testing.T) (*httptest.Server, *config.RuntimeConfig) {
59-
t.Helper()
60-
61-
transport, err := recorder.New(filepath.Join("testdata", "cassettes", t.Name()),
62-
recorder.WithMode(recorder.ModeRecordOnce),
63-
recorder.WithMatcher(customMatcher(t)),
64-
recorder.WithSkipRequestLatency(true),
65-
recorder.WithHook(removeHeadersHook, recorder.AfterCaptureHook),
45+
proxyURL, cleanup, err := fake.StartProxyWithOptions(
46+
cassettePath,
47+
recorder.ModeRecordOnce,
48+
matcher,
49+
headerUpdater,
6650
)
6751
require.NoError(t, err)
6852

69-
t.Cleanup(func() { require.NoError(t, transport.Stop()) })
70-
71-
e := echo.New()
72-
e.Any("/*", handle(transport))
53+
t.Cleanup(func() {
54+
require.NoError(t, cleanup())
55+
})
7356

74-
httpServer := httptest.NewServer(e)
75-
t.Cleanup(httpServer.Close)
76-
77-
return httpServer, &config.RuntimeConfig{
57+
return &httptest.Server{URL: proxyURL}, &config.RuntimeConfig{
7858
Config: config.Config{
79-
ModelsGateway: httpServer.URL,
59+
ModelsGateway: proxyURL,
8060
},
8161
EnvProviderForTests: &testEnvProvider{
8262
environment.DockerDesktopTokenEnv: "DUMMY",
8363
},
8464
}
8565
}
8666

87-
func handle(transport http.RoundTripper) echo.HandlerFunc {
88-
return func(c echo.Context) error {
89-
ctx := c.Request().Context()
90-
91-
host := c.Request().Header.Get("X-Cagent-Forward")
92-
host = strings.TrimSuffix(host, "/")
93-
94-
var toTargetURL func(req *http.Request) string
95-
var updateHeaders func(req *http.Request)
96-
switch host {
97-
case "https://api.openai.com/v1":
98-
toTargetURL = func(req *http.Request) string {
99-
return "https://api.openai.com" + req.URL.Redacted()
100-
}
101-
updateHeaders = func(req *http.Request) {
102-
req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY"))
103-
}
104-
case "https://api.anthropic.com":
105-
toTargetURL = func(req *http.Request) string {
106-
return "https://api.anthropic.com" + req.URL.Redacted()
107-
}
108-
updateHeaders = func(req *http.Request) {
109-
req.Header.Del("Authorization")
110-
req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY"))
111-
}
112-
case "https://generativelanguage.googleapis.com":
113-
toTargetURL = func(req *http.Request) string {
114-
return "https://generativelanguage.googleapis.com" + req.URL.Redacted()
115-
}
116-
updateHeaders = func(req *http.Request) {
117-
req.Header.Del("Authorization")
118-
req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY"))
119-
}
120-
case "https://api.mistral.ai/v1":
121-
toTargetURL = func(req *http.Request) string {
122-
return "https://api.mistral.ai" + req.URL.Redacted()
123-
}
124-
updateHeaders = func(req *http.Request) {
125-
req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY"))
126-
}
127-
default:
128-
return echo.NewHTTPError(http.StatusBadRequest, "unknown service host "+host)
129-
}
130-
131-
targetURL := toTargetURL(c.Request())
132-
133-
req, err := http.NewRequestWithContext(ctx, c.Request().Method, targetURL, c.Request().Body)
134-
if err != nil {
135-
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create new request")
136-
}
137-
138-
maps.Copy(req.Header, c.Request().Header)
139-
updateHeaders(req)
140-
141-
client := &http.Client{
142-
Timeout: 0, // no timeout, let ctx control it
143-
Transport: transport,
144-
}
145-
146-
resp, err := client.Do(req)
147-
if err != nil {
148-
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to run request"+err.Error())
149-
}
150-
defer resp.Body.Close()
151-
152-
maps.Copy(c.Response().Header(), resp.Header)
153-
154-
c.Response().WriteHeader(resp.StatusCode)
155-
156-
if isStreamResponse(resp) {
157-
return streamCopy(c, resp)
158-
}
159-
160-
_, err = io.Copy(c.Response().Writer, resp.Body)
161-
return err
162-
}
163-
}
164-
165-
func streamCopy(c echo.Context, resp *http.Response) error {
166-
ctx := c.Request().Context()
167-
168-
writer := c.Response().Writer.(io.ReaderFrom)
169-
170-
for {
171-
select {
172-
case <-ctx.Done():
173-
slog.WarnContext(ctx, "client disconnected, stop streaming")
174-
return nil
175-
default:
176-
n, err := writer.ReadFrom(io.LimitReader(resp.Body, 256))
177-
if n > 0 {
178-
c.Response().Flush() // keep flushing to client
179-
}
180-
if err != nil {
181-
if err == io.EOF || ctx.Err() != nil {
182-
return nil
183-
}
184-
slog.ErrorContext(ctx, "stream read error", "error", err)
185-
return err
186-
}
187-
}
188-
}
189-
}
190-
191-
func isStreamResponse(resp *http.Response) bool {
192-
ct := strings.ToLower(resp.Header.Get("Content-Type"))
193-
if strings.Contains(ct, "text/event-stream") {
194-
return true
195-
}
196-
197-
te := strings.ToLower(resp.Header.Get("Transfer-Encoding"))
198-
if strings.Contains(te, "chunked") && !strings.Contains(ct, "application/json") {
199-
return true
200-
}
201-
202-
return strings.Contains(ct, "application/octet-stream") ||
203-
strings.Contains(ct, "application/x-ndjson") ||
204-
strings.Contains(ct, "application/stream+json")
205-
}
206-
20767
type testEnvProvider map[string]string
20868

20969
func (p *testEnvProvider) Get(_ context.Context, name string) string {

0 commit comments

Comments
 (0)