diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8fd7328d10..a72f95bb16 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -99,7 +99,7 @@ func registerSharedTools( } // Web tools - if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + searchTool, err := tools.NewWebSearchTool(tools.WebSearchToolOptions{ BraveAPIKey: cfg.Tools.Web.Brave.APIKey, BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, BraveEnabled: cfg.Tools.Web.Brave.Enabled, @@ -113,10 +113,18 @@ func registerSharedTools( PerplexityMaxResults: cfg.Tools.Web.Perplexity.MaxResults, PerplexityEnabled: cfg.Tools.Web.Perplexity.Enabled, Proxy: cfg.Tools.Web.Proxy, - }); searchTool != nil { + }) + if err != nil { + logger.ErrorCF("agent", "Failed to create web search tool", map[string]any{"error": err.Error()}) + } else if searchTool != nil { agent.Tools.Register(searchTool) } - agent.Tools.Register(tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy)) + fetchTool, err := tools.NewWebFetchToolWithProxy(50000, cfg.Tools.Web.Proxy) + if err != nil { + logger.ErrorCF("agent", "Failed to create web fetch tool", map[string]any{"error": err.Error()}) + } else { + agent.Tools.Register(fetchTool) + } // Hardware tools (I2C, SPI) - Linux only, returns error on other platforms agent.Tools.Register(tools.NewI2CTool()) diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 9fac2831c7..398f12e6b8 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -45,11 +45,13 @@ type replyTokenEntry struct { type LINEChannel struct { *channels.BaseChannel config config.LINEConfig - botUserID string // Bot's user ID - botBasicID string // Bot's basic ID (e.g. @216ru...) - botDisplayName string // Bot's display name for text-based mention detection - replyTokens sync.Map // chatID -> replyTokenEntry - quoteTokens sync.Map // chatID -> quoteToken (string) + infoClient *http.Client // for bot info lookups (short timeout) + apiClient *http.Client // for messaging API calls + botUserID string // Bot's user ID + botBasicID string // Bot's basic ID (e.g. @216ru...) + botDisplayName string // Bot's display name for text-based mention detection + replyTokens sync.Map // chatID -> replyTokenEntry + quoteTokens sync.Map // chatID -> quoteToken (string) ctx context.Context cancel context.CancelFunc } @@ -69,6 +71,8 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return &LINEChannel{ BaseChannel: base, config: cfg, + infoClient: &http.Client{Timeout: 10 * time.Second}, + apiClient: &http.Client{Timeout: 30 * time.Second}, }, nil } @@ -104,8 +108,7 @@ func (c *LINEChannel) fetchBotInfo() error { } req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) + resp, err := c.infoClient.Do(req) if err != nil { return err } @@ -644,8 +647,7 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := c.apiClient.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 7a23f9617c..292a71fd28 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -32,6 +32,7 @@ const ( type WeComAppChannel struct { *channels.BaseChannel config config.WeComAppConfig + client *http.Client accessToken string tokenExpiry time.Time tokenMu sync.RWMutex @@ -129,10 +130,18 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) + // Client timeout must be >= the configured ReplyTimeout so the + // per-request context deadline is always the effective limit. + clientTimeout := 30 * time.Second + if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { + clientTimeout = d + } + ctx, cancel := context.WithCancel(context.Background()) return &WeComAppChannel{ BaseChannel: base, config: cfg, + client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, processedMsgs: make(map[string]bool), @@ -306,8 +315,7 @@ func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaTyp } req.Header.Set("Content-Type", writer.FormDataContentType()) - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return "", channels.ClassifyNetError(err) } @@ -364,8 +372,7 @@ func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, use } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } @@ -746,8 +753,7 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 39f84d55c4..0d0426c0dd 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -25,6 +25,7 @@ import ( type WeComBotChannel struct { *channels.BaseChannel config config.WeComConfig + client *http.Client ctx context.Context cancel context.CancelFunc processedMsgs map[string]bool // Message deduplication: msg_id -> processed @@ -93,10 +94,18 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) + // Client timeout must be >= the configured ReplyTimeout so the + // per-request context deadline is always the effective limit. + clientTimeout := 30 * time.Second + if d := time.Duration(cfg.ReplyTimeout) * time.Second; d > clientTimeout { + clientTimeout = d + } + ctx, cancel := context.WithCancel(context.Background()) return &WeComBotChannel{ BaseChannel: base, config: cfg, + client: &http.Client{Timeout: clientTimeout}, ctx: ctx, cancel: cancel, processedMsgs: make(map[string]bool), @@ -450,8 +459,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content } req.Header.Set("Content-Type", "application/json") - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) + resp, err := c.client.Do(req) if err != nil { return channels.ClassifyNetError(err) } diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 8ba2a723aa..834e7bfc7e 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -15,6 +15,14 @@ import ( const ( userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" + + // HTTP client timeouts for web tool providers. + searchTimeout = 10 * time.Second // Brave, Tavily, DuckDuckGo + perplexityTimeout = 30 * time.Second // Perplexity (LLM-based, slower) + fetchTimeout = 60 * time.Second // WebFetchTool + + defaultMaxChars = 50000 + maxRedirects = 5 ) // Pre-compiled regexes for HTML text extraction @@ -74,6 +82,7 @@ type SearchProvider interface { type BraveSearchProvider struct { apiKey string proxy string + client *http.Client } func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -88,11 +97,7 @@ func (p *BraveSearchProvider) Search(ctx context.Context, query string, count in req.Header.Set("Accept", "application/json") req.Header.Set("X-Subscription-Token", p.apiKey) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -143,6 +148,7 @@ type TavilySearchProvider struct { apiKey string baseURL string proxy string + client *http.Client } func (p *TavilySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -174,11 +180,7 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -226,7 +228,8 @@ func (p *TavilySearchProvider) Search(ctx context.Context, query string, count i } type DuckDuckGoSearchProvider struct { - proxy string + proxy string + client *http.Client } func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -239,11 +242,7 @@ func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 10*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -322,6 +321,7 @@ func stripTags(content string) string { type PerplexitySearchProvider struct { apiKey string proxy string + client *http.Client } func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, count int) (string, error) { @@ -356,11 +356,7 @@ func (p *PerplexitySearchProvider) Search(ctx context.Context, query string, cou req.Header.Set("Authorization", "Bearer "+p.apiKey) req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(p.proxy, 30*time.Second) - if err != nil { - return "", fmt.Errorf("failed to create HTTP client: %w", err) - } - resp, err := client.Do(req) + resp, err := p.client.Do(req) if err != nil { return "", fmt.Errorf("request failed: %w", err) } @@ -415,43 +411,60 @@ type WebSearchToolOptions struct { Proxy string } -func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { +func NewWebSearchTool(opts WebSearchToolOptions) (*WebSearchTool, error) { var provider SearchProvider maxResults := 5 // Priority: Perplexity > Brave > Tavily > DuckDuckGo if opts.PerplexityEnabled && opts.PerplexityAPIKey != "" { - provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, perplexityTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Perplexity: %w", err) + } + provider = &PerplexitySearchProvider{apiKey: opts.PerplexityAPIKey, proxy: opts.Proxy, client: client} if opts.PerplexityMaxResults > 0 { maxResults = opts.PerplexityMaxResults } } else if opts.BraveEnabled && opts.BraveAPIKey != "" { - provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Brave: %w", err) + } + provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey, proxy: opts.Proxy, client: client} if opts.BraveMaxResults > 0 { maxResults = opts.BraveMaxResults } } else if opts.TavilyEnabled && opts.TavilyAPIKey != "" { + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for Tavily: %w", err) + } provider = &TavilySearchProvider{ apiKey: opts.TavilyAPIKey, baseURL: opts.TavilyBaseURL, proxy: opts.Proxy, + client: client, } if opts.TavilyMaxResults > 0 { maxResults = opts.TavilyMaxResults } } else if opts.DuckDuckGoEnabled { - provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy} + client, err := createHTTPClient(opts.Proxy, searchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for DuckDuckGo: %w", err) + } + provider = &DuckDuckGoSearchProvider{proxy: opts.Proxy, client: client} if opts.DuckDuckGoMaxResults > 0 { maxResults = opts.DuckDuckGoMaxResults } } else { - return nil + return nil, nil } return &WebSearchTool{ provider: provider, maxResults: maxResults, - } + }, nil } func (t *WebSearchTool) Name() string { @@ -508,25 +521,34 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]any) *ToolR type WebFetchTool struct { maxChars int proxy string + client *http.Client } func NewWebFetchTool(maxChars int) *WebFetchTool { - if maxChars <= 0 { - maxChars = 50000 - } - return &WebFetchTool{ - maxChars: maxChars, - } + // createHTTPClient cannot fail with an empty proxy string. + tool, _ := NewWebFetchToolWithProxy(maxChars, "") + return tool } -func NewWebFetchToolWithProxy(maxChars int, proxy string) *WebFetchTool { +func NewWebFetchToolWithProxy(maxChars int, proxy string) (*WebFetchTool, error) { if maxChars <= 0 { - maxChars = 50000 + maxChars = defaultMaxChars + } + client, err := createHTTPClient(proxy, fetchTimeout) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client for web fetch: %w", err) + } + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= maxRedirects { + return fmt.Errorf("stopped after %d redirects", maxRedirects) + } + return nil } return &WebFetchTool{ maxChars: maxChars, proxy: proxy, - } + client: client, + }, nil } func (t *WebFetchTool) Name() string { @@ -588,20 +610,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]any) *ToolRe req.Header.Set("User-Agent", userAgent) - client, err := createHTTPClient(t.proxy, 60*time.Second) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to create HTTP client: %v", err)) - } - - // Configure redirect handling - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if len(via) >= 5 { - return fmt.Errorf("stopped after 5 redirects") - } - return nil - } - - resp, err := client.Do(req) + resp, err := t.client.Do(req) if err != nil { return ErrorResult(fmt.Sprintf("request failed: %v", err)) } diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go index 2cd79eb241..db3c08ba65 100644 --- a/pkg/tools/web_test.go +++ b/pkg/tools/web_test.go @@ -176,13 +176,19 @@ func TestWebTool_WebFetch_Truncation(t *testing.T) { // TestWebTool_WebSearch_NoApiKey verifies that no tool is created when API key is missing func TestWebTool_WebSearch_NoApiKey(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: ""}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } if tool != nil { t.Errorf("Expected nil tool when Brave API key is empty") } // Also nil when nothing is enabled - tool = NewWebSearchTool(WebSearchToolOptions{}) + tool, err = NewWebSearchTool(WebSearchToolOptions{}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } if tool != nil { t.Errorf("Expected nil tool when no provider is enabled") } @@ -190,7 +196,10 @@ func TestWebTool_WebSearch_NoApiKey(t *testing.T) { // TestWebTool_WebSearch_MissingQuery verifies error handling for missing query func TestWebTool_WebSearch_MissingQuery(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + tool, err := NewWebSearchTool(WebSearchToolOptions{BraveEnabled: true, BraveAPIKey: "test-key", BraveMaxResults: 5}) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } ctx := context.Background() args := map[string]any{} @@ -438,7 +447,10 @@ func TestCreateHTTPClient_ProxyFromEnvironmentWhenConfigEmpty(t *testing.T) { } func TestNewWebFetchToolWithProxy(t *testing.T) { - tool := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + tool, err := NewWebFetchToolWithProxy(1024, "http://127.0.0.1:7890") + if err != nil { + t.Fatalf("NewWebFetchToolWithProxy() error: %v", err) + } if tool.maxChars != 1024 { t.Fatalf("maxChars = %d, want %d", tool.maxChars, 1024) } @@ -446,7 +458,10 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { t.Fatalf("proxy = %q, want %q", tool.proxy, "http://127.0.0.1:7890") } - tool = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + tool, err = NewWebFetchToolWithProxy(0, "http://127.0.0.1:7890") + if err != nil { + t.Fatalf("NewWebFetchToolWithProxy() error: %v", err) + } if tool.maxChars != 50000 { t.Fatalf("default maxChars = %d, want %d", tool.maxChars, 50000) } @@ -454,12 +469,15 @@ func TestNewWebFetchToolWithProxy(t *testing.T) { func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { t.Run("perplexity", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ PerplexityEnabled: true, PerplexityAPIKey: "k", PerplexityMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*PerplexitySearchProvider) if !ok { t.Fatalf("provider type = %T, want *PerplexitySearchProvider", tool.provider) @@ -470,12 +488,15 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { }) t.Run("brave", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ BraveEnabled: true, BraveAPIKey: "k", BraveMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*BraveSearchProvider) if !ok { t.Fatalf("provider type = %T, want *BraveSearchProvider", tool.provider) @@ -486,11 +507,14 @@ func TestNewWebSearchTool_PropagatesProxy(t *testing.T) { }) t.Run("duckduckgo", func(t *testing.T) { - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ DuckDuckGoEnabled: true, DuckDuckGoMaxResults: 3, Proxy: "http://127.0.0.1:7890", }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } p, ok := tool.provider.(*DuckDuckGoSearchProvider) if !ok { t.Fatalf("provider type = %T, want *DuckDuckGoSearchProvider", tool.provider) @@ -542,12 +566,15 @@ func TestWebTool_TavilySearch_Success(t *testing.T) { })) defer server.Close() - tool := NewWebSearchTool(WebSearchToolOptions{ + tool, err := NewWebSearchTool(WebSearchToolOptions{ TavilyEnabled: true, TavilyAPIKey: "test-key", TavilyBaseURL: server.URL, TavilyMaxResults: 5, }) + if err != nil { + t.Fatalf("NewWebSearchTool() error: %v", err) + } ctx := context.Background() args := map[string]any{ diff --git a/pkg/utils/http_retry.go b/pkg/utils/http_retry.go index e90fa21294..135ea0ef52 100644 --- a/pkg/utils/http_retry.go +++ b/pkg/utils/http_retry.go @@ -37,6 +37,9 @@ func DoRequestWithRetry(client *http.Client, req *http.Request) (*http.Response, if i < maxRetries-1 { if err = sleepWithCtx(req.Context(), retryDelayUnit*time.Duration(i+1)); err != nil { + if resp != nil { + resp.Body.Close() + } return nil, fmt.Errorf("failed to sleep: %w", err) } } diff --git a/pkg/utils/http_retry_test.go b/pkg/utils/http_retry_test.go index 1c2dbe115e..d64cd5edaa 100644 --- a/pkg/utils/http_retry_test.go +++ b/pkg/utils/http_retry_test.go @@ -1,8 +1,11 @@ package utils import ( + "context" + "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -77,6 +80,91 @@ func TestDoRequestWithRetry(t *testing.T) { } } +func TestDoRequestWithRetry_ContextCancel(t *testing.T) { + // Use a long retry delay so cancellation always hits during sleepWithCtx. + retryDelayUnit = 10 * time.Second + t.Cleanup(func() { retryDelayUnit = time.Second }) + + bodyClosed := false + firstRoundTripDone := make(chan struct{}, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("error")) + })) + defer server.Close() + + client := server.Client() + client.Timeout = 30 * time.Second + client.Transport = &bodyCloseTracker{ + rt: client.Transport, + onClose: func() { bodyClosed = true }, + // Signal after the first round-trip response is fully constructed on the client side. + onRoundTrip: func() { + select { + case firstRoundTripDone <- struct{}{}: + default: + } + }, + trackURL: server.URL, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Cancel the context after the first round-trip completes on the client side. + // This ensures client.Do has returned a valid resp (with body) and the retry + // loop is about to enter sleepWithCtx, where the cancel will be detected. + go func() { + <-firstRoundTripDone + cancel() + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := DoRequestWithRetry(client, req) + if resp != nil { + resp.Body.Close() + } + require.Error(t, err, "expected error from context cancellation") + assert.Nil(t, resp, "expected nil response when context is canceled") + assert.True(t, bodyClosed, "expected resp.Body to be closed on context cancellation") +} + +// bodyCloseTracker wraps an http.RoundTripper and records when response bodies are closed. +type bodyCloseTracker struct { + rt http.RoundTripper + onClose func() + onRoundTrip func() // called after each successful round-trip + trackURL string +} + +func (t *bodyCloseTracker) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.rt.RoundTrip(req) + if err != nil { + return resp, err + } + if strings.HasPrefix(req.URL.String(), t.trackURL) { + resp.Body = &closeNotifier{ReadCloser: resp.Body, onClose: t.onClose} + if t.onRoundTrip != nil { + t.onRoundTrip() + } + } + return resp, nil +} + +// closeNotifier wraps an io.ReadCloser to detect Close calls. +type closeNotifier struct { + io.ReadCloser + onClose func() +} + +func (c *closeNotifier) Close() error { + c.onClose() + return c.ReadCloser.Close() +} + func TestDoRequestWithRetry_Delay(t *testing.T) { retryDelayUnit = time.Millisecond t.Cleanup(func() { retryDelayUnit = time.Second })