Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
//
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
package ec2metadata

import (
"bytes"
"errors"
"io"
"net/http"
"os"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/aws/defaults"
)

// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"

// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
Expand Down Expand Up @@ -42,6 +51,21 @@ func New(config aws.Config) *EC2Metadata {
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)

// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
svc.Handlers.Send.SwapNamed(aws.NamedHandler{
Name: defaults.SendHandler.Name,
Fn: func(r *aws.Request) {
r.Error = awserr.New(
aws.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
nil)
},
})
}

return svc
}

Expand Down
37 changes: 37 additions & 0 deletions aws/ec2metadata/service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package ec2metadata_test

import (
"os"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws/awserr"
"github.com/aws/aws-sdk-go-v2/aws/ec2metadata"
"github.com/aws/aws-sdk-go-v2/internal/awstesting/unit"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
)

func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)

os.Setenv("AWS_EC2_METADATA_DISABLED", "true")

svc := ec2metadata.New(unit.Config())
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}

aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}