Skip to content

Commit 2ec883d

Browse files
committed
Implement Anthropic Messages API support
Complete Anthropic Messages API implementation Signed-off-by: Eric Curtin <[email protected]>
1 parent 6bef10a commit 2ec883d

File tree

13 files changed

+526
-139
lines changed

13 files changed

+526
-139
lines changed

main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"syscall"
1212
"time"
1313

14+
"github.com/docker/model-runner/pkg/anthropic"
1415
"github.com/docker/model-runner/pkg/gpuinfo"
1516
"github.com/docker/model-runner/pkg/inference"
1617
"github.com/docker/model-runner/pkg/inference/backends/llamacpp"
@@ -184,6 +185,10 @@ func main() {
184185
ollamaHandler := ollama.NewHTTPHandler(log, scheduler, schedulerHTTP, nil, modelManager)
185186
router.Handle(ollama.APIPrefix+"/", ollamaHandler)
186187

188+
// Add Anthropic Messages API compatibility layer
189+
anthropicHandler := anthropic.NewHandler(log, schedulerHTTP, nil, modelManager)
190+
router.Handle(anthropic.APIPrefix+"/", anthropicHandler)
191+
187192
// Register root handler LAST - it will only catch exact "/" requests that don't match other patterns
188193
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
189194
// Only respond to exact root path

pkg/anthropic/handler.go

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
package anthropic
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"errors"
7+
"io"
8+
"net/http"
9+
10+
"github.com/docker/model-runner/pkg/inference"
11+
"github.com/docker/model-runner/pkg/inference/models"
12+
"github.com/docker/model-runner/pkg/inference/scheduling"
13+
"github.com/docker/model-runner/pkg/internal/utils"
14+
"github.com/docker/model-runner/pkg/logging"
15+
"github.com/docker/model-runner/pkg/middleware"
16+
)
17+
18+
const (
	// APIPrefix is the URL prefix under which all Anthropic-compatible routes
	// are served. llama.cpp implements the Anthropic API at /v1/messages,
	// matching the official Anthropic API structure, so routes below this
	// prefix mirror those paths.
	APIPrefix = "/anthropic"

	// maxRequestBodySize caps how many bytes of a request body are read
	// (10 MiB).
	maxRequestBodySize = 10 << 20
)
26+
27+
// Handler implements the Anthropic Messages API compatibility layer.
28+
// It forwards requests to the scheduler which proxies to llama.cpp,
29+
// which natively supports the Anthropic Messages API format.
30+
type Handler struct {
31+
log logging.Logger
32+
router *http.ServeMux
33+
httpHandler http.Handler
34+
modelManager *models.Manager
35+
schedulerHTTP *scheduling.HTTPHandler
36+
}
37+
38+
// NewHandler creates a new Anthropic API handler.
39+
func NewHandler(log logging.Logger, schedulerHTTP *scheduling.HTTPHandler, allowedOrigins []string, modelManager *models.Manager) *Handler {
40+
h := &Handler{
41+
log: log,
42+
router: http.NewServeMux(),
43+
schedulerHTTP: schedulerHTTP,
44+
modelManager: modelManager,
45+
}
46+
47+
// Register routes
48+
for route, handler := range h.routeHandlers() {
49+
h.router.HandleFunc(route, handler)
50+
}
51+
52+
// Apply CORS middleware
53+
h.httpHandler = middleware.CorsMiddleware(allowedOrigins, h.router)
54+
55+
return h
56+
}
57+
58+
// ServeHTTP implements the http.Handler interface.
59+
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
60+
safeMethod := utils.SanitizeForLog(r.Method, -1)
61+
safePath := utils.SanitizeForLog(r.URL.Path, -1)
62+
h.log.Infof("Anthropic API request: %s %s", safeMethod, safePath)
63+
h.httpHandler.ServeHTTP(w, r)
64+
}
65+
66+
// routeHandlers returns the mapping of routes to their handlers.
67+
func (h *Handler) routeHandlers() map[string]http.HandlerFunc {
68+
return map[string]http.HandlerFunc{
69+
// Messages API endpoint - main chat completion endpoint
70+
"POST " + APIPrefix + "/v1/messages": h.handleMessages,
71+
// Token counting endpoint
72+
"POST " + APIPrefix + "/v1/messages/count_tokens": h.handleCountTokens,
73+
}
74+
}
75+
76+
// MessagesRequest represents an Anthropic Messages API request.
77+
// This is used to extract the model field for routing purposes.
78+
type MessagesRequest struct {
79+
Model string `json:"model"`
80+
}
81+
82+
// handleMessages handles POST /anthropic/v1/messages requests.
83+
// It forwards the request to the scheduler which proxies to the llama.cpp backend.
84+
// The llama.cpp backend natively handles the Anthropic Messages API format conversion.
85+
func (h *Handler) handleMessages(w http.ResponseWriter, r *http.Request) {
86+
h.proxyToBackend(w, r, "/v1/messages")
87+
}
88+
89+
// handleCountTokens handles POST /anthropic/v1/messages/count_tokens requests.
90+
// It forwards the request to the scheduler which proxies to the llama.cpp backend.
91+
func (h *Handler) handleCountTokens(w http.ResponseWriter, r *http.Request) {
92+
h.proxyToBackend(w, r, "/v1/messages/count_tokens")
93+
}
94+
95+
// proxyToBackend proxies the request to the llama.cpp backend via the scheduler.
96+
func (h *Handler) proxyToBackend(w http.ResponseWriter, r *http.Request, targetPath string) {
97+
ctx := r.Context()
98+
99+
// Read the request body
100+
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxRequestBodySize))
101+
if err != nil {
102+
var maxBytesError *http.MaxBytesError
103+
if errors.As(err, &maxBytesError) {
104+
h.writeAnthropicError(w, http.StatusRequestEntityTooLarge, "request_too_large", "Request body too large")
105+
} else {
106+
h.writeAnthropicError(w, http.StatusInternalServerError, "internal_error", "Failed to read request body")
107+
}
108+
return
109+
}
110+
111+
// Parse the model field from the request to route to the correct backend
112+
var req MessagesRequest
113+
if err := json.Unmarshal(body, &req); err != nil {
114+
h.writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "Invalid JSON in request body")
115+
return
116+
}
117+
118+
if req.Model == "" {
119+
h.writeAnthropicError(w, http.StatusBadRequest, "invalid_request_error", "Missing required field: model")
120+
return
121+
}
122+
123+
// Normalize model name
124+
modelName := models.NormalizeModelName(req.Model)
125+
126+
// Verify the model exists locally
127+
_, err = h.modelManager.GetLocal(modelName)
128+
if err != nil {
129+
h.writeAnthropicError(w, http.StatusNotFound, "not_found_error", "Model not found: "+modelName)
130+
return
131+
}
132+
133+
// Create the proxied request to the inference endpoint
134+
// The scheduler will route to the appropriate backend
135+
newReq := r.Clone(ctx)
136+
newReq.URL.Path = inference.InferencePrefix + targetPath
137+
newReq.Body = io.NopCloser(bytes.NewReader(body))
138+
newReq.ContentLength = int64(len(body))
139+
newReq.Header.Set("Content-Type", "application/json")
140+
newReq.Header.Set(inference.RequestOriginHeader, inference.OriginAnthropicMessages)
141+
142+
// Forward to the scheduler HTTP handler
143+
h.schedulerHTTP.ServeHTTP(w, newReq)
144+
}
145+
146+
type (
	// AnthropicError is the envelope of an error response in the Anthropic
	// API format: a top-level type plus a nested error object.
	AnthropicError struct {
		Type  string            `json:"type"`
		Error AnthropicErrorObj `json:"error"`
	}

	// AnthropicErrorObj is the nested error object of an Anthropic error
	// response, carrying the error type and a human-readable message.
	AnthropicErrorObj struct {
		Type    string `json:"type"`
		Message string `json:"message"`
	}
)
157+
158+
// writeAnthropicError writes an error response in the Anthropic API format.
159+
func (h *Handler) writeAnthropicError(w http.ResponseWriter, statusCode int, errorType, message string) {
160+
w.Header().Set("Content-Type", "application/json")
161+
w.WriteHeader(statusCode)
162+
163+
errResp := AnthropicError{
164+
Type: "error",
165+
Error: AnthropicErrorObj{
166+
Type: errorType,
167+
Message: message,
168+
},
169+
}
170+
171+
if err := json.NewEncoder(w).Encode(errResp); err != nil {
172+
h.log.Errorf("Failed to encode error response: %v", err)
173+
}
174+
}

