Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 200 additions & 46 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
context2 "github.com/databricks/databricks-sql-go/internal/compat/context"
"github.com/databricks/databricks-sql-go/internal/config"
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
"github.com/databricks/databricks-sql-go/internal/retry"
"github.com/databricks/databricks-sql-go/internal/rows"
"github.com/databricks/databricks-sql-go/internal/sentinel"
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
Expand Down Expand Up @@ -647,17 +648,21 @@ var _ driver.ConnBeginTx = (*conn)(nil)
var _ driver.NamedValueChecker = (*conn)(nil)

func Succeeded(response *http.Response) bool {
if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 {
return true
}
return false
return statusInSuccessRange(response.StatusCode)
}

// statusInSuccessRange returns true for the 2xx status codes the staging
// HTTP path treats as success: 200 OK / 201 Created / 202 Accepted / 204
// No Content. Exposed separately from Succeeded so handlers can extend the
// accept set (e.g. REMOVE accepts 404 for idempotent-delete semantics).
func statusInSuccessRange(status int) bool {
return status == 200 || status == 201 || status == 202 || status == 204
}

func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
if localFile == "" {
return dbsqlerrint.NewDriverError(ctx, "cannot perform PUT without specifying a local_file", nil)
}
client := &http.Client{}

dat, err := os.Open(localFile) //nolint:gosec // localFile is provided by the application, not user input
if err != nil {
Expand All @@ -669,73 +674,222 @@ func (c *conn) handleStagingPut(ctx context.Context, presignedUrl string, header
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error reading local file info", err)
}

req, _ := http.NewRequest("PUT", presignedUrl, dat)
req.ContentLength = info.Size() // backend actually requires content length to be known

for k, v := range headers {
req.Header.Set(k, v)
}
res, err := client.Do(req)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
size := info.Size()

// Each retry attempt needs a fresh request because http.Client.Do consumes
// the request body. Rewind the *os.File between attempts so the server
// receives the full payload on every retry, not just attempt 1.
//
// Wrap the file in io.NopCloser so http.Client.Do can't close it — by
// default it closes any body that implements io.Closer, which would break
// the Seek on the next retry. The outer defer dat.Close() above owns the
// file's lifecycle.
reqFactory := func(attempt int) (*http.Request, error) {
if attempt > 0 {
if _, seekErr := dat.Seek(0, io.SeekStart); seekErr != nil {
return nil, seekErr
}
}
req, reqErr := http.NewRequestWithContext(ctx, http.MethodPut, presignedUrl, io.NopCloser(dat))
if reqErr != nil {
return nil, reqErr
}
req.ContentLength = size // backend actually requires content length to be known
for k, v := range headers {
req.Header.Set(k, v)
}
return req, nil
}
defer res.Body.Close() //nolint:errcheck
content, err := io.ReadAll(res.Body)

if err != nil || !Succeeded(res) {
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
if _, err := c.doStagingRequestWithRetry(ctx, reqFactory); err != nil {
return err
}
return nil

}

func (c *conn) handleStagingGet(ctx context.Context, presignedUrl string, headers map[string]string, localFile string) dbsqlerr.DBError {
if localFile == "" {
return dbsqlerrint.NewDriverError(ctx, "cannot perform GET without specifying a local_file", nil)
}
client := &http.Client{}
req, _ := http.NewRequest("GET", presignedUrl, nil)

for k, v := range headers {
req.Header.Set(k, v)
reqFactory := func(_ int) (*http.Request, error) {
req, reqErr := http.NewRequestWithContext(ctx, http.MethodGet, presignedUrl, nil)
if reqErr != nil {
return nil, reqErr
}
for k, v := range headers {
req.Header.Set(k, v)
}
return req, nil
}
res, err := client.Do(req)

content, err := c.doStagingRequestWithRetry(ctx, reqFactory)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
return err
}
defer res.Body.Close() //nolint:errcheck
content, err := io.ReadAll(res.Body)
if writeErr := os.WriteFile(localFile, content, 0644); writeErr != nil { //nolint:gosec
return dbsqlerrint.NewDriverError(ctx, "error writing local file", writeErr)
}
return nil
}

if err != nil || !Succeeded(res) {
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content), nil)
func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError {
reqFactory := func(_ int) (*http.Request, error) {
req, reqErr := http.NewRequestWithContext(ctx, http.MethodDelete, presignedUrl, nil)
if reqErr != nil {
return nil, reqErr
}
for k, v := range headers {
req.Header.Set(k, v)
}
return req, nil
}

err = os.WriteFile(localFile, content, 0644) //nolint:gosec
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error writing local file", err)
// Treat 404 as success on REMOVE: DELETE is idempotent, and a 404 means
// the object is already absent — which is the post-condition the caller
// asked for. This also avoids spurious failures when a successful DELETE
// returns a transient 5xx mid-response and the retry sees 404 from the
// server having already applied the original request.
acceptStatus := func(status int) bool {
return statusInSuccessRange(status) || status == http.StatusNotFound
}

if _, err := c.doStagingRequestWithRetryAccept(ctx, reqFactory, acceptStatus); err != nil {
return err
}
return nil
}

