Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/lambda_codebase/account_processing/process_account_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import tempfile
import logging
from typing import Any, TypedDict
import re
import yaml

from yaml.error import YAMLError
Expand Down Expand Up @@ -42,6 +43,7 @@ class AccountFileData(TypedDict):
Class used to return YAML account file data and its related
metadata like the execution_id of the CodePipeline that uploaded it.
"""

content: Any
execution_id: str

Expand All @@ -65,8 +67,8 @@ def get_file_from_s3(
try:
LOGGER.debug(
"Reading YAML from S3: %s from %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
)
s3_object = s3_resource.Object(**s3_object_location)
object_adf_version = s3_object.metadata.get(
Expand All @@ -80,12 +82,9 @@ def get_file_from_s3(
s3_object_location,
object_adf_version,
)
return {
"content": {},
"execution_id": ""
}
return {"content": {}, "execution_id": ""}

with tempfile.TemporaryFile(mode='w+b') as file_pointer:
with tempfile.TemporaryFile(mode="w+b") as file_pointer:
s3_object.download_fileobj(file_pointer)

# Move pointer to the start of the file
Expand All @@ -98,16 +97,16 @@ def get_file_from_s3(
except ClientError as error:
LOGGER.error(
"Failed to download %s from %s, due to %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
error,
)
raise
except YAMLError as yaml_error:
LOGGER.error(
"Failed to parse YAML file: %s from %s, due to %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
yaml_error,
)
raise
Expand All @@ -129,19 +128,23 @@ def process_account(account_lookup, account):


def process_account_list(all_accounts, accounts_in_file):
    """
    Match the account definitions read from an account file against the
    accounts that already exist.

    :param all_accounts: iterable of existing account dicts, each holding
        at least a "Name" and an "Id" key.
    :param accounts_in_file: list of account definitions parsed from the
        uploaded account YAML file.
    :return: list of processed account dicts as produced by
        process_account for each entry in accounts_in_file.
    """
    # Defect fixed: this span contained BOTH the old and the new diff
    # variants of the body concatenated, so the lookup table and the
    # processed list were each built twice. Collapsed to one implementation.
    #
    # Index existing accounts by name once, for O(1) lookups per entry.
    account_lookup = {account["Name"]: account["Id"] for account in all_accounts}
    # A comprehension is the idiomatic form of list(map(lambda ...)).
    return [
        process_account(account_lookup=account_lookup, account=account)
        for account in accounts_in_file
    ]


def sanitize_account_name_for_snf(account_name):
    """
    Make an account name safe to embed in a Step Functions execution name.

    Truncates the name to at most 30 characters, then replaces every
    character outside ``[a-zA-Z0-9_]`` with an underscore.

    NOTE(review): "snf" in the name looks like a transposition of "sfn"
    (Step Functions) — renaming would break callers, so it is kept as-is.

    :param account_name: the full account name to sanitize.
    :return: the truncated, sanitized name.
    """
    truncated_name = account_name[:30]
    return re.sub(r"[^a-zA-Z0-9_]", "_", truncated_name)


def start_executions(
sfn_client,
processed_account_list,
Expand All @@ -158,14 +161,14 @@ def start_executions(
run_id,
)
for account in processed_account_list:
full_account_name = account.get('account_full_name', 'no-account-name')
full_account_name = account.get("account_full_name", "no-account-name")
# AWS Step Functions supports max 80 characters.
# Since the run_id equals 49 characters plus the dash, we have 30
# characters available. To ensure we don't run over, let's use a
# truncated version instead:
truncated_account_name = full_account_name[:30]
sfn_execution_name = f"{truncated_account_name}-{run_id}"

sfn_execution_name = (
f"{sanitize_account_name_for_snf(full_account_name)}-{run_id}"
)
LOGGER.debug(
"Payload for %s: %s",
sfn_execution_name,
Expand All @@ -182,8 +185,9 @@ def lambda_handler(event, context):
"""Main Lambda Entry point"""
LOGGER.debug(
"Processing event: %s",
json.dumps(event, indent=2) if LOGGER.isEnabledFor(logging.DEBUG)
else "--data-hidden--"
json.dumps(event, indent=2)
if LOGGER.isEnabledFor(logging.DEBUG)
else "--data-hidden--",
)
sfn_client = boto3.client("stepfunctions")
s3_resource = boto3.resource("s3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
Tests the account file processing lambda
"""
import unittest
from ..process_account_files import process_account, process_account_list, get_details_from_event
from ..process_account_files import (
process_account,
process_account_list,
get_details_from_event,
sanitize_account_name_for_snf,
)


class SuccessTestCase(unittest.TestCase):
Expand All @@ -20,7 +25,7 @@ def test_process_account_when_account_exists(self):
"account_full_name": "myTestAccountName",
"account_id": 123456789012,
"needs_created": False,
}
},
)

def test_process_account_when_account_does_not_exist(self):
Expand All @@ -35,7 +40,7 @@ def test_process_account_when_account_does_not_exist(self):
"alias": "MyCoolAlias",
"account_full_name": "myTestAccountName",
"needs_created": True,
}
},
)

def test_process_account_list(self):
Expand All @@ -59,6 +64,43 @@ def test_process_account_list(self):
],
)

def test_get_sanitize_account_name(self):
    """
    sanitize_account_name_for_snf keeps [a-zA-Z0-9_], replaces every
    other character with '_', and truncates the result to 30 characters.
    """
    # Already-clean name passes through unchanged.
    self.assertEqual(
        sanitize_account_name_for_snf("myTestAccountName"), "myTestAccountName"
    )
    # Clean but over-long name is truncated to 30 characters.
    self.assertEqual(
        sanitize_account_name_for_snf(
            "thisIsALongerAccountNameForTestingTruncatedNames"
        ),
        "thisIsALongerAccountNameForTes",
    )
    # Truncation and replacement combined (space becomes '_').
    self.assertEqual(
        sanitize_account_name_for_snf(
            "thisIsALongerAccountName ForTestingTruncatedNames"
        ),
        "thisIsALongerAccountName_ForTe",
    )
    # Angle brackets and spaces are all replaced with '_'.
    self.assertEqual(
        sanitize_account_name_for_snf("this accountname <has illegal> chars"),
        "this_accountname__has_illegal_",
    )
    # A backslash is an illegal character too.
    self.assertEqual(
        sanitize_account_name_for_snf("this accountname \\has illegal"),
        "this_accountname__has_illegal",
    )
    # Replacement also applies to the very first character.
    self.assertEqual(
        sanitize_account_name_for_snf("^startswithanillegalchar"),
        "_startswithanillegalchar",
    )
    # Output never exceeds the 30-character budget.
    self.assertEqual(
        len(
            sanitize_account_name_for_snf(
                "ReallyLongAccountNameThatShouldBeTruncatedBecauseItsTooLong"
            )
        ),
        30,
    )


class FailureTestCase(unittest.TestCase):
# pylint: disable=W0106
Expand All @@ -67,6 +109,5 @@ def test_event_parsing(self):
with self.assertRaises(ValueError) as _error:
get_details_from_event(sample_event)
self.assertEqual(
str(_error.exception),
"No S3 Event details present in event trigger"
str(_error.exception), "No S3 Event details present in event trigger"
)