Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 22 additions & 31 deletions app/integrations/aws/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import wraps

import structlog
import boto3 # type: ignore
from botocore.client import BaseClient # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from core.config import settings
from core.logging import get_module_logger

logger = get_module_logger()
logger = structlog.get_logger()

SYSTEM_ADMIN_PERMISSIONS = settings.aws.SYSTEM_ADMIN_PERMISSIONS
VIEW_ONLY_PERMISSIONS = settings.aws.VIEW_ONLY_PERMISSIONS
Expand All @@ -30,39 +30,36 @@ def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except BotoCoreError as e:
logger.error(
log = logger.bind(module=func.__module__, function=func.__name__)
log.error(
"boto_core_error",
module=func.__module__,
function=func.__name__,
error=str(e),
)
except ClientError as e:
log = logger.bind(
module=func.__module__,
function=func.__name__,
error_code=e.response["Error"]["Code"],
)
if e.response["Error"]["Code"] in THROTTLING_ERRS:
logger.info(
log.info(
"aws_throttling_error",
module=func.__module__,
function=func.__name__,
error=str(e),
)
elif e.response["Error"]["Code"] in RESOURCE_NOT_FOUND_ERRS:
logger.warning(
log.warning(
"aws_resource_not_found",
module=func.__module__,
function=func.__name__,
error=str(e),
)
else:
logger.error(
log.error(
"aws_client_error",
module=func.__module__,
function=func.__name__,
error=str(e),
)
except Exception as e: # Catch-all for any other types of exceptions
logger.error(
log = logger.bind(module=func.__module__, function=func.__name__)
log.error(
"unexpected_error",
module=func.__module__,
function=func.__name__,
error=str(e),
)
return False
Expand Down Expand Up @@ -153,9 +150,8 @@ def execute_aws_api_call(
if client_config is None:
client_config = {"region_name": AWS_REGION}

logger.debug(
"aws_api_call_started", service=service_name, method=method, paginated=paginated
)
log = logger.bind(service=service_name, method=method, paginated=paginated)
log.debug("aws_api_call_started")

client = get_aws_service_client(
service_name,
Expand All @@ -173,21 +169,15 @@ def execute_aws_api_call(
"ResponseMetadata" in results
and results["ResponseMetadata"]["HTTPStatusCode"] != 200
):
logger.error(
log.error(
"aws_api_call_failed",
service=service_name,
method=method,
status_code=results["ResponseMetadata"]["HTTPStatusCode"],
)
raise RuntimeError(
f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}"
)

logger.debug(
"aws_api_call_completed",
service=service_name,
method=method,
)
log.debug("aws_api_call_completed")

return results

Expand All @@ -208,6 +198,9 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs):
"""
paginator = client.get_paginator(operation)
results = []
log = logger.bind(
service=client.meta.service_model.service_name, operation=operation
)

for page in paginator.paginate(**kwargs):
if keys is None:
Expand All @@ -219,10 +212,8 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs):
results.append(value)
else:
if key == "ResponseMetadata" and value["HTTPStatusCode"] != 200:
logger.error(
log.error(
"api_call_failed_during_pagination",
service=client.meta.service_model.service_name,
operation=operation,
status_code=value["HTTPStatusCode"],
)
raise RuntimeError(
Expand Down
5 changes: 3 additions & 2 deletions app/integrations/aws/client_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
import time
from typing import Any, List, Optional, Callable, cast

import structlog
import boto3 # type: ignore
from botocore.client import BaseClient # type: ignore
from botocore.exceptions import BotoCoreError, ClientError # type: ignore
from core.config import settings
from core.logging import get_module_logger

from infrastructure.operations.result import OperationResult

logger = get_module_logger()
logger = structlog.get_logger()

AWS_REGION = settings.aws.AWS_REGION
THROTTLING_ERRS = settings.aws.THROTTLING_ERRS
Expand Down
16 changes: 7 additions & 9 deletions app/integrations/aws/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import structlog

from core.config import settings
from core.logging import get_module_logger
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

logger = get_module_logger()
logger = structlog.get_logger()
AUDIT_ROLE_ARN = settings.aws.AUDIT_ROLE_ARN


Expand All @@ -17,11 +18,12 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter
Returns:
list: A list of compliance objects
"""
logger.debug(
"config_describe_aggregate_compliance_started",
log = logger.bind(
operation="describe_aggregate_compliance_by_config_rules",
aggregator=config_aggregator_name,
filter_keys=list(filters.keys()) if filters else [],
)
log.debug("config_describe_aggregate_compliance_started")

params = {
"ConfigurationAggregatorName": config_aggregator_name,
Expand All @@ -37,10 +39,6 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter
)

rule_count = len(response) if response else 0
logger.debug(
"config_describe_aggregate_compliance_completed",
aggregator=config_aggregator_name,
rule_count=rule_count,
)
log.debug("config_describe_aggregate_compliance_completed", rule_count=rule_count)

return response if response else []
13 changes: 7 additions & 6 deletions app/integrations/aws/cost_explorer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
"""Cost Explorer API integration."""

import structlog
from core.config import settings
from core.logging import get_module_logger
from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors

logger = get_module_logger()
logger = structlog.get_logger()
ORG_ROLE_ARN = settings.aws.ORG_ROLE_ARN


@handle_aws_api_errors
def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=None):
logger.debug(
log = logger.bind(
operation="get_cost_and_usage", granularity=granularity, metrics=str(metrics)
)
log.debug(
"cost_explorer_get_cost_and_usage_started",
granularity=granularity,
metrics=metrics,
filter_present=filter is not None,
group_by_present=group_by is not None,
)
Expand All @@ -36,6 +37,6 @@ def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=
)

result_size = len(response.get("ResultsByTime", [])) if response else 0
logger.debug("cost_explorer_get_cost_and_usage_completed", result_count=result_size)
log.debug("cost_explorer_get_cost_and_usage_completed", result_count=result_size)

return response
48 changes: 25 additions & 23 deletions app/integrations/aws/dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""AWS DynamoDB API client"""

import structlog
from core.config import settings
from core.logging import get_module_logger
from integrations.aws.client import (
execute_aws_api_call,
handle_aws_api_errors,
)

logger = get_module_logger()
logger = structlog.get_logger()

client_config = dict(
region_name=settings.aws.AWS_REGION,
Expand All @@ -22,7 +22,8 @@ def query(
TableName,
**kwargs,
):
logger.debug("dynamodb_query_started", table=TableName)
log = logger.bind(operation="query", table=TableName)
log.debug("dynamodb_query_started")
params = {
"TableName": TableName,
}
Expand All @@ -31,9 +32,8 @@ def query(
response = execute_aws_api_call(
"dynamodb", "query", paginated=True, client_config=client_config, **params
)
logger.debug(
log.debug(
"dynamodb_query_completed",
table=TableName,
item_count=len(response) if response else 0,
)
return response
Expand All @@ -42,7 +42,8 @@ def query(
@handle_aws_api_errors
def scan(TableName, **kwargs):
"""Scan a DynamoDB table. Will return only the list of items found in the table that match the query."""
logger.debug("dynamodb_scan_started", table=TableName)
log = logger.bind(operation="scan", table=TableName)
log.debug("dynamodb_scan_started")
params = {
"TableName": TableName,
}
Expand All @@ -56,17 +57,17 @@ def scan(TableName, **kwargs):
client_config=client_config,
**params,
)
logger.debug(
log.debug(
"dynamodb_scan_completed",
table=TableName,
item_count=len(response) if response else 0,
)
return response


@handle_aws_api_errors
def put_item(TableName, **kwargs):
logger.debug("dynamodb_put_item_started", table=TableName)
log = logger.bind(operation="put_item", table=TableName)
log.debug("dynamodb_put_item_started")
params = {
"TableName": TableName,
}
Expand All @@ -75,7 +76,7 @@ def put_item(TableName, **kwargs):
response = execute_aws_api_call(
"dynamodb", "put_item", client_config=client_config, **params
)
logger.debug("dynamodb_put_item_completed", table=TableName)
log.debug("dynamodb_put_item_completed")
return response


Expand All @@ -85,7 +86,8 @@ def get_item(TableName, **kwargs) -> dict:

Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item
"""
logger.debug("dynamodb_get_item_started", table=TableName)
log = logger.bind(operation="get_item", table=TableName)
log.debug("dynamodb_get_item_started")
params = {
"TableName": TableName,
}
Expand All @@ -95,16 +97,14 @@ def get_item(TableName, **kwargs) -> dict:
"dynamodb", "get_item", client_config=client_config, **params
)
if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200:
logger.debug(
log.debug(
"dynamodb_get_item_completed",
table=TableName,
item_found=bool(response.get("Item")),
)
return response.get("Item")
else:
logger.warning(
log.warning(
"dynamodb_get_item_failed",
table=TableName,
status_code=response.get("ResponseMetadata", {}).get("HTTPStatusCode"),
)
return None
Expand All @@ -122,7 +122,8 @@ def update_item(TableName, **kwargs):
dict: Response from the AWS API call
Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.update_item
"""
logger.debug("dynamodb_update_item_started", table=TableName)
log = logger.bind(operation="update_item", table=TableName)
log.debug("dynamodb_update_item_started")
params = {
"TableName": TableName,
}
Expand All @@ -132,19 +133,19 @@ def update_item(TableName, **kwargs):
"dynamodb", "update_item", client_config=client_config, **params
)
if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200:
logger.debug("dynamodb_update_item_completed", table=TableName)
log.debug("dynamodb_update_item_completed")
return response
else:
logger.warning(
log.warning(
"dynamodb_update_item_failed",
table=TableName,
status_code=response.get("ResponseMetadata", {}).get("HTTPStatusCode"),
)


@handle_aws_api_errors
def delete_item(TableName, **kwargs):
logger.debug("dynamodb_delete_item_started", table=TableName)
log = logger.bind(operation="delete_item", table=TableName)
log.debug("dynamodb_delete_item_started")
params = {
"TableName": TableName,
}
Expand All @@ -153,17 +154,18 @@ def delete_item(TableName, **kwargs):
response = execute_aws_api_call(
"dynamodb", "delete_item", client_config=client_config, **params
)
logger.debug("dynamodb_delete_item_completed", table=TableName)
log.debug("dynamodb_delete_item_completed")
return response


@handle_aws_api_errors
def list_tables(**kwargs):
logger.debug("dynamodb_list_tables_started")
log = logger.bind(operation="list_tables")
log.debug("dynamodb_list_tables_started")
response = execute_aws_api_call(
"dynamodb", "list_tables", client_config=client_config, **kwargs
)
logger.debug(
log.debug(
"dynamodb_list_tables_completed",
table_count=len(response.get("TableNames", [])),
)
Expand Down
4 changes: 2 additions & 2 deletions app/integrations/aws/dynamodb_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@

from typing import Any, Dict

import structlog
from core.config import settings
from core.logging import get_module_logger
from integrations.aws.client_next import execute_aws_api_call
from infrastructure.operations.result import OperationResult

logger = get_module_logger()
logger = structlog.get_logger()

AWS_REGION = settings.aws.AWS_REGION

Expand Down
Loading