diff --git a/connector.go b/connector.go index 53908b4c..c31f4a15 100644 --- a/connector.go +++ b/connector.go @@ -63,7 +63,16 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { CanUseMultipleCatalogs: &c.cfg.CanUseMultipleCatalogs, }) if err != nil { - return nil, dbsqlerrint.NewRequestError(ctx, fmt.Sprintf("error connecting: host=%s port=%d, httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath), err) + return nil, dbsqlerrint.NewRequestError( + ctx, + fmt.Sprintf( + "error connecting: host=%s port=%d, httpPath=%s", + c.cfg.Host, + c.cfg.Port, + c.cfg.HTTPPath, + ), + err, + ) } conn := &conn{ @@ -74,7 +83,8 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { } log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "") - log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion) + log.Info(). + Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion) return conn, nil } @@ -241,7 +251,10 @@ func WithSessionParams(params map[string]string) ConnOption { func WithSkipTLSHostVerify() ConnOption { return func(c *config.Config) { if c.TLSConfig == nil { - c.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12, InsecureSkipVerify: true} // #nosec G402 + c.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: true, + } // #nosec G402 } else { c.TLSConfig.InsecureSkipVerify = true // #nosec G402 } @@ -269,6 +282,12 @@ func WithCloudFetch(useCloudFetch bool) ConnOption { } } +func WithCloudFetchHttpClient(httpclient *http.Client) ConnOption { + return func(c *config.Config) { + c.UserConfig.CloudFetchConfig.HttpClient = httpclient + } +} + // WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10. func WithMaxDownloadThreads(numThreads int) ConnOption { return func(c *config.Config) { diff --git a/connector_test.go b/connector_test.go index 57554b98..927e7c62 100644 --- a/connector_test.go +++ b/connector_test.go @@ -246,6 +246,42 @@ func TestNewConnector(t *testing.T) { require.True(t, ok) assert.False(t, coni.cfg.EnableMetricViewMetadata) }) + + t.Run("Connector test WithCloudFetchHTTPClient sets custom client", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + customClient := &http.Client{Timeout: 5 * time.Second} + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + WithCloudFetchHttpClient(customClient), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Equal(t, customClient, coni.cfg.UserConfig.CloudFetchConfig.HttpClient) + }) + + t.Run("Connector test WithCloudFetchHTTPClient with nil client is accepted", func(t *testing.T) { + host := "databricks-host" + accessToken := "token" + httpPath := "http-path" + + con, err := NewConnector( + WithServerHostname(host), + WithAccessToken(accessToken), + WithHTTPPath(httpPath), + ) + assert.Nil(t, err) + + coni, ok := con.(*connector) + require.True(t, ok) + assert.Nil(t, coni.cfg.UserConfig.CloudFetchConfig.HttpClient) + }) } type mockRoundTripper struct{} diff --git a/internal/config/config.go b/internal/config/config.go index 67437a9c..d1c30187 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,7 +10,6 @@ import ( "strings" "time" - dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/pkg/errors" "github.com/databricks/databricks-sql-go/auth" @@ -18,6 +17,7 @@ import ( "github.com/databricks/databricks-sql-go/auth/oauth/m2m" "github.com/databricks/databricks-sql-go/auth/oauth/u2m" "github.com/databricks/databricks-sql-go/auth/pat" + dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors" "github.com/databricks/databricks-sql-go/logger" @@ -198,7 +198,6 @@ func WithDefaults() *Config { ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8, ThriftDebugClientProtocol: false, } - } // ParseDSN constructs UserConfig and CloudFetchConfig by parsing DSN string supplied to `sql.Open()` @@ -209,14 +208,22 @@ func ParseDSN(dsn string) (UserConfig, error) { } parsedURL, err := url.Parse(fullDSN) if err != nil { - return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.ErrInvalidDSNFormat, err) + return UserConfig{}, dbsqlerrint.NewRequestError( + context.TODO(), + dbsqlerr.ErrInvalidDSNFormat, + err, + ) } ucfg := UserConfig{}.WithDefaults() ucfg.Protocol = parsedURL.Scheme ucfg.Host = parsedURL.Hostname() port, err := strconv.Atoi(parsedURL.Port()) if err != nil { - return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.ErrInvalidDSNPort, err) + return UserConfig{}, dbsqlerrint.NewRequestError( + context.TODO(), + dbsqlerr.ErrInvalidDSNPort, + err, + ) } ucfg.Port = port @@ -395,7 +402,11 @@ func (params *extractableParams) extractAsInt(key string) (int, bool, error) { if intString, ok := extractParam(key, params, false, true); ok { i, err := strconv.Atoi(intString) if err != nil { - return 0, true, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat(key, intString, "int"), err) + return 0, true, dbsqlerrint.NewRequestError( + context.TODO(), + dbsqlerr.InvalidDSNFormat(key, intString, "int"), + err, + ) } return i, true, nil @@ -408,7 +419,11 @@ func (params *extractableParams) extractAsBool(key string) (bool, bool, error) { if boolString, ok := extractParam(key, params, false, true); ok { b, err := strconv.ParseBool(boolString) if err != nil { - return false, true, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat(key, boolString, "bool"), err) + return false, true, dbsqlerrint.NewRequestError( + context.TODO(), + dbsqlerr.InvalidDSNFormat(key, boolString, "bool"), + err, + ) } return b, true, nil @@ -422,7 +437,12 @@ func (params *extractableParams) getNoCase(key string) (string, bool) { return extractParam(key, params, true, false) } -func extractParam(key string, params *extractableParams, ignoreCase bool, delValue bool) (string, bool) { +func extractParam( + key string, + params *extractableParams, + ignoreCase bool, + delValue bool, +) (string, bool) { if ignoreCase { key = strings.ToLower(key) } @@ -479,6 +499,7 @@ type CloudFetchConfig struct { MaxFilesInMemory int MinTimeToExpiry time.Duration CloudFetchSpeedThresholdMbps float64 // Minimum download speed in MBps before WARN logging (default: 0.1) + HttpClient *http.Client } func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig { diff --git a/internal/rows/arrowbased/batchloader.go b/internal/rows/arrowbased/batchloader.go index d26d8a4a..06b6361e 100644 --- a/internal/rows/arrowbased/batchloader.go +++ b/internal/rows/arrowbased/batchloader.go @@ -40,6 +40,7 @@ func NewCloudIPCStreamIterator( startRowOffset: startRowOffset, pendingLinks: NewQueue[cli_service.TSparkArrowResultLink](), downloadTasks: NewQueue[cloudFetchDownloadTask](), + httpClient: cfg.CloudFetchConfig.HttpClient, } for _, link := range files { @@ -140,6 +141,7 @@ type cloudIPCStreamIterator struct { startRowOffset int64 pendingLinks Queue[cli_service.TSparkArrowResultLink] downloadTasks Queue[cloudFetchDownloadTask] + httpClient *http.Client } var _ IPCStreamIterator = (*cloudIPCStreamIterator)(nil) @@ -162,6 +164,7 @@ func (bi *cloudIPCStreamIterator) Next() (io.Reader, error) { resultChan: make(chan cloudFetchDownloadTaskResult), minTimeToExpiry: bi.cfg.MinTimeToExpiry, speedThresholdMbps: bi.cfg.CloudFetchSpeedThresholdMbps, + httpClient: bi.httpClient, } task.Run() bi.downloadTasks.Enqueue(task) @@ -210,6 +213,7 @@ type cloudFetchDownloadTask struct { link *cli_service.TSparkArrowResultLink resultChan chan cloudFetchDownloadTaskResult speedThresholdMbps float64 + httpClient *http.Client } func (cft *cloudFetchDownloadTask) GetResult() (io.Reader, error) { @@ -252,7 +256,7 @@ func (cft *cloudFetchDownloadTask) Run() { cft.link.StartRowOffset, cft.link.RowCount, ) - data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps) + data, err := fetchBatchBytes(cft.ctx, cft.link, cft.minTimeToExpiry, cft.speedThresholdMbps, cft.httpClient) if err != nil { cft.resultChan <- cloudFetchDownloadTaskResult{data: nil, err: err} return @@ -300,6 +304,7 @@ func fetchBatchBytes( link *cli_service.TSparkArrowResultLink, minTimeToExpiry time.Duration, speedThresholdMbps float64, + httpClient *http.Client, ) (io.ReadCloser, error) { if isLinkExpired(link.ExpiryTime, minTimeToExpiry) { return nil, errors.New(dbsqlerr.ErrLinkExpired) @@ -317,9 +322,12 @@ func fetchBatchBytes( } } + if httpClient == nil { + httpClient = http.DefaultClient + } + startTime := time.Now() - client := http.DefaultClient - res, err := client.Do(req) + res, err := httpClient.Do(req) if err != nil { return nil, err } diff --git a/internal/rows/arrowbased/batchloader_test.go b/internal/rows/arrowbased/batchloader_test.go index b018eb6d..0b2aecb7 100644 --- a/internal/rows/arrowbased/batchloader_test.go +++ b/internal/rows/arrowbased/batchloader_test.go @@ -253,6 +253,103 @@ func TestCloudFetchIterator(t *testing.T) { assert.NotNil(t, err3) assert.ErrorContains(t, err3, fmt.Sprintf("%s %d", "HTTP error", http.StatusNotFound)) }) + + t.Run("should use custom HTTPClient when provided", func(t *testing.T) { + customClient := &http.Client{Timeout: 5 * time.Second} + requestCount := 0 + + handler = func(w http.ResponseWriter, r *http.Request) { + requestCount++ + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + cfg.UserConfig.CloudFetchConfig.HttpClient = customClient + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify custom client is passed through the iterator chain + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Equal(t, customClient, cbi.httpClient) + + // Fetch should work with custom client + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + assert.Greater(t, requestCount, 0) // Verify request was made + }) + + t.Run("should use http.DefaultClient when HTTPClient is nil", func(t *testing.T) { + handler = func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write(generateMockArrowBytes(generateArrowRecord())) + if err != nil { + panic(err) + } + } + + startRowOffset := int64(100) + + links := []*cli_service.TSparkArrowResultLink{ + { + FileLink: server.URL, + ExpiryTime: time.Now().Add(10 * time.Minute).Unix(), + StartRowOffset: startRowOffset, + RowCount: 1, + }, + } + + cfg := config.WithDefaults() + cfg.UseLz4Compression = false + cfg.MaxDownloadThreads = 1 + // HTTPClient is nil by default + + bi, err := NewCloudBatchIterator( + context.Background(), + links, + startRowOffset, + cfg, + ) + assert.Nil(t, err) + + // Verify nil client is passed through + wrapper, ok := bi.(*batchIterator) + assert.True(t, ok) + cbi, ok := wrapper.ipcIterator.(*cloudIPCStreamIterator) + assert.True(t, ok) + assert.Nil(t, cbi.httpClient) + + // Fetch should work (falls back to http.DefaultClient) + sab1, nextErr := bi.Next() + assert.Nil(t, nextErr) + assert.NotNil(t, sab1) + }) } func generateArrowRecord() arrow.Record {