pkg/anthropic/handler_test.go

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
package anthropic
2+
3+
import (
4+
"io"
5+
"net/http"
6+
"net/http/httptest"
7+
"strings"
8+
"testing"
9+
10+
"github.com/sirupsen/logrus"
11+
)
12+
13+
func TestWriteAnthropicError(t *testing.T) {
14+
t.Parallel()
15+
16+
tests := []struct {
17+
name string
18+
statusCode int
19+
errorType string
20+
message string
21+
wantBody string
22+
}{
23+
{
24+
name: "invalid request error",
25+
statusCode: http.StatusBadRequest,
26+
errorType: "invalid_request_error",
27+
message: "Missing required field: model",
28+
wantBody: `{"type":"error","error":{"type":"invalid_request_error","message":"Missing required field: model"}}`,
29+
},
30+
{
31+
name: "not found error",
32+
statusCode: http.StatusNotFound,
33+
errorType: "not_found_error",
34+
message: "Model not found: test-model",
35+
wantBody: `{"type":"error","error":{"type":"not_found_error","message":"Model not found: test-model"}}`,
36+
},
37+
{
38+
name: "internal error",
39+
statusCode: http.StatusInternalServerError,
40+
errorType: "internal_error",
41+
message: "An internal error occurred",
42+
wantBody: `{"type":"error","error":{"type":"internal_error","message":"An internal error occurred"}}`,
43+
},
44+
}
45+
46+
for _, tt := range tests {
47+
t.Run(tt.name, func(t *testing.T) {
48+
t.Parallel()
49+
50+
rec := httptest.NewRecorder()
51+
discard := logrus.New()
52+
discard.SetOutput(io.Discard)
53+
h := &Handler{log: logrus.NewEntry(discard)}
54+
h.writeAnthropicError(rec, tt.statusCode, tt.errorType, tt.message)
55+
56+
if rec.Code != tt.statusCode {
57+
t.Errorf("expected status %d, got %d", tt.statusCode, rec.Code)
58+
}
59+
60+
if contentType := rec.Header().Get("Content-Type"); contentType != "application/json" {
61+
t.Errorf("expected Content-Type application/json, got %s", contentType)
62+
}
63+
64+
body := strings.TrimSpace(rec.Body.String())
65+
if body != tt.wantBody {
66+
t.Errorf("expected body %s, got %s", tt.wantBody, body)
67+
}
68+
})
69+
}
70+
}
71+
72+
func TestRouteHandlers(t *testing.T) {
73+
t.Parallel()
74+
75+
h := &Handler{
76+
router: http.NewServeMux(),
77+
}
78+
79+
routes := h.routeHandlers()
80+
81+
expectedRoutes := []string{
82+
"POST " + APIPrefix + "/v1/messages",
83+
"POST " + APIPrefix + "/v1/messages/count_tokens",
84+
}
85+
86+
for _, route := range expectedRoutes {
87+
if _, exists := routes[route]; !exists {
88+
t.Errorf("expected route %s to be registered", route)
89+
}
90+
}
91+
92+
if len(routes) != len(expectedRoutes) {
93+
t.Errorf("expected %d routes, got %d", len(expectedRoutes), len(routes))
94+
}
95+
}
96+
97+
func TestAPIPrefix(t *testing.T) {
98+
t.Parallel()
99+
100+
if APIPrefix != "/anthropic" {
101+
t.Errorf("expected APIPrefix to be /anthropic, got %s", APIPrefix)
102+
}
103+
}
104+
105+
func TestProxyToBackend_InvalidJSON(t *testing.T) {
106+
t.Parallel()
107+
108+
discard := logrus.New()
109+
discard.SetOutput(io.Discard)
110+
h := &Handler{log: logrus.NewEntry(discard)}
111+
112+
rec := httptest.NewRecorder()
113+
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{invalid json`))
114+
115+
h.proxyToBackend(rec, req, "/v1/messages")
116+
117+
if rec.Code != http.StatusBadRequest {
118+
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
119+
}
120+
121+
body := rec.Body.String()
122+
if !strings.Contains(body, "invalid_request_error") {
123+
t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
124+
}
125+
if !strings.Contains(body, "Invalid JSON") {
126+
t.Errorf("expected body to contain 'Invalid JSON', got %s", body)
127+
}
128+
}
129+
130+
func TestProxyToBackend_MissingModel(t *testing.T) {
131+
t.Parallel()
132+
133+
discard := logrus.New()
134+
discard.SetOutput(io.Discard)
135+
h := &Handler{log: logrus.NewEntry(discard)}
136+
137+
rec := httptest.NewRecorder()
138+
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"messages": []}`))
139+
140+
h.proxyToBackend(rec, req, "/v1/messages")
141+
142+
if rec.Code != http.StatusBadRequest {
143+
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
144+
}
145+
146+
body := rec.Body.String()
147+
if !strings.Contains(body, "invalid_request_error") {
148+
t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
149+
}
150+
if !strings.Contains(body, "Missing required field: model") {
151+
t.Errorf("expected body to contain 'Missing required field: model', got %s", body)
152+
}
153+
}
154+
155+
func TestProxyToBackend_EmptyModel(t *testing.T) {
156+
t.Parallel()
157+
158+
discard := logrus.New()
159+
discard.SetOutput(io.Discard)
160+
h := &Handler{log: logrus.NewEntry(discard)}
161+
162+
rec := httptest.NewRecorder()
163+
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", strings.NewReader(`{"model": ""}`))
164+
165+
h.proxyToBackend(rec, req, "/v1/messages")
166+
167+
if rec.Code != http.StatusBadRequest {
168+
t.Errorf("expected status %d, got %d", http.StatusBadRequest, rec.Code)
169+
}
170+
171+
body := rec.Body.String()
172+
if !strings.Contains(body, "invalid_request_error") {
173+
t.Errorf("expected body to contain 'invalid_request_error', got %s", body)
174+
}
175+
if !strings.Contains(body, "Missing required field: model") {
176+
t.Errorf("expected body to contain 'Missing required field: model', got %s", body)
177+
}
178+
}
179+
180+
func TestProxyToBackend_RequestTooLarge(t *testing.T) {
181+
t.Parallel()
182+
183+
discard := logrus.New()
184+
discard.SetOutput(io.Discard)
185+
h := &Handler{log: logrus.NewEntry(discard)}
186+
187+
// Create a request body that exceeds the maxRequestBodySize (10MB)
188+
// We'll use a reader that simulates a large body without actually allocating it
189+
largeBody := strings.NewReader(`{"model": "test-model", "data": "` + strings.Repeat("x", maxRequestBodySize+1) + `"}`)
190+
191+
rec := httptest.NewRecorder()
192+
req := httptest.NewRequest(http.MethodPost, "/anthropic/v1/messages", largeBody)
193+
194+
h.proxyToBackend(rec, req, "/v1/messages")
195+
196+
if rec.Code != http.StatusRequestEntityTooLarge {
197+
t.Errorf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)
198+
}
199+
200+
body := rec.Body.String()
201+
if !strings.Contains(body, "request_too_large") {
202+
t.Errorf("expected body to contain 'request_too_large', got %s", body)
203+
}
204+
}

0 commit comments

Comments
 (0)