Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 87 additions & 54 deletions aws/waiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
Expand All @@ -24,12 +22,21 @@ type mockClient struct {
}
type MockInput struct{}
type MockOutput struct {
States []*MockState
States []MockState
}
type MockState struct {
State *string
StatePtr *string
State StateType
}

type StateType string

const (
StateTypeStopping StateType = "stopping"
StateTypePending StateType = "pending"
StateTypeRunning StateType = "running"
)

func (c *mockClient) MockRequest(input *MockInput) (*aws.Request, *MockOutput) {
op := &aws.Operation{
Name: "Mock",
Expand Down Expand Up @@ -66,21 +73,21 @@ func TestWaiterPathAll(t *testing.T) {
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypePending},
{State: StateTypePending},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypePending},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypeRunning},
},
},
}
Expand All @@ -91,7 +98,7 @@ func TestWaiterPathAll(t *testing.T) {
})
svc.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
t.Fatal("too many polling requests made")
return
}
r.Data = resps[reqNum]
Expand All @@ -114,9 +121,15 @@ func TestWaiterPathAll(t *testing.T) {
}

err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 3, numBuiltReq; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestWaiterPath(t *testing.T) {
Expand All @@ -130,21 +143,21 @@ func TestWaiterPath(t *testing.T) {
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypePending},
{State: StateTypePending},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypePending},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypeRunning},
},
},
}
Expand All @@ -155,7 +168,7 @@ func TestWaiterPath(t *testing.T) {
})
svc.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
t.Fatalf("too many polling requests made")
return
}
r.Data = resps[reqNum]
Expand All @@ -178,9 +191,15 @@ func TestWaiterPath(t *testing.T) {
}

err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 3, numBuiltReq; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestWaiterFailure(t *testing.T) {
Expand All @@ -194,21 +213,21 @@ func TestWaiterFailure(t *testing.T) {
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypePending},
{State: StateTypePending},
},
},
{ // Request 2
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypePending},
},
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("stopping")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypeStopping},
},
},
}
Expand All @@ -219,7 +238,7 @@ func TestWaiterFailure(t *testing.T) {
})
svc.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
t.Fatalf("too many polling requests made")
return
}
r.Data = resps[reqNum]
Expand Down Expand Up @@ -248,11 +267,21 @@ func TestWaiterFailure(t *testing.T) {
}

err := w.WaitWithContext(aws.BackgroundContext()).(awserr.Error)
assert.Error(t, err)
assert.Equal(t, aws.WaiterResourceNotReadyErrorCode, err.Code())
assert.Equal(t, "failed waiting for successful resource state", err.Message())
assert.Equal(t, 3, numBuiltReq)
assert.Equal(t, 3, reqNum)
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := aws.WaiterResourceNotReadyErrorCode, err.Code(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "failed waiting for successful resource state", err.Message(); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := 3, numBuiltReq; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestWaiterError(t *testing.T) {
Expand All @@ -266,19 +295,19 @@ func TestWaiterError(t *testing.T) {
reqNum := 0
resps := []*MockOutput{
{ // Request 1
States: []*MockState{
{State: aws.String("pending")},
{State: aws.String("pending")},
States: []MockState{
{State: StateTypePending},
{State: StateTypePending},
},
},
{ // Request 1, error case retry
},
{ // Request 2, error case failure
},
{ // Request 3
States: []*MockState{
{State: aws.String("running")},
{State: aws.String("running")},
States: []MockState{
{State: StateTypeRunning},
{State: StateTypeRunning},
},
},
}
Expand All @@ -303,7 +332,7 @@ func TestWaiterError(t *testing.T) {
})
svc.Handlers.Unmarshal.PushBack(func(r *aws.Request) {
if reqNum >= len(resps) {
assert.Fail(t, "too many polling requests made")
t.Fatalf("too many polling requests made")
return
}
r.Data = resps[reqNum]
Expand Down Expand Up @@ -401,8 +430,12 @@ func TestWaiterStatus(t *testing.T) {
}

err := w.WaitWithContext(aws.BackgroundContext())
assert.NoError(t, err)
assert.Equal(t, 3, reqNum)
if err != nil {
t.Errorf("expect no error, got %v", err)
}
if e, a := 3, reqNum; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestWaiter_ApplyOptions(t *testing.T) {
Expand Down
6 changes: 6 additions & 0 deletions internal/awsutil/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,11 @@ func DeepEqual(a, b interface{}) bool {
return false
}

// Special casing for strings as typed enumerations are string aliases
// but are not deep equal.
if ra.Kind() == reflect.String && rb.Kind() == reflect.String {
return ra.String() == rb.String()
}

return reflect.DeepEqual(ra.Interface(), rb.Interface())
}
8 changes: 8 additions & 0 deletions internal/awsutil/equal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
)

func TestDeepEqual(t *testing.T) {
type StringAlias string

cases := []struct {
a, b interface{}
equal bool
Expand All @@ -20,6 +22,12 @@ func TestDeepEqual(t *testing.T) {
{(*bool)(nil), (*bool)(nil), true},
{(*bool)(nil), (*string)(nil), false},
{nil, nil, true},
{StringAlias("abc"), "abc", true},
{StringAlias("abc"), "efg", false},
{StringAlias("abc"), aws.String("abc"), true},
{"abc", StringAlias("abc"), true},
{StringAlias("abc"), StringAlias("abc"), true},
{StringAlias("abc"), StringAlias("efg"), false},
}

for i, c := range cases {
Expand Down