Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions default_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestConnectionReuse(t *testing.T) {

router := httprouter.New()
router.GET("/", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
fmt.Fprintf(w, "this is a test")
_, _ = fmt.Fprintf(w, "this is a test")
})
ts := httptest.NewServer(router)
defer ts.Close()
Expand All @@ -55,7 +55,7 @@ func TestConnectionReuse(t *testing.T) {
resp, err := client.Do(req)
require.Nil(t, err)
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()
_ = resp.Body.Close()
}
}()
}
Expand Down
6 changes: 3 additions & 3 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// By default, we close the response body and return an error without
// returning the response
if resp != nil {
resp.Body.Close()
_ = resp.Body.Close()
}
c.closeIdleConnections()
return nil, fmt.Errorf("%s %s giving up after %d attempts: %w", req.Method, req.URL, retryMax+1, err)
Expand All @@ -139,7 +139,7 @@ func (c *Client) drainBody(req *Request, resp *http.Response) {
if err != nil {
req.Metrics.DrainErrors++
}
resp.Body.Close()
_ = resp.Body.Close()
}

const closeConnectionsCounter = 100
Expand Down Expand Up @@ -248,5 +248,5 @@ func (c *Client) wrapContextWithTrace(req *Request) {
}
req.TraceInfo = traceInfo

req.Request = req.Request.WithContext(httptrace.WithClientTrace(req.Request.Context(), trace))
req.Request = req.Request.WithContext(httptrace.WithClientTrace(req.Context(), trace))
}
109 changes: 101 additions & 8 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/http/httputil"
Expand Down Expand Up @@ -92,7 +93,7 @@ func (r *Request) WithContext(ctx context.Context) *Request {
// This function is not thread-safe; do not call it at the same time as another
// call, or at the same time this request is being used with Client.Do.
func (r *Request) BodyBytes() ([]byte, error) {
if r.Request.Body == nil {
if r.Body == nil {
return nil, nil
}
buf := new(bytes.Buffer)
Expand All @@ -103,6 +104,96 @@ func (r *Request) BodyBytes() ([]byte, error) {
return buf.Bytes(), nil
}

// SetBodyReader sets the request body and populates GetBody for retries.
//
// The provided body MUST be reusable (e.g. created via
// [readerutil.NewReusableReadCloser]).
// If the body is not reusable, retries and 307/308 redirects will send an
// empty body.
//
// This method does NOT set content length. The caller must set it manually.
//
// Prefer [SetBody], [SetBodyString], or [SetBodyStream] which handle
// reusability and content length automatically.
func (r *Request) SetBodyReader(body io.ReadCloser) {
r.Body = body
r.GetBody = func() (io.ReadCloser, error) {
return body, nil
}
}

// SetBody sets the request body from a byte slice.
// It creates a reusable reader, sets content length, and populates GetBody for
// retries.
func (r *Request) SetBody(body []byte) error {
bodyReader, err := readerutil.NewReusableReadCloser(body)
if err != nil {
return err
}

r.Body = bodyReader
r.ContentLength = int64(len(body))
r.GetBody = func() (io.ReadCloser, error) {
return readerutil.NewReusableReadCloser(body)
}

return nil
}

// SetBodyString sets the request body from a string.
// It creates a reusable reader, sets content length, and populates GetBody for
// retries.
func (r *Request) SetBodyString(body string) error {
return r.SetBody([]byte(body))
}

// SetBodyStream sets the request body from an [io.Reader].
//
// If bodySize is >= 0, it reads exactly that many bytes.
// If bodySize < 0, it reads bodyStream until io.EOF.
//
// It creates a reusable reader, calculates and sets content length, and
// populates GetBody for retries.
//
// If bodyStream implements [io.Closer], it is closed after the content is
// read into the reusable reader.
func (r *Request) SetBodyStream(bodyStream io.Reader, bodySize int64) error {
if closer, ok := bodyStream.(io.Closer); ok {
defer closer.Close()

// Wrap in NopCloser to prevent NewReusableReadCloser from closing it,
// since we are handling the close via defer.
bodyStream = io.NopCloser(bodyStream)
}

if bodySize >= 0 {
bodyStream = io.LimitReader(bodyStream, bodySize)
}

bodyReader, err := readerutil.NewReusableReadCloser(bodyStream)
if err != nil {
return err
}

r.SetBodyReader(bodyReader)

// If bodySize is provided, use it as ContentLength
if bodySize >= 0 {
r.ContentLength = bodySize
return nil
}

// Otherwise, calculate the length by reading the body
length, err := getLength(bodyReader)
if err == nil {
r.ContentLength = length
} else {
r.ContentLength = 0
}

return nil
}

// Update request URL with new changes of parameters if any
func (r *Request) Update() {
// Make a copy of the URL to avoid data races
Expand Down Expand Up @@ -192,7 +283,7 @@ func FromRequest(r *http.Request) (*Request, error) {
if err != nil {
return nil, err
}
r.Body = body
req.SetBodyReader(body)
req.ContentLength, err = getLength(body)
if err != nil {
return nil, err
Expand Down Expand Up @@ -252,21 +343,23 @@ func NewRequestFromURLWithContext(ctx context.Context, method string, urlx *urlu
return nil, err
}
urlx.Update()

httpReq.URL = urlx.URL
updateScheme(httpReq.URL)
// content-length and body should be assigned only
// if request has body
if bodyReader != nil {
httpReq.ContentLength = contentLength
httpReq.Body = bodyReader
}

request := &Request{
Request: httpReq,
URL: urlx,
Metrics: Metrics{},
}

// content-length and body should be assigned only
// if request has body
if bodyReader != nil {
request.SetBodyReader(bodyReader)
request.ContentLength = contentLength
}

return request, nil
}

Expand Down
Loading
Loading