|
1 | 1 | package e2e_test |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "bytes" |
5 | 4 | "context" |
6 | | - "io" |
7 | | - "log/slog" |
8 | | - "maps" |
9 | 5 | "net/http" |
10 | 6 | "net/http/httptest" |
11 | 7 | "os" |
12 | 8 | "path/filepath" |
13 | | - "regexp" |
14 | | - "strings" |
15 | 9 | "testing" |
16 | 10 |
|
17 | | - "github.com/labstack/echo/v4" |
18 | 11 | "github.com/stretchr/testify/require" |
19 | | - "gopkg.in/dnaeon/go-vcr.v4/pkg/cassette" |
20 | 12 | "gopkg.in/dnaeon/go-vcr.v4/pkg/recorder" |
21 | 13 |
|
22 | 14 | "github.com/docker/cagent/pkg/config" |
23 | 15 | "github.com/docker/cagent/pkg/environment" |
| 16 | + "github.com/docker/cagent/pkg/fake" |
24 | 17 | ) |
25 | 18 |
|
26 | | -func removeHeadersHook(i *cassette.Interaction) error { |
27 | | - i.Request.Headers = map[string][]string{} |
28 | | - i.Response.Headers = map[string][]string{} |
29 | | - return nil |
30 | | -} |
31 | | - |
32 | | -func customMatcher(t *testing.T) recorder.MatcherFunc { |
| 19 | +func startRecordingAIProxy(t *testing.T) (*httptest.Server, *config.RuntimeConfig) { |
33 | 20 | t.Helper() |
34 | 21 |
|
35 | | - callIDRegex := regexp.MustCompile(`call_[a-z0-9\-]+`) |
36 | | - |
37 | | - return func(r *http.Request, i cassette.Request) bool { |
38 | | - if r.Body == nil || r.Body == http.NoBody { |
39 | | - return cassette.DefaultMatcher(r, i) |
40 | | - } |
41 | | - if r.Method != i.Method { |
42 | | - return false |
43 | | - } |
44 | | - if r.URL.String() != i.URL { |
45 | | - return false |
46 | | - } |
| 22 | + cassettePath := filepath.Join("testdata", "cassettes", t.Name()) |
47 | 23 |
|
48 | | - reqBody, err := io.ReadAll(r.Body) |
| 24 | + // Create a matcher that fails the test on error |
| 25 | + matcher := fake.CustomMatcher(func(err error) { |
49 | 26 | require.NoError(t, err) |
50 | | - r.Body.Close() |
51 | | - r.Body = io.NopCloser(bytes.NewBuffer(reqBody)) |
| 27 | + }) |
52 | 28 |
|
53 | | - // Ignore Gemini's functionResponse's names |
54 | | - return callIDRegex.ReplaceAllString(string(reqBody), "call_ID") == callIDRegex.ReplaceAllString(i.Body, "call_ID") |
| 29 | + // Header updater that adds real API keys for recording |
| 30 | + headerUpdater := func(host string, req *http.Request) { |
| 31 | + switch host { |
| 32 | + case "https://api.openai.com/v1": |
| 33 | + req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY")) |
| 34 | + case "https://api.anthropic.com": |
| 35 | + req.Header.Del("Authorization") |
| 36 | + req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY")) |
| 37 | + case "https://generativelanguage.googleapis.com": |
| 38 | + req.Header.Del("Authorization") |
| 39 | + req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY")) |
| 40 | + case "https://api.mistral.ai/v1": |
| 41 | + req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY")) |
| 42 | + } |
55 | 43 | } |
56 | | -} |
57 | 44 |
|
58 | | -func startRecordingAIProxy(t *testing.T) (*httptest.Server, *config.RuntimeConfig) { |
59 | | - t.Helper() |
60 | | - |
61 | | - transport, err := recorder.New(filepath.Join("testdata", "cassettes", t.Name()), |
62 | | - recorder.WithMode(recorder.ModeRecordOnce), |
63 | | - recorder.WithMatcher(customMatcher(t)), |
64 | | - recorder.WithSkipRequestLatency(true), |
65 | | - recorder.WithHook(removeHeadersHook, recorder.AfterCaptureHook), |
| 45 | + proxyURL, cleanup, err := fake.StartProxyWithOptions( |
| 46 | + cassettePath, |
| 47 | + recorder.ModeRecordOnce, |
| 48 | + matcher, |
| 49 | + headerUpdater, |
66 | 50 | ) |
67 | 51 | require.NoError(t, err) |
68 | 52 |
|
69 | | - t.Cleanup(func() { require.NoError(t, transport.Stop()) }) |
70 | | - |
71 | | - e := echo.New() |
72 | | - e.Any("/*", handle(transport)) |
| 53 | + t.Cleanup(func() { |
| 54 | + require.NoError(t, cleanup()) |
| 55 | + }) |
73 | 56 |
|
74 | | - httpServer := httptest.NewServer(e) |
75 | | - t.Cleanup(httpServer.Close) |
76 | | - |
77 | | - return httpServer, &config.RuntimeConfig{ |
| 57 | + return &httptest.Server{URL: proxyURL}, &config.RuntimeConfig{ |
78 | 58 | Config: config.Config{ |
79 | | - ModelsGateway: httpServer.URL, |
| 59 | + ModelsGateway: proxyURL, |
80 | 60 | }, |
81 | 61 | EnvProviderForTests: &testEnvProvider{ |
82 | 62 | environment.DockerDesktopTokenEnv: "DUMMY", |
83 | 63 | }, |
84 | 64 | } |
85 | 65 | } |
86 | 66 |
|
87 | | -func handle(transport http.RoundTripper) echo.HandlerFunc { |
88 | | - return func(c echo.Context) error { |
89 | | - ctx := c.Request().Context() |
90 | | - |
91 | | - host := c.Request().Header.Get("X-Cagent-Forward") |
92 | | - host = strings.TrimSuffix(host, "/") |
93 | | - |
94 | | - var toTargetURL func(req *http.Request) string |
95 | | - var updateHeaders func(req *http.Request) |
96 | | - switch host { |
97 | | - case "https://api.openai.com/v1": |
98 | | - toTargetURL = func(req *http.Request) string { |
99 | | - return "https://api.openai.com" + req.URL.Redacted() |
100 | | - } |
101 | | - updateHeaders = func(req *http.Request) { |
102 | | - req.Header.Set("Authorization", "Bearer "+os.Getenv("OPENAI_API_KEY")) |
103 | | - } |
104 | | - case "https://api.anthropic.com": |
105 | | - toTargetURL = func(req *http.Request) string { |
106 | | - return "https://api.anthropic.com" + req.URL.Redacted() |
107 | | - } |
108 | | - updateHeaders = func(req *http.Request) { |
109 | | - req.Header.Del("Authorization") |
110 | | - req.Header.Set("X-Api-Key", os.Getenv("ANTHROPIC_API_KEY")) |
111 | | - } |
112 | | - case "https://generativelanguage.googleapis.com": |
113 | | - toTargetURL = func(req *http.Request) string { |
114 | | - return "https://generativelanguage.googleapis.com" + req.URL.Redacted() |
115 | | - } |
116 | | - updateHeaders = func(req *http.Request) { |
117 | | - req.Header.Del("Authorization") |
118 | | - req.Header.Set("X-Goog-Api-Key", os.Getenv("GOOGLE_API_KEY")) |
119 | | - } |
120 | | - case "https://api.mistral.ai/v1": |
121 | | - toTargetURL = func(req *http.Request) string { |
122 | | - return "https://api.mistral.ai" + req.URL.Redacted() |
123 | | - } |
124 | | - updateHeaders = func(req *http.Request) { |
125 | | - req.Header.Set("Authorization", "Bearer "+os.Getenv("MISTRAL_API_KEY")) |
126 | | - } |
127 | | - default: |
128 | | - return echo.NewHTTPError(http.StatusBadRequest, "unknown service host "+host) |
129 | | - } |
130 | | - |
131 | | - targetURL := toTargetURL(c.Request()) |
132 | | - |
133 | | - req, err := http.NewRequestWithContext(ctx, c.Request().Method, targetURL, c.Request().Body) |
134 | | - if err != nil { |
135 | | - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create new request") |
136 | | - } |
137 | | - |
138 | | - maps.Copy(req.Header, c.Request().Header) |
139 | | - updateHeaders(req) |
140 | | - |
141 | | - client := &http.Client{ |
142 | | - Timeout: 0, // no timeout, let ctx control it |
143 | | - Transport: transport, |
144 | | - } |
145 | | - |
146 | | - resp, err := client.Do(req) |
147 | | - if err != nil { |
148 | | - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to run request"+err.Error()) |
149 | | - } |
150 | | - defer resp.Body.Close() |
151 | | - |
152 | | - maps.Copy(c.Response().Header(), resp.Header) |
153 | | - |
154 | | - c.Response().WriteHeader(resp.StatusCode) |
155 | | - |
156 | | - if isStreamResponse(resp) { |
157 | | - return streamCopy(c, resp) |
158 | | - } |
159 | | - |
160 | | - _, err = io.Copy(c.Response().Writer, resp.Body) |
161 | | - return err |
162 | | - } |
163 | | -} |
164 | | - |
165 | | -func streamCopy(c echo.Context, resp *http.Response) error { |
166 | | - ctx := c.Request().Context() |
167 | | - |
168 | | - writer := c.Response().Writer.(io.ReaderFrom) |
169 | | - |
170 | | - for { |
171 | | - select { |
172 | | - case <-ctx.Done(): |
173 | | - slog.WarnContext(ctx, "client disconnected, stop streaming") |
174 | | - return nil |
175 | | - default: |
176 | | - n, err := writer.ReadFrom(io.LimitReader(resp.Body, 256)) |
177 | | - if n > 0 { |
178 | | - c.Response().Flush() // keep flushing to client |
179 | | - } |
180 | | - if err != nil { |
181 | | - if err == io.EOF || ctx.Err() != nil { |
182 | | - return nil |
183 | | - } |
184 | | - slog.ErrorContext(ctx, "stream read error", "error", err) |
185 | | - return err |
186 | | - } |
187 | | - } |
188 | | - } |
189 | | -} |
190 | | - |
191 | | -func isStreamResponse(resp *http.Response) bool { |
192 | | - ct := strings.ToLower(resp.Header.Get("Content-Type")) |
193 | | - if strings.Contains(ct, "text/event-stream") { |
194 | | - return true |
195 | | - } |
196 | | - |
197 | | - te := strings.ToLower(resp.Header.Get("Transfer-Encoding")) |
198 | | - if strings.Contains(te, "chunked") && !strings.Contains(ct, "application/json") { |
199 | | - return true |
200 | | - } |
201 | | - |
202 | | - return strings.Contains(ct, "application/octet-stream") || |
203 | | - strings.Contains(ct, "application/x-ndjson") || |
204 | | - strings.Contains(ct, "application/stream+json") |
205 | | -} |
206 | | - |
207 | 67 | type testEnvProvider map[string]string |
208 | 68 |
|
209 | 69 | func (p *testEnvProvider) Get(_ context.Context, name string) string { |
|
0 commit comments