diff --git a/service/rds/customizations.go b/service/rds/customizations.go index d3023d1f764..d412fb282ba 100644 --- a/service/rds/customizations.go +++ b/service/rds/customizations.go @@ -29,6 +29,8 @@ func fillPresignedURL(r *request.Request) { fns := map[string]func(r *request.Request){ opCopyDBSnapshot: copyDBSnapshotPresign, opCreateDBInstanceReadReplica: createDBInstanceReadReplicaPresign, + opCopyDBClusterSnapshot: copyDBClusterSnapshotPresign, + opCreateDBCluster: createDBClusterPresign, } if !r.ParamsFilled() { return @@ -41,7 +43,7 @@ func fillPresignedURL(r *request.Request) { func copyDBSnapshotPresign(r *request.Request) { originParams := r.Params.(*CopyDBSnapshotInput) - if originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { return } @@ -53,7 +55,7 @@ func copyDBSnapshotPresign(r *request.Request) { func createDBInstanceReadReplicaPresign(r *request.Request) { originParams := r.Params.(*CreateDBInstanceReadReplicaInput) - if originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { return } @@ -62,6 +64,30 @@ func createDBInstanceReadReplicaPresign(r *request.Request) { originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) } +func copyDBClusterSnapshotPresign(r *request.Request) { + originParams := r.Params.(*CopyDBClusterSnapshotInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + newParams := awsutil.CopyOf(r.Params).(*CopyDBClusterSnapshotInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + +func createDBClusterPresign(r *request.Request) { + originParams := r.Params.(*CreateDBClusterInput) + + if originParams.SourceRegion == nil || originParams.PreSignedUrl != nil || originParams.DestinationRegion != nil { + return + } + + originParams.DestinationRegion = r.Config.Region + newParams := awsutil.CopyOf(r.Params).(*CreateDBClusterInput) + originParams.PreSignedUrl = presignURL(r, originParams.SourceRegion, newParams) +} + // presignURL will presign the request by using SoureRegion to sign with. SourceRegion is not // sent to the service, and is only used to not have the SDKs parsing ARNs. func presignURL(r *request.Request, sourceRegion *string, newParams interface{}) *string { diff --git a/service/rds/customizations_test.go b/service/rds/customizations_test.go index df12d9bdc38..62db13af74c 100644 --- a/service/rds/customizations_test.go +++ b/service/rds/customizations_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/url" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" @@ -79,3 +80,26 @@ func TestPresignWithPresignSet(t *testing.T) { assert.Regexp(t, `presignedURL`, u) } } + +func TestPresignWithSourceNotSet(t *testing.T) { + reqs := map[string]*request.Request{} + svc := New(unit.Session, &aws.Config{Region: aws.String("us-west-2")}) + + assert.NotPanics(t, func() { + // Doesn't panic on nil input + req, _ := svc.CopyDBSnapshotRequest(nil) + req.Sign() + }) + + reqs[opCopyDBSnapshot], _ = svc.CopyDBSnapshotRequest(&CopyDBSnapshotInput{ + SourceDBSnapshotIdentifier: aws.String("foo"), + TargetDBSnapshotIdentifier: aws.String("bar"), + }) + + for _, req := range reqs { + _, err := req.Presign(5 * time.Minute) + if err != nil { + t.Fatal(err) + } + } +}