Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ func AsActor(ctx context.Context, actorID string, metadata recorder.Metadata) co
return aibcontext.AsActor(ctx, actorID, metadata)
}

func WithOriginalHost(ctx context.Context, host string) context.Context {
return aibcontext.WithOriginalHost(ctx, host)
}

func NewAnthropicProvider(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) provider.Provider {
return provider.NewAnthropic(cfg, bedrockCfg)
}
Expand Down
4 changes: 2 additions & 2 deletions bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
name: "copilot_no_base_path",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return NewCopilotProvider(config.Copilot{DefaultUpstreamURL: baseURL})
},
expectPath: "/models",
},
Expand All @@ -71,7 +71,7 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
baseURLPath: "/v1",
requestPath: "/copilot/models",
provider: func(baseURL string) provider.Provider {
return NewCopilotProvider(config.Copilot{BaseURL: baseURL})
return NewCopilotProvider(config.Copilot{DefaultUpstreamURL: baseURL})
},
expectPath: "/v1/models",
},
Expand Down
8 changes: 5 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ func DefaultCircuitBreaker() CircuitBreaker {
}

type Copilot struct {
BaseURL string
APIDumpDir string
CircuitBreaker *CircuitBreaker
// DefaultUpstreamURL is the fallback upstream URL when no upstream
// header is provided in the request.
DefaultUpstreamURL string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}
15 changes: 14 additions & 1 deletion context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import (
)

type (
actorContextKey struct{}
actorContextKey struct{}
originalHostContextKey struct{}
)

