Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/lambda_codebase/account_processing/configure_account_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,19 @@
AWS_PARTITION = os.getenv("AWS_PARTITION")


def delete_account_aliases(account, iam_client, current_aliases):
for alias in current_aliases:
LOGGER.info(
"Account %s, removing alias %s",
account.get('account_full_name'),
alias,
)
iam_client.delete_account_alias(AccountAlias=alias)


def create_account_alias(account, iam_client):
LOGGER.info(
"Ensuring Account: %s has alias %s",
"Adding alias to: %s alias %s",
account.get('account_full_name'),
account.get('alias'),
)
Expand All @@ -28,11 +38,33 @@ def create_account_alias(account, iam_client):
except iam_client.exceptions.EntityAlreadyExistsException as error:
LOGGER.error(
f"The account alias {account.get('alias')} already exists."
"The account alias must be unique across all Amazon Web Services products."
"Refer to https://docs.aws.amazon.com/IAM/latest/UserGuide/console_account-alias.html#AboutAccountAlias"
"The account alias must be unique across all Amazon Web Services "
"products. Refer to "
"https://docs.aws.amazon.com/IAM/latest/UserGuide/"
"console_account-alias.html#AboutAccountAlias"
)
raise error
return account


def ensure_account_has_alias(account, iam_client):
LOGGER.info(
"Ensuring Account: %s has alias %s",
account.get('account_full_name'),
account.get('alias'),
)
current_aliases = iam_client.list_account_aliases().get('AccountAliases')
if account.get('alias') in current_aliases:
LOGGER.info(
"Account: %s already has alias %s",
account.get('account_full_name'),
account.get('alias'),
)
return

# Since you can only have one alias per account, lets
# remove all old aliases (is at most one)
delete_account_aliases(account, iam_client, current_aliases)
create_account_alias(account, iam_client)


def lambda_handler(event, _):
Expand All @@ -43,7 +75,7 @@ def lambda_handler(event, _):
f"arn:{AWS_PARTITION}:iam::{account_id}:role/{ADF_ROLE_NAME}",
"adf_account_alias_config",
)
create_account_alias(event, role.client("iam"))
ensure_account_has_alias(event, role.client("iam"))
else:
LOGGER.info(
"Account: %s does not need an alias",
Expand Down
60 changes: 48 additions & 12 deletions src/lambda_codebase/account_processing/tests/test_account_alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,62 @@
import boto3
from botocore.stub import Stubber
from botocore.exceptions import ClientError
from mock import Mock
from aws_xray_sdk import global_sdk_config
from ..configure_account_alias import create_account_alias
from ..configure_account_alias import (
create_account_alias,
ensure_account_has_alias,
)

global_sdk_config.set_sdk_enabled(False)


class SuccessTestCase(unittest.TestCase):
# pylint: disable=W0106
def test_account_alias(self):
@staticmethod
def test_account_alias_exists_already():
test_account = {"account_id": 123456789012, "alias": "MyCoolAlias"}
iam_client = boto3.client("iam")
stubber = Stubber(iam_client)
create_alias_response = {}
stubber.add_response(
"create_account_alias", create_alias_response, {"AccountAlias": "MyCoolAlias"}
),
stubber.activate()
iam_client = Mock()
iam_client.list_account_aliases.return_value = {
"AccountAliases": ["MyCoolAlias"],
}

ensure_account_has_alias(test_account, iam_client)
iam_client.list_account_aliases.assert_called_once_with()
iam_client.delete_account_alias.assert_not_called()
iam_client.create_account_alias.assert_not_called()

@staticmethod
def test_account_alias_another_alias_exists():
test_account = {"account_id": 123456789012, "alias": "MyCoolAlias"}
iam_client = Mock()
iam_client.list_account_aliases.return_value = {
"AccountAliases": ["AnotherCoolAlias"],
}

ensure_account_has_alias(test_account, iam_client)
iam_client.list_account_aliases.assert_called_once_with()
iam_client.delete_account_alias.assert_called_once_with(
AccountAlias='AnotherCoolAlias',
)
iam_client.create_account_alias.assert_called_once_with(
AccountAlias='MyCoolAlias',
)

response = create_account_alias(test_account, iam_client)
@staticmethod
def test_account_alias_no_aliases_yet():
test_account = {"account_id": 123456789012, "alias": "MyCoolAlias"}
iam_client = Mock()
iam_client.list_account_aliases.return_value = {
"AccountAliases": [],
}

ensure_account_has_alias(test_account, iam_client)
iam_client.list_account_aliases.assert_called_once_with()
iam_client.delete_account_alias.assert_not_called()
iam_client.create_account_alias.assert_called_once_with(
AccountAlias='MyCoolAlias',
)

self.assertEqual(response, test_account)

class FailureTestCase(unittest.TestCase):
# pylint: disable=W0106
Expand Down
12 changes: 6 additions & 6 deletions src/template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand All @@ -685,7 +685,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand Down Expand Up @@ -714,7 +714,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand All @@ -738,7 +738,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand All @@ -762,7 +762,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand Down Expand Up @@ -797,7 +797,7 @@ Resources:
"ErrorEquals": ["States.TaskFailed"],
"IntervalSeconds": 3,
"BackoffRate": 1.5,
"MaxAttempts": 30
"MaxAttempts": 10
}, {
"ErrorEquals": [
"Lambda.Unknown",
Expand Down