diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 9ddc9589b03..cc08e0dc0b3 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -10,11 +10,16 @@ * `aws/ec2metadata`: Add marketplaceProductCodes to EC2 Instance Identity Document ([#374](https://github.com/aws/aws-sdk-go-v2/pull/374)) * Adds `MarketplaceProductCodes` to the EC2 Instance Metadata's Identity Document. The ec2metadata client will now retrieve these values if they are available. * Related to: [aws/aws-sdk-go#2781](https://github.com/aws/aws-sdk-go/issues/2781) - +* `aws`: Adds configurations to the default retryer ([#375](https://github.com/aws/aws-sdk-go-v2/pull/375)) + * Provides more customization options for retryer by adding a constructor for default Retryer which accepts functional options. Adds NoOpRetryer to support no retry behavior. Exposes members of default retryer. + * Updates the underlying logic used by the default retryer to calculate jittered delay for retry. Handles int overflow for retry delay. + * Fixes [#370](https://github.com/aws/aws-sdk-go-v2/issues/370) + ### SDK Bugs * `aws`: Fixes bug in calculating throttled retry delay ([#373](https://github.com/aws/aws-sdk-go-v2/pull/373)) * The `Retry-After` duration specified in the request is now added to the Retry delay for throttled exception. Adds test for retry delays for throttled exceptions. Fixes bug where the throttled retry's math was off. * Fixes [#45](https://github.com/aws/aws-sdk-go-v2/issues/45) -* `aws` : Adds missing sdk error checking when seeking readers [#379](https://github.com/aws/aws-sdk-go-v2/pull/379). +* `aws` : Adds missing sdk error checking when seeking readers ([#379](https://github.com/aws/aws-sdk-go-v2/pull/379)) * Adds support for nonseekable io.Reader. Adds support for streamed payloads for unsigned body request. * Fixes [#371](https://github.com/aws/aws-sdk-go-v2/issues/371) + diff --git a/aws/client.go b/aws/client.go index 8e90d2e96d6..ffbfc3df357 100644 --- a/aws/client.go +++ b/aws/client.go @@ -63,13 +63,10 @@ func NewClient(cfg Config, metadata Metadata) *Client { retryer := cfg.Retryer if retryer == nil { - // TODO need better way of specifing default num retries - retryer = DefaultRetryer{NumMaxRetries: 3} + retryer = NewDefaultRetryer() } svc.Retryer = retryer - svc.AddDebugHandlers() - return svc } diff --git a/aws/default_retryer.go b/aws/default_retryer.go index bcf6fcce7c2..ae952a4d928 100644 --- a/aws/default_retryer.go +++ b/aws/default_retryer.go @@ -1,6 +1,7 @@ package aws import ( + "math" "math/rand" "strconv" "sync" @@ -8,21 +9,33 @@ import ( ) // DefaultRetryer implements basic retry logic using exponential backoff for -// most services. If you want to implement custom retry logic, implement the -// Retryer interface or create a structure type that composes this -// struct and override the specific methods. For example, to override only -// the MaxRetries method: -// -// type retryer struct { -// client.DefaultRetryer -// } -// -// // This implementation always has 100 max retries -// func (d retryer) MaxRetries() int { return 100 } +// most services. You can implement your own custom retryer by implementing +// retryer interface. type DefaultRetryer struct { - NumMaxRetries int + NumMaxRetries int + MinRetryDelay time.Duration + MinThrottleDelay time.Duration + MaxRetryDelay time.Duration + MaxThrottleDelay time.Duration } +const ( + // DefaultRetryerMaxNumRetries sets maximum number of retries + DefaultRetryerMaxNumRetries = 3 + + // DefaultRetryerMinRetryDelay sets minimum retry delay + DefaultRetryerMinRetryDelay = 30 * time.Millisecond + + // DefaultRetryerMinThrottleDelay sets minimum delay when throttled + DefaultRetryerMinThrottleDelay = 500 * time.Millisecond + + // DefaultRetryerMaxRetryDelay sets maximum retry delay + DefaultRetryerMaxRetryDelay = 300 * time.Second + + // DefaultRetryerMaxThrottleDelay sets maximum delay when throttled + DefaultRetryerMaxThrottleDelay = 300 * time.Second +) + // MaxRetries returns the number of maximum returns the service will use to make // an individual API func (d DefaultRetryer) MaxRetries() int { @@ -31,30 +44,63 @@ func (d DefaultRetryer) MaxRetries() int { var seededRand = rand.New(&lockedSource{src: rand.NewSource(time.Now().UnixNano())}) +// NewDefaultRetryer returns a retryer initialized with default values and optionally takes function +// to override values for default retryer. +func NewDefaultRetryer(opts ...func(d *DefaultRetryer)) DefaultRetryer { + d := DefaultRetryer{ + NumMaxRetries: DefaultRetryerMaxNumRetries, + MinRetryDelay: DefaultRetryerMinRetryDelay, + MinThrottleDelay: DefaultRetryerMinThrottleDelay, + MaxRetryDelay: DefaultRetryerMaxRetryDelay, + MaxThrottleDelay: DefaultRetryerMaxThrottleDelay, + } + + for _, opt := range opts { + opt(&d) + } + return d +} + // RetryRules returns the delay duration before retrying this request again +// +// Note: RetryRules method must be a value receiver so that the +// defaultRetryer is safe. func (d DefaultRetryer) RetryRules(r *Request) time.Duration { - // Set the upper limit of delay in retrying at ~five minutes - var minTime int64 = 30 + minDelay := d.MinRetryDelay var initialDelay time.Duration - throttle := d.shouldThrottle(r) if throttle { if delay, ok := getRetryAfterDelay(r); ok { initialDelay = delay } - - minTime = 500 + minDelay = d.MinThrottleDelay } retryCount := r.RetryCount - if throttle && retryCount > 8 { - retryCount = 8 - } else if retryCount > 12 { - retryCount = 12 + + maxDelay := d.MaxRetryDelay + if throttle { + maxDelay = d.MaxThrottleDelay + } + + var delay time.Duration + + // Logic to cap the retry count based on the minDelay provided + actualRetryCount := int(math.Log2(float64(minDelay))) + 1 + if actualRetryCount < 63-retryCount { + delay = time.Duration(1< maxDelay { + delay = getJitterDelay(maxDelay / 2) + } + } else { + delay = getJitterDelay(maxDelay / 2) } + return delay + initialDelay +} - delay := (1 << uint(retryCount)) * (seededRand.Int63n(minTime) + minTime) - return (time.Duration(delay) * time.Millisecond) + initialDelay +// getJitterDelay returns a jittered delay for retry +func getJitterDelay(duration time.Duration) time.Duration { + return time.Duration(seededRand.Int63n(int64(duration)) + int64(duration)) } // ShouldRetry returns true if the request should be retried. @@ -73,16 +119,18 @@ func (d DefaultRetryer) ShouldRetry(r *Request) bool { // ShouldThrottle returns true if the request should be throttled. func (d DefaultRetryer) shouldThrottle(r *Request) bool { - switch r.HTTPResponse.StatusCode { - case 429: - case 502: - case 503: - case 504: - default: - return r.IsErrorThrottle() + if r.HTTPResponse != nil { + switch r.HTTPResponse.StatusCode { + case 429: + case 502: + case 503: + case 504: + default: + return r.IsErrorThrottle() + } + return true } - - return true + return r.IsErrorThrottle() } // This will look in the Retry-After header, RFC 7231, for how long diff --git a/aws/default_retryer_test.go b/aws/default_retryer_test.go index 7bb485084de..2c35e10aa6b 100644 --- a/aws/default_retryer_test.go +++ b/aws/default_retryer_test.go @@ -56,7 +56,9 @@ func TestRetryThrottleStatusCodes(t *testing.T) { }, } - d := DefaultRetryer{NumMaxRetries: 10} + d := NewDefaultRetryer(func(d *DefaultRetryer) { + d.NumMaxRetries = 100 + }) for i, c := range cases { throttle := d.shouldThrottle(&c.r) retry := d.ShouldRetry(&c.r) @@ -71,7 +73,7 @@ func TestRetryThrottleStatusCodes(t *testing.T) { } } -func TestCanUseRetryAfter(t *testing.T) { +func TestGetRetryAfterDelay(t *testing.T) { cases := []struct { r Request e bool @@ -164,7 +166,9 @@ func TestGetRetryDelay(t *testing.T) { } func TestRetryDelay(t *testing.T) { - d := DefaultRetryer{100} + d := NewDefaultRetryer(func(d *DefaultRetryer) { + d.NumMaxRetries = 100 + }) r := Request{} for i := 0; i < 100; i++ { rTemp := r @@ -190,7 +194,7 @@ func TestRetryDelay(t *testing.T) { rTemp.RetryCount = 1 rTemp.HTTPResponse = &http.Response{StatusCode: 503, Header: http.Header{"Retry-After": []string{"300"}}} a := d.RetryRules(&rTemp) - if a < 5*time.Minute{ + if a < 5*time.Minute { t.Errorf("retry delay should not be less than retry-after duration, received %s for retrycount %d", a, 1) } } diff --git a/aws/http_request_retry_test.go b/aws/http_request_retry_test.go index 270d5f8ab09..e4c145a2729 100644 --- a/aws/http_request_retry_test.go +++ b/aws/http_request_retry_test.go @@ -25,7 +25,9 @@ func TestRequestCancelRetry(t *testing.T) { reqNum := 0 cfg := unit.Config() cfg.EndpointResolver = aws.ResolveWithEndpointURL("http://endpoint") - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := mock.NewMockClient(cfg) diff --git a/aws/no_op_retryer.go b/aws/no_op_retryer.go new file mode 100644 index 00000000000..e0294dabcae --- /dev/null +++ b/aws/no_op_retryer.go @@ -0,0 +1,24 @@ +package aws + +import "time" + +// NoOpRetryer provides a retryer that performs no retries. +// It should be used when we do not want retries to be performed. +type NoOpRetryer struct{} + +// MaxRetries returns the number of maximum returns the service will use to make +// an individual API; For NoOpRetryer the MaxRetries will always be zero. +func (d NoOpRetryer) MaxRetries() int { + return 0 +} + +// ShouldRetry will always return false for NoOpRetryer, as it should never retry. +func (d NoOpRetryer) ShouldRetry(_ *Request) bool { + return false +} + +// RetryRules returns the delay duration before retrying this request again; +// since NoOpRetryer does not retry, RetryRules always returns 0. +func (d NoOpRetryer) RetryRules(_ *Request) time.Duration { + return 0 +} diff --git a/aws/no_op_retryer_test.go b/aws/no_op_retryer_test.go new file mode 100644 index 00000000000..dec439393e9 --- /dev/null +++ b/aws/no_op_retryer_test.go @@ -0,0 +1,44 @@ +package aws + +import ( + "net/http" + "testing" + "time" +) + +func TestNoOpRetryer(t *testing.T) { + cases := []struct { + r Request + expectMaxRetries int + expectRetryDelay time.Duration + expectRetry bool + }{ + { + r: Request{ + HTTPResponse: &http.Response{StatusCode: 200}, + }, + expectMaxRetries: 0, + expectRetryDelay: 0, + expectRetry: false, + }, + } + + d := NoOpRetryer{} + for i, c := range cases { + maxRetries := d.MaxRetries() + retry := d.ShouldRetry(&c.r) + retryDelay := d.RetryRules(&c.r) + + if e, a := c.expectMaxRetries, maxRetries; e != a { + t.Errorf("%d: expected %v, but received %v for number of max retries", i, e, a) + } + + if e, a := c.expectRetry, retry; e != a { + t.Errorf("%d: expected %v, but received %v for should retry", i, e, a) + } + + if e, a := c.expectRetryDelay, retryDelay; e != a { + t.Errorf("%d: expected %v, but received %v as retry delay", i, e, a) + } + } +} diff --git a/aws/request_1_6_test.go b/aws/request_1_6_test.go index d9f7921f980..8a5c9b337db 100644 --- a/aws/request_1_6_test.go +++ b/aws/request_1_6_test.go @@ -21,7 +21,7 @@ func TestRequestInvalidEndpoint(t *testing.T) { cfg, aws.Metadata{}, cfg.Handlers, - aws.DefaultRetryer{}, + aws.NewDefaultRetryer(), &aws.Operation{}, nil, nil, diff --git a/aws/request_pagination_test.go b/aws/request_pagination_test.go index 99b4b1b9b36..e32bdd2f24b 100644 --- a/aws/request_pagination_test.go +++ b/aws/request_pagination_test.go @@ -58,7 +58,9 @@ func TestPagination(t *testing.T) { }, } - retryer := aws.DefaultRetryer{NumMaxRetries: 2} + retryer := aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) op := aws.Operation{ Name: "Operation", Paginator: &aws.Paginator{ @@ -160,7 +162,9 @@ func TestPaginationTruncation(t *testing.T) { } reqNum := 0 - retryer := aws.DefaultRetryer{NumMaxRetries: 2} + retryer := aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) ops := []aws.Operation{ { Name: "Operation", @@ -271,7 +275,9 @@ func BenchmarkPagination(b *testing.B) { {aws.String("3"), aws.String("")}, } - retryer := aws.DefaultRetryer{NumMaxRetries: 2} + retryer := aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) op := aws.Operation{ Name: "Operation", Paginator: &aws.Paginator{ @@ -339,7 +345,9 @@ func TestPaginationWithContextCancel(t *testing.T) { }, } - retryer := aws.DefaultRetryer{NumMaxRetries: 2} + retryer := aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) op := aws.Operation{ Name: "Operation", Paginator: &aws.Paginator{ @@ -350,7 +358,7 @@ func TestPaginationWithContextCancel(t *testing.T) { for _, c := range cases { input := c.input - inValues := []string{} + var inValues []string p := aws.Pager{ NewRequest: func(ctx context.Context) (*aws.Request, error) { h := defaults.Handlers() @@ -380,7 +388,7 @@ func TestPaginationWithContextCancel(t *testing.T) { ctx, cancelFn := context.WithCancel(context.Background()) cancelFn() - results := []*string{} + var results []*string for p.Next(ctx) { page := p.CurrentPage() output := page.(*mockOutput) diff --git a/aws/request_test.go b/aws/request_test.go index 92c4eca33be..c8e894621c1 100644 --- a/aws/request_test.go +++ b/aws/request_test.go @@ -87,7 +87,9 @@ func TestRequestRecoverRetry5xx(t *testing.T) { } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := awstesting.NewClient(cfg) s.Handlers.Validate.Clear() @@ -126,7 +128,9 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := awstesting.NewClient(cfg) s.Handlers.Validate.Clear() @@ -154,7 +158,9 @@ func TestRequestRecoverRetry4xxRetryable(t *testing.T) { // test that retries don't occur for 4xx status codes with a response type that can't be retried func TestRequest4xxUnretryable(t *testing.T) { cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := awstesting.NewClient(cfg) @@ -193,7 +199,7 @@ func TestRequestExhaustRetries(t *testing.T) { orig := sdk.SleepWithContext defer func() { sdk.SleepWithContext = orig }() - delays := []time.Duration{} + var delays []time.Duration sdk.SleepWithContext = func(ctx context.Context, dur time.Duration) error { delays = append(delays, dur) return nil @@ -236,7 +242,7 @@ func TestRequestExhaustRetries(t *testing.T) { t.Errorf("expect %d retry count, got %d", e, a) } - expectDelays := []struct{ min, max time.Duration }{{30, 59}, {60, 118}, {120, 236}} + expectDelays := []struct{ min, max time.Duration }{{30, 60}, {60, 120}, {120, 240}} for i, v := range delays { min := expectDelays[i].min * time.Millisecond max := expectDelays[i].max * time.Millisecond @@ -266,7 +272,9 @@ func TestRequest_RecoverExpiredCreds(t *testing.T) { } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) credsInvalidated := false credsProvider := func() aws.CredentialsProvider { @@ -394,7 +402,7 @@ func TestRequestThrottleRetries(t *testing.T) { orig := sdk.SleepWithContext defer func() { sdk.SleepWithContext = orig }() - delays := []time.Duration{} + var delays []time.Duration sdk.SleepWithContext = func(ctx context.Context, dur time.Duration) error { delays = append(delays, dur) return nil @@ -460,7 +468,9 @@ func TestRequestRecoverTimeoutWithNilBody(t *testing.T) { } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := awstesting.NewClient(cfg) @@ -507,7 +517,9 @@ func TestRequestRecoverTimeoutWithNilResponse(t *testing.T) { } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 10} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + }) s := awstesting.NewClient(cfg) @@ -572,7 +584,7 @@ func TestRequest_NoBody(t *testing.T) { cfg := unit.Config() cfg.Region = "mock-region" - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) s := awstesting.NewClient(cfg) @@ -744,13 +756,17 @@ func TestSerializationErrConnectionReset(t *testing.T) { TargetPrefix: "Foo", } cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 5} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 5 + }) req := aws.New( cfg, meta, handlers, - aws.DefaultRetryer{NumMaxRetries: 5}, + aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 5 + }), op, &struct{}{}, &struct{}{}, @@ -895,7 +911,9 @@ func TestRequest_TemporaryRetry(t *testing.T) { defer server.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 1} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 1 + }) cfg.HTTPClient = &http.Client{ Timeout: 100 * time.Millisecond, } diff --git a/aws/timeout_read_closer_benchmark_test.go b/aws/timeout_read_closer_benchmark_test.go index f25097dd0af..b34be170296 100644 --- a/aws/timeout_read_closer_benchmark_test.go +++ b/aws/timeout_read_closer_benchmark_test.go @@ -56,7 +56,9 @@ func BenchmarkTimeoutReadCloser(b *testing.B) { cfg, meta, handlers, - aws.DefaultRetryer{NumMaxRetries: 5}, + aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 5 + }), op, &struct { Foo *string diff --git a/internal/awstesting/client.go b/internal/awstesting/client.go index 20ef4055deb..1cd260bd1dc 100644 --- a/internal/awstesting/client.go +++ b/internal/awstesting/client.go @@ -6,8 +6,5 @@ import ( // NewClient creates and initializes a generic service client for testing. func NewClient(cfg aws.Config) *aws.Client { - if cfg.Retryer == nil { - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 3} - } return aws.NewClient(cfg, aws.Metadata{ServiceName: "mockService"}) } diff --git a/private/model/api/shape.go b/private/model/api/shape.go index fb20130e669..ad9d3af758f 100644 --- a/private/model/api/shape.go +++ b/private/model/api/shape.go @@ -30,11 +30,11 @@ type ShapeRef struct { Ignore bool XMLNamespace XMLInfo Payload string - IdempotencyToken bool `json:"idempotencyToken"` + IdempotencyToken bool `json:"idempotencyToken"` TimestampFormat string `json:"timestampFormat"` - JSONValue bool `json:"jsonvalue"` - Deprecated bool `json:"deprecated"` - HostLabel bool `json:"hostLabel"` + JSONValue bool `json:"jsonvalue"` + Deprecated bool `json:"deprecated"` + HostLabel bool `json:"hostLabel"` OrigShapeName string `json:"-"` @@ -84,7 +84,7 @@ type Shape struct { Streaming bool Location string LocationName string - IdempotencyToken bool `json:"idempotencyToken"` + IdempotencyToken bool `json:"idempotencyToken"` TimestampFormat string `json:"timestampFormat"` XMLNamespace XMLInfo Min float64 // optional Minimum length (string, list) or value (number) diff --git a/private/model/api/shape_marshal.go b/private/model/api/shape_marshal.go index 9ee9cfe5adb..964ea0d9c54 100644 --- a/private/model/api/shape_marshal.go +++ b/private/model/api/shape_marshal.go @@ -79,7 +79,7 @@ func getContentType(s *Shape) string { if s.API.Metadata.JSONVersion != "" && s.API.Metadata.Protocol == "json" { return fmt.Sprintf("application/x-amz-json-%s", s.API.Metadata.JSONVersion) } - if s.API.Metadata.Protocol == "json" || s.API.Metadata.Protocol == "rest-json" { + if s.API.Metadata.Protocol == "json" || s.API.Metadata.Protocol == "rest-json" { return "application/json" } return "" @@ -90,8 +90,7 @@ var marshalShapeTmpl = template.Must(template.New("marshalShapeTmpl").Funcs( "MarshalShapeRefGoCode": MarshalShapeRefGoCode, "nestedRefsByLocation": nestedRefsByLocation, "isShapeFieldsNested": isShapeFieldsNested, - "getContentType": getContentType, - + "getContentType": getContentType, }, ).Parse(` {{ define "encode shape" -}} @@ -207,15 +206,15 @@ func isShapeFieldsNested(loc string, s *Shape) bool { return loc == "Body" && len(s.LocationName) != 0 && s.API.Metadata.Protocol == "rest-xml" } -func QuotedFormatTime(s marshalShapeRef) string { - if (s.Ref.API.Metadata.Protocol == "json" || s.Ref.API.Metadata.Protocol == "rest-json") && s.Location() == "Body" { +func QuotedFormatTime(s marshalShapeRef) string { + if (s.Ref.API.Metadata.Protocol == "json" || s.Ref.API.Metadata.Protocol == "rest-json") && s.Location() == "Body" { return "true" } return "false" } var marshalShapeRefTmpl = template.Must(template.New("marshalShapeRefTmpl").Funcs(template.FuncMap{ - "Collection": Collection, + "Collection": Collection, "quotedFormatTime": QuotedFormatTime, }).Parse(` {{ define "encode field" -}} @@ -723,11 +722,11 @@ func (r marshalShapeRef) IsIdempotencyToken() bool { } func (r marshalShapeRef) TimeFormat() string { - if r.Ref.TimestampFormat!="" { - return fmt.Sprintf("%q",r.Ref.TimestampFormat) + if r.Ref.TimestampFormat != "" { + return fmt.Sprintf("%q", r.Ref.TimestampFormat) } - if r.Ref.Shape.TimestampFormat!= "" { - return fmt.Sprintf("%q",r.Ref.Shape.TimestampFormat) + if r.Ref.Shape.TimestampFormat != "" { + return fmt.Sprintf("%q", r.Ref.Shape.TimestampFormat) } switch r.Location() { diff --git a/private/protocol/ec2query/build_bench_test.go b/private/protocol/ec2query/build_bench_test.go index 54c3895f33b..875d0a33dcc 100644 --- a/private/protocol/ec2query/build_bench_test.go +++ b/private/protocol/ec2query/build_bench_test.go @@ -49,7 +49,7 @@ func BenchmarkEC2QueryBuild_Complex_ec2AuthorizeSecurityGroupEgress(b *testing.B IpProtocol: aws.String("String"), SourceSecurityGroupName: aws.String("String"), SourceSecurityGroupOwnerId: aws.String("String"), - ToPort: aws.Int64(1), + ToPort: aws.Int64(1), } benchEC2QueryBuild(b, "AuthorizeSecurityGroupEgress", params) diff --git a/private/protocol/fields.go b/private/protocol/fields.go index 41aa6f6703f..3ea555caf62 100644 --- a/private/protocol/fields.go +++ b/private/protocol/fields.go @@ -135,8 +135,8 @@ func (v JSONValue) MarshalValueBuf(b []byte) ([]byte, error) { // TimeValue provies encoding of time.Time for AWS protocols. type TimeValue struct { - V time.Time - Format string + V time.Time + Format string QuotedFormatTime bool } @@ -148,8 +148,8 @@ func (v TimeValue) MarshalValue() (string, error) { } if v.QuotedFormatTime { - format, err := FormatTime(v.Format, v.V) - return fmt.Sprintf("%q",format),err + format, err := FormatTime(v.Format, v.V) + return fmt.Sprintf("%q", format), err } return FormatTime(v.Format, v.V) diff --git a/private/protocol/json/jsonutil/unmarshal_bench_test.go b/private/protocol/json/jsonutil/unmarshal_bench_test.go index 291f5cd8b60..4c3f2c2ba5d 100644 --- a/private/protocol/json/jsonutil/unmarshal_bench_test.go +++ b/private/protocol/json/jsonutil/unmarshal_bench_test.go @@ -12,7 +12,7 @@ import ( ) var ( - simpleJSON = []byte(`{"FooEnum": "foo", "ListEnums": ["0", "1"]}`) + simpleJSON = []byte(`{"FooEnum": "foo", "ListEnums": ["0", "1"]}`) complexJSON = []byte(`{"Table":{"AttributeDefinitions":[{"AttributeName":"1","AttributeType":"N"}],"CreationDateTime":1.562054355238E9,"ItemCount":0,"KeySchema":[{"AttributeName":"1","KeyType":"HASH"}],"ProvisionedThroughput":{"NumberOfDecreasesToday":0,"ReadCapacityUnits":5,"WriteCapacityUnits":5},"TableArn":"arn:aws:dynamodb:us-west-2:183557167593:table/TestTable","TableId":"575d0be6-34e3-4843-838c-8e8e8d4ea2f7","TableName":"TestTable","TableSizeBytes":0,"TableStatus":"ACTIVE"}}`) ) diff --git a/service/dynamodb/customizations.go b/service/dynamodb/customizations.go index a2dcf2cf06f..9eb032ee433 100644 --- a/service/dynamodb/customizations.go +++ b/service/dynamodb/customizations.go @@ -5,25 +5,13 @@ import ( "hash/crc32" "io" "io/ioutil" - "math" "strconv" "time" "github.com/aws/aws-sdk-go-v2/aws" - client "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/awserr" ) -type retryer struct { - client.DefaultRetryer -} - -func (d retryer) RetryRules(r *request.Request) time.Duration { - delay := time.Duration(math.Pow(2, float64(r.RetryCount))) * 50 - return delay * time.Millisecond -} - func init() { initClient = func(c *Client) { if c.Config.Retryer == nil { @@ -45,11 +33,10 @@ func init() { } func setCustomRetryer(c *Client) { - c.Retryer = retryer{ - DefaultRetryer: client.DefaultRetryer{ - NumMaxRetries: 10, - }, - } + c.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 10 + d.MinRetryDelay = 50 * time.Millisecond + }) } func drainBody(b io.ReadCloser, length int64) (out *bytes.Buffer, err error) { @@ -69,13 +56,13 @@ func drainBody(b io.ReadCloser, length int64) (out *bytes.Buffer, err error) { var disableCompressionHandler = aws.NamedHandler{Name: "dynamodb.DisableCompression", Fn: disableCompression} -func disableCompression(r *request.Request) { +func disableCompression(r *aws.Request) { r.HTTPRequest.Header.Set("Accept-Encoding", "identity") } var validateCRC32Handler = aws.NamedHandler{Name: "dynamodb.ValidateCRC32", Fn: validateCRC32} -func validateCRC32(r *request.Request) { +func validateCRC32(r *aws.Request) { if r.Error != nil { return // already have an error, no need to verify CRC } diff --git a/service/dynamodb/customizations_test.go b/service/dynamodb/customizations_test.go index ae23100c503..7025f9dd7cc 100644 --- a/service/dynamodb/customizations_test.go +++ b/service/dynamodb/customizations_test.go @@ -8,8 +8,6 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" - client "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/awserr" "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" "github.com/aws/aws-sdk-go-v2/service/dynamodb" @@ -19,7 +17,9 @@ var db *dynamodb.Client func TestMain(m *testing.M) { cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 2} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) db = dynamodb.New(cfg) db.Handlers.Send.Clear() // mock sending @@ -27,13 +27,13 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func mockCRCResponse(svc *dynamodb.Client, status int, body, crc string) (req *request.Request) { +func mockCRCResponse(svc *dynamodb.Client, status int, body, crc string) (req *aws.Request) { header := http.Header{} header.Set("x-amz-crc32", crc) listReq := svc.ListTablesRequest(nil) req = listReq.Request - req.Handlers.Send.PushBack(func(*request.Request) { + req.Handlers.Send.PushBack(func(*aws.Request) { req.HTTPResponse = &http.Response{ ContentLength: int64(len(body)), StatusCode: status, @@ -57,7 +57,9 @@ func TestDefaultRetryRules(t *testing.T) { func TestCustomRetryRules(t *testing.T) { cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 2} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) svc := dynamodb.New(cfg) if e, a := 2, svc.Retryer.MaxRetries(); e != a { @@ -66,12 +68,14 @@ func TestCustomRetryRules(t *testing.T) { } type testCustomRetryer struct { - client.DefaultRetryer + aws.DefaultRetryer } func TestCustomRetry_FromConfig(t *testing.T) { cfg := unit.Config() - cfg.Retryer = testCustomRetryer{client.DefaultRetryer{NumMaxRetries: 9}} + cfg.Retryer = testCustomRetryer{aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 9. + })} svc := dynamodb.New(cfg) @@ -138,7 +142,9 @@ func TestValidateCRC32DoesNotMatch(t *testing.T) { func TestValidateCRC32DoesNotMatchNoComputeChecksum(t *testing.T) { cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 2} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 2 + }) svc := dynamodb.New(cfg) svc.DisableComputeChecksums = true diff --git a/service/ec2/customizations.go b/service/ec2/customizations.go index be2abb295b0..6830c59a9fe 100644 --- a/service/ec2/customizations.go +++ b/service/ec2/customizations.go @@ -4,19 +4,41 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - request "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/internal/awsutil" ) +const ( + // customRetryerMaxNumRetries sets max number of retries + customRetryerMaxNumRetries = 3 + + // customRetryerMinRetryDelay sets min retry delay + customRetryerMinRetryDelay = 1 * time.Second + + // customRetryerMaxRetryDelay sets max retry delay + customRetryerMaxRetryDelay = 8 * time.Second +) + +// setRetryerConfig overrides the default Retryer values +func setRetryerConfig(d *aws.DefaultRetryer) { + d.NumMaxRetries = customRetryerMaxNumRetries + d.MinRetryDelay = customRetryerMinRetryDelay + d.MinThrottleDelay = customRetryerMinRetryDelay + d.MaxRetryDelay = customRetryerMaxRetryDelay + d.MaxThrottleDelay = customRetryerMaxRetryDelay +} + func init() { - initRequest = func(c *Client, r *request.Request) { + initRequest = func(c *Client, r *aws.Request) { if r.Operation.Name == opCopySnapshot { // fill the PresignedURL parameter r.Handlers.Build.PushFront(fillPresignedURL) } + if c.Config.Retryer == nil && (r.Operation.Name == opModifyNetworkInterfaceAttribute || r.Operation.Name == opAssignPrivateIpAddresses) { + r.Retryer = aws.NewDefaultRetryer(setRetryerConfig) + } } } -func fillPresignedURL(r *request.Request) { +func fillPresignedURL(r *aws.Request) { if !r.ParamsFilled() { return } @@ -49,7 +71,7 @@ func fillPresignedURL(r *request.Request) { metadata.SigningRegion = resolved.SigningRegion // Presign a CopySnapshot request with modified params - req := request.New(cfgCp, metadata, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) + req := aws.New(cfgCp, metadata, r.Handlers, r.Retryer, r.Operation, newParams, r.Data) url, err := req.Presign(5 * time.Minute) // 5 minutes should be enough. if err != nil { // bubble error back up to original request r.Error = err diff --git a/service/ec2/retryer_test.go b/service/ec2/retryer_test.go new file mode 100644 index 00000000000..a7a8389c106 --- /dev/null +++ b/service/ec2/retryer_test.go @@ -0,0 +1,108 @@ +package ec2 + +import ( + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/internal/awstesting/unit" +) + +func TestCustomRetryRules(t *testing.T) { + + cfg := unit.Config() + svc := New(cfg) + + req := svc.ModifyNetworkInterfaceAttributeRequest(&ModifyNetworkInterfaceAttributeInput{ + NetworkInterfaceId: aws.String("foo"), + }) + + duration := req.Request.Retryer.RetryRules(req.Request) + if duration < time.Second*1 || duration > time.Second*2 { + t.Errorf("expected duration to be between 1s and 2s, but received %s", duration) + } + + req.Request.RetryCount = 15 + duration = req.Request.Retryer.RetryRules(req.Request) + + if duration < time.Second*4 || duration > time.Second*8 { + t.Errorf("expected duration to be between 4s and 8s, but received %s", duration) + } + +} + +func TestCustomRetryer_WhenRetrierSpecified(t *testing.T) { + svc := New(aws.Config{ + Region: "us-west-2", + Retryer: aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 4 + d.MinThrottleDelay = 50 * time.Millisecond + d.MinRetryDelay = 10 * time.Millisecond + d.MaxThrottleDelay = 200 * time.Millisecond + d.MaxRetryDelay = 300 * time.Millisecond + }), + EndpointResolver: unit.Config().EndpointResolver, + }) + + if _, ok := svc.Client.Retryer.(aws.DefaultRetryer); !ok { + t.Error("expected default retryer, but received otherwise") + } + + req := svc.AssignPrivateIpAddressesRequest(&AssignPrivateIpAddressesInput{ + NetworkInterfaceId: aws.String("foo"), + }) + + d := req.Request.Retryer.(aws.DefaultRetryer) + + if d.NumMaxRetries != 4 { + t.Errorf("expected max retries to be %v, got %v", 4, d.NumMaxRetries) + } + + if d.MinRetryDelay != 10*time.Millisecond { + t.Errorf("expected min retry delay to be %v, got %v", "10 ms", d.MinRetryDelay) + } + + if d.MinThrottleDelay != 50*time.Millisecond { + t.Errorf("expected min throttle delay to be %v, got %v", "50 ms", d.MinThrottleDelay) + } + + if d.MaxRetryDelay != 300*time.Millisecond { + t.Errorf("expected max retry delay to be %v, got %v", "300 ms", d.MaxRetryDelay) + } + + if d.MaxThrottleDelay != 200*time.Millisecond { + t.Errorf("expected max throttle delay to be %v, got %v", "200 ms", d.MaxThrottleDelay) + } +} + +func TestCustomRetryer(t *testing.T) { + + cfg := unit.Config() + svc := New(cfg) + + req := svc.AssignPrivateIpAddressesRequest(&AssignPrivateIpAddressesInput{ + NetworkInterfaceId: aws.String("foo"), + }) + + d := req.Request.Retryer.(aws.DefaultRetryer) + + if d.NumMaxRetries != customRetryerMaxNumRetries { + t.Errorf("expected max retries to be %v, got %v", customRetryerMaxNumRetries, d.NumMaxRetries) + } + + if d.MinRetryDelay != customRetryerMinRetryDelay { + t.Errorf("expected min retry delay to be %v, got %v", customRetryerMinRetryDelay, d.MinRetryDelay) + } + + if d.MinThrottleDelay != customRetryerMinRetryDelay { + t.Errorf("expected min throttle delay to be %v, got %v", customRetryerMinRetryDelay, d.MinThrottleDelay) + } + + if d.MaxRetryDelay != customRetryerMaxRetryDelay { + t.Errorf("expected max retry delay to be %v, got %v", customRetryerMaxRetryDelay, d.MaxRetryDelay) + } + + if d.MaxThrottleDelay != customRetryerMaxRetryDelay { + t.Errorf("expected max throttle delay to be %v, got %v", customRetryerMaxRetryDelay, d.MaxThrottleDelay) + } +} diff --git a/service/kinesis/customizations_test.go b/service/kinesis/customizations_test.go index 6c1e0b05e80..ae18b6f9fb2 100644 --- a/service/kinesis/customizations_test.go +++ b/service/kinesis/customizations_test.go @@ -33,7 +33,9 @@ func TestKinesisGetRecordsCustomization(t *testing.T) { retryCount := 0 cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 4} + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = 4 + }) svc := New(cfg) req := svc.GetRecordsRequest(&GetRecordsInput{ diff --git a/service/s3/s3crypto/cipher_util_test.go b/service/s3/s3crypto/cipher_util_test.go index c896ec8dbd7..6a7f73a775c 100644 --- a/service/s3/s3crypto/cipher_util_test.go +++ b/service/s3/s3crypto/cipher_util_test.go @@ -100,7 +100,7 @@ func TestCEKFactory(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" @@ -156,7 +156,7 @@ func TestCEKFactoryNoCEK(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" @@ -212,7 +212,7 @@ func TestCEKFactoryCustomEntry(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" diff --git a/service/s3/s3crypto/decryption_client_test.go b/service/s3/s3crypto/decryption_client_test.go index 3303a21f5e4..2b1f5fa0411 100644 --- a/service/s3/s3crypto/decryption_client_test.go +++ b/service/s3/s3crypto/decryption_client_test.go @@ -30,16 +30,16 @@ func TestGetObjectGCM(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" c := s3crypto.NewDecryptionClient(cfg) - c.S3Client.(*s3.Client).ForcePathStyle = true - if c == nil { - t.Error("expected non-nil value") + t.Fatalf("failed to create a new S3 crypto decryption client") } + + c.S3Client.(*s3.Client).ForcePathStyle = true input := &s3.GetObjectInput{ Key: aws.String("test"), Bucket: aws.String("test"), @@ -101,16 +101,16 @@ func TestGetObjectCBC(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" c := s3crypto.NewDecryptionClient(cfg) - c.S3Client.(*s3.Client).ForcePathStyle = true - if c == nil { - t.Error("expected non-nil value") + t.Fatalf("failed to create a new S3 crypto decryption client") } + + c.S3Client.(*s3.Client).ForcePathStyle = true input := &s3.GetObjectInput{ Key: aws.String("test"), Bucket: aws.String("test"), @@ -170,16 +170,16 @@ func TestGetObjectCBC2(t *testing.T) { defer ts.Close() cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" c := s3crypto.NewDecryptionClient(cfg) - c.S3Client.(*s3.Client).ForcePathStyle = true - if c == nil { - t.Error("expected non-nil value") + t.Fatalf("failed to create a new S3 crypto decryption client") } + + c.S3Client.(*s3.Client).ForcePathStyle = true input := &s3.GetObjectInput{ Key: aws.String("test"), Bucket: aws.String("test"), diff --git a/service/s3/s3crypto/encryption_client_test.go b/service/s3/s3crypto/encryption_client_test.go index 46110b11307..c12d6154060 100644 --- a/service/s3/s3crypto/encryption_client_test.go +++ b/service/s3/s3crypto/encryption_client_test.go @@ -21,7 +21,7 @@ import ( func TestDefaultConfigValues(t *testing.T) { cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.Region = "us-west-2" svc := kms.New(cfg) @@ -49,15 +49,15 @@ func TestPutObject(t *testing.T) { cb := mockCipherBuilder{generator} cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.Region = "us-west-2" c := s3crypto.NewEncryptionClient(cfg, cb) - c.S3Client.(*s3.Client).ForcePathStyle = true - if c == nil { - t.Error("expected non-vil client value") + t.Fatalf("failed to create a new S3 crypto encryption client") } + + c.S3Client.(*s3.Client).ForcePathStyle = true input := &s3.PutObjectInput{ Key: aws.String("test"), Bucket: aws.String("test"), diff --git a/service/s3/s3crypto/kms_key_handler_test.go b/service/s3/s3crypto/kms_key_handler_test.go index 85136741a9f..9cf33a407c7 100644 --- a/service/s3/s3crypto/kms_key_handler_test.go +++ b/service/s3/s3crypto/kms_key_handler_test.go @@ -49,7 +49,7 @@ func TestKMSGenerateCipherData(t *testing.T) { })) cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" @@ -79,7 +79,7 @@ func TestKMSDecrypt(t *testing.T) { })) cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" @@ -106,7 +106,7 @@ func TestKMSDecryptBadJSON(t *testing.T) { })) cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} + cfg.Retryer = aws.NoOpRetryer{} cfg.EndpointResolver = aws.ResolveWithEndpointURL(ts.URL) cfg.Region = "us-west-2" diff --git a/service/s3/s3manager/download.go b/service/s3/s3manager/download.go index d263f74450b..3c743f1baec 100644 --- a/service/s3/s3manager/download.go +++ b/service/s3/s3manager/download.go @@ -86,9 +86,6 @@ func NewDownloader(cfg aws.Config, options ...func(*Downloader)) *Downloader { Concurrency: DefaultDownloadConcurrency, Retryer: cfg.Retryer, } - if d.Retryer == nil { - d.Retryer = aws.DefaultRetryer{NumMaxRetries: 3} - } for _, option := range options { option(d) @@ -121,8 +118,6 @@ func NewDownloaderWithClient(svc s3iface.ClientAPI, options ...func(*Downloader) if s3Svc, ok := svc.(*s3.Client); ok { retryer = s3Svc.Retryer - } else { - retryer = aws.DefaultRetryer{NumMaxRetries: 3} } d := &Downloader{ diff --git a/service/s3/s3manager/download_test.go b/service/s3/s3manager/download_test.go index 3777411d5c9..fe1bb422d23 100644 --- a/service/s3/s3manager/download_test.go +++ b/service/s3/s3manager/download_test.go @@ -157,11 +157,18 @@ func dlLoggingSvcContentRangeTotalAny(data []byte, states []int) (*s3.Client, *[ func dlLoggingSvcWithErrReader(cases []testErrReader) (*s3.Client, *[]string) { var m sync.Mutex - names := []string{} + var names []string var index int cfg := unit.Config() - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: len(cases) - 1} + switch len(cases) - 1 { + case 0: // zero retries expected + cfg.Retryer = aws.NoOpRetryer{} + default: + cfg.Retryer = aws.NewDefaultRetryer(func(d *aws.DefaultRetryer) { + d.NumMaxRetries = len(cases) - 1 + }) + } svc := s3.New(cfg) diff --git a/service/s3/statusok_error_test.go b/service/s3/statusok_error_test.go index 8bf6521490f..34916837a1c 100644 --- a/service/s3/statusok_error_test.go +++ b/service/s3/statusok_error_test.go @@ -167,8 +167,7 @@ func newCopyTestSvc(errMsg string) *s3.Client { })) cfg := unit.Config() cfg.EndpointResolver = aws.ResolveWithEndpointURL(server.URL) - cfg.Retryer = aws.DefaultRetryer{NumMaxRetries: 0} - + cfg.Retryer = aws.NoOpRetryer{} svc := s3.New(cfg) svc.ForcePathStyle = true