Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions service/s3/s3crypto/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ import (
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
request "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/service/s3"
)

// SaveStrategy is how the data's metadata wants to be saved
type SaveStrategy interface {
Save(Envelope, *request.Request) error
Save(Envelope, *aws.Request) error
}

// S3SaveStrategy will save the metadata to a separate instruction file in S3
Expand All @@ -25,7 +24,7 @@ type S3SaveStrategy struct {
}

// Save will save the envelope contents to s3.
func (strat S3SaveStrategy) Save(env Envelope, req *request.Request) error {
func (strat S3SaveStrategy) Save(env Envelope, req *aws.Request) error {
input := req.Params.(*s3.PutObjectInput)
b, err := json.Marshal(env)
if err != nil {
Expand All @@ -43,8 +42,7 @@ func (strat S3SaveStrategy) Save(env Envelope, req *request.Request) error {
instInput.Key = aws.String(*input.Key + strat.InstructionFileSuffix)
}

putReq := strat.Client.PutObjectRequest(&instInput)
_, err = putReq.Send()
_, err = strat.Client.PutObjectRequest(&instInput).Send()
return err
}

Expand All @@ -53,7 +51,7 @@ func (strat S3SaveStrategy) Save(env Envelope, req *request.Request) error {
type HeaderV2SaveStrategy struct{}

// Save will save the envelope to the request's header.
func (strat HeaderV2SaveStrategy) Save(env Envelope, req *request.Request) error {
func (strat HeaderV2SaveStrategy) Save(env Envelope, req *aws.Request) error {
input := req.Params.(*s3.PutObjectInput)
if input.Metadata == nil {
input.Metadata = map[string]string{}
Expand All @@ -64,15 +62,18 @@ func (strat HeaderV2SaveStrategy) Save(env Envelope, req *request.Request) error
input.Metadata[http.CanonicalHeaderKey(matDescHeader)] = env.MatDesc
input.Metadata[http.CanonicalHeaderKey(wrapAlgorithmHeader)] = env.WrapAlg
input.Metadata[http.CanonicalHeaderKey(cekAlgorithmHeader)] = env.CEKAlg
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = env.TagLen
input.Metadata[http.CanonicalHeaderKey(unencryptedMD5Header)] = env.UnencryptedMD5
input.Metadata[http.CanonicalHeaderKey(unencryptedContentLengthHeader)] = env.UnencryptedContentLen

if len(env.TagLen) > 0 {
input.Metadata[http.CanonicalHeaderKey(tagLengthHeader)] = env.TagLen
}
return nil
}

// LoadStrategy ...
type LoadStrategy interface {
Load(*request.Request) (Envelope, error)
Load(*aws.Request) (Envelope, error)
}

// S3LoadStrategy will load the instruction file from s3
Expand All @@ -82,23 +83,22 @@ type S3LoadStrategy struct {
}

// Load from a given instruction file suffix
func (load S3LoadStrategy) Load(req *request.Request) (Envelope, error) {
func (load S3LoadStrategy) Load(req *aws.Request) (Envelope, error) {
env := Envelope{}
if load.InstructionFileSuffix == "" {
load.InstructionFileSuffix = DefaultInstructionKeySuffix
}

input := req.Params.(*s3.GetObjectInput)
getReq := load.Client.GetObjectRequest(&s3.GetObjectInput{
out, err := load.Client.GetObjectRequest(&s3.GetObjectInput{
Key: aws.String(strings.Join([]string{*input.Key, load.InstructionFileSuffix}, "")),
Bucket: input.Bucket,
})
resp, err := getReq.Send()
}).Send()
if err != nil {
return env, err
}

b, err := ioutil.ReadAll(resp.Body)
b, err := ioutil.ReadAll(out.Body)
if err != nil {
return env, err
}
Expand All @@ -110,7 +110,7 @@ func (load S3LoadStrategy) Load(req *request.Request) (Envelope, error) {
type HeaderV2LoadStrategy struct{}

// Load from a given object's header
func (load HeaderV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
func (load HeaderV2LoadStrategy) Load(req *aws.Request) (Envelope, error) {
env := Envelope{}
env.CipherKey = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-"))
env.IV = req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, ivHeader}, "-"))
Expand All @@ -128,7 +128,7 @@ type defaultV2LoadStrategy struct {
suffix string
}

func (load defaultV2LoadStrategy) Load(req *request.Request) (Envelope, error) {
func (load defaultV2LoadStrategy) Load(req *aws.Request) (Envelope, error) {
if value := req.HTTPResponse.Header.Get(strings.Join([]string{metaHeader, keyV2Header}, "-")); value != "" {
strat := HeaderV2LoadStrategy{}
return strat.Load(req)
Expand Down
91 changes: 60 additions & 31 deletions service/s3/s3crypto/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,73 @@ import (
"reflect"
"testing"

request "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/s3crypto"
)

func TestHeaderV2SaveStrategy(t *testing.T) {
env := s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
TagLen: "128",
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
}
params := &s3.PutObjectInput{}
req := &request.Request{
Params: params,
}
strat := s3crypto.HeaderV2SaveStrategy{}
err := strat.Save(env, req)
if err != nil {
t.Errorf("expected no error, but received %v", err)
cases := []struct {
env s3crypto.Envelope
expected map[string]string
}{
{
s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
TagLen: "128",
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
},
map[string]string{
"X-Amz-Key-V2": "Foo",
"X-Amz-Iv": "Bar",
"X-Amz-Matdesc": "{}",
"X-Amz-Wrap-Alg": s3crypto.KMSWrap,
"X-Amz-Cek-Alg": s3crypto.AESGCMNoPadding,
"X-Amz-Tag-Len": "128",
"X-Amz-Unencrypted-Content-Md5": "hello",
"X-Amz-Unencrypted-Content-Length": "0",
},
},
{
s3crypto.Envelope{
CipherKey: "Foo",
IV: "Bar",
MatDesc: "{}",
WrapAlg: s3crypto.KMSWrap,
CEKAlg: s3crypto.AESGCMNoPadding,
UnencryptedMD5: "hello",
UnencryptedContentLen: "0",
},
map[string]string{
"X-Amz-Key-V2": "Foo",
"X-Amz-Iv": "Bar",
"X-Amz-Matdesc": "{}",
"X-Amz-Wrap-Alg": s3crypto.KMSWrap,
"X-Amz-Cek-Alg": s3crypto.AESGCMNoPadding,
"X-Amz-Unencrypted-Content-Md5": "hello",
"X-Amz-Unencrypted-Content-Length": "0",
},
},
}

expected := map[string]string{
"X-Amz-Key-V2": "Foo",
"X-Amz-Iv": "Bar",
"X-Amz-Matdesc": "{}",
"X-Amz-Wrap-Alg": s3crypto.KMSWrap,
"X-Amz-Cek-Alg": s3crypto.AESGCMNoPadding,
"X-Amz-Tag-Len": "128",
"X-Amz-Unencrypted-Content-Md5": "hello",
"X-Amz-Unencrypted-Content-Length": "0",
}
for _, c := range cases {
params := &s3.PutObjectInput{}
req := &aws.Request{
Params: params,
}
strat := s3crypto.HeaderV2SaveStrategy{}
err := strat.Save(c.env, req)
if err != nil {
t.Errorf("expected no error, but received %v", err)
}

if !reflect.DeepEqual(expected, params.Metadata) {
t.Errorf("expected %v, but received %v", expected, params.Metadata)
if !reflect.DeepEqual(c.expected, params.Metadata) {
t.Errorf("expected %v, but received %v", c.expected, params.Metadata)
}
}
}