diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index cc08e0dc0b3..74a46fb4e84 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -14,6 +14,10 @@ * Provides more customization options for retryer by adding a constructor for default Retryer which accepts functional options. Adds NoOpRetryer to support no retry behavior. Exposes members of default retryer. * Updates the underlying logic used by the default retryer to calculate jittered delay for retry. Handles int overflow for retry delay. * Fixes [#370](https://github.com/aws/aws-sdk-go-v2/issues/370) +* `aws` : Refactors request retry behavior path logic ([#384](https://github.com/aws/aws-sdk-go-v2/pull/384)) + * Retry utilities now follow a consistent code path. aws.IsErrorRetryable is the primary entry point to determine if a request is retryable. + * Corrects sdk's behavior by not retrying errors with status code 501. Adds support for retrying the Kinesis API error, LimitExceededException. + * Fixes [#372](https://github.com/aws/aws-sdk-go-v2/issues/372), [#145](https://github.com/aws/aws-sdk-go-v2/issues/145) ### SDK Bugs * `aws`: Fixes bug in calculating throttled retry delay ([#373](https://github.com/aws/aws-sdk-go-v2/pull/373)) diff --git a/aws/default_retryer.go b/aws/default_retryer.go index ae952a4d928..88b6d21d9d5 100644 --- a/aws/default_retryer.go +++ b/aws/default_retryer.go @@ -66,23 +66,21 @@ func NewDefaultRetryer(opts ...func(d *DefaultRetryer)) DefaultRetryer { // Note: RetryRules method must be a value receiver so that the // defaultRetryer is safe. func (d DefaultRetryer) RetryRules(r *Request) time.Duration { + minDelay := d.MinRetryDelay + maxDelay := d.MaxRetryDelay + var initialDelay time.Duration - throttle := d.shouldThrottle(r) - if throttle { + isThrottle := r.IsErrorThrottle() + if isThrottle { if delay, ok := getRetryAfterDelay(r); ok { initialDelay = delay } minDelay = d.MinThrottleDelay - } - - retryCount := r.RetryCount - - maxDelay := d.MaxRetryDelay - if throttle { maxDelay = d.MaxThrottleDelay } + retryCount := r.RetryCount var delay time.Duration // Logic to cap the retry count based on the minDelay provided @@ -111,26 +109,7 @@ func (d DefaultRetryer) ShouldRetry(r *Request) bool { return *r.Retryable } - if r.HTTPResponse.StatusCode >= 500 { - return true - } - return r.IsErrorRetryable() || d.shouldThrottle(r) -} - -// ShouldThrottle returns true if the request should be throttled. -func (d DefaultRetryer) shouldThrottle(r *Request) bool { - if r.HTTPResponse != nil { - switch r.HTTPResponse.StatusCode { - case 429: - case 502: - case 503: - case 504: - default: - return r.IsErrorThrottle() - } - return true - } - return r.IsErrorThrottle() + return r.IsErrorRetryable() || r.IsErrorThrottle() } // This will look in the Retry-After header, RFC 7231, for how long diff --git a/aws/default_retryer_test.go b/aws/default_retryer_test.go index 2c35e10aa6b..f8a1132554c 100644 --- a/aws/default_retryer_test.go +++ b/aws/default_retryer_test.go @@ -60,7 +60,7 @@ func TestRetryThrottleStatusCodes(t *testing.T) { d.NumMaxRetries = 100 }) for i, c := range cases { - throttle := d.shouldThrottle(&c.r) + throttle := c.r.IsErrorThrottle() retry := d.ShouldRetry(&c.r) if e, a := c.expectThrottle, throttle; e != a { diff --git a/aws/defaults/handlers.go b/aws/defaults/handlers.go index 86e155ca35a..7711f9fbace 100644 --- a/aws/defaults/handlers.go +++ b/aws/defaults/handlers.go @@ -104,7 +104,7 @@ var SendHandler = aws.NamedHandler{ // TODO remove this complexity the SDK's built http.Request should // set Request.Body to nil, if there is no body to send. #318 - if aws.NoBody == r.HTTPRequest.Body { + if http.NoBody == r.HTTPRequest.Body { // Strip off the request body if the NoBody reader was used as a // place holder for a request body. This prevents the SDK from // making requests with a request body when it would be invalid @@ -158,9 +158,10 @@ func handleSendError(r *aws.Request, err error) { Body: ioutil.NopCloser(bytes.NewReader([]byte{})), } } - // Catch all other request errors. + + // Catch all request errors, and let the retryer determine + // if the error is retryable. r.Error = awserr.New("RequestError", "send request failed", err) - r.Retryable = aws.Bool(true) // network errors are retryable // Override the error with a context canceled error, if that was canceled. ctx := r.Context() @@ -183,34 +184,36 @@ var ValidateResponseHandler = aws.NamedHandler{Name: "core.ValidateResponseHandl // AfterRetryHandler performs final checks to determine if the request should // be retried and how long to delay. -var AfterRetryHandler = aws.NamedHandler{Name: "core.AfterRetryHandler", Fn: func(r *aws.Request) { - // If one of the other handlers already set the retry state - // we don't want to override it based on the service's state - if r.Retryable == nil || r.Config.EnforceShouldRetryCheck { - r.Retryable = aws.Bool(r.ShouldRetry(r)) - } +var AfterRetryHandler = aws.NamedHandler{ + Name: "core.AfterRetryHandler", + Fn: func(r *aws.Request) { + // If one of the other handlers already set the retry state + // we don't want to override it based on the service's state + if r.Retryable == nil || r.Config.EnforceShouldRetryCheck { + r.Retryable = aws.Bool(r.ShouldRetry(r)) + } - if r.WillRetry() { - r.RetryDelay = r.RetryRules(r) + if r.WillRetry() { + r.RetryDelay = r.RetryRules(r) - if err := sdk.SleepWithContext(r.Context(), r.RetryDelay); err != nil { - r.Error = awserr.New(aws.ErrCodeRequestCanceled, - "request context canceled", err) - r.Retryable = aws.Bool(false) - return - } + if err := sdk.SleepWithContext(r.Context(), r.RetryDelay); err != nil { + r.Error = awserr.New(aws.ErrCodeRequestCanceled, + "request context canceled", err) + r.Retryable = aws.Bool(false) + return + } - // when the expired token exception occurs the credentials - // need to be expired locally so that the next request to - // get credentials will trigger a credentials refresh. - if p, ok := r.Config.Credentials.(sdk.Invalidator); ok && r.IsErrorExpired() { - p.Invalidate() - } + // when the expired token exception occurs the credentials + // need to be expired locally so that the next request to + // get credentials will trigger a credentials refresh. + if p, ok := r.Config.Credentials.(sdk.Invalidator); ok && r.IsErrorExpired() { + p.Invalidate() + } - r.RetryCount++ - r.Error = nil - } -}} + r.RetryCount++ + r.Error = nil + } + }} // ValidateEndpointHandler is a request handler to validate a request had the // appropriate Region and Endpoint set. Will set r.Error if the endpoint or diff --git a/aws/defaults/handlers_1_8_test.go b/aws/defaults/handlers_1_8_test.go deleted file mode 100644 index e9d566e215d..00000000000 --- a/aws/defaults/handlers_1_8_test.go +++ /dev/null @@ -1,44 +0,0 @@ -// +build go1.8 - -package defaults_test - -import ( - "context" - "net/http" - "net/http/httptest" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/defaults" - "github.com/aws/aws-sdk-go-v2/service/s3" -) - -func TestSendHandler_HEADNoBody(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - cfg := defaults.Config() - cfg.Region = "mock-region" - cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) - - svc := s3.New(cfg) - svc.ForcePathStyle = true - - req := svc.HeadObjectRequest(&s3.HeadObjectInput{ - Bucket: aws.String("bucketname"), - Key: aws.String("keyname"), - }) - - if e, a := aws.NoBody, req.HTTPRequest.Body; e != a { - t.Fatalf("expect %T request body, got %T", e, a) - } - - _, err := req.Send(context.Background()) - if err != nil { - t.Fatalf("expect no error, got %v", err) - } - if e, a := http.StatusOK, req.HTTPResponse.StatusCode; e != a { - t.Errorf("expect %d status code, got %d", e, a) - } -} diff --git a/aws/defaults/handlers_test.go b/aws/defaults/handlers_test.go index 8db6ea0f02c..61d3baf16cb 100644 --- a/aws/defaults/handlers_test.go +++ b/aws/defaults/handlers_test.go @@ -398,3 +398,33 @@ func TestBuildContentLength_WithBody(t *testing.T) { t.Errorf("expect no error, got %v", err) } } + +func TestSendHandler_HEADNoBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + cfg := defaults.Config() + cfg.Region = "mock-region" + cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) + + svc := s3.New(cfg) + svc.ForcePathStyle = true + + req := svc.HeadObjectRequest(&s3.HeadObjectInput{ + Bucket: aws.String("bucketname"), + Key: aws.String("keyname"), + }) + + if e, a := http.NoBody, req.HTTPRequest.Body; e != a { + t.Fatalf("expect %T request body, got %T", e, a) + } + + _, err := req.Send(context.Background()) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + if e, a := http.StatusOK, req.HTTPResponse.StatusCode; e != a { + t.Errorf("expect %d status code, got %d", e, a) + } +} diff --git a/aws/handlers.go b/aws/handlers.go index be2b4535d18..f464b9c9ab0 100644 --- a/aws/handlers.go +++ b/aws/handlers.go @@ -18,6 +18,7 @@ type Handlers struct { UnmarshalError HandlerList Retry HandlerList AfterRetry HandlerList + CompleteAttempt HandlerList Complete HandlerList } @@ -34,6 +35,7 @@ func (h *Handlers) Copy() Handlers { UnmarshalMeta: h.UnmarshalMeta.copy(), Retry: h.Retry.copy(), AfterRetry: h.AfterRetry.copy(), + CompleteAttempt: h.CompleteAttempt.copy(), Complete: h.Complete.copy(), } } @@ -50,6 +52,7 @@ func (h *Handlers) Clear() { h.ValidateResponse.Clear() h.Retry.Clear() h.AfterRetry.Clear() + h.CompleteAttempt.Clear() h.Complete.Clear() } @@ -172,6 +175,21 @@ func (l *HandlerList) SwapNamed(n NamedHandler) (swapped bool) { return swapped } +// Swap will swap out all handlers matching the name passed in. The matched +// handlers will be swapped in. True is returned if the handlers were swapped. +func (l *HandlerList) Swap(name string, replace NamedHandler) bool { + var swapped bool + + for i := 0; i < len(l.list); i++ { + if l.list[i].Name == name { + l.list[i] = replace + swapped = true + } + } + + return swapped +} + // SetBackNamed will replace the named handler if it exists in the handler list. // If the handler does not exist the handler will be added to the end of the list. func (l *HandlerList) SetBackNamed(n NamedHandler) { diff --git a/aws/http_request_retry_test.go b/aws/http_request_retry_test.go index e4c145a2729..e337a5e1330 100644 --- a/aws/http_request_retry_test.go +++ b/aws/http_request_retry_test.go @@ -1,11 +1,7 @@ -// +build go1.5 - package aws_test import ( "context" - "errors" - "fmt" "strings" "testing" "time" @@ -20,13 +16,10 @@ func TestRequestCancelRetry(t *testing.T) { restoreSleep := mockSleep() defer restoreSleep() - c := make(chan struct{}) - reqNum := 0 cfg := unit.Config() - cfg.EndpointResolver = aws.ResolveWithEndpointURL("http://endpoint") cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { - d.NumMaxRetries = 10 + d.NumMaxRetries = 1 }) s := mock.NewMockClient(cfg) @@ -37,15 +30,14 @@ func TestRequestCancelRetry(t *testing.T) { s.Handlers.UnmarshalError.Clear() s.Handlers.Send.PushFront(func(r *aws.Request) { reqNum++ - r.Error = errors.New("net/http: request canceled") }) out := &testData{} + ctx, cancelFn := context.WithCancel(context.Background()) r := s.NewRequest(&aws.Operation{Name: "Operation"}, nil, out) - r.HTTPRequest.Cancel = c - close(c) + r.SetContext(ctx) + cancelFn() // cancelling the context associated with the request err := r.Send() - fmt.Println("request error", err) if e, a := "canceled", err.Error(); !strings.Contains(a, e) { t.Errorf("expect %q to be in %q", e, a) } diff --git a/aws/request.go b/aws/request.go index 59475eb12d1..afd9022f4ff 100644 --- a/aws/request.go +++ b/aws/request.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "net" "net/http" "net/url" "reflect" @@ -40,6 +39,7 @@ type Request struct { Handlers Handlers Retryer + AttemptTime time.Time Time time.Time ExpireTime time.Duration Operation *Operation @@ -62,6 +62,15 @@ type Request struct { built bool + // Additional API error codes that should be retried. IsErrorRetryable + // will consider these codes in addition to its built in cases. + RetryErrorCodes []string + + // Additional API error codes that should be retried with throttle backoff + // delay. IsErrorThrottle will consider these codes in addition to its + // built in cases. + ThrottleErrorCodes []string + // Need to persist an intermediate body between the input Body and HTTP // request body because the HTTP Client's transport can maintain a reference // to the HTTP request's body after the client has returned. This value is @@ -228,12 +237,17 @@ func (r *Request) SetContext(ctx context.Context) { // WillRetry returns if the request's can be retried. func (r *Request) WillRetry() bool { - if !IsReaderSeekable(r.Body) && r.HTTPRequest.Body != NoBody { + if !IsReaderSeekable(r.Body) && r.HTTPRequest.Body != http.NoBody { return false } return r.Error != nil && BoolValue(r.Retryable) && r.RetryCount < r.MaxRetries() } +// fmtAttemptCount returns a formatted string with attempt count +func fmtAttemptCount(retryCount, maxRetries int) string { + return fmt.Sprintf("attempt %v/%v", retryCount, maxRetries) +} + // ParamsFilled returns if the request's parameters have been populated // and the parameters are valid. False is returned if no parameters are // provided or invalid. @@ -308,16 +322,15 @@ func (r *Request) PresignRequest(expireTime time.Duration) (string, http.Header, return r.HTTPRequest.URL.String(), r.SignedHeaderVals, nil } -func debugLogReqError(r *Request, stage string, retrying bool, err error) { +const ( + notRetrying = "not retrying" +) + +func debugLogReqError(r *Request, stage string, retryStr string, err error) { if !r.Config.LogLevel.Matches(LogDebugWithRequestErrors) { return } - retryStr := "not retrying" - if retrying { - retryStr = "will retry" - } - r.Config.Logger.Log(fmt.Sprintf("DEBUG: %s %s/%s failed, %s, error %v", stage, r.Metadata.ServiceName, r.Operation.Name, retryStr, err)) } @@ -336,12 +349,12 @@ func (r *Request) Build() error { if !r.built { r.Handlers.Validate.Run(r) if r.Error != nil { - debugLogReqError(r, "Validate Request", false, r.Error) + debugLogReqError(r, "Validate Request", notRetrying, r.Error) return r.Error } r.Handlers.Build.Run(r) if r.Error != nil { - debugLogReqError(r, "Build Request", false, r.Error) + debugLogReqError(r, "Build Request", notRetrying, r.Error) return r.Error } r.built = true @@ -357,7 +370,7 @@ func (r *Request) Build() error { func (r *Request) Sign() error { r.Build() if r.Error != nil { - debugLogReqError(r, "Build Request", false, r.Error) + debugLogReqError(r, "Build Request", notRetrying, r.Error) return r.Error } @@ -396,7 +409,7 @@ func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) { } if l == 0 { - body = NoBody + body = http.NoBody } else if l > 0 { body = r.safeBody } else { @@ -411,7 +424,7 @@ func (r *Request) getNextRequestBody() (body io.ReadCloser, err error) { // implement Len() method. switch r.Operation.HTTPMethod { case "GET", "HEAD", "DELETE": - body = NoBody + body = http.NoBody default: body = r.safeBody } @@ -446,79 +459,90 @@ func (r *Request) Send() error { r.Handlers.Complete.Run(r) }() + if err := r.Error; err != nil { + return err + } + for { - if BoolValue(r.Retryable) { - if r.Config.LogLevel.Matches(LogDebugWithRequestRetries) { - r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", - r.Metadata.ServiceName, r.Operation.Name, r.RetryCount)) - } - - // The previous http.Request will have a reference to the r.Body - // and the HTTP Client's Transport may still be reading from - // the request's body even though the Client's Do returned. - r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil) - r.ResetBody() - - // Closing response body to ensure that no response body is leaked - // between retry attempts. - if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { - r.HTTPResponse.Body.Close() - } - } + r.Error = nil + r.AttemptTime = time.Now() - r.Sign() - if r.Error != nil { - return r.Error + if err := r.Sign(); err != nil { + debugLogReqError(r, "Sign Request", notRetrying, err) + return err } - r.Retryable = nil - - r.Handlers.Send.Run(r) - if r.Error != nil { - if !shouldRetryCancel(r) { - return r.Error - } - - err := r.Error - r.Handlers.Retry.Run(r) - r.Handlers.AfterRetry.Run(r) - if r.Error != nil { - debugLogReqError(r, "Send Request", false, err) - return r.Error - } - debugLogReqError(r, "Send Request", true, err) - continue + if err := r.sendRequest(); err == nil { + return nil } - r.Handlers.UnmarshalMeta.Run(r) - r.Handlers.ValidateResponse.Run(r) - if r.Error != nil { - r.Handlers.UnmarshalError.Run(r) - err := r.Error - - r.Handlers.Retry.Run(r) - r.Handlers.AfterRetry.Run(r) - if r.Error != nil { - debugLogReqError(r, "Validate Response", false, err) - return r.Error - } - debugLogReqError(r, "Validate Response", true, err) - continue + r.Handlers.Retry.Run(r) + r.Handlers.AfterRetry.Run(r) + + if r.Error != nil || !BoolValue(r.Retryable) { + return r.Error } - r.Handlers.Unmarshal.Run(r) - if r.Error != nil { - err := r.Error - r.Handlers.Retry.Run(r) - r.Handlers.AfterRetry.Run(r) - if r.Error != nil { - debugLogReqError(r, "Unmarshal Response", false, err) - return r.Error - } - debugLogReqError(r, "Unmarshal Response", true, err) - continue + if err := r.prepareRetry(); err != nil { + r.Error = err + return err } + } +} - break +func (r *Request) prepareRetry() error { + if r.Config.LogLevel.Matches(LogDebugWithRequestRetries) { + r.Config.Logger.Log(fmt.Sprintf("DEBUG: Retrying Request %s/%s, attempt %d", + r.Metadata.ServiceName, r.Operation.Name, r.RetryCount)) + } + + // The previous http.Request will have a reference to the r.Body + // and the HTTP Client's Transport may still be reading from + // the request's body even though the Client's Do returned. + r.HTTPRequest = copyHTTPRequest(r.HTTPRequest, nil) + r.ResetBody() + if err := r.Error; err != nil { + return awserr.New(ErrCodeSerialization, + "failed to prepare body for retry", err) + + } + + // Closing response body to ensure that no response body is leaked + // between retry attempts. + if r.HTTPResponse != nil && r.HTTPResponse.Body != nil { + r.HTTPResponse.Body.Close() + } + + return nil +} + +func (r *Request) sendRequest() (sendErr error) { + defer r.Handlers.CompleteAttempt.Run(r) + + r.Retryable = nil + r.Handlers.Send.Run(r) + if r.Error != nil { + debugLogReqError(r, "Send Request", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) + return r.Error + } + + r.Handlers.UnmarshalMeta.Run(r) + r.Handlers.ValidateResponse.Run(r) + if r.Error != nil { + r.Handlers.UnmarshalError.Run(r) + debugLogReqError(r, "Validate Response", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) + return r.Error + } + + r.Handlers.Unmarshal.Run(r) + if r.Error != nil { + debugLogReqError(r, "Unmarshal Response", + fmtAttemptCount(r.RetryCount, r.MaxRetries()), + r.Error) + return r.Error } return nil @@ -544,32 +568,6 @@ func AddToUserAgent(r *Request, s string) { r.HTTPRequest.Header.Set("User-Agent", s) } -func shouldRetryCancel(r *Request) bool { - awsErr, ok := r.Error.(awserr.Error) - timeoutErr := false - errStr := r.Error.Error() - if ok { - if awsErr.Code() == ErrCodeRequestCanceled { - return false - } - err := awsErr.OrigErr() - netErr, netOK := err.(net.Error) - timeoutErr = netOK && netErr.Temporary() - if urlErr, ok := err.(*url.Error); !timeoutErr && ok { - errStr = urlErr.Err.Error() - } - } - - // There can be two types of canceled errors here. - // The first being a net.Error and the other being an error. - // If the request was timed out, we want to continue the retry - // process. Otherwise, return the canceled error. - return timeoutErr || - (errStr != "net/http: request canceled" && - errStr != "net/http: request canceled while waiting for connection") - -} - // SanitizeHostForHeader removes default port from host and updates request.Host func SanitizeHostForHeader(r *http.Request) { host := getHost(r) @@ -638,3 +636,24 @@ func isDefaultPort(scheme, port string) bool { return false } + +// ResetBody rewinds the request body back to its starting position, and +// set's the HTTP Request body reference. When the body is read prior +// to being sent in the HTTP request it will need to be rewound. +// +// ResetBody will automatically be called by the SDK's build handler, but if +// the request is being used directly ResetBody must be called before the request +// is Sent. SetStringBody, SetBufferBody, and SetReaderBody will automatically +// call ResetBody. +// +func (r *Request) ResetBody() { + body, err := r.getNextRequestBody() + if err != nil { + r.Error = awserr.New(ErrCodeSerialization, + "failed to reset request body", err) + return + } + + r.HTTPRequest.Body = body + r.HTTPRequest.GetBody = r.getNextRequestBody +} diff --git a/aws/request_1_5_test.go b/aws/request_1_5_test.go deleted file mode 100644 index f3675205400..00000000000 --- a/aws/request_1_5_test.go +++ /dev/null @@ -1,11 +0,0 @@ -// +build !go1.6 - -package aws_test - -import ( - "errors" - - "github.com/aws/aws-sdk-go-v2/aws/awserr" -) - -var errTimeout = awserr.New("foo", "bar", errors.New("net/http: request canceled Timeout")) diff --git a/aws/request_1_6_test.go b/aws/request_1_6_test.go deleted file mode 100644 index 8a5c9b337db..00000000000 --- a/aws/request_1_6_test.go +++ /dev/null @@ -1,49 +0,0 @@ -// +build go1.6 - -package aws_test - -import ( - "errors" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/aws/awserr" - "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" -) - -// go version 1.4 and 1.5 do not return an error. Version 1.5 will url encode -// the uri while 1.4 will not -func TestRequestInvalidEndpoint(t *testing.T) { - cfg := unit.Config() - cfg.EndpointResolver = aws.ResolveWithEndpointURL("http://localhost:90 ") - - r := aws.New( - cfg, - aws.Metadata{}, - cfg.Handlers, - aws.NewDefaultRetryer(), - &aws.Operation{}, - nil, - nil, - ) - - if r.Error == nil { - t.Errorf("expect error, got none") - } -} - -type timeoutErr struct { - error -} - -var errTimeout = awserr.New("foo", "bar", &timeoutErr{ - errors.New("net/http: request canceled"), -}) - -func (e *timeoutErr) Timeout() bool { - return true -} - -func (e *timeoutErr) Temporary() bool { - return true -} diff --git a/aws/request_1_7.go b/aws/request_1_7.go deleted file mode 100644 index 6db88c8648f..00000000000 --- a/aws/request_1_7.go +++ /dev/null @@ -1,39 +0,0 @@ -// +build !go1.8 - -package aws - -import "io" - -// NoBody is an io.ReadCloser with no bytes. Read always returns EOF -// and Close always returns nil. It can be used in an outgoing client -// request to explicitly signal that a request has zero bytes. -// An alternative, however, is to simply set Request.Body to nil. -// -// Copy of Go 1.8 NoBody type from net/http/http.go -type noBody struct{} - -func (noBody) Read([]byte) (int, error) { return 0, io.EOF } -func (noBody) Close() error { return nil } -func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } - -// NoBody is an empty reader that will trigger the Go HTTP client to not include -// and body in the HTTP request. -var NoBody = noBody{} - -// ResetBody rewinds the request body back to its starting position, and -// set's the HTTP Request body reference. When the body is read prior -// to being sent in the HTTP request it will need to be rewound. -// -// ResetBody will automatically be called by the SDK's build handler, but if -// the request is being used directly ResetBody must be called before the request -// is Sent. SetStringBody, SetBufferBody, and SetReaderBody will automatically -// call ResetBody. -func (r *Request) ResetBody() { - body, err := r.getNextRequestBody() - if err != nil { - r.Error = err - return - } - - r.HTTPRequest.Body = body -} diff --git a/aws/request_1_7_test.go b/aws/request_1_7_test.go deleted file mode 100644 index 1cdc0b34420..00000000000 --- a/aws/request_1_7_test.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !go1.8 - -package aws - -import ( - "net/http" - "strings" - "testing" -) - -func TestResetBody_WithEmptyBody(t *testing.T) { - r := Request{ - HTTPRequest: &http.Request{}, - } - - reader := strings.NewReader("") - r.Body = reader - - r.ResetBody() - - if a, e := r.HTTPRequest.Body, (noBody{}); a != e { - t.Errorf("expected request body to be set to reader, got %#v", r.HTTPRequest.Body) - } -} diff --git a/aws/request_1_8.go b/aws/request_1_8.go deleted file mode 100644 index 71ea240d549..00000000000 --- a/aws/request_1_8.go +++ /dev/null @@ -1,36 +0,0 @@ -// +build go1.8 - -package aws - -import ( - "net/http" - - "github.com/aws/aws-sdk-go-v2/aws/awserr" -) - -// NoBody is a http.NoBody reader instructing Go HTTP client to not include -// and body in the HTTP request. -var NoBody = http.NoBody - -// ResetBody rewinds the request body back to its starting position, and -// set's the HTTP Request body reference. When the body is read prior -// to being sent in the HTTP request it will need to be rewound. -// -// ResetBody will automatically be called by the SDK's build handler, but if -// the request is being used directly ResetBody must be called before the request -// is Sent. SetStringBody, SetBufferBody, and SetReaderBody will automatically -// call ResetBody. -// -// Will also set the Go 1.8's http.Request.GetBody member to allow retrying -// PUT/POST redirects. -func (r *Request) ResetBody() { - body, err := r.getNextRequestBody() - if err != nil { - r.Error = awserr.New(ErrCodeSerialization, - "failed to reset request body", err) - return - } - - r.HTTPRequest.Body = body - r.HTTPRequest.GetBody = r.getNextRequestBody -} diff --git a/aws/request_1_8_test.go b/aws/request_1_8_test.go deleted file mode 100644 index eaf2ded83fe..00000000000 --- a/aws/request_1_8_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// +build go1.8 - -package aws_test - -import ( - "bytes" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/internal/awstesting" - "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" -) - -func TestResetBody_WithEmptyBody(t *testing.T) { - r := aws.Request{ - HTTPRequest: &http.Request{}, - } - - reader := strings.NewReader("") - r.Body = reader - - r.ResetBody() - - if a, e := r.HTTPRequest.Body, http.NoBody; a != e { - t.Errorf("expected request body to be set to reader, got %#v", - r.HTTPRequest.Body) - } -} - -func TestRequest_FollowPUTRedirects(t *testing.T) { - const bodySize = 1024 - - redirectHit := 0 - endpointHit := 0 - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/redirect-me": - u := *r.URL - u.Path = "/endpoint" - w.Header().Set("Location", u.String()) - w.WriteHeader(307) - redirectHit++ - case "/endpoint": - b := bytes.Buffer{} - io.Copy(&b, r.Body) - r.Body.Close() - if e, a := bodySize, b.Len(); e != a { - t.Fatalf("expect %d body size, got %d", e, a) - } - endpointHit++ - default: - t.Fatalf("unexpected endpoint used, %q", r.URL.String()) - } - })) - defer server.Close() - - cfg := unit.Config() - cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) - - svc := awstesting.NewClient(cfg) - - req := svc.NewRequest(&aws.Operation{ - Name: "Operation", - HTTPMethod: "PUT", - HTTPPath: "/redirect-me", - }, &struct{}{}, &struct{}{}) - req.SetReaderBody(bytes.NewReader(make([]byte, bodySize))) - - err := req.Send() - if err != nil { - t.Errorf("expect no error, got %v", err) - } - if e, a := 1, redirectHit; e != a { - t.Errorf("expect %d redirect hits, got %d", e, a) - } - if e, a := 1, endpointHit; e != a { - t.Errorf("expect %d endpoint hits, got %d", e, a) - } -} diff --git a/aws/request_resetbody_test.go b/aws/request_resetbody_test.go index 164494a6345..cd219ecfd27 100644 --- a/aws/request_resetbody_test.go +++ b/aws/request_resetbody_test.go @@ -98,7 +98,7 @@ func TestResetBody_ExcludeEmptyUnseekableBodyByMethod(t *testing.T) { r.SetReaderBody(c.Body) - if a, e := r.HTTPRequest.Body == NoBody, c.IsNoBody; a != e { + if a, e := r.HTTPRequest.Body == http.NoBody, c.IsNoBody; a != e { t.Errorf("%d, expect body to be set to noBody(%t), but was %t", i, e, a) } } diff --git a/aws/request_retry_test.go b/aws/request_retry_test.go new file mode 100644 index 00000000000..ffeb6a52ba8 --- /dev/null +++ b/aws/request_retry_test.go @@ -0,0 +1,204 @@ +package aws + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws/awserr" +) + +func newRequest(t *testing.T, url string) *http.Request { + r, err := http.NewRequest("GET", url, nil) + if err != nil { + t.Fatalf("can't forge request: %v", err) + } + return r +} + +func TestShouldRetryError_nil(t *testing.T) { + if shouldRetryError(nil) != true { + t.Error("shouldRetryError(nil) should return true") + } +} + +func TestShouldRetryError_timeout(t *testing.T) { + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + client := http.Client{ + Timeout: time.Nanosecond, + Transport: tr, + } + + resp, err := client.Do(newRequest(t, "https://179.179.179.179/no/such/host")) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("This should have failed.") + } + debugerr(t, err) + + if shouldRetryError(err) == false { + t.Errorf("this request timed out and should be retried") + } +} + +func TestShouldRetryError_cancelled(t *testing.T) { + tr := &http.Transport{} + defer tr.CloseIdleConnections() + client := http.Client{ + Transport: tr, + } + + cancelWait := make(chan bool) + srvrWait := make(chan bool) + srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + close(cancelWait) // Trigger the request cancel. + time.Sleep(100 * time.Millisecond) + + fmt.Fprintf(w, "Hello") + w.(http.Flusher).Flush() // send headers and some body + <-srvrWait // block forever + })) + defer srvr.Close() + defer close(srvrWait) + + r := newRequest(t, srvr.URL) + ch := make(chan struct{}) + r.Cancel = ch + + // Ensure the request has started, and client has started to receive bytes. + // This ensures the test is stable and does not run into timing with the + // request being canceled, before or after the http request is made. + go func() { + <-cancelWait + close(ch) // request is cancelled before anything + }() + + resp, err := client.Do(r) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("This should have failed.") + } + + debugerr(t, err) + + if shouldRetryError(err) == true { + t.Errorf("this request was cancelled and should not be retried") + } +} + +func TestShouldRetry(t *testing.T) { + + syscallError := os.SyscallError{ + Err: ErrInvalidParams{}, + Syscall: "open", + } + + opError := net.OpError{ + Op: "dial", + Net: "tcp", + Source: net.Addr(nil), + Err: &syscallError, + } + + urlError := url.Error{ + Op: "Post", + URL: "https://localhost:52398", + Err: &opError, + } + origError := awserr.New("ErrorTestShouldRetry", "Test should retry when error received", &urlError).OrigErr() + if e, a := true, shouldRetryError(origError); e != a { + t.Errorf("Expected to return %v to retry when error occured, got %v instead", e, a) + } + +} + +func debugerr(t *testing.T, err error) { + t.Logf("Error, %v", err) + + switch err := err.(type) { + case temporary: + t.Logf("%s is a temporary error: %t", err, err.Temporary()) + return + case *url.Error: + t.Logf("err: %s, nested err: %#v", err, err.Err) + if operr, ok := err.Err.(*net.OpError); ok { + t.Logf("operr: %#v", operr) + } + debugerr(t, err.Err) + return + default: + return + } +} + +func TestRequest_retryCustomCodes(t *testing.T) { + cases := map[string]struct { + Code string + RetryErrorCodes []string + ThrottleErrorCodes []string + Retryable bool + Throttle bool + }{ + "retry code": { + Code: "RetryMePlease", + RetryErrorCodes: []string{ + "RetryMePlease", + "SomeOtherError", + }, + Retryable: true, + }, + "throttle code": { + Code: "AThrottleableError", + RetryErrorCodes: []string{ + "RetryMePlease", + "SomeOtherError", + }, + ThrottleErrorCodes: []string{ + "AThrottleableError", + "SomeOtherError", + }, + Throttle: true, + }, + "unknown code": { + Code: "UnknownCode", + RetryErrorCodes: []string{ + "RetryMePlease", + "SomeOtherError", + }, + Retryable: false, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + req := Request{ + HTTPRequest: &http.Request{}, + HTTPResponse: &http.Response{}, + Error: awserr.New(c.Code, "some error", nil), + RetryErrorCodes: c.RetryErrorCodes, + ThrottleErrorCodes: c.ThrottleErrorCodes, + } + + retryable := req.IsErrorRetryable() + if e, a := c.Retryable, retryable; e != a { + t.Errorf("%s, expect %v retryable, got %v", name, e, a) + } + + throttle := req.IsErrorThrottle() + if e, a := c.Throttle, throttle; e != a { + t.Errorf("%s, expect %v throttle, got %v", name, e, a) + } + }) + } +} diff --git a/aws/request_test.go b/aws/request_test.go index 180c375fb56..abb6d63e1d2 100644 --- a/aws/request_test.go +++ b/aws/request_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "io/ioutil" @@ -82,7 +83,7 @@ func TestRequestRecoverRetry5xx(t *testing.T) { reqNum := 0 reqs := []http.Response{ {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, - {StatusCode: 501, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 502, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, {StatusCode: 200, Body: body(`{"data":"valid"}`)}, } @@ -159,7 +160,7 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { func TestRequest4xxUnretryable(t *testing.T) { cfg := unit.Config() cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { - d.NumMaxRetries = 10 + d.NumMaxRetries = 1 }) s := awstesting.NewClient(cfg) @@ -169,7 +170,10 @@ func TestRequest4xxUnretryable(t *testing.T) { s.Handlers.UnmarshalError.PushBack(unmarshalError) s.Handlers.Send.Clear() // mock sending s.Handlers.Send.PushBack(func(r *aws.Request) { - r.HTTPResponse = &http.Response{StatusCode: 401, Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`)} + r.HTTPResponse = &http.Response{ + StatusCode: 401, + Body: body(`{"__type":"SignatureDoesNotMatch","message":"Signature does not match."}`), + } }) out := &testData{} r := s.NewRequest(&aws.Operation{Name: "Operation"}, nil, out) @@ -644,7 +648,7 @@ func TestIsSerializationErrorRetryable(t *testing.T) { Error: c.err, } if r.IsErrorRetryable() != c.expected { - t.Errorf("Case %d: expected %v, but received %v", i+1, c.expected, !c.expected) + t.Errorf("Case %d: expected %v, but received %v", i, c.expected, !c.expected) } } } @@ -882,11 +886,11 @@ func TestIsNoBodyReader(t *testing.T) { {ioutil.NopCloser(bytes.NewReader([]byte("abc"))), false}, {ioutil.NopCloser(bytes.NewReader(nil)), false}, {nil, false}, - {aws.NoBody, true}, + {http.NoBody, true}, } for i, c := range cases { - if e, a := c.expect, aws.NoBody == c.reader; e != a { + if e, a := c.expect, http.NoBody == c.reader; e != a { t.Errorf("%d, expect %t match, but was %t", i, e, a) } } @@ -1003,6 +1007,138 @@ func TestRequestBodySeekFails(t *testing.T) { } +func Test501NotRetrying(t *testing.T) { + reqNum := 0 + reqs := []http.Response{ + {StatusCode: 500, Body: body(`{"__type":"UnknownError","message":"An error occurred."}`)}, + {StatusCode: 501, Body: body(`{"__type":"NotImplemented","message":"An error occurred."}`)}, + {StatusCode: 200, Body: body(`{"data":"valid"}`)}, + } + + cfg := unit.Config() + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) + s := awstesting.NewClient(cfg) + s.Handlers.Validate.Clear() + s.Handlers.Unmarshal.PushBack(unmarshal) + s.Handlers.UnmarshalError.PushBack(unmarshalError) + s.Handlers.Send.Clear() // mock sending + s.Handlers.Send.PushBack(func(r *aws.Request) { + r.HTTPResponse = &reqs[reqNum] + reqNum++ + }) + out := &testData{} + r := s.NewRequest(&aws.Operation{Name: "Operation"}, nil, out) + err := r.Send() + if err == nil { + t.Fatal("expect error, but got none") + } + + aerr := err.(awserr.Error) + if e, a := "NotImplemented", aerr.Code(); e != a { + t.Errorf("expected error code %q, but received %q", e, a) + } + if e, a := 1, int(r.RetryCount); e != a { + t.Errorf("expect %d retry count, got %d", e, a) + } +} + +func TestRequestInvalidEndpoint(t *testing.T) { + cfg := unit.Config() + cfg.EndpointResolver = aws.ResolveWithEndpointURL("http://localhost:90 ") + + r := aws.New( + cfg, + aws.Metadata{}, + cfg.Handlers, + aws.NewDefaultRetryer(), + &aws.Operation{}, + nil, + nil, + ) + + if r.Error == nil { + t.Errorf("expect error, got none") + } +} + +func TestResetBody_WithEmptyBody(t *testing.T) { + r := aws.Request{ + HTTPRequest: &http.Request{}, + } + + reader := strings.NewReader("") + r.Body = reader + + r.ResetBody() + + if a, e := r.HTTPRequest.Body, http.NoBody; a != e { + t.Errorf("expected request body to be set to reader, got %#v", + r.HTTPRequest.Body) + } +} + +func TestRequest_FollowPUTRedirects(t *testing.T) { + const bodySize = 1024 + + redirectHit := 0 + endpointHit := 0 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/redirect-me": + u := *r.URL + u.Path = "/endpoint" + w.Header().Set("Location", u.String()) + w.WriteHeader(307) + redirectHit++ + case "/endpoint": + b := bytes.Buffer{} + io.Copy(&b, r.Body) + r.Body.Close() + if e, a := bodySize, b.Len(); e != a { + t.Fatalf("expect %d body size, got %d", e, a) + } + endpointHit++ + default: + t.Fatalf("unexpected endpoint used, %q", r.URL.String()) + } + })) + defer server.Close() + + cfg := unit.Config() + cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) + + svc := awstesting.NewClient(cfg) + + req := svc.NewRequest(&aws.Operation{ + Name: "Operation", + HTTPMethod: "PUT", + HTTPPath: "/redirect-me", + }, &struct{}{}, &struct{}{}) + req.SetReaderBody(bytes.NewReader(make([]byte, bodySize))) + + err := req.Send() + if err != nil { + t.Errorf("expect no error, got %v", err) + } + if e, a := 1, redirectHit; e != a { + t.Errorf("expect %d redirect hits, got %d", e, a) + } + if e, a := 1, endpointHit; e != a { + t.Errorf("expect %d endpoint hits, got %d", e, a) + } +} + +type timeoutErr struct { + error +} + +var errTimeout = awserr.New("foo", "bar", &timeoutErr{ + errors.New("net/http: request canceled"), +}) + type stubSeekFail struct { Err error } diff --git a/aws/retryer.go b/aws/retryer.go index 4add53b6f1a..b4d329f2dba 100644 --- a/aws/retryer.go +++ b/aws/retryer.go @@ -1,6 +1,9 @@ package aws import ( + "net" + "net/url" + "strings" "time" "github.com/aws/aws-sdk-go-v2/aws/awserr" @@ -10,16 +13,21 @@ import ( // The default implementation used by most services is the client.DefaultRetryer // structure, which contains basic retry logic using exponential backoff. type Retryer interface { + + // RetryRules return the retry delay that should be used by the SDK before + // making another request attempt for the failed request. RetryRules(*Request) time.Duration + + // ShouldRetry returns if the failed request is retryable. + // + // Implementations may consider request attempt count when determining if a + // request is retryable, but the SDK will use MaxRetries to limit the + // number of attempts a request are made. ShouldRetry(*Request) bool - MaxRetries() int -} -// WithRetryer sets a config Retryer value to the given Config returning it -// for chaining. -func WithRetryer(cfg *Config, retryer Retryer) *Config { - cfg.Retryer = retryer - return cfg + // MaxRetries is the number of times a request may be retried before + // failing. + MaxRetries() int } // retryableCodes is a collection of service response codes which are retry-able @@ -74,10 +82,6 @@ var validParentCodes = map[string]struct{}{ ErrCodeRead: {}, } -type temporaryError interface { - Temporary() bool -} - func isNestedErrorRetryable(parentErr awserr.Error) bool { if parentErr == nil { return false @@ -96,7 +100,7 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool { return isCodeRetryable(aerr.Code()) } - if t, ok := err.(temporaryError); ok { + if t, ok := err.(temporary); ok { return t.Temporary() || isErrConnectionReset(err) } @@ -106,32 +110,90 @@ func isNestedErrorRetryable(parentErr awserr.Error) bool { // IsErrorRetryable returns whether the error is retryable, based on its Code. // Returns false if error is nil. func IsErrorRetryable(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeRetryable(aerr.Code()) || isNestedErrorRetryable(aerr) + if err == nil { + return false + } + return shouldRetryError(err) +} + +type temporary interface { + Temporary() bool +} + +func shouldRetryError(origErr error) bool { + switch err := origErr.(type) { + case awserr.Error: + if err.Code() == ErrCodeRequestCanceled { + return false + } + if isNestedErrorRetryable(err) { + return true + } + + origErr := err.OrigErr() + var shouldRetry bool + if origErr != nil { + shouldRetry := shouldRetryError(origErr) + if err.Code() == "RequestError" && !shouldRetry { + return false + } + } + if isCodeRetryable(err.Code()) { + return true + } + return shouldRetry + + case *url.Error: + if strings.Contains(err.Error(), "connection refused") { + // Refused connections should be retried as the service may not yet + // be running on the port. Go TCP dial considers refused + // connections as not temporary. + return true + } + // *url.Error only implements Temporary after golang 1.6 but since + // url.Error only wraps the error: + return shouldRetryError(err.Err) + + case temporary: + if netErr, ok := err.(*net.OpError); ok && netErr.Op == "dial" { + return true + } + // If the error is temporary, we want to allow continuation of the + // retry process + return err.Temporary() || isErrConnectionReset(origErr) + + case nil: + // `awserr.Error.OrigErr()` can be nil, meaning there was an error but + // because we don't know the cause, it is marked as retryable. See + // TestRequest4xxUnretryable for an example. + return true + + default: + switch err.Error() { + case "net/http: request canceled", + "net/http: request canceled while waiting for connection": + // known 1.5 error case when an http request is cancelled + return false } + // here we don't know the error; so we allow a retry. + return true } - return false } // IsErrorThrottle returns whether the error is to be throttled based on its code. // Returns false if error is nil. func IsErrorThrottle(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeThrottle(aerr.Code()) - } + if aerr, ok := err.(awserr.Error); ok && aerr != nil { + return isCodeThrottle(aerr.Code()) } return false } -// IsErrorExpiredCreds returns whether the error code is a credential expiry error. -// Returns false if error is nil. +// IsErrorExpiredCreds returns whether the error code is a credential expiry +// error. Returns false if error is nil. func IsErrorExpiredCreds(err error) bool { - if err != nil { - if aerr, ok := err.(awserr.Error); ok { - return isCodeExpiredCreds(aerr.Code()) - } + if aerr, ok := err.(awserr.Error); ok && aerr != nil { + return isCodeExpiredCreds(aerr.Code()) } return false } @@ -141,6 +203,22 @@ func IsErrorExpiredCreds(err error) bool { // // Alias for the utility function IsErrorRetryable func (r *Request) IsErrorRetryable() bool { + if isErrCode(r.Error, r.RetryErrorCodes) { + return true + } + + // HTTP response status code 501 should not be retried. + // 501 represents Not Implemented which means the request method is not + // supported by the server and cannot be handled. + if r.HTTPResponse != nil { + // HTTP response status code 500 represents internal server error and + // should be retried without any throttle. + if r.HTTPResponse.StatusCode == 500 { + return true + } + + } + return IsErrorRetryable(r.Error) } @@ -149,9 +227,36 @@ func (r *Request) IsErrorRetryable() bool { // // Alias for the utility function IsErrorThrottle func (r *Request) IsErrorThrottle() bool { + if isErrCode(r.Error, r.ThrottleErrorCodes) { + return true + } + + if r.HTTPResponse != nil { + switch r.HTTPResponse.StatusCode { + case + 429, // error caused due to too many requests, thus retry should be throttled + 502, // Bad Gateway error should be throttled + 503, // caused when service is unavailable, thus retry should be throttled + 504: // error occurred due to gateway timeout, thus retry should be throttled + return true + } + } + return IsErrorThrottle(r.Error) } +func isErrCode(err error, codes []string) bool { + if aerr, ok := err.(awserr.Error); ok && aerr != nil { + for _, code := range codes { + if code == aerr.Code() { + return true + } + } + } + + return false +} + // IsErrorExpired returns whether the error code is a credential expiry error. // Returns false if the request has no Error set. // diff --git a/aws/retryer_test.go b/aws/retryer_test.go index 83dda31502d..7b69012c7b9 100644 --- a/aws/retryer_test.go +++ b/aws/retryer_test.go @@ -28,34 +28,34 @@ func (e mockTempError) Temporary() bool { func TestIsErrorRetryable(t *testing.T) { cases := []struct { - Err error - IsTemp bool + Err error + Retryable bool }{ { - Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(true)), - IsTemp: true, + Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(true)), + Retryable: true, }, { - Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(false)), - IsTemp: false, + Err: awserr.New(ErrCodeSerialization, "temporary error", mockTempError(false)), + Retryable: false, }, { - Err: awserr.New(ErrCodeSerialization, "some error", errors.New("blah")), - IsTemp: false, + Err: awserr.New(ErrCodeSerialization, "some error", errors.New("blah")), + Retryable: false, }, { - Err: awserr.New("SomeError", "some error", nil), - IsTemp: false, + Err: awserr.New("SomeError", "some error", nil), + Retryable: false, }, { - Err: awserr.New("RequestError", "some error", nil), - IsTemp: true, + Err: awserr.New("RequestError", "some error", nil), + Retryable: true, }, } for i, c := range cases { retryable := IsErrorRetryable(c.Err) - if e, a := c.IsTemp, retryable; e != a { + if e, a := c.Retryable, retryable; e != a { t.Errorf("%d, expect %t temporary error, got %t", i, e, a) } } diff --git a/service/kinesis/customizations.go b/service/kinesis/customizations.go index 9690d74f86b..c51cf746304 100644 --- a/service/kinesis/customizations.go +++ b/service/kinesis/customizations.go @@ -3,20 +3,18 @@ package kinesis import ( "time" - request "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" ) var readDuration = 5 * time.Second func init() { - ops := []string{ - opGetRecords, - } - initRequest = func(c *Client, r *request.Request) { - for _, operation := range ops { - if r.Operation.Name == operation { - r.ApplyOptions(request.WithResponseReadTimeout(readDuration)) - } + initRequest = func(c *Client, r *aws.Request) { + if r.Operation.Name == opGetRecords { + r.ApplyOptions(aws.WithResponseReadTimeout(readDuration)) } + + // Service specific error codes. + r.RetryErrorCodes = append(r.RetryErrorCodes, ErrCodeLimitExceededException) } } diff --git a/service/kinesis/customizations_test.go b/service/kinesis/customizations_test.go index ae18b6f9fb2..5228f8ecdbb 100644 --- a/service/kinesis/customizations_test.go +++ b/service/kinesis/customizations_test.go @@ -1,15 +1,18 @@ package kinesis import ( + "bytes" "context" + "fmt" "io" + "io/ioutil" "net/http" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/awserr" + "github.com/aws/aws-sdk-go-v2/aws/defaults" "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" ) @@ -42,7 +45,7 @@ func TestKinesisGetRecordsCustomization(t *testing.T) { ShardIterator: aws.String("foo"), }) req.Handlers.Send.Clear() - req.Handlers.Send.PushBack(func(r *request.Request) { + req.Handlers.Send.PushBack(func(r *aws.Request) { r.HTTPResponse = &http.Response{ StatusCode: 200, Header: http.Header{ @@ -54,13 +57,13 @@ func TestKinesisGetRecordsCustomization(t *testing.T) { r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode) retryCount++ }) - req.ApplyOptions(request.WithResponseReadTimeout(time.Second)) + req.ApplyOptions(aws.WithResponseReadTimeout(time.Second)) _, err := req.Send(context.Background()) if err == nil { t.Errorf("Expected error, but received nil") } else if v, ok := err.(awserr.Error); !ok { t.Errorf("Expected awserr.Error but received %v", err) - } else if v.Code() != request.ErrCodeResponseTimeout { + } else if v.Code() != aws.ErrCodeResponseTimeout { t.Errorf("Expected 'RequestTimeout' error, but received %s instead", v.Code()) } if retryCount != 5 { @@ -75,7 +78,7 @@ func TestKinesisGetRecordsNoTimeout(t *testing.T) { ShardIterator: aws.String("foo"), }) req.Handlers.Send.Clear() - req.Handlers.Send.PushBack(func(r *request.Request) { + req.Handlers.Send.PushBack(func(r *aws.Request) { r.HTTPResponse = &http.Response{ StatusCode: 200, Header: http.Header{ @@ -86,9 +89,55 @@ func TestKinesisGetRecordsNoTimeout(t *testing.T) { } r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode) }) - req.ApplyOptions(request.WithResponseReadTimeout(time.Second)) + req.ApplyOptions(aws.WithResponseReadTimeout(time.Second)) _, err := req.Send(context.Background()) if err != nil { t.Errorf("Expected no error, but received %v", err) } } + +func TestKinesisCustomRetryErrorCodes(t *testing.T) { + + cfg := unit.Config() + cfg.LogLevel = aws.LogLevel(aws.LogDebugWithHTTPBody) + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 1 + }) + + svc := New(cfg) + svc.Handlers.Validate.Clear() + + const jsonErr = `{"__type":%q, "message":"some error message"}` + var reqCount int + resps := []*http.Response{ + { + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(fmt.Sprintf(jsonErr, ErrCodeLimitExceededException)), + )), + }, + { + StatusCode: 200, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, + } + + req := svc.GetRecordsRequest(&GetRecordsInput{}) + req.Handlers.Send.Swap(defaults.SendHandler.Name, aws.NamedHandler{ + Name: "custom send handler", + Fn: func(r *aws.Request) { + r.HTTPResponse = resps[reqCount] + reqCount++ + }, + }) + + if _, err := req.Send(context.Background()); err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := 2, reqCount; e != a { + t.Errorf("expect %v requests, got %v", e, a) + } +} diff --git a/service/sts/customizations.go b/service/sts/customizations.go new file mode 100644 index 00000000000..06f47229df6 --- /dev/null +++ b/service/sts/customizations.go @@ -0,0 +1,9 @@ +package sts + +import "github.com/aws/aws-sdk-go-v2/aws" + +func init() { + initRequest = func(c *Client, r *aws.Request) { + r.RetryErrorCodes = append(r.RetryErrorCodes, ErrCodeIDPCommunicationErrorException) + } +} diff --git a/service/sts/customizations_test.go b/service/sts/customizations_test.go index 0a11bc16198..6e57dfcc2ea 100644 --- a/service/sts/customizations_test.go +++ b/service/sts/customizations_test.go @@ -1,9 +1,15 @@ package sts_test import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "net/http" "testing" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/defaults" "github.com/aws/aws-sdk-go-v2/aws/endpoints" "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -49,3 +55,47 @@ func TestUnsignedRequest_AssumeRoleWithWebIdentity(t *testing.T) { t.Errorf("expect %v, got %v", e, a) } } + +func TestSTSCustomRetryErrorCodes(t *testing.T) { + cfg := unit.Config() + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 1 + }) + + svc := sts.New(cfg) + svc.Handlers.Validate.Clear() + + const xmlErr = `%ssome error message` + var reqCount int + resps := []*http.Response{ + { + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader( + []byte(fmt.Sprintf(xmlErr, sts.ErrCodeIDPCommunicationErrorException)), + )), + }, + { + StatusCode: 200, + Header: http.Header{}, + Body: ioutil.NopCloser(bytes.NewReader([]byte{})), + }, + } + + req := svc.AssumeRoleWithWebIdentityRequest(&sts.AssumeRoleWithWebIdentityInput{}) + req.Handlers.Send.Swap(defaults.SendHandler.Name, aws.NamedHandler{ + Name: "custom send handler", + Fn: func(r *aws.Request) { + r.HTTPResponse = resps[reqCount] + reqCount++ + }, + }) + + if _, err := req.Send(context.Background()); err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := 2, reqCount; e != a { + t.Errorf("expect %v requests, got %v", e, a) + } +}