Skip to content
This repository was archived by the owner on Jul 31, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions aws/ec2metadata/service.go
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
// Package ec2metadata provides the client for making API calls to the
// EC2 Metadata service.
//
// This package's client can be disabled completely by setting the environment
// variable "AWS_EC2_METADATA_DISABLED=true". This environment variable set to
// true instructs the SDK to disable the EC2 Metadata client. The client cannot
// be used while the environemnt variable is set to true, (case insensitive).
package ec2metadata

import (
"bytes"
"errors"
"io"
"net/http"
"os"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/request"
)

// ServiceName is the name of the service.
const ServiceName = "ec2metadata"
const disableServiceEnvVar = "AWS_EC2_METADATA_DISABLED"

// A EC2Metadata is an EC2 Metadata service Client.
type EC2Metadata struct {
Expand Down Expand Up @@ -75,6 +84,21 @@ func NewClient(cfg aws.Config, handlers request.Handlers, endpoint, signingRegio
svc.Handlers.Validate.Clear()
svc.Handlers.Validate.PushBack(validateEndpointHandler)

// Disable the EC2 Metadata service if the environment variable is set.
// This shortcirctes the service's functionality to always fail to send
// requests.
if strings.ToLower(os.Getenv(disableServiceEnvVar)) == "true" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to return an error or just remove that specific handler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case returning error is preferred.

svc.Handlers.Send.SwapNamed(request.NamedHandler{
Name: corehandlers.SendHandler.Name,
Fn: func(r *request.Request) {
r.Error = awserr.New(
request.CanceledErrorCode,
"EC2 IMDS access disabled via "+disableServiceEnvVar+" env var",
nil)
},
})
}

// Add additional options to the service config
for _, option := range opts {
option(svc.Client)
Expand Down
62 changes: 52 additions & 10 deletions aws/ec2metadata/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,30 @@ package ec2metadata_test
import (
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/stretchr/testify/assert"
)

func TestClientOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session)

assert.NotEqual(t, http.DefaultClient, svc.Config.HTTPClient)
assert.Equal(t, 5*time.Second, svc.Config.HTTPClient.Timeout)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e == a {
t.Errorf("expect %v, not to equal %v", e, a)
}

if e, a := 5*time.Second, svc.Config.HTTPClient.Timeout; e != a {
t.Errorf("expect %v to be %v", e, a)
}
}

func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {
Expand All @@ -28,18 +37,25 @@ func TestClientNotOverrideDefaultHTTPClientTimeout(t *testing.T) {

svc := ec2metadata.New(unit.Session)

assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}

tr, ok := svc.Config.HTTPClient.Transport.(*http.Transport)
assert.True(t, ok)
assert.NotNil(t, tr)
assert.Nil(t, tr.Dial)
tr := svc.Config.HTTPClient.Transport.(*http.Transport)
if tr == nil {
t.Fatalf("expect transport not to be nil")
}
if tr.Dial != nil {
t.Errorf("expect dial to be nil, was not")
}
}

func TestClientDisableOverrideDefaultHTTPClientTimeout(t *testing.T) {
svc := ec2metadata.New(unit.Session, aws.NewConfig().WithEC2MetadataDisableTimeoutOverride(true))

assert.Equal(t, http.DefaultClient, svc.Config.HTTPClient)
if e, a := http.DefaultClient, svc.Config.HTTPClient; e != a {
t.Errorf("expect %v, got %v", e, a)
}
}

func TestClientOverrideDefaultHTTPClientTimeoutRace(t *testing.T) {
Expand All @@ -63,14 +79,40 @@ func TestClientOverrideDefaultHTTPClientTimeoutRaceWithTransport(t *testing.T) {
runEC2MetadataClients(t, cfg, 100)
}

func TestClientDisableIMDS(t *testing.T) {
env := awstesting.StashEnv()
defer awstesting.PopEnv(env)

os.Setenv("AWS_EC2_METADATA_DISABLED", "true")

svc := ec2metadata.New(unit.Session)
resp, err := svc.Region()
if err == nil {
t.Fatalf("expect error, got none")
}
if len(resp) != 0 {
t.Errorf("expect no response, got %v", resp)
}

aerr := err.(awserr.Error)
if e, a := request.CanceledErrorCode, aerr.Code(); e != a {
t.Errorf("expect %v error code, got %v", e, a)
}
if e, a := "AWS_EC2_METADATA_DISABLED", aerr.Message(); !strings.Contains(a, e) {
t.Errorf("expect %v in error message, got %v", e, a)
}
}

func runEC2MetadataClients(t *testing.T, cfg *aws.Config, atOnce int) {
var wg sync.WaitGroup
wg.Add(atOnce)
for i := 0; i < atOnce; i++ {
go func() {
svc := ec2metadata.New(unit.Session, cfg)
_, err := svc.Region()
assert.NoError(t, err)
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
wg.Done()
}()
}
Expand Down