diff --git a/aws/awserr/error.go b/aws/awserr/error.go index 56fdfc2bfc7..99849c0e19c 100644 --- a/aws/awserr/error.go +++ b/aws/awserr/error.go @@ -138,8 +138,27 @@ type RequestFailure interface { RequestID() string } -// NewRequestFailure returns a new request error wrapper for the given Error -// provided. +// NewRequestFailure returns a wrapped error with additional information for +// request status code, and service requestID. +// +// Should be used to wrap all request which involve service requests. Even if +// the request failed without a service response, but had an HTTP status code +// that may be meaningful. func NewRequestFailure(err Error, statusCode int, reqID string) RequestFailure { return newRequestError(err, statusCode, reqID) } + +// UnmarshalError provides the interface for the SDK failing to unmarshal data. +type UnmarshalError interface { + awsError + Bytes() []byte +} + +// NewUnmarshalError returns an initialized UnmarshalError error wrapper adding +// the bytes that fail to unmarshal to the error. +func NewUnmarshalError(err error, msg string, bytes []byte) UnmarshalError { + return &unmarshalError{ + awsError: New("UnmarshalError", msg, err), + bytes: bytes, + } +} diff --git a/aws/awserr/types.go b/aws/awserr/types.go index 0202a008f5d..a2c5817c48f 100644 --- a/aws/awserr/types.go +++ b/aws/awserr/types.go @@ -1,6 +1,9 @@ package awserr -import "fmt" +import ( + "encoding/hex" + "fmt" +) // SprintError returns a string of the formatted error code. // @@ -119,6 +122,7 @@ type requestError struct { awsError statusCode int requestID string + bytes []byte } // newRequestError returns a wrapped error with additional information for @@ -170,6 +174,29 @@ func (r requestError) OrigErrs() []error { return []error{r.OrigErr()} } +type unmarshalError struct { + awsError + bytes []byte +} + +// Error returns the string representation of the error. +// Satisfies the error interface. +func (e unmarshalError) Error() string { + extra := hex.Dump(e.bytes) + return SprintError(e.Code(), e.Message(), extra, e.OrigErr()) +} + +// String returns the string representation of the error. +// Alias for Error to satisfy the stringer interface. +func (e unmarshalError) String() string { + return e.Error() +} + +// Bytes returns the bytes that failed to unmarshal. +func (e unmarshalError) Bytes() []byte { + return e.bytes +} + // An error list that satisfies the golang interface type errorList []error diff --git a/aws/credentials/ec2rolecreds/ec2_role_provider.go b/aws/credentials/ec2rolecreds/ec2_role_provider.go index 0ed791be641..43d4ed386ab 100644 --- a/aws/credentials/ec2rolecreds/ec2_role_provider.go +++ b/aws/credentials/ec2rolecreds/ec2_role_provider.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/internal/sdkuri" ) @@ -142,7 +143,8 @@ func requestCredList(client *ec2metadata.EC2Metadata) ([]string, error) { } if err := s.Err(); err != nil { - return nil, awserr.New("SerializationError", "failed to read EC2 instance role from metadata service", err) + return nil, awserr.New(request.ErrCodeSerialization, + "failed to read EC2 instance role from metadata service", err) } return credsList, nil @@ -164,7 +166,7 @@ func requestCred(client *ec2metadata.EC2Metadata, credsName string) (ec2RoleCred respCreds := ec2RoleCredRespBody{} if err := json.NewDecoder(strings.NewReader(resp)).Decode(&respCreds); err != nil { return ec2RoleCredRespBody{}, - awserr.New("SerializationError", + awserr.New(request.ErrCodeSerialization, fmt.Sprintf("failed to decode %s EC2 instance role credentials", credsName), err) } diff --git a/aws/credentials/endpointcreds/provider.go b/aws/credentials/endpointcreds/provider.go index ace51313820..c2b2c5d65c3 100644 --- a/aws/credentials/endpointcreds/provider.go +++ b/aws/credentials/endpointcreds/provider.go @@ -39,6 +39,7 @@ import ( "github.com/aws/aws-sdk-go/aws/client/metadata" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" ) // ProviderName is the name of the credentials provider. @@ -174,7 +175,7 @@ func unmarshalHandler(r *request.Request) { out := r.Data.(*getCredentialsOutput) if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&out); err != nil { - r.Error = awserr.New("SerializationError", + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode endpoint credentials", err, ) @@ -185,11 +186,15 @@ func unmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() var errOut errorOutput - if err := json.NewDecoder(r.HTTPResponse.Body).Decode(&errOut); err != nil { - r.Error = awserr.New("SerializationError", - "failed to decode endpoint credentials", - err, + err := jsonutil.UnmarshalJSONError(&errOut, r.HTTPResponse.Body) + if err != nil { + r.Error = awserr.NewRequestFailure( + awserr.New(request.ErrCodeSerialization, + "failed to decode error message", err), + r.HTTPResponse.StatusCode, + r.RequestID, ) + return } // Response body format is not consistent between metadata endpoints. diff --git a/aws/csm/reporter.go b/aws/csm/reporter.go index 0b5571acfbf..d9aa5b062a4 100644 --- a/aws/csm/reporter.go +++ b/aws/csm/reporter.go @@ -96,7 +96,7 @@ func getMetricException(err awserr.Error) metricException { switch code { case "RequestError", - "SerializationError", + request.ErrCodeSerialization, request.CanceledErrorCode: return sdkException{ requestException{exception: code, message: msg}, diff --git a/aws/ec2metadata/api.go b/aws/ec2metadata/api.go index d57a1af5992..2c8d5f56d0e 100644 --- a/aws/ec2metadata/api.go +++ b/aws/ec2metadata/api.go @@ -82,7 +82,7 @@ func (c *EC2Metadata) GetInstanceIdentityDocument() (EC2InstanceIdentityDocument doc := EC2InstanceIdentityDocument{} if err := json.NewDecoder(strings.NewReader(resp)).Decode(&doc); err != nil { return EC2InstanceIdentityDocument{}, - awserr.New("SerializationError", + awserr.New(request.ErrCodeSerialization, "failed to decode EC2 instance identity document", err) } @@ -101,7 +101,7 @@ func (c *EC2Metadata) IAMInfo() (EC2IAMInfo, error) { info := EC2IAMInfo{} if err := json.NewDecoder(strings.NewReader(resp)).Decode(&info); err != nil { return EC2IAMInfo{}, - awserr.New("SerializationError", + awserr.New(request.ErrCodeSerialization, "failed to decode EC2 IAM info", err) } diff --git a/aws/ec2metadata/service.go b/aws/ec2metadata/service.go index f4438eae9c9..f0c1d31e756 100644 --- a/aws/ec2metadata/service.go +++ b/aws/ec2metadata/service.go @@ -123,7 +123,7 @@ func unmarshalHandler(r *request.Request) { defer r.HTTPResponse.Body.Close() b := &bytes.Buffer{} if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { - r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata respose", err) + r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata respose", err) return } @@ -136,7 +136,7 @@ func unmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() b := &bytes.Buffer{} if _, err := io.Copy(b, r.HTTPResponse.Body); err != nil { - r.Error = awserr.New("SerializationError", "unable to unmarshal EC2 metadata error respose", err) + r.Error = awserr.New(request.ErrCodeSerialization, "unable to unmarshal EC2 metadata error respose", err) return } diff --git a/aws/request/request_1_11_test.go b/aws/request/request_1_11_test.go index af6e910b1fc..e43370f8367 100644 --- a/aws/request/request_1_11_test.go +++ b/aws/request/request_1_11_test.go @@ -71,9 +71,9 @@ func TestSerializationErrConnectionReset_read(t *testing.T) { req.ApplyOptions(request.WithResponseReadTimeout(time.Second)) err := req.Send() if err == nil { - t.Error("Expected rror 'SerializationError', but received nil") + t.Error("Expected error 'SerializationError', but received nil") } - if aerr, ok := err.(awserr.Error); ok && aerr.Code() != "SerializationError" { + if aerr, ok := err.(awserr.Error); ok && aerr.Code() != request.ErrCodeSerialization { t.Errorf("Expected 'SerializationError', but received %q", aerr.Code()) } else if !ok { t.Errorf("Expected 'awserr.Error', but received %v", reflect.TypeOf(err)) diff --git a/aws/request/request_test.go b/aws/request/request_test.go index e4f125a8595..09c0fb3ae23 100644 --- a/aws/request/request_test.go +++ b/aws/request/request_test.go @@ -675,9 +675,9 @@ func TestSerializationErrConnectionReset_accept(t *testing.T) { req.ApplyOptions(request.WithResponseReadTimeout(time.Second)) err := req.Send() if err == nil { - t.Error("Expected rror 'SerializationError', but received nil") + t.Error("Expected error 'SerializationError', but received nil") } - if aerr, ok := err.(awserr.Error); ok && aerr.Code() != "SerializationError" { + if aerr, ok := err.(awserr.Error); ok && aerr.Code() != request.ErrCodeSerialization { t.Errorf("Expected 'SerializationError', but received %q", aerr.Code()) } else if !ok { t.Errorf("Expected 'awserr.Error', but received %v", reflect.TypeOf(err)) diff --git a/example/service/dynamodb/transactWriteItems/error_handler.go b/example/service/dynamodb/transactWriteItems/error_handler.go index 108ad723591..1ac48fd3a01 100644 --- a/example/service/dynamodb/transactWriteItems/error_handler.go +++ b/example/service/dynamodb/transactWriteItems/error_handler.go @@ -45,14 +45,15 @@ func TxAwareUnmarshalError(req *request.Request) { err := json.NewDecoder(req.HTTPResponse.Body).Decode(&jsonErr) if err == io.EOF { req.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", req.HTTPResponse.Status, nil), + awserr.New(request.ErrCodeSerialization, req.HTTPResponse.Status, nil), req.HTTPResponse.StatusCode, req.RequestID, ) return } else if err != nil { req.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding JSON RPC error response", err), + awserr.New(request.ErrCodeSerialization, + "failed decoding JSON RPC error response", err), req.HTTPResponse.StatusCode, req.RequestID, ) diff --git a/private/protocol/ec2query/build.go b/private/protocol/ec2query/build.go index 3104e6ce4c9..50c5ed76005 100644 --- a/private/protocol/ec2query/build.go +++ b/private/protocol/ec2query/build.go @@ -21,7 +21,8 @@ func Build(r *request.Request) { "Version": {r.ClientInfo.APIVersion}, } if err := queryutil.Parse(body, r.Params, true); err != nil { - r.Error = awserr.New("SerializationError", "failed encoding EC2 Query request", err) + r.Error = awserr.New(request.ErrCodeSerialization, + "failed encoding EC2 Query request", err) } if !r.IsPresigned() { diff --git a/private/protocol/ec2query/unmarshal.go b/private/protocol/ec2query/unmarshal.go index 5793c047373..105d732f9d3 100644 --- a/private/protocol/ec2query/unmarshal.go +++ b/private/protocol/ec2query/unmarshal.go @@ -4,7 +4,6 @@ package ec2query import ( "encoding/xml" - "io" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" @@ -28,7 +27,8 @@ func Unmarshal(r *request.Request) { err := xmlutil.UnmarshalXML(r.Data, decoder, "") if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding EC2 Query response", err), + awserr.New(request.ErrCodeSerialization, + "failed decoding EC2 Query response", err), r.HTTPResponse.StatusCode, r.RequestID, ) @@ -39,7 +39,11 @@ func Unmarshal(r *request.Request) { // UnmarshalMeta unmarshals response headers for the EC2 protocol. func UnmarshalMeta(r *request.Request) { - // TODO implement unmarshaling of request IDs + r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") + if r.RequestID == "" { + // Alternative version of request id in the header + r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") + } } type xmlErrorResponse struct { @@ -53,19 +57,21 @@ type xmlErrorResponse struct { func UnmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() - resp := &xmlErrorResponse{} - err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp) - if err != nil && err != io.EOF { + var respErr xmlErrorResponse + err := xmlutil.UnmarshalXMLError(&respErr, r.HTTPResponse.Body) + if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding EC2 Query error response", err), + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal error message", err), r.HTTPResponse.StatusCode, r.RequestID, ) - } else { - r.Error = awserr.NewRequestFailure( - awserr.New(resp.Code, resp.Message, nil), - r.HTTPResponse.StatusCode, - resp.RequestID, - ) + return } + + r.Error = awserr.NewRequestFailure( + awserr.New(respErr.Code, respErr.Message, nil), + r.HTTPResponse.StatusCode, + respErr.RequestID, + ) } diff --git a/private/protocol/ec2query/unmarshal_error_test.go b/private/protocol/ec2query/unmarshal_error_test.go new file mode 100644 index 00000000000..4dd0a7d01ab --- /dev/null +++ b/private/protocol/ec2query/unmarshal_error_test.go @@ -0,0 +1,82 @@ +// +build go1.8 + +package ec2query + +import ( + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" +) + +func TestUnmarshalError(t *testing.T) { + cases := map[string]struct { + Request *request.Request + Code, Msg string + ReqID string + Status int + }{ + "ErrorResponse": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` + + + codeAbc + msg123 + + + reqID123 + `)), + }, + }, + Code: "codeAbc", Msg: "msg123", + Status: 400, ReqID: "reqID123", + }, + "unknown tag": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` + . + `)), + }, + }, + Code: request.ErrCodeSerialization, + Msg: "failed to unmarshal error message", + Status: 400, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + r := c.Request + UnmarshalError(r) + if r.Error == nil { + t.Fatalf("expect error, got none") + } + + aerr := r.Error.(awserr.RequestFailure) + if e, a := c.Code, aerr.Code(); e != a { + t.Errorf("expect %v code, got %v", e, a) + } + if e, a := c.Msg, aerr.Message(); e != a { + t.Errorf("expect %q message, got %q", e, a) + } + if e, a := c.ReqID, aerr.RequestID(); e != a { + t.Errorf("expect %v request ID, got %v", e, a) + } + if e, a := c.Status, aerr.StatusCode(); e != a { + t.Errorf("expect %v status code, got %v", e, a) + } + }) + } +} diff --git a/private/protocol/json/jsonutil/unmarshal.go b/private/protocol/json/jsonutil/unmarshal.go index b11f3ee45b5..ea0da79a5e0 100644 --- a/private/protocol/json/jsonutil/unmarshal.go +++ b/private/protocol/json/jsonutil/unmarshal.go @@ -1,6 +1,7 @@ package jsonutil import ( + "bytes" "encoding/base64" "encoding/json" "fmt" @@ -9,9 +10,30 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/private/protocol" ) +// UnmarshalJSONError unmarshal's the reader's JSON document into the passed in +// type. The value to unmarshal the json document into must be a pointer to the +// type. +func UnmarshalJSONError(v interface{}, stream io.Reader) error { + var errBuf bytes.Buffer + body := io.TeeReader(stream, &errBuf) + + err := json.NewDecoder(body).Decode(v) + if err != nil { + msg := "failed decoding error message" + if err == io.EOF { + msg = "error message missing" + err = nil + } + return awserr.NewUnmarshalError(err, msg, errBuf.Bytes()) + } + + return nil +} + // UnmarshalJSON reads a stream and unmarshals the results in object v. func UnmarshalJSON(v interface{}, stream io.Reader) error { var out interface{} diff --git a/private/protocol/jsonrpc/jsonrpc.go b/private/protocol/jsonrpc/jsonrpc.go index 36ceab088c0..bfedc9fd422 100644 --- a/private/protocol/jsonrpc/jsonrpc.go +++ b/private/protocol/jsonrpc/jsonrpc.go @@ -6,8 +6,6 @@ package jsonrpc //go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/json.json unmarshal_test.go import ( - "encoding/json" - "io" "strings" "github.com/aws/aws-sdk-go/aws/awserr" @@ -37,7 +35,7 @@ func Build(req *request.Request) { if req.ParamsFilled() { buf, err = jsonutil.BuildJSON(req.Params) if err != nil { - req.Error = awserr.New("SerializationError", "failed encoding JSON RPC request", err) + req.Error = awserr.New(request.ErrCodeSerialization, "failed encoding JSON RPC request", err) return } } else { @@ -68,7 +66,7 @@ func Unmarshal(req *request.Request) { err := jsonutil.UnmarshalJSON(req.Data, req.HTTPResponse.Body) if err != nil { req.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding JSON RPC response", err), + awserr.New(request.ErrCodeSerialization, "failed decoding JSON RPC response", err), req.HTTPResponse.StatusCode, req.RequestID, ) @@ -87,17 +85,11 @@ func UnmarshalError(req *request.Request) { defer req.HTTPResponse.Body.Close() var jsonErr jsonErrorResponse - err := json.NewDecoder(req.HTTPResponse.Body).Decode(&jsonErr) - if err == io.EOF { + err := jsonutil.UnmarshalJSONError(&jsonErr, req.HTTPResponse.Body) + if err != nil { req.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", req.HTTPResponse.Status, nil), - req.HTTPResponse.StatusCode, - req.RequestID, - ) - return - } else if err != nil { - req.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding JSON RPC error response", err), + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal error message", err), req.HTTPResponse.StatusCode, req.RequestID, ) diff --git a/private/protocol/jsonrpc/unmarshal_err_test.go b/private/protocol/jsonrpc/unmarshal_err_test.go new file mode 100644 index 00000000000..a12a4107e08 --- /dev/null +++ b/private/protocol/jsonrpc/unmarshal_err_test.go @@ -0,0 +1,79 @@ +// +build go1.8 + +package jsonrpc + +import ( + "bytes" + "encoding/hex" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" +) + +func TestUnmarshalError_SerializationError(t *testing.T) { + cases := map[string]struct { + Request *request.Request + ExpectMsg string + ExpectBytes []byte + }{ + "empty body": { + Request: &request.Request{ + Data: &struct{}{}, + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{ + "X-Amzn-Requestid": []string{"abc123"}, + }, + Body: ioutil.NopCloser( + bytes.NewReader([]byte{}), + ), + }, + }, + ExpectMsg: "error message missing", + }, + "HTML body": { + Request: &request.Request{ + Data: &struct{}{}, + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{ + "X-Amzn-Requestid": []string{"abc123"}, + }, + Body: ioutil.NopCloser( + bytes.NewReader([]byte(``)), + ), + }, + }, + ExpectBytes: []byte(``), + ExpectMsg: "failed decoding", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + req := c.Request + + UnmarshalError(req) + if req.Error == nil { + t.Fatal("expect error, got none") + } + + aerr := req.Error.(awserr.RequestFailure) + if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { + t.Errorf("expect %v, got %v", e, a) + } + + uerr := aerr.OrigErr().(awserr.UnmarshalError) + if e, a := c.ExpectMsg, uerr.Message(); !strings.Contains(a, e) { + t.Errorf("Expect %q, in %q", e, a) + } + if e, a := c.ExpectBytes, uerr.Bytes(); !bytes.Equal(e, a) { + t.Errorf("expect:\n%v\nactual:\n%v", hex.Dump(e), hex.Dump(a)) + } + }) + } +} diff --git a/private/protocol/query/build.go b/private/protocol/query/build.go index 60e5b09d548..0cb99eb5796 100644 --- a/private/protocol/query/build.go +++ b/private/protocol/query/build.go @@ -21,7 +21,7 @@ func Build(r *request.Request) { "Version": {r.ClientInfo.APIVersion}, } if err := queryutil.Parse(body, r.Params, false); err != nil { - r.Error = awserr.New("SerializationError", "failed encoding Query request", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed encoding Query request", err) return } diff --git a/private/protocol/query/unmarshal.go b/private/protocol/query/unmarshal.go index 3495c73070b..f69c1efc93a 100644 --- a/private/protocol/query/unmarshal.go +++ b/private/protocol/query/unmarshal.go @@ -24,7 +24,7 @@ func Unmarshal(r *request.Request) { err := xmlutil.UnmarshalXML(r.Data, decoder, r.Operation.Name+"Result") if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding Query response", err), + awserr.New(request.ErrCodeSerialization, "failed decoding Query response", err), r.HTTPResponse.StatusCode, r.RequestID, ) diff --git a/private/protocol/query/unmarshal_error.go b/private/protocol/query/unmarshal_error.go index 46d354e826f..831b0110c54 100644 --- a/private/protocol/query/unmarshal_error.go +++ b/private/protocol/query/unmarshal_error.go @@ -2,73 +2,68 @@ package query import ( "encoding/xml" - "io/ioutil" + "fmt" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" ) +// UnmarshalErrorHandler is a name request handler to unmarshal request errors +var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError} + type xmlErrorResponse struct { - XMLName xml.Name `xml:"ErrorResponse"` - Code string `xml:"Error>Code"` - Message string `xml:"Error>Message"` - RequestID string `xml:"RequestId"` + Code string `xml:"Error>Code"` + Message string `xml:"Error>Message"` + RequestID string `xml:"RequestId"` } -type xmlServiceUnavailableResponse struct { - XMLName xml.Name `xml:"ServiceUnavailableException"` +type xmlResponseError struct { + xmlErrorResponse } -// UnmarshalErrorHandler is a name request handler to unmarshal request errors -var UnmarshalErrorHandler = request.NamedHandler{Name: "awssdk.query.UnmarshalError", Fn: UnmarshalError} +func (e *xmlResponseError) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + const svcUnavailableTagName = "ServiceUnavailableException" + const errorResponseTagName = "ErrorResponse" + + switch start.Name.Local { + case svcUnavailableTagName: + e.Code = svcUnavailableTagName + e.Message = "service is unavailable" + return d.Skip() + + case errorResponseTagName: + return d.DecodeElement(&e.xmlErrorResponse, &start) + + default: + return fmt.Errorf("unknown error response tag, %v", start) + } +} // UnmarshalError unmarshals an error response for an AWS Query service. func UnmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() - bodyBytes, err := ioutil.ReadAll(r.HTTPResponse.Body) + var respErr xmlResponseError + err := xmlutil.UnmarshalXMLError(&respErr, r.HTTPResponse.Body) if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed to read from query HTTP response body", err), + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal error message", err), r.HTTPResponse.StatusCode, r.RequestID, ) return } - // First check for specific error - resp := xmlErrorResponse{} - decodeErr := xml.Unmarshal(bodyBytes, &resp) - if decodeErr == nil { - reqID := resp.RequestID - if reqID == "" { - reqID = r.RequestID - } - r.Error = awserr.NewRequestFailure( - awserr.New(resp.Code, resp.Message, nil), - r.HTTPResponse.StatusCode, - reqID, - ) - return - } - - // Check for unhandled error - servUnavailResp := xmlServiceUnavailableResponse{} - unavailErr := xml.Unmarshal(bodyBytes, &servUnavailResp) - if unavailErr == nil { - r.Error = awserr.NewRequestFailure( - awserr.New("ServiceUnavailableException", "service is unavailable", nil), - r.HTTPResponse.StatusCode, - r.RequestID, - ) - return + reqID := respErr.RequestID + if len(reqID) == 0 { + reqID = r.RequestID } - // Failed to retrieve any error message from the response body r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", - "failed to decode query XML error response", decodeErr), + awserr.New(respErr.Code, respErr.Message, nil), r.HTTPResponse.StatusCode, - r.RequestID, + reqID, ) } diff --git a/private/protocol/query/unmarshal_error_test.go b/private/protocol/query/unmarshal_error_test.go new file mode 100644 index 00000000000..5242a0cfd6f --- /dev/null +++ b/private/protocol/query/unmarshal_error_test.go @@ -0,0 +1,94 @@ +// +build go1.8 + +package query + +import ( + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" +) + +func TestUnmarshalError(t *testing.T) { + cases := map[string]struct { + Request *request.Request + Code, Msg string + ReqID string + Status int + }{ + "ErrorResponse": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` + + codeAbcmsg123 + + reqID123 + `)), + }, + }, + Code: "codeAbc", Msg: "msg123", + Status: 400, ReqID: "reqID123", + }, + "ServiceUnavailableException": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 502, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` + else + `)), + }, + }, + Code: "ServiceUnavailableException", + Msg: "service is unavailable", + Status: 502, + }, + "unknown tag": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` + . + `)), + }, + }, + Code: request.ErrCodeSerialization, + Msg: "failed to unmarshal error message", + Status: 400, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + r := c.Request + UnmarshalError(r) + if r.Error == nil { + t.Fatalf("expect error, got none") + } + + aerr := r.Error.(awserr.RequestFailure) + if e, a := c.Code, aerr.Code(); e != a { + t.Errorf("expect %v code, got %v", e, a) + } + if e, a := c.Msg, aerr.Message(); e != a { + t.Errorf("expect %q message, got %q", e, a) + } + if e, a := c.ReqID, aerr.RequestID(); e != a { + t.Errorf("expect %v request ID, got %v", e, a) + } + if e, a := c.Status, aerr.StatusCode(); e != a { + t.Errorf("expect %v status code, got %v", e, a) + } + }) + } +} diff --git a/private/protocol/rest/build.go b/private/protocol/rest/build.go index b80f84fbb86..8460a26a139 100644 --- a/private/protocol/rest/build.go +++ b/private/protocol/rest/build.go @@ -137,7 +137,7 @@ func buildBody(r *request.Request, v reflect.Value) { case string: r.SetStringBody(reader) default: - r.Error = awserr.New("SerializationError", + r.Error = awserr.New(request.ErrCodeSerialization, "failed to encode REST request", fmt.Errorf("unknown payload type %s", payload.Type())) } @@ -152,7 +152,7 @@ func buildHeader(header *http.Header, v reflect.Value, name string, tag reflect. if err == errValueNotSet { return nil } else if err != nil { - return awserr.New("SerializationError", "failed to encode REST request", err) + return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err) } name = strings.TrimSpace(name) @@ -170,7 +170,7 @@ func buildHeaderMap(header *http.Header, v reflect.Value, tag reflect.StructTag) if err == errValueNotSet { continue } else if err != nil { - return awserr.New("SerializationError", "failed to encode REST request", err) + return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err) } keyStr := strings.TrimSpace(key.String()) @@ -186,7 +186,7 @@ func buildURI(u *url.URL, v reflect.Value, name string, tag reflect.StructTag) e if err == errValueNotSet { return nil } else if err != nil { - return awserr.New("SerializationError", "failed to encode REST request", err) + return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err) } u.Path = strings.Replace(u.Path, "{"+name+"}", value, -1) @@ -219,7 +219,7 @@ func buildQueryString(query url.Values, v reflect.Value, name string, tag reflec if err == errValueNotSet { return nil } else if err != nil { - return awserr.New("SerializationError", "failed to encode REST request", err) + return awserr.New(request.ErrCodeSerialization, "failed to encode REST request", err) } query.Set(name, str) } diff --git a/private/protocol/rest/unmarshal.go b/private/protocol/rest/unmarshal.go index 33fd53b126a..de021367da2 100644 --- a/private/protocol/rest/unmarshal.go +++ b/private/protocol/rest/unmarshal.go @@ -57,7 +57,7 @@ func unmarshalBody(r *request.Request, v reflect.Value) { defer r.HTTPResponse.Body.Close() b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { - r.Error = awserr.New("SerializationError", "failed to decode REST response", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) } else { payload.Set(reflect.ValueOf(b)) } @@ -65,7 +65,7 @@ func unmarshalBody(r *request.Request, v reflect.Value) { defer r.HTTPResponse.Body.Close() b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { - r.Error = awserr.New("SerializationError", "failed to decode REST response", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) } else { str := string(b) payload.Set(reflect.ValueOf(&str)) @@ -77,7 +77,7 @@ func unmarshalBody(r *request.Request, v reflect.Value) { case "io.ReadSeeker": b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { - r.Error = awserr.New("SerializationError", + r.Error = awserr.New(request.ErrCodeSerialization, "failed to read response body", err) return } @@ -85,7 +85,7 @@ func unmarshalBody(r *request.Request, v reflect.Value) { default: io.Copy(ioutil.Discard, r.HTTPResponse.Body) defer r.HTTPResponse.Body.Close() - r.Error = awserr.New("SerializationError", + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", fmt.Errorf("unknown payload type %s", payload.Type())) } @@ -115,14 +115,14 @@ func unmarshalLocationElements(r *request.Request, v reflect.Value) { case "header": err := unmarshalHeader(m, r.HTTPResponse.Header.Get(name), field.Tag) if err != nil { - r.Error = awserr.New("SerializationError", "failed to decode REST response", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) break } case "headers": prefix := field.Tag.Get("locationName") err := unmarshalHeaderMap(m, r.HTTPResponse.Header, prefix) if err != nil { - r.Error = awserr.New("SerializationError", "failed to decode REST response", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err) break } } diff --git a/private/protocol/restjson/restjson.go b/private/protocol/restjson/restjson.go index 8e88f3042aa..af4f6154d70 100644 --- a/private/protocol/restjson/restjson.go +++ b/private/protocol/restjson/restjson.go @@ -6,12 +6,11 @@ package restjson //go:generate go run -tags codegen ../../../models/protocol_tests/generate.go ../../../models/protocol_tests/output/rest-json.json unmarshal_test.go import ( - "encoding/json" - "io" "strings" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/private/protocol/json/jsonutil" "github.com/aws/aws-sdk-go/private/protocol/jsonrpc" "github.com/aws/aws-sdk-go/private/protocol/rest" ) @@ -59,17 +58,11 @@ func UnmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() var jsonErr jsonErrorResponse - err := json.NewDecoder(r.HTTPResponse.Body).Decode(&jsonErr) - if err == io.EOF { + err := jsonutil.UnmarshalJSONError(&jsonErr, r.HTTPResponse.Body) + if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", r.HTTPResponse.Status, nil), - r.HTTPResponse.StatusCode, - r.RequestID, - ) - return - } else if err != nil { - r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed decoding REST JSON error response", err), + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal response error", err), r.HTTPResponse.StatusCode, r.RequestID, ) diff --git a/private/protocol/restjson/unmarshal_error_test.go b/private/protocol/restjson/unmarshal_error_test.go new file mode 100644 index 00000000000..d5125707f40 --- /dev/null +++ b/private/protocol/restjson/unmarshal_error_test.go @@ -0,0 +1,79 @@ +// +build go1.8 + +package restjson + +import ( + "bytes" + "encoding/hex" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" +) + +func TestUnmarshalError_SerializationError(t *testing.T) { + cases := map[string]struct { + Request *request.Request + ExpectMsg string + ExpectBytes []byte + }{ + "empty body": { + Request: &request.Request{ + Data: &struct{}{}, + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{ + "X-Amzn-Requestid": []string{"abc123"}, + }, + Body: ioutil.NopCloser( + bytes.NewReader([]byte{}), + ), + }, + }, + ExpectMsg: "error message missing", + }, + "HTML body": { + Request: &request.Request{ + Data: &struct{}{}, + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{ + "X-Amzn-Requestid": []string{"abc123"}, + }, + Body: ioutil.NopCloser( + bytes.NewReader([]byte(``)), + ), + }, + }, + ExpectBytes: []byte(``), + ExpectMsg: "failed decoding", + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + req := c.Request + + UnmarshalError(req) + if req.Error == nil { + t.Fatal("expect error, got none") + } + + aerr := req.Error.(awserr.RequestFailure) + if e, a := request.ErrCodeSerialization, aerr.Code(); e != a { + t.Errorf("expect %v, got %v", e, a) + } + + uerr := aerr.OrigErr().(awserr.UnmarshalError) + if e, a := c.ExpectMsg, uerr.Message(); !strings.Contains(a, e) { + t.Errorf("Expect %q, in %q", e, a) + } + if e, a := c.ExpectBytes, uerr.Bytes(); !bytes.Equal(e, a) { + t.Errorf("expect:\n%v\nactual:\n%v", hex.Dump(e), hex.Dump(a)) + } + }) + } +} diff --git a/private/protocol/restxml/restxml.go b/private/protocol/restxml/restxml.go index b0f4e245661..cf569645dc2 100644 --- a/private/protocol/restxml/restxml.go +++ b/private/protocol/restxml/restxml.go @@ -37,7 +37,8 @@ func Build(r *request.Request) { err := xmlutil.BuildXML(r.Params, xml.NewEncoder(&buf)) if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed to encode rest XML request", err), + awserr.New(request.ErrCodeSerialization, + "failed to encode rest XML request", err), r.HTTPResponse.StatusCode, r.RequestID, ) @@ -55,7 +56,8 @@ func Unmarshal(r *request.Request) { err := xmlutil.UnmarshalXML(r.Data, decoder, "") if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "failed to decode REST XML response", err), + awserr.New(request.ErrCodeSerialization, + "failed to decode REST XML response", err), r.HTTPResponse.StatusCode, r.RequestID, ) diff --git a/private/protocol/unmarshal_test.go b/private/protocol/unmarshal_test.go index f278f0eab14..bd3e68f06bb 100644 --- a/private/protocol/unmarshal_test.go +++ b/private/protocol/unmarshal_test.go @@ -76,7 +76,7 @@ func TestUnmarshalSeriaizationError(t *testing.T) { }, unmarshalFn: jsonrpc.Unmarshal, expectedError: awserr.NewRequestFailure( - awserr.New("SerializationError", "", nil), + awserr.New(request.ErrCodeSerialization, "", nil), 502, "", ), @@ -92,7 +92,7 @@ func TestUnmarshalSeriaizationError(t *testing.T) { }, unmarshalFn: ec2query.Unmarshal, expectedError: awserr.NewRequestFailure( - awserr.New("SerializationError", "", nil), + awserr.New(request.ErrCodeSerialization, "", nil), 111, "", ), @@ -111,7 +111,7 @@ func TestUnmarshalSeriaizationError(t *testing.T) { }, unmarshalFn: query.Unmarshal, expectedError: awserr.NewRequestFailure( - awserr.New("SerializationError", "", nil), + awserr.New(request.ErrCodeSerialization, "", nil), 1, "", ), @@ -127,7 +127,7 @@ func TestUnmarshalSeriaizationError(t *testing.T) { }, unmarshalFn: restjson.Unmarshal, expectedError: awserr.NewRequestFailure( - awserr.New("SerializationError", "", nil), + awserr.New(request.ErrCodeSerialization, "", nil), 123, "", ), @@ -143,7 +143,7 @@ func TestUnmarshalSeriaizationError(t *testing.T) { }, unmarshalFn: restxml.Unmarshal, expectedError: awserr.NewRequestFailure( - awserr.New("SerializationError", "", nil), + awserr.New(request.ErrCodeSerialization, "", nil), 456, "", ), diff --git a/private/protocol/xml/xmlutil/unmarshal.go b/private/protocol/xml/xmlutil/unmarshal.go index ff1ef6830b9..7108d380093 100644 --- a/private/protocol/xml/xmlutil/unmarshal.go +++ b/private/protocol/xml/xmlutil/unmarshal.go @@ -1,6 +1,7 @@ package xmlutil import ( + "bytes" "encoding/base64" "encoding/xml" "fmt" @@ -10,9 +11,27 @@ import ( "strings" "time" + "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/private/protocol" ) +// UnmarshalXMLError unmarshals the XML error from the stream into the value +// type specified. The value must be a pointer. If the message fails to +// unmarshal, the message content will be included in the returned error as a +// awserr.UnmarshalError. +func UnmarshalXMLError(v interface{}, stream io.Reader) error { + var errBuf bytes.Buffer + body := io.TeeReader(stream, &errBuf) + + err := xml.NewDecoder(body).Decode(v) + if err != nil && err != io.EOF { + return awserr.NewUnmarshalError(err, + "failed to unmarshal error message", errBuf.Bytes()) + } + + return nil +} + // UnmarshalXML deserializes an xml.Decoder into the container v. V // needs to match the shape of the XML expected to be decoded. // If the shape doesn't match unmarshaling will fail. diff --git a/service/route53/customizations.go b/service/route53/customizations.go index efe2d6e7c0a..7aca8722e99 100644 --- a/service/route53/customizations.go +++ b/service/route53/customizations.go @@ -33,7 +33,7 @@ func sanitizeURL(r *request.Request) { // Update Path so that it reflects the cleaned RawPath updated, err := url.Parse(r.HTTPRequest.URL.RawPath) if err != nil { - r.Error = awserr.New("SerializationError", "failed to clean Route53 URL", err) + r.Error = awserr.New(request.ErrCodeSerialization, "failed to clean Route53 URL", err) return } diff --git a/service/route53/unmarshal_error.go b/service/route53/unmarshal_error.go index 266e9a8ba43..b3b95a126e2 100644 --- a/service/route53/unmarshal_error.go +++ b/service/route53/unmarshal_error.go @@ -1,77 +1,106 @@ package route53 import ( - "bytes" "encoding/xml" - "io/ioutil" + "fmt" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/private/protocol/restxml" + "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" ) -type baseXMLErrorResponse struct { - XMLName xml.Name -} +const errorRespTag = "ErrorResponse" +const invalidChangeTag = "InvalidChangeBatch" type standardXMLErrorResponse struct { - XMLName xml.Name `xml:"ErrorResponse"` - Code string `xml:"Error>Code"` - Message string `xml:"Error>Message"` - RequestID string `xml:"RequestId"` + Code string `xml:"Error>Code"` + Message string `xml:"Error>Message"` + RequestID string `xml:"RequestId"` +} + +func (e standardXMLErrorResponse) FillCommon(c *xmlErrorResponse) { + c.Code = e.Code + c.Message = e.Message + c.RequestID = e.RequestID } type invalidChangeBatchXMLErrorResponse struct { - XMLName xml.Name `xml:"InvalidChangeBatch"` - Messages []string `xml:"Messages>Message"` + Messages []string `xml:"Messages>Message"` + RequestID string `xml:"RequestId"` } -func unmarshalChangeResourceRecordSetsError(r *request.Request) { - defer r.HTTPResponse.Body.Close() +func (e invalidChangeBatchXMLErrorResponse) FillCommon(c *xmlErrorResponse) { + c.Code = invalidChangeTag + c.Message = "ChangeBatch errors occurred" + c.Messages = e.Messages + c.RequestID = e.RequestID +} - responseBody, err := ioutil.ReadAll(r.HTTPResponse.Body) +type xmlErrorResponse struct { + Code string + Message string + Messages []string + RequestID string +} - if err != nil { - r.Error = awserr.New("SerializationError", "failed to read Route53 XML error response", err) - return +func (e *xmlErrorResponse) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + type commonFiller interface { + FillCommon(*xmlErrorResponse) } - baseError := &baseXMLErrorResponse{} + var errResp commonFiller + switch start.Name.Local { + case errorRespTag: + errResp = &standardXMLErrorResponse{} - if err := xml.Unmarshal(responseBody, baseError); err != nil { - r.Error = awserr.New("SerializationError", "failed to decode Route53 XML error response", err) - return - } + case invalidChangeTag: + errResp = &invalidChangeBatchXMLErrorResponse{} - switch baseError.XMLName.Local { - case "InvalidChangeBatch": - unmarshalInvalidChangeBatchError(r, responseBody) default: - r.HTTPResponse.Body = ioutil.NopCloser(bytes.NewReader(responseBody)) - restxml.UnmarshalError(r) + return fmt.Errorf("unknown error message, %v", start.Name.Local) } + + if err := d.DecodeElement(errResp, &start); err != nil { + return err + } + + errResp.FillCommon(e) + return nil } -func unmarshalInvalidChangeBatchError(r *request.Request, requestBody []byte) { - resp := &invalidChangeBatchXMLErrorResponse{} - err := xml.Unmarshal(requestBody, resp) +func unmarshalChangeResourceRecordSetsError(r *request.Request) { + defer r.HTTPResponse.Body.Close() + var errResp xmlErrorResponse + err := xmlutil.UnmarshalXMLError(&errResp, r.HTTPResponse.Body) if err != nil { - r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err) + r.Error = awserr.NewRequestFailure( + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal error message", err), + r.HTTPResponse.StatusCode, + r.RequestID, + ) return } - const errorCode = "InvalidChangeBatch" - errors := []error{} - - for _, msg := range resp.Messages { - errors = append(errors, awserr.New(errorCode, msg, nil)) + var baseErr awserr.Error + if len(errResp.Messages) != 0 { + var errs []error + for _, msg := range errResp.Messages { + errs = append(errs, awserr.New(invalidChangeTag, msg, nil)) + } + baseErr = awserr.NewBatchError(errResp.Code, errResp.Message, errs) + } else { + baseErr = awserr.New(errResp.Code, errResp.Message, nil) } + reqID := errResp.RequestID + if len(reqID) == 0 { + reqID = r.RequestID + } r.Error = awserr.NewRequestFailure( - awserr.NewBatchError(errorCode, "ChangeBatch errors occurred", errors), + baseErr, r.HTTPResponse.StatusCode, - r.RequestID, + reqID, ) - } diff --git a/service/route53/unmarshal_error_test.go b/service/route53/unmarshal_error_test.go index 750a937fddb..29329b41ebc 100644 --- a/service/route53/unmarshal_error_test.go +++ b/service/route53/unmarshal_error_test.go @@ -1,130 +1,113 @@ -package route53_test +// +build go1.8 + +package route53 import ( - "bytes" "io/ioutil" "net/http" + "strings" "testing" - "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/awstesting/unit" - "github.com/aws/aws-sdk-go/service/route53" ) -func makeClientWithResponse(response string) *route53.Route53 { - r := route53.New(unit.Session) - r.Handlers.Send.Clear() - r.Handlers.Send.PushBack(func(r *request.Request) { - body := ioutil.NopCloser(bytes.NewReader([]byte(response))) - r.HTTPResponse = &http.Response{ - ContentLength: int64(len(response)), - StatusCode: 400, - Status: "Bad Request", - Body: body, - } - }) +func TestUnmarshalInvalidChangeBatch(t *testing.T) { + const errorMessage = ` +Tried to create resource record set duplicate.example.com. type A, +but it already exists +` - return r -} + type batchError struct { + Code, Message string + } -func TestUnmarshalStandardError(t *testing.T) { - const errorResponse = ` + cases := map[string]struct { + Request *request.Request + Code, Message, RequestID string + StatusCode int + BatchErrors []batchError + }{ + "standard error": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` - InvalidDomainName - The domain name is invalid + InvalidDomainName + The domain name is invalid 12345 - -` - - r := makeClientWithResponse(errorResponse) - - _, err := r.CreateHostedZone(&route53.CreateHostedZoneInput{ - CallerReference: aws.String("test"), - Name: aws.String("test_zone"), - }) - - if err == nil { - t.Error("expected error, but received none") - } - - if e, a := "InvalidDomainName", err.(awserr.Error).Code(); e != a { - t.Errorf("expected %s, but received %s", e, a) - } - - if e, a := "The domain name is invalid", err.(awserr.Error).Message(); e != a { - t.Errorf("expected %s, but received %s", e, a) - } -} - -func TestUnmarshalInvalidChangeBatch(t *testing.T) { - const errorMessage = ` -Tried to create resource record set duplicate.example.com. type A, -but it already exists -` - const errorResponse = ` +`)), + }, + }, + Code: "InvalidDomainName", Message: "The domain name is invalid", + StatusCode: 400, RequestID: "12345", + }, + "batched error": { + Request: &request.Request{ + HTTPResponse: &http.Response{ + StatusCode: 400, + Header: http.Header{}, + Body: ioutil.NopCloser(strings.NewReader( + ` - ` + errorMessage + ` + ` + errorMessage + ` - -` - - r := makeClientWithResponse(errorResponse) - - req := &route53.ChangeResourceRecordSetsInput{ - HostedZoneId: aws.String("zoneId"), - ChangeBatch: &route53.ChangeBatch{ - Changes: []*route53.Change{ - { - Action: aws.String("CREATE"), - ResourceRecordSet: &route53.ResourceRecordSet{ - Name: aws.String("domain"), - Type: aws.String("CNAME"), - TTL: aws.Int64(120), - ResourceRecords: []*route53.ResourceRecord{ - { - Value: aws.String("cname"), - }, - }, - }, + 12345 +`)), }, }, + Code: "InvalidChangeBatch", Message: "ChangeBatch errors occurred", + StatusCode: 400, RequestID: "12345", + BatchErrors: []batchError{ + {Code: "InvalidChangeBatch", Message: errorMessage}, + }, }, } - _, err := r.ChangeResourceRecordSets(req) - if err == nil { - t.Error("expected error, but received none") - } - - if reqErr, ok := err.(awserr.RequestFailure); ok { - if reqErr == nil { - t.Error("expected error, but received none") - } - - if e, a := 400, reqErr.StatusCode(); e != a { - t.Errorf("expected %d, but received %d", e, a) - } - } else { - t.Fatal("returned error is not a RequestFailure") - } - - if batchErr, ok := err.(awserr.BatchedErrors); ok { - errs := batchErr.OrigErrs() - if e, a := 1, len(errs); e != a { - t.Errorf("expected %d, but received %d", e, a) - } - if e, a := "InvalidChangeBatch", errs[0].(awserr.Error).Code(); e != a { - t.Errorf("expected %s, but received %s", e, a) - } - if e, a := errorMessage, errs[0].(awserr.Error).Message(); e != a { - t.Errorf("expected %s, but received %s", e, a) - } - } else { - t.Fatal("returned error is not a BatchedErrors") + for name, c := range cases { + t.Run(name, func(t *testing.T) { + unmarshalChangeResourceRecordSetsError(c.Request) + err := c.Request.Error + if err == nil { + t.Error("expected error, but received none") + } + + reqErr := err.(awserr.RequestFailure) + if e, a := c.StatusCode, reqErr.StatusCode(); e != a { + t.Errorf("expected %d status, got %d", e, a) + } + if e, a := c.Code, reqErr.Code(); e != a { + t.Errorf("expected %v code, got %v", e, a) + } + if e, a := c.Message, reqErr.Message(); e != a { + t.Errorf("expected %q message, got %q", e, a) + } + if e, a := c.RequestID, reqErr.RequestID(); e != a { + t.Errorf("expected %v request ID, got %v", e, a) + } + + batchErr := err.(awserr.BatchedErrors) + batchedErrs := batchErr.OrigErrs() + + if e, a := len(c.BatchErrors), len(batchedErrs); e != a { + t.Fatalf("expect %v batch errors, got %v", e, a) + } + + for i, ee := range c.BatchErrors { + bErr := batchedErrs[i].(awserr.Error) + if e, a := ee.Code, bErr.Code(); e != a { + t.Errorf("expect %v code, got %v", e, a) + } + if e, a := ee.Message, bErr.Message(); e != a { + t.Errorf("expect %v message, got %v", e, a) + } + } + }) } } diff --git a/service/s3/bucket_location.go b/service/s3/bucket_location.go index bc68a46acfa..9ba8a788720 100644 --- a/service/s3/bucket_location.go +++ b/service/s3/bucket_location.go @@ -80,7 +80,8 @@ func buildGetBucketLocation(r *request.Request) { out := r.Data.(*GetBucketLocationOutput) b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { - r.Error = awserr.New("SerializationError", "failed reading response body", err) + r.Error = awserr.New(request.ErrCodeSerialization, + "failed reading response body", err) return } diff --git a/service/s3/statusok_error.go b/service/s3/statusok_error.go index fde3050f95b..f6a69aed11b 100644 --- a/service/s3/statusok_error.go +++ b/service/s3/statusok_error.go @@ -14,7 +14,7 @@ func copyMultipartStatusOKUnmarhsalError(r *request.Request) { b, err := ioutil.ReadAll(r.HTTPResponse.Body) if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New("SerializationError", "unable to read response body", err), + awserr.New(request.ErrCodeSerialization, "unable to read response body", err), r.HTTPResponse.StatusCode, r.RequestID, ) @@ -31,7 +31,7 @@ func copyMultipartStatusOKUnmarhsalError(r *request.Request) { unmarshalError(r) if err, ok := r.Error.(awserr.Error); ok && err != nil { - if err.Code() == "SerializationError" { + if err.Code() == request.ErrCodeSerialization { r.Error = nil return } diff --git a/service/s3/unmarshal_error.go b/service/s3/unmarshal_error.go index 1db7e133baf..5b63fac72ff 100644 --- a/service/s3/unmarshal_error.go +++ b/service/s3/unmarshal_error.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" ) type xmlErrorResponse struct { @@ -42,29 +43,34 @@ func unmarshalError(r *request.Request) { return } - var errCode, errMsg string - // Attempt to parse error from body if it is known - resp := &xmlErrorResponse{} - err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp) - if err != nil && err != io.EOF { - errCode = "SerializationError" - errMsg = "failed to decode S3 XML error response" - } else { - errCode = resp.Code - errMsg = resp.Message + var errResp xmlErrorResponse + err := xmlutil.UnmarshalXMLError(&errResp, r.HTTPResponse.Body) + if err == io.EOF { + // Only capture the error if an unmarshal error occurs that is not EOF, + // because S3 might send an error without a error message which causes + // the XML unmarshal to fail with EOF. err = nil } + if err != nil { + r.Error = awserr.NewRequestFailure( + awserr.New(request.ErrCodeSerialization, + "failed to unmarshal error message", err), + r.HTTPResponse.StatusCode, + r.RequestID, + ) + return + } // Fallback to status code converted to message if still no error code - if len(errCode) == 0 { + if len(errResp.Code) == 0 { statusText := http.StatusText(r.HTTPResponse.StatusCode) - errCode = strings.Replace(statusText, " ", "", -1) - errMsg = statusText + errResp.Code = strings.Replace(statusText, " ", "", -1) + errResp.Message = statusText } r.Error = awserr.NewRequestFailure( - awserr.New(errCode, errMsg, err), + awserr.New(errResp.Code, errResp.Message, err), r.HTTPResponse.StatusCode, r.RequestID, ) diff --git a/service/simpledb/unmarshall_error.go b/service/simpledb/unmarshall_error.go index acc8a86eb7c..f64f6cc19f2 100644 --- a/service/simpledb/unmarshall_error.go +++ b/service/simpledb/unmarshall_error.go @@ -8,19 +8,45 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" ) type xmlErrorDetail struct { Code string `xml:"Code"` Message string `xml:"Message"` } - -type xmlErrorResponse struct { +type xmlErrorMessage struct { XMLName xml.Name `xml:"Response"` Errors []xmlErrorDetail `xml:"Errors>Error"` RequestID string `xml:"RequestID"` } +type xmlErrorResponse struct { + Code string + Message string + RequestID string + OtherErrors []xmlErrorDetail +} + +func (r *xmlErrorResponse) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { + var errResp xmlErrorMessage + if err := d.DecodeElement(&errResp, &start); err != nil { + return err + } + + r.RequestID = errResp.RequestID + if len(errResp.Errors) == 0 { + r.Code = "MissingError" + r.Message = "missing error code in SimpleDB XML error response" + } else { + r.Code = errResp.Errors[0].Code + r.Message = errResp.Errors[0].Message + r.OtherErrors = errResp.Errors[1:] + } + + return nil +} + func unmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() defer io.Copy(ioutil.Discard, r.HTTPResponse.Body) @@ -30,24 +56,32 @@ func unmarshalError(r *request.Request) { r.Error = awserr.NewRequestFailure( awserr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil), r.HTTPResponse.StatusCode, - "", + r.RequestID, ) return } - resp := &xmlErrorResponse{} - err := xml.NewDecoder(r.HTTPResponse.Body).Decode(resp) - if err != nil && err != io.EOF { - r.Error = awserr.New("SerializationError", "failed to decode SimpleDB XML error response", nil) - } else if len(resp.Errors) == 0 { - r.Error = awserr.New("MissingError", "missing error code in SimpleDB XML error response", nil) - } else { - // If there are multiple error codes, return only the first as the aws.Error interface only supports - // one error code. + var errResp xmlErrorResponse + err := xmlutil.UnmarshalXMLError(&errResp, r.HTTPResponse.Body) + if err != nil { r.Error = awserr.NewRequestFailure( - awserr.New(resp.Errors[0].Code, resp.Errors[0].Message, nil), + awserr.New(request.ErrCodeSerialization, "failed to unmarshal error message", err), r.HTTPResponse.StatusCode, - resp.RequestID, + r.RequestID, ) + return + } + + var otherErrs []error + for _, e := range errResp.OtherErrors { + otherErrs = append(otherErrs, awserr.New(e.Code, e.Message, nil)) } + + // If there are multiple error codes, return only the first as the + // aws.Error interface only supports one error code. + r.Error = awserr.NewRequestFailure( + awserr.NewBatchError(errResp.Code, errResp.Message, otherErrs), + r.HTTPResponse.StatusCode, + errResp.RequestID, + ) }