diff --git a/NOTICE b/NOTICE index 1e2aac0728..2c90b58f2d 100644 --- a/NOTICE +++ b/NOTICE @@ -79,10 +79,6 @@ Copyright (c) 2013 Dario Castañé. All rights reserved. Copyright (c) 2012 The Go Authors. All rights reserved. License - https://github.com/darccio/mergo/blob/master/LICENSE -gorilla/mux - https://github.com/gorilla/mux -Copyright (c) 2023 The Gorilla Authors. All rights reserved. -License - https://github.com/gorilla/mux/blob/main/LICENSE - palantir/pkg - https://github.com/palantir/pkg Copyright (c) 2016, Palantir Technologies, Inc. License - https://github.com/palantir/pkg/blob/master/LICENSE diff --git a/acceptance/bundle/invariant/test.toml b/acceptance/bundle/invariant/test.toml index beabef5ef1..257e33005a 100644 --- a/acceptance/bundle/invariant/test.toml +++ b/acceptance/bundle/invariant/test.toml @@ -81,5 +81,5 @@ Pattern = "POST /api/2.0/sql/statements/" Response.Body = '{"status": {"state": "SUCCEEDED"}, "manifest": {"schema": {"columns": []}}}' [[Server]] -Pattern = "DELETE /api/2.1/unity-catalog/tables/{name}" +Pattern = "DELETE /api/2.1/unity-catalog/tables/{full_name}" Response.Body = '{"status": "OK"}' diff --git a/acceptance/bundle/resources/synced_database_tables/basic/test.toml b/acceptance/bundle/resources/synced_database_tables/basic/test.toml index d41d9b917c..191670590b 100644 --- a/acceptance/bundle/resources/synced_database_tables/basic/test.toml +++ b/acceptance/bundle/resources/synced_database_tables/basic/test.toml @@ -20,5 +20,5 @@ Pattern = "POST /api/2.0/sql/statements/" Response.Body = '{"status": {"state": "SUCCEEDED"}, "manifest": {"schema": {"columns": []}}}' [[Server]] -Pattern = "DELETE /api/2.1/unity-catalog/tables/{name}" +Pattern = "DELETE /api/2.1/unity-catalog/tables/{full_name}" Response.Body = '{"status": "OK"}' diff --git a/acceptance/internal/prepare_server.go b/acceptance/internal/prepare_server.go index 702b4e145e..8f18d1c61b 100644 --- a/acceptance/internal/prepare_server.go +++ b/acceptance/internal/prepare_server.go @@ -188,8 +188,8 @@ func startLocalServer(t *testing.T, killCountersMu := &sync.Mutex{} for ind := range stubs { - // We want later stubs takes precedence, because then leaf configs take precedence over parent directory configs - // In gorilla/mux earlier handlers take precedence, so we need to reverse the order + // Later stubs take precedence over earlier ones (leaf configs override parent configs). + // The first handler registered for a given pattern wins, so we reverse the order. stub := stubs[len(stubs)-1-ind] require.NotEmpty(t, stub.Pattern) items := strings.Split(stub.Pattern, " ") @@ -226,7 +226,8 @@ func startLocalServer(t *testing.T, }) } - // The earliest handlers take precedence, add default handlers last + // The first handler registered for a given pattern wins, so default + // handlers registered last serve as fallbacks. testserver.AddDefaultHandlers(s) return s.URL } diff --git a/go.mod b/go.mod index 7b79b86e7f..2e841c0d31 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,6 @@ require ( github.com/databricks/databricks-sdk-go v0.128.0 // Apache-2.0 github.com/google/jsonschema-go v0.4.3 // MIT github.com/google/uuid v1.6.0 // BSD-3-Clause - github.com/gorilla/mux v1.8.1 // BSD-3-Clause github.com/gorilla/websocket v1.5.3 // BSD-2-Clause github.com/hashicorp/go-version v1.9.0 // MPL-2.0 github.com/hashicorp/hc-install v0.9.3 // MPL-2.0 diff --git a/go.sum b/go.sum index 993ac401ff..0d6543b5ce 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,6 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dq github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= diff --git a/libs/testserver/handlers.go b/libs/testserver/handlers.go index 8bd5339184..d98011fc7b 100644 --- a/libs/testserver/handlers.go +++ b/libs/testserver/handlers.go @@ -109,7 +109,7 @@ func AddDefaultHandlers(server *Server) { return "" }) - server.Handle("POST", "/api/2.0/workspace-files/import-file/{path:.*}", func(req Request) any { + server.Handle("POST", "/api/2.0/workspace-files/import-file/{path...}", func(req Request) any { path := req.Vars["path"] overwrite := req.URL.Query().Get("overwrite") == "true" return req.Workspace.WorkspaceFilesImportFile(path, req.Body, overwrite) @@ -145,12 +145,12 @@ func AddDefaultHandlers(server *Server) { return req.Workspace.WorkspaceFilesImportFile(request.Path, decoded, request.Overwrite) }) - server.Handle("GET", "/api/2.0/workspace-files/{path:.*}", func(req Request) any { + server.Handle("GET", "/api/2.0/workspace-files/{path...}", func(req Request) any { path := req.Vars["path"] return req.Workspace.WorkspaceFilesExportFile(path) }) - server.Handle("HEAD", "/api/2.0/fs/directories/{path:.*}", func(req Request) any { + server.Handle("HEAD", "/api/2.0/fs/directories/{path...}", func(req Request) any { dirPath := req.Vars["path"] if !strings.HasPrefix(dirPath, "/") { dirPath = "/" + dirPath @@ -165,7 +165,7 @@ func AddDefaultHandlers(server *Server) { return Response{StatusCode: 404} }) - server.Handle("HEAD", "/api/2.0/fs/files/{path:.*}", func(req Request) any { + server.Handle("HEAD", "/api/2.0/fs/files/{path...}", func(req Request) any { path := req.Vars["path"] if req.Workspace.FileExists(path) { return Response{StatusCode: 200} @@ -173,7 +173,7 @@ func AddDefaultHandlers(server *Server) { return Response{StatusCode: 404} }) - server.Handle("PUT", "/api/2.0/fs/directories/{path:.*}", func(req Request) any { + server.Handle("PUT", "/api/2.0/fs/directories/{path...}", func(req Request) any { dirPath := req.Vars["path"] if !strings.HasPrefix(dirPath, "/") { dirPath = "/" + dirPath @@ -194,13 +194,13 @@ func AddDefaultHandlers(server *Server) { return Response{} }) - server.Handle("PUT", "/api/2.0/fs/files/{path:.*}", func(req Request) any { + server.Handle("PUT", "/api/2.0/fs/files/{path...}", func(req Request) any { path := req.Vars["path"] overwrite := req.URL.Query().Get("overwrite") == "true" return req.Workspace.WorkspaceFilesImportFile(path, req.Body, overwrite) }) - server.Handle("GET", "/api/2.0/fs/files/{path:.*}", func(req Request) any { + server.Handle("GET", "/api/2.0/fs/files/{path...}", func(req Request) any { path := req.Vars["path"] data := req.Workspace.WorkspaceFilesExportFile(path) if data == nil { diff --git a/libs/testserver/router.go b/libs/testserver/router.go new file mode 100644 index 0000000000..00381eae6a --- /dev/null +++ b/libs/testserver/router.go @@ -0,0 +1,124 @@ +package testserver + +import ( + "net/http" + "strings" +) + +// HandlerFunc is the test-server handler signature. +type HandlerFunc func(req Request) any + +// Router maps method+path to a HandlerFunc. Wildcards use Go 1.22 ServeMux +// placeholder syntax ({name} or {name...}). +// +// # Why a custom router +// +// Go 1.22 added method matching ("GET /path") and {name}/{name...} +// placeholders to http.ServeMux, covering most of what we previously used +// gorilla/mux for. But two ServeMux behaviors make it inconvenient to use +// directly in the test server: +// +// - Exact and wildcard patterns conflict when they cover the same +// request under different methods. ServeMux treats `GET /x` as +// matching both GET and HEAD, so it overlaps with `HEAD /{path...}` +// and panics at registration. Test fixtures register both kinds of +// routes side by side, so we keep exact paths in our own map and +// only delegate wildcards to ServeMux. Exact lookup runs first; +// misses fall through to ServeMux, which also lets a wildcard +// handler serve methods that the exact registration doesn't cover. +// +// - ServeMux panics on duplicate pattern registration. Router silently +// drops the later registration — first wins. Two callers rely on this: +// AddDefaultHandlers (libs/testserver/handlers.go) installs fallback +// handlers that any test stub for the same pattern can override, and +// startLocalServer (acceptance/internal/prepare_server.go) iterates +// test.toml stubs in reverse so leaf-directory configs register first +// and win over inherited parent stubs. +// +// Router also clears req.URL.RawPath before dispatch so percent-encoded +// slashes (%2F) match literal slashes in patterns; workspace file paths +// in tests routinely contain encoded slashes. +type Router struct { + mux *http.ServeMux + exact map[string]map[string]HandlerFunc + wildcard map[string]bool + + // Dispatch is invoked when a route matches. vars holds the path values for + // wildcard routes and is nil for exact routes. + Dispatch func(w http.ResponseWriter, r *http.Request, h HandlerFunc, vars map[string]string) + // NotFound is invoked when no route matches. + NotFound http.HandlerFunc +} + +func NewRouter() *Router { + r := &Router{ + mux: http.NewServeMux(), + exact: map[string]map[string]HandlerFunc{}, + wildcard: map[string]bool{}, + } + r.mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + if r.NotFound != nil { + r.NotFound(w, req) + } + }) + return r +} + +// Handle registers a handler for method+path. First registration wins; +// duplicate (method, path) registrations are ignored. +func (r *Router) Handle(method, path string, handler HandlerFunc) { + if !strings.Contains(path, "{") { + if r.exact[path] == nil { + r.exact[path] = map[string]HandlerFunc{} + } + if _, ok := r.exact[path][method]; !ok { + r.exact[path][method] = handler + } + return + } + pattern := method + " " + path + if r.wildcard[pattern] { + return + } + r.wildcard[pattern] = true + names := wildcardNames(path) + r.mux.HandleFunc(pattern, func(w http.ResponseWriter, req *http.Request) { + vars := make(map[string]string, len(names)) + for _, name := range names { + vars[name] = req.PathValue(name) + } + if r.Dispatch != nil { + r.Dispatch(w, req, handler, vars) + } + }) +} + +// ServeHTTP routes a request to the registered handler, falling back to +// NotFound if no route matches. +func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // Force ServeMux to match against the decoded path; see the type doc. + req.URL.RawPath = "" + if methods, ok := r.exact[req.URL.Path]; ok { + if h, ok := methods[req.Method]; ok { + if r.Dispatch != nil { + r.Dispatch(w, req, h, nil) + } + return + } + } + r.mux.ServeHTTP(w, req) +} + +// wildcardNames extracts wildcard parameter names from a path pattern, +// e.g. "/api/{id}/files/{path...}" returns ["id", "path"]. +func wildcardNames(path string) []string { + var names []string + for part := range strings.SplitSeq(path, "/") { + if strings.HasPrefix(part, "{") && strings.HasSuffix(part, "}") { + name := part[1 : len(part)-1] + name = strings.TrimSuffix(name, "...") + names = append(names, name) + } + } + return names +} diff --git a/libs/testserver/router_test.go b/libs/testserver/router_test.go new file mode 100644 index 0000000000..9d2c8c603e --- /dev/null +++ b/libs/testserver/router_test.go @@ -0,0 +1,137 @@ +package testserver_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/databricks/cli/libs/testserver" + "github.com/stretchr/testify/assert" +) + +type capture struct { + handler string + vars map[string]string + notFound bool +} + +func newRouter(t *testing.T) (*testserver.Router, *capture) { + t.Helper() + c := &capture{} + r := testserver.NewRouter() + r.Dispatch = func(w http.ResponseWriter, req *http.Request, h testserver.HandlerFunc, vars map[string]string) { + c.vars = vars + c.handler = h(testserver.Request{}).(string) + } + r.NotFound = func(w http.ResponseWriter, req *http.Request) { + c.notFound = true + } + return r, c +} + +func handlerNamed(name string) testserver.HandlerFunc { + return func(req testserver.Request) any { return name } +} + +func TestRouterExactMatch(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/foo", handlerNamed("foo-get")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/foo", nil)) + assert.Equal(t, "foo-get", c.handler) + assert.Nil(t, c.vars) +} + +func TestRouterWildcardMatch(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/items/{id}", handlerNamed("item-get")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/items/42", nil)) + assert.Equal(t, "item-get", c.handler) + assert.Equal(t, map[string]string{"id": "42"}, c.vars) +} + +func TestRouterCatchAllWildcard(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/files/{path...}", handlerNamed("files-get")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/files/a/b/c", nil)) + assert.Equal(t, "files-get", c.handler) + assert.Equal(t, map[string]string{"path": "a/b/c"}, c.vars) +} + +func TestRouterMultipleWildcards(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/items/{id}/files/{path...}", handlerNamed("nested")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/items/42/files/a/b", nil)) + assert.Equal(t, "nested", c.handler) + assert.Equal(t, map[string]string{"id": "42", "path": "a/b"}, c.vars) +} + +func TestRouterExactBeforeWildcard(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/foo", handlerNamed("exact")) + r.Handle("HEAD", "/{path...}", handlerNamed("wildcard-head")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/foo", nil)) + assert.Equal(t, "exact", c.handler) + + c.handler = "" + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodHead, "/foo", nil)) + assert.Equal(t, "wildcard-head", c.handler) +} + +func TestRouterFirstRegistrationWins(t *testing.T) { + t.Run("exact", func(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/foo", handlerNamed("first")) + r.Handle("GET", "/foo", handlerNamed("second")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/foo", nil)) + assert.Equal(t, "first", c.handler) + }) + + t.Run("wildcard", func(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/items/{id}", handlerNamed("first")) + r.Handle("GET", "/items/{id}", handlerNamed("second")) + + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/items/42", nil)) + assert.Equal(t, "first", c.handler) + }) +} + +func TestRouterNotFound(t *testing.T) { + r, c := newRouter(t) + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/missing", nil)) + assert.True(t, c.notFound) +} + +func TestRouterMethodNotAllowed(t *testing.T) { + t.Run("exact", func(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/foo", handlerNamed("foo-get")) + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/foo", nil)) + assert.True(t, c.notFound, "wrong method on exact path should hit NotFound") + assert.Empty(t, c.handler) + }) + + t.Run("wildcard", func(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/items/{id}", handlerNamed("item-get")) + r.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodPost, "/items/42", nil)) + assert.True(t, c.notFound, "wrong method on wildcard path should hit NotFound") + assert.Empty(t, c.handler) + }) +} + +func TestRouterPercentEncodedSlash(t *testing.T) { + r, c := newRouter(t) + r.Handle("GET", "/files/{path...}", handlerNamed("files-get")) + + req := httptest.NewRequest(http.MethodGet, "/files/a%2Fb%2Fc", nil) + r.ServeHTTP(httptest.NewRecorder(), req) + assert.Equal(t, "files-get", c.handler) + assert.Equal(t, "a/b/c", c.vars["path"]) +} diff --git a/libs/testserver/server.go b/libs/testserver/server.go index da35473868..40556e5529 100644 --- a/libs/testserver/server.go +++ b/libs/testserver/server.go @@ -17,7 +17,6 @@ import ( "sync" "github.com/databricks/cli/internal/testutil" - "github.com/gorilla/mux" ) const testPidKey = "test-pid" @@ -39,7 +38,7 @@ func ExtractPidFromHeaders(headers http.Header) int { type Server struct { *httptest.Server - Router *mux.Router + *Router t testutil.TestingT @@ -84,7 +83,6 @@ func NewRequest(t testutil.TestingT, r *http.Request, fakeWorkspace *FakeWorkspa URL: r.URL, Headers: r.Header, Body: body, - Vars: mux.Vars(r), Workspace: fakeWorkspace, Context: r.Context(), } @@ -201,7 +199,7 @@ func getHeaders(value []byte) http.Header { } func New(t testutil.TestingT) *Server { - router := mux.NewRouter() + router := NewRouter() server := httptest.NewServer(router) t.Cleanup(server.Close) @@ -212,6 +210,7 @@ func New(t testutil.TestingT) *Server { fakeWorkspaces: map[string]*FakeWorkspace{}, fakeOidc: &FakeOidc{url: server.URL}, } + router.Dispatch = s.serve t.Cleanup(func() { for _, ws := range s.fakeWorkspaces { @@ -257,8 +256,7 @@ Response.Body = '' t.Errorf("Response write error: %s", err) } }) - router.NotFoundHandler = notFoundFunc - router.MethodNotAllowedHandler = notFoundFunc + router.NotFound = notFoundFunc // Register a default handler for the SDK's host metadata discovery endpoint. // The SDK resolves this during config initialization (as of v0.126.0) to @@ -290,48 +288,45 @@ func (s *Server) getWorkspaceForToken(token string) *FakeWorkspace { return s.fakeWorkspaces[token] } -type HandlerFunc func(req Request) any +func (s *Server) serve(w http.ResponseWriter, r *http.Request, handler HandlerFunc, vars map[string]string) { + // Each test uses unique DATABRICKS_TOKEN, we simulate each token having + // it's own fake fakeWorkspace to avoid interference between tests. + fakeWorkspace := s.getWorkspaceForToken(getToken(r)) -func (s *Server) Handle(method, path string, handler HandlerFunc) { - s.Router.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - // Each test uses unique DATABRICKS_TOKEN, we simulate each token having - // it's own fake fakeWorkspace to avoid interference between tests. - fakeWorkspace := s.getWorkspaceForToken(getToken(r)) + request := NewRequest(s.t, r, fakeWorkspace) + request.Vars = vars - request := NewRequest(s.t, r, fakeWorkspace) - - if s.RequestCallback != nil { - s.RequestCallback(&request) - } + if s.RequestCallback != nil { + s.RequestCallback(&request) + } - var resp EncodedResponse + var resp EncodedResponse - if bytes.Contains(request.Body, []byte("INJECT_ERROR")) { - resp = EncodedResponse{ - StatusCode: 500, - Body: []byte("INJECTED"), - } - } else { - respAny := handler(request) - if respAny == nil && request.Context.Err() != nil { - return - } - resp = normalizeResponse(s.t, respAny) + if bytes.Contains(request.Body, []byte("INJECT_ERROR")) { + resp = EncodedResponse{ + StatusCode: 500, + Body: []byte("INJECTED"), } + } else { + respAny := handler(request) + if respAny == nil && request.Context.Err() != nil { + return + } + resp = normalizeResponse(s.t, respAny) + } - maps.Copy(w.Header(), resp.Headers) + maps.Copy(w.Header(), resp.Headers) - w.WriteHeader(resp.StatusCode) + w.WriteHeader(resp.StatusCode) - if s.ResponseCallback != nil { - s.ResponseCallback(&request, &resp) - } + if s.ResponseCallback != nil { + s.ResponseCallback(&request, &resp) + } - if _, err := w.Write(resp.Body); err != nil { - s.t.Errorf("Failed to write response: %s", err) - return - } - }).Methods(method) + if _, err := w.Write(resp.Body); err != nil { + s.t.Errorf("Failed to write response: %s", err) + return + } } func getToken(r *http.Request) string {