func (c *conn) handleStagingRemove(ctx context.Context, presignedUrl string, headers map[string]string) dbsqlerr.DBError {
client := &http.Client{}
req, _ := http.NewRequest("DELETE", presignedUrl, nil)
for k, v := range headers {
req.Header.Set(k, v)
// maxStagingErrorBodyBytes bounds the response body bytes included in
// terminal staging error messages. Proxies and misconfigured backends can
// return multi-MB error bodies; truncating keeps the driver error readable
// without dropping the typical S3 XML error code that fits well under 512B.
const maxStagingErrorBodyBytes = 512

// doStagingRequestWithRetry executes a staging HTTP request with retry on
// transient object-storage failures (ES-1911239). Wraps
// doStagingRequestWithRetryAccept with the default success predicate (2xx
// from statusInSuccessRange / Succeeded).
func (c *conn) doStagingRequestWithRetry(ctx context.Context, reqFactory func(attempt int) (*http.Request, error)) ([]byte, dbsqlerr.DBError) {
return c.doStagingRequestWithRetryAccept(ctx, reqFactory, statusInSuccessRange)
}

// doStagingRequestWithRetryAccept is the generalized staging retry helper
// used by all three handleStaging* methods. Mirrors the CloudFetch retry
// path (ES-1892645) in semantics — same retryable status set, same
// exponential-backoff-with-jitter schedule, same RetryMax/RetryWaitMin/
// RetryWaitMax config knobs — so behavior is consistent across the driver's
// two object-storage code paths.
//
// reqFactory must return a fresh *http.Request on each call. Attempt 0 is
// the initial request; attempt N>0 is a retry. The PUT path uses this to
// rewind the file body between attempts; other staging paths just construct
// a new request each time.
//
// acceptStatus reports whether a given HTTP status code should be treated
// as success. Most handlers pass statusInSuccessRange. The REMOVE handler
// extends this to also accept 404 (idempotent-delete semantics).
//
// On success returns the response body bytes. On terminal failure (non-
// retryable status, exhausted retries, or context cancellation) returns a
// dbsqlerr.DBError describing the final state.
func (c *conn) doStagingRequestWithRetryAccept(
ctx context.Context,
reqFactory func(attempt int) (*http.Request, error),
acceptStatus func(status int) bool,
) ([]byte, dbsqlerr.DBError) {
retryMax := c.cfg.RetryMax
if retryMax < 0 {
retryMax = 0
}
res, err := client.Do(req)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error sending http request", err)
client := &http.Client{}

var (
lastErr error
lastStatus int
lastBody []byte
lastRetryAfter string
)

for attempt := 0; attempt <= retryMax; attempt++ {
if attempt > 0 {
wait := retry.Backoff(attempt, c.cfg.RetryWaitMin, c.cfg.RetryWaitMax, lastRetryAfter)
logger.Debug().Msgf(
"staging: retrying HTTP request (attempt %d/%d) in %v; lastStatus=%d lastErr=%v",
attempt, retryMax, wait, lastStatus, lastErr,
)
t := time.NewTimer(wait)
select {
case <-ctx.Done():
if !t.Stop() {
<-t.C
}
return nil, dbsqlerrint.NewDriverError(ctx, "staging operation cancelled during retry backoff", ctx.Err())
case <-t.C:
}
}

req, reqErr := reqFactory(attempt)
if reqErr != nil {
return nil, dbsqlerrint.NewDriverError(ctx, "error building staging http request", reqErr)
}

res, err := client.Do(req)
if err != nil {
// Caller cancellation is terminal; otherwise treat transport
// errors (TCP RST, TLS timeout, etc.) as transient.
if ctx.Err() != nil {
return nil, dbsqlerrint.NewDriverError(ctx, "error sending http request", ctx.Err())
}
lastErr = err
lastStatus = 0
lastRetryAfter = ""
continue
}

body, readErr := io.ReadAll(res.Body)
res.Body.Close() //nolint:errcheck,gosec // G104: close after drain

if readErr != nil {
if ctx.Err() != nil {
return nil, dbsqlerrint.NewDriverError(ctx, "error reading http response", ctx.Err())
}
lastErr = readErr
lastStatus = 0
lastRetryAfter = ""
continue
}

if acceptStatus(res.StatusCode) {
return body, nil
}

lastStatus = res.StatusCode
lastErr = nil
lastBody = body
lastRetryAfter = res.Header.Get("Retry-After")

if !retry.IsRetryableStatus(res.StatusCode) {
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, truncateErrorBody(body)), nil)
}
}
defer res.Body.Close() //nolint:errcheck
content, err := io.ReadAll(res.Body)

if err != nil || !Succeeded(res) {
return dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s, nil", res.StatusCode, content), nil)
if lastStatus != 0 {
// lastErr is nil here by construction: the HTTP-status branch above
// explicitly clears it on every iteration. The status code and body
// are captured in msg, so there's no underlying error to wrap.
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation over HTTP was unsuccessful: %d-%s (after %d retries)", lastStatus, truncateErrorBody(lastBody), retryMax), nil)
}
return nil, dbsqlerrint.NewDriverError(ctx, fmt.Sprintf("staging operation HTTP request failed: %v (after %d retries)", lastErr, retryMax), lastErr)
}

return nil
// truncateErrorBody caps b at maxStagingErrorBodyBytes for inclusion in error
// messages, appending an indicator when truncation occurred.
func truncateErrorBody(b []byte) string {
if len(b) <= maxStagingErrorBodyBytes {
return string(b)
}
return fmt.Sprintf("%s... (%d bytes total, truncated)", b[:maxStagingErrorBodyBytes], len(b))
}

func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool {
Expand Down
Loading
Loading