type Actor struct {
Expand All @@ -28,6 +29,18 @@ func ActorFromContext(ctx context.Context) *Actor {
return a
}

// WithOriginalHost stores the original destination host in the context.
func WithOriginalHost(ctx context.Context, host string) context.Context {
return context.WithValue(ctx, originalHostContextKey{}, host)
}

// OriginalHostFromContext retrieves the original destination host from the context.
// Returns an empty string if not set.
func OriginalHostFromContext(ctx context.Context) string {
h, _ := ctx.Value(originalHostContextKey{}).(string)
return h
}

// ActorIDFromContext safely extracts the actor ID from the context.
// Returns an empty string if no actor is found.
func ActorIDFromContext(ctx context.Context) string {
Expand Down
2 changes: 1 addition & 1 deletion internal/integrationtest/apidump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func TestAPIDumpPassthrough(t *testing.T) {
{
name: "copilot",
providerFunc: func(addr string, dumpDir string) aibridge.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: addr, APIDumpDir: dumpDir})
return provider.NewCopilot(config.Copilot{DefaultUpstreamURL: addr, APIDumpDir: dumpDir})
},
requestPath: "/copilot/models",
expectDumpName: "-models-",
Expand Down
41 changes: 35 additions & 6 deletions provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/intercept"
"github.com/coder/aibridge/intercept/chatcompletions"
"github.com/coder/aibridge/intercept/responses"
Expand All @@ -20,13 +21,22 @@ import (
)

const (
copilotBaseURL = "https://api.individual.githubcopilot.com"
copilotIndividualUpstreamURL = "https://api.individual.githubcopilot.com"
copilotBusinessUpstreamURL = "https://api.business.githubcopilot.com"
copilotEnterpriseUpstreamURL = "https://api.enterprise.githubcopilot.com"

// Copilot exposes an OpenAI-compatible API, including for Anthropic models.
routeCopilotChatCompletions = "/chat/completions"
routeCopilotResponses = "/responses"
)

// copilotUpstreams maps upstream URLs to their names.
var copilotUpstreams = map[string]string{
copilotIndividualUpstreamURL: "individual",
copilotBusinessUpstreamURL: "business",
copilotEnterpriseUpstreamURL: "enterprise",
}

var copilotOpenErrorResponse = func() []byte {
return []byte(`{"error":{"message":"circuit breaker is open","type":"server_error","code":"service_unavailable"}}`)
}
Expand All @@ -52,8 +62,8 @@ type Copilot struct {
var _ Provider = &Copilot{}

func NewCopilot(cfg config.Copilot) *Copilot {
if cfg.BaseURL == "" {
cfg.BaseURL = copilotBaseURL
if cfg.DefaultUpstreamURL == "" {
cfg.DefaultUpstreamURL = copilotIndividualUpstreamURL
}
if cfg.APIDumpDir == "" {
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
Expand All @@ -72,11 +82,25 @@ func (p *Copilot) Name() string {
}

func (p *Copilot) BaseURL() string {
return p.cfg.BaseURL
return p.cfg.DefaultUpstreamURL
}

func (p *Copilot) ResolveUpstream(_ *http.Request) intercept.ResolvedUpstream {
return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.BaseURL}
// ResolveUpstream determines the Copilot upstream based on the original
// destination host stored in the request context by coder. The host is
// mapped to a known upstream URL and name.
// If the host is absent or unknown, it falls back to the configured
// default upstream URL.
func (p *Copilot) ResolveUpstream(r *http.Request) intercept.ResolvedUpstream {
if host := aibcontext.OriginalHostFromContext(r.Context()); host != "" {
upstreamURL := "https://" + host
if name, ok := copilotUpstreams[upstreamURL]; ok {
return intercept.ResolvedUpstream{
Name: config.ProviderCopilot + "-" + name,
URL: upstreamURL,
}
}
}
return intercept.ResolvedUpstream{Name: p.Name(), URL: p.cfg.DefaultUpstreamURL}
}

func (p *Copilot) RoutePrefix() string {
Expand Down Expand Up @@ -119,6 +143,11 @@ func (p *Copilot) APIDumpDir() string {
}

func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tracer trace.Tracer) (_ intercept.Interceptor, outErr error) {
fmt.Println("################### aibridge copilot CreateInterceptor headers received:")
for k, v := range r.Header {
fmt.Printf(" %s: %s\n", k, v)
}

_, span := tracer.Start(r.Context(), "Intercept.CreateInterceptor")
defer tracing.EndSpanErr(span, &outErr)

Expand Down
65 changes: 63 additions & 2 deletions provider/copilot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"go.opentelemetry.io/otel"

"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/internal/testutil"
)

Expand Down Expand Up @@ -145,7 +146,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {

// Create provider with mock upstream URL
provider := NewCopilot(config.Copilot{
BaseURL: mockUpstream.URL,
DefaultUpstreamURL: mockUpstream.URL,
})

body := `{"model": "gpt-4", "messages": [{"role": "user", "content": "hello"}], "stream": false}`
Expand Down Expand Up @@ -236,7 +237,7 @@ func TestCopilot_CreateInterceptor(t *testing.T) {

// Create provider with mock upstream URL
provider := NewCopilot(config.Copilot{
BaseURL: mockUpstream.URL,
DefaultUpstreamURL: mockUpstream.URL,
})

body := `{"model": "gpt-5-mini", "input": "hello", "stream": false}`
Expand Down Expand Up @@ -325,3 +326,63 @@ func TestExtractCopilotHeaders(t *testing.T) {
})
}
}

func TestCopilot_ResolveUpstream(t *testing.T) {
t.Parallel()

provider := NewCopilot(config.Copilot{})

tests := []struct {
name string
host string
expectName string
expectURL string
}{
{
name: "no_header_returns_default",
host: "",
expectName: config.ProviderCopilot,
expectURL: copilotIndividualUpstreamURL,
},
{
name: "individual",
host: "api.individual.githubcopilot.com",
expectName: "copilot-individual",
expectURL: copilotIndividualUpstreamURL,
},
{
name: "business",
host: "api.business.githubcopilot.com",
expectName: "copilot-business",
expectURL: copilotBusinessUpstreamURL,
},
{
name: "enterprise",
host: "api.enterprise.githubcopilot.com",
expectName: "copilot-enterprise",
expectURL: copilotEnterpriseUpstreamURL,
},
{
name: "unknown_host_returns_default",
host: "unknown.example.com",
expectName: config.ProviderCopilot,
expectURL: copilotIndividualUpstreamURL,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

req := httptest.NewRequest(http.MethodPost, "/", nil)
if tc.host != "" {
ctx := aibcontext.WithOriginalHost(req.Context(), tc.host)
req = req.WithContext(ctx)
}

upstream := provider.ResolveUpstream(req)
assert.Equal(t, tc.expectName, upstream.Name)
assert.Equal(t, tc.expectURL, upstream.URL)
})
}
}
Loading