diff --git a/app/integrations/aws/client.py b/app/integrations/aws/client.py index 28d51520a..89539bfee 100644 --- a/app/integrations/aws/client.py +++ b/app/integrations/aws/client.py @@ -1,12 +1,12 @@ from functools import wraps +import structlog import boto3 # type: ignore from botocore.client import BaseClient # type: ignore from botocore.exceptions import BotoCoreError, ClientError # type: ignore from core.config import settings -from core.logging import get_module_logger -logger = get_module_logger() +logger = structlog.get_logger() SYSTEM_ADMIN_PERMISSIONS = settings.aws.SYSTEM_ADMIN_PERMISSIONS VIEW_ONLY_PERMISSIONS = settings.aws.VIEW_ONLY_PERMISSIONS @@ -30,39 +30,36 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except BotoCoreError as e: - logger.error( + log = logger.bind(module=func.__module__, function=func.__name__) + log.error( "boto_core_error", - module=func.__module__, - function=func.__name__, error=str(e), ) except ClientError as e: + log = logger.bind( + module=func.__module__, + function=func.__name__, + error_code=e.response["Error"]["Code"], + ) if e.response["Error"]["Code"] in THROTTLING_ERRS: - logger.info( + log.info( "aws_throttling_error", - module=func.__module__, - function=func.__name__, error=str(e), ) elif e.response["Error"]["Code"] in RESOURCE_NOT_FOUND_ERRS: - logger.warning( + log.warning( "aws_resource_not_found", - module=func.__module__, - function=func.__name__, error=str(e), ) else: - logger.error( + log.error( "aws_client_error", - module=func.__module__, - function=func.__name__, error=str(e), ) except Exception as e: # Catch-all for any other types of exceptions - logger.error( + log = logger.bind(module=func.__module__, function=func.__name__) + log.error( "unexpected_error", - module=func.__module__, - function=func.__name__, error=str(e), ) return False @@ -153,9 +150,8 @@ def execute_aws_api_call( if client_config is None: client_config = {"region_name": AWS_REGION} - logger.debug( - "aws_api_call_started", service=service_name, method=method, paginated=paginated - ) + log = logger.bind(service=service_name, method=method, paginated=paginated) + log.debug("aws_api_call_started") client = get_aws_service_client( service_name, @@ -173,21 +169,15 @@ def execute_aws_api_call( "ResponseMetadata" in results and results["ResponseMetadata"]["HTTPStatusCode"] != 200 ): - logger.error( + log.error( "aws_api_call_failed", - service=service_name, - method=method, status_code=results["ResponseMetadata"]["HTTPStatusCode"], ) raise RuntimeError( f"API call to {service_name}.{method} failed with status code {results['ResponseMetadata']['HTTPStatusCode']}" ) - logger.debug( - "aws_api_call_completed", - service=service_name, - method=method, - ) + log.debug("aws_api_call_completed") return results @@ -208,6 +198,9 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs): """ paginator = client.get_paginator(operation) results = [] + log = logger.bind( + service=client.meta.service_model.service_name, operation=operation + ) for page in paginator.paginate(**kwargs): if keys is None: @@ -219,10 +212,8 @@ def paginator(client: BaseClient, operation, keys=None, **kwargs): results.append(value) else: if key == "ResponseMetadata" and value["HTTPStatusCode"] != 200: - logger.error( + log.error( "api_call_failed_during_pagination", - service=client.meta.service_model.service_name, - operation=operation, status_code=value["HTTPStatusCode"], ) raise RuntimeError( diff --git a/app/integrations/aws/client_next.py b/app/integrations/aws/client_next.py index 3e1754043..d077621cc 100644 --- a/app/integrations/aws/client_next.py +++ b/app/integrations/aws/client_next.py @@ -31,14 +31,15 @@ import time from typing import Any, List, Optional, Callable, cast +import structlog import boto3 # type: ignore from botocore.client import BaseClient # type: ignore from botocore.exceptions import BotoCoreError, ClientError # type: ignore from core.config import settings -from core.logging import get_module_logger + from infrastructure.operations.result import OperationResult -logger = get_module_logger() +logger = structlog.get_logger() AWS_REGION = settings.aws.AWS_REGION THROTTLING_ERRS = settings.aws.THROTTLING_ERRS diff --git a/app/integrations/aws/config.py b/app/integrations/aws/config.py index b9bfb431f..723ea5a70 100644 --- a/app/integrations/aws/config.py +++ b/app/integrations/aws/config.py @@ -1,8 +1,9 @@ +import structlog + from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() AUDIT_ROLE_ARN = settings.aws.AUDIT_ROLE_ARN @@ -17,11 +18,12 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter Returns: list: A list of compliance objects """ - logger.debug( - "config_describe_aggregate_compliance_started", + log = logger.bind( + operation="describe_aggregate_compliance_by_config_rules", aggregator=config_aggregator_name, filter_keys=list(filters.keys()) if filters else [], ) + log.debug("config_describe_aggregate_compliance_started") params = { "ConfigurationAggregatorName": config_aggregator_name, @@ -37,10 +39,6 @@ def describe_aggregate_compliance_by_config_rules(config_aggregator_name, filter ) rule_count = len(response) if response else 0 - logger.debug( - "config_describe_aggregate_compliance_completed", - aggregator=config_aggregator_name, - rule_count=rule_count, - ) + log.debug("config_describe_aggregate_compliance_completed", rule_count=rule_count) return response if response else [] diff --git a/app/integrations/aws/cost_explorer.py b/app/integrations/aws/cost_explorer.py index 4c733d6a1..b6d0f3420 100644 --- a/app/integrations/aws/cost_explorer.py +++ b/app/integrations/aws/cost_explorer.py @@ -1,19 +1,20 @@ """Cost Explorer API integration.""" +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() ORG_ROLE_ARN = settings.aws.ORG_ROLE_ARN @handle_aws_api_errors def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by=None): - logger.debug( + log = logger.bind( + operation="get_cost_and_usage", granularity=granularity, metrics=str(metrics) + ) + log.debug( "cost_explorer_get_cost_and_usage_started", - granularity=granularity, - metrics=metrics, filter_present=filter is not None, group_by_present=group_by is not None, ) @@ -36,6 +37,6 @@ def get_cost_and_usage(time_period, granularity, metrics, filter=None, group_by= ) result_size = len(response.get("ResultsByTime", [])) if response else 0 - logger.debug("cost_explorer_get_cost_and_usage_completed", result_count=result_size) + log.debug("cost_explorer_get_cost_and_usage_completed", result_count=result_size) return response diff --git a/app/integrations/aws/dynamodb.py b/app/integrations/aws/dynamodb.py index 1d56db150..b6102ec58 100644 --- a/app/integrations/aws/dynamodb.py +++ b/app/integrations/aws/dynamodb.py @@ -1,13 +1,13 @@ """AWS DynamoDB API client""" +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import ( execute_aws_api_call, handle_aws_api_errors, ) -logger = get_module_logger() +logger = structlog.get_logger() client_config = dict( region_name=settings.aws.AWS_REGION, @@ -22,7 +22,8 @@ def query( TableName, **kwargs, ): - logger.debug("dynamodb_query_started", table=TableName) + log = logger.bind(operation="query", table=TableName) + log.debug("dynamodb_query_started") params = { "TableName": TableName, } @@ -31,9 +32,8 @@ def query( response = execute_aws_api_call( "dynamodb", "query", paginated=True, client_config=client_config, **params ) - logger.debug( + log.debug( "dynamodb_query_completed", - table=TableName, item_count=len(response) if response else 0, ) return response @@ -42,7 +42,8 @@ def query( @handle_aws_api_errors def scan(TableName, **kwargs): """Scan a DynamoDB table. Will return only the list of items found in the table that match the query.""" - logger.debug("dynamodb_scan_started", table=TableName) + log = logger.bind(operation="scan", table=TableName) + log.debug("dynamodb_scan_started") params = { "TableName": TableName, } @@ -56,9 +57,8 @@ def scan(TableName, **kwargs): client_config=client_config, **params, ) - logger.debug( + log.debug( "dynamodb_scan_completed", - table=TableName, item_count=len(response) if response else 0, ) return response @@ -66,7 +66,8 @@ def scan(TableName, **kwargs): @handle_aws_api_errors def put_item(TableName, **kwargs): - logger.debug("dynamodb_put_item_started", table=TableName) + log = logger.bind(operation="put_item", table=TableName) + log.debug("dynamodb_put_item_started") params = { "TableName": TableName, } @@ -75,7 +76,7 @@ def put_item(TableName, **kwargs): response = execute_aws_api_call( "dynamodb", "put_item", client_config=client_config, **params ) - logger.debug("dynamodb_put_item_completed", table=TableName) + log.debug("dynamodb_put_item_completed") return response @@ -85,7 +86,8 @@ def get_item(TableName, **kwargs) -> dict: Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.get_item """ - logger.debug("dynamodb_get_item_started", table=TableName) + log = logger.bind(operation="get_item", table=TableName) + log.debug("dynamodb_get_item_started") params = { "TableName": TableName, } @@ -95,16 +97,14 @@ def get_item(TableName, **kwargs) -> dict: "dynamodb", "get_item", client_config=client_config, **params ) if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200: - logger.debug( + log.debug( "dynamodb_get_item_completed", - table=TableName, item_found=bool(response.get("Item")), ) return response.get("Item") else: - logger.warning( + log.warning( "dynamodb_get_item_failed", - table=TableName, status_code=response.get("ResponseMetadata", {}).get("HTTPStatusCode"), ) return None @@ -122,7 +122,8 @@ def update_item(TableName, **kwargs): dict: Response from the AWS API call Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Client.update_item """ - logger.debug("dynamodb_update_item_started", table=TableName) + log = logger.bind(operation="update_item", table=TableName) + log.debug("dynamodb_update_item_started") params = { "TableName": TableName, } @@ -132,19 +133,19 @@ def update_item(TableName, **kwargs): "dynamodb", "update_item", client_config=client_config, **params ) if response.get("ResponseMetadata", {}).get("HTTPStatusCode") == 200: - logger.debug("dynamodb_update_item_completed", table=TableName) + log.debug("dynamodb_update_item_completed") return response else: - logger.warning( + log.warning( "dynamodb_update_item_failed", - table=TableName, status_code=response.get("ResponseMetadata", {}).get("HTTPStatusCode"), ) @handle_aws_api_errors def delete_item(TableName, **kwargs): - logger.debug("dynamodb_delete_item_started", table=TableName) + log = logger.bind(operation="delete_item", table=TableName) + log.debug("dynamodb_delete_item_started") params = { "TableName": TableName, } @@ -153,17 +154,18 @@ def delete_item(TableName, **kwargs): response = execute_aws_api_call( "dynamodb", "delete_item", client_config=client_config, **params ) - logger.debug("dynamodb_delete_item_completed", table=TableName) + log.debug("dynamodb_delete_item_completed") return response @handle_aws_api_errors def list_tables(**kwargs): - logger.debug("dynamodb_list_tables_started") + log = logger.bind(operation="list_tables") + log.debug("dynamodb_list_tables_started") response = execute_aws_api_call( "dynamodb", "list_tables", client_config=client_config, **kwargs ) - logger.debug( + log.debug( "dynamodb_list_tables_completed", table_count=len(response.get("TableNames", [])), ) diff --git a/app/integrations/aws/dynamodb_next.py b/app/integrations/aws/dynamodb_next.py index bb04641c1..a4d5320f7 100644 --- a/app/integrations/aws/dynamodb_next.py +++ b/app/integrations/aws/dynamodb_next.py @@ -20,12 +20,12 @@ from typing import Any, Dict +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client_next import execute_aws_api_call from infrastructure.operations.result import OperationResult -logger = get_module_logger() +logger = structlog.get_logger() AWS_REGION = settings.aws.AWS_REGION diff --git a/app/integrations/aws/guard_duty.py b/app/integrations/aws/guard_duty.py index cd8262c4e..d9fbea204 100644 --- a/app/integrations/aws/guard_duty.py +++ b/app/integrations/aws/guard_duty.py @@ -1,10 +1,10 @@ """AWS GuardDuty integration module.""" +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() LOGGING_ROLE_ARN = settings.aws.LOGGING_ROLE_ARN @@ -15,7 +15,8 @@ def list_detectors(): Returns: list: A list of detector objects. """ - logger.debug("guard_duty_list_detectors_started") + log = logger.bind(operation="list_detectors") + log.debug("guard_duty_list_detectors_started") response = execute_aws_api_call( "guardduty", "list_detectors", @@ -24,7 +25,7 @@ def list_detectors(): role_arn=LOGGING_ROLE_ARN, ) detector_count = len(response) if response else 0 - logger.debug("guard_duty_list_detectors_completed", detector_count=detector_count) + log.debug("guard_duty_list_detectors_completed", detector_count=detector_count) return response if response else [] @@ -39,9 +40,9 @@ def get_findings_statistics(detector_id, finding_criteria=None): Returns: dict: The findings statistics. """ - logger.debug( + log = logger.bind(operation="get_findings_statistics", detector_id=detector_id) + log.debug( "guard_duty_get_findings_statistics_started", - detector_id=detector_id, criteria_present=finding_criteria is not None, ) @@ -62,9 +63,8 @@ def get_findings_statistics(detector_id, finding_criteria=None): has_findings = bool( response.get("FindingStatistics", {}).get("CountBySeverity", {}) ) - logger.debug( + log.debug( "guard_duty_get_findings_statistics_completed", - detector_id=detector_id, has_findings=has_findings, ) diff --git a/app/integrations/aws/identity_store.py b/app/integrations/aws/identity_store.py index b4043555b..979e34797 100644 --- a/app/integrations/aws/identity_store.py +++ b/app/integrations/aws/identity_store.py @@ -1,7 +1,7 @@ """AWS Identity Store module""" +import structlog from core.config import settings -from core.logging import get_module_logger import pandas as pd from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors @@ -10,7 +10,7 @@ INSTANCE_ID = settings.aws.INSTANCE_ID ROLE_ARN = settings.aws.ORG_ROLE_ARN -logger = get_module_logger() +logger = structlog.get_logger() def resolve_identity_store_id(kwargs): diff --git a/app/integrations/aws/identity_store_next.py b/app/integrations/aws/identity_store_next.py index 345fe5227..050ddafde 100644 --- a/app/integrations/aws/identity_store_next.py +++ b/app/integrations/aws/identity_store_next.py @@ -17,8 +17,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any, Dict, List, Optional, Mapping +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client_next import execute_aws_api_call from infrastructure.operations.result import OperationResult from infrastructure.operations.status import OperationStatus @@ -27,7 +27,7 @@ AWS_IDENTITY_STORE_ID = settings.aws.INSTANCE_ID ROLE_ARN = settings.aws.ORG_ROLE_ARN -logger = get_module_logger() +logger = structlog.get_logger() # User Management Functions diff --git a/app/integrations/aws/lambdas.py b/app/integrations/aws/lambdas.py index 107c547da..2009471b4 100644 --- a/app/integrations/aws/lambdas.py +++ b/app/integrations/aws/lambdas.py @@ -1,7 +1,7 @@ -from core.logging import get_module_logger +import structlog from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() @handle_aws_api_errors @@ -11,12 +11,13 @@ def list_functions(): Returns: list: A list of Lambda functions. """ - logger.debug("lambda_list_functions_started") + log = logger.bind(operation="list_functions") + log.debug("lambda_list_functions_started") response = execute_aws_api_call( "lambda", "list_functions", paginated=True, keys=["Functions"] ) function_count = len(response) if response else 0 - logger.debug("lambda_list_functions_completed", function_count=function_count) + log.debug("lambda_list_functions_completed", function_count=function_count) return response @@ -27,12 +28,13 @@ def list_layers(): Returns: list: A list of Lambda layers. """ - logger.debug("lambda_list_layers_started") + log = logger.bind(operation="list_layers") + log.debug("lambda_list_layers_started") response = execute_aws_api_call( "lambda", "list_layers", paginated=True, keys=["Layers"] ) layer_count = len(response) if response else 0 - logger.debug("lambda_list_layers_completed", layer_count=layer_count) + log.debug("lambda_list_layers_completed", layer_count=layer_count) return response @@ -47,20 +49,15 @@ def get_layer_version(layer_name, version_number): Returns: dict: The Lambda layer version. """ - logger.debug( - "lambda_get_layer_version_started", - layer_name=layer_name, - version=version_number, + log = logger.bind( + operation="get_layer_version", layer_name=layer_name, version=version_number ) + log.debug("lambda_get_layer_version_started") response = execute_aws_api_call( "lambda", "get_layer_version", LayerName=layer_name, VersionNumber=version_number, ) - logger.debug( - "lambda_get_layer_version_completed", - layer_name=layer_name, - version=version_number, - ) + log.debug("lambda_get_layer_version_completed") return response diff --git a/app/integrations/aws/organizations.py b/app/integrations/aws/organizations.py index 09c5ba5d2..bffab427a 100644 --- a/app/integrations/aws/organizations.py +++ b/app/integrations/aws/organizations.py @@ -1,10 +1,10 @@ +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors ORG_ROLE_ARN = settings.aws.ORG_ROLE_ARN -logger = get_module_logger() +logger = structlog.get_logger() @handle_aws_api_errors diff --git a/app/integrations/aws/security_hub.py b/app/integrations/aws/security_hub.py index c8dcab87f..145a51033 100644 --- a/app/integrations/aws/security_hub.py +++ b/app/integrations/aws/security_hub.py @@ -1,8 +1,8 @@ +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() LOGGING_ROLE_ARN = settings.aws.LOGGING_ROLE_ARN @@ -16,7 +16,8 @@ def get_findings(filters): Returns: list: A list of finding objects. """ - logger.debug("security_hub_get_findings_started", filter_keys=list(filters.keys())) + log = logger.bind(operation="get_findings", filter_keys=list(filters.keys())) + log.debug("security_hub_get_findings_started") response = execute_aws_api_call( "securityhub", "get_findings", @@ -25,5 +26,5 @@ def get_findings(filters): Filters=filters, ) finding_count = len(response) if response else 0 - logger.debug("security_hub_get_findings_completed", finding_count=finding_count) + log.debug("security_hub_get_findings_completed", finding_count=finding_count) return response diff --git a/app/integrations/aws/sqs.py b/app/integrations/aws/sqs.py index cdf560e15..935f8cf7e 100644 --- a/app/integrations/aws/sqs.py +++ b/app/integrations/aws/sqs.py @@ -1,7 +1,7 @@ -from core.logging import get_module_logger +import structlog from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() @handle_aws_api_errors @@ -14,11 +14,8 @@ def get_queue_url(queue_name): Returns: str: The URL of the SQS queue. """ - logger.info( - "getting_queue_url", - service="sqs", - queue_name=queue_name, - ) + log = logger.bind(operation="get_queue_url", queue_name=queue_name) + log.info("getting_queue_url") if not queue_name: raise ValueError("Queue_name must not be empty") return execute_aws_api_call("sqs", "get_queue_url", QueueName=queue_name)[ @@ -38,13 +35,12 @@ def send_message(queue_url, message_body, message_group_id): Returns: dict: The response from the SQS service. """ - logger.info( - "sending_message", - service="sqs", + log = logger.bind( + operation="send_message", queue_url=queue_url, - message_body=message_body, message_group_id=message_group_id, ) + log.info("sending_message") return execute_aws_api_call( "sqs", "send_message", @@ -66,13 +62,13 @@ def receive_message(queue_url, max_number_of_messages=10, wait_time_seconds=10): Returns: list: A list of messages. """ - logger.info( - "receiving_message", - service="sqs", + log = logger.bind( + operation="receive_message", queue_url=queue_url, max_number_of_messages=max_number_of_messages, wait_time_seconds=wait_time_seconds, ) + log.info("receiving_message") return execute_aws_api_call( "sqs", "receive_message", diff --git a/app/integrations/aws/sso_admin.py b/app/integrations/aws/sso_admin.py index f5bc8a2f4..f3670c266 100644 --- a/app/integrations/aws/sso_admin.py +++ b/app/integrations/aws/sso_admin.py @@ -1,9 +1,9 @@ +import structlog from core.config import settings -from core.logging import get_module_logger from integrations.aws.client import execute_aws_api_call, handle_aws_api_errors -logger = get_module_logger() +logger = structlog.get_logger() ROLE_ARN = settings.aws.ORG_ROLE_ARN INSTANCE_ARN = settings.aws.INSTANCE_ARN diff --git a/app/tests/integrations/aws/test_legacy_aws_client.py b/app/tests/integrations/aws/test_legacy_aws_client.py index 62663f2eb..63617306f 100644 --- a/app/tests/integrations/aws/test_legacy_aws_client.py +++ b/app/tests/integrations/aws/test_legacy_aws_client.py @@ -15,18 +15,20 @@ def test_handle_aws_api_errors_catches_botocore_error(mock_logger): mock_func.__name__ = "mock_func_name" mock_func.__module__ = "mock_module" decorated_func = aws_client.handle_aws_api_errors(mock_func) - + mock_logger_bind = MagicMock() + mock_logger.bind.return_value = mock_logger_bind result = decorated_func() assert result is False mock_func.assert_called_once() - mock_logger.error.assert_called_once_with( + mock_logger.bind.assert_called_once_with( + module="mock_module", function="mock_func_name" + ) + mock_logger_bind.error.assert_called_once_with( "boto_core_error", - module="mock_module", - function="mock_func_name", error="An unspecified error occurred", ) - mock_logger.info.assert_not_called() + mock_logger_bind.info.assert_not_called() @patch("integrations.aws.client.logger") @@ -38,20 +40,25 @@ def test_handle_aws_api_errors_catches_client_error_resource_not_found(mock_logg ) mock_func.__name__ = "mock_func_name" mock_func.__module__ = "mock_module" + mock_bind_logger = MagicMock() + mock_logger.bind.return_value = mock_bind_logger decorated_func = aws_client.handle_aws_api_errors(mock_func) result = decorated_func() assert result is False mock_func.assert_called_once() - mock_logger.warning.assert_called_once_with( - "aws_resource_not_found", + mock_logger.bind.assert_called_once_with( module="mock_module", function="mock_func_name", + error_code="ResourceNotFoundException", + ) + mock_bind_logger.warning.assert_called_once_with( + "aws_resource_not_found", error="An error occurred (ResourceNotFoundException) when calling the operation_name operation: Unknown", ) - mock_logger.error.assert_not_called() - mock_logger.info.assert_not_called() + mock_bind_logger.error.assert_not_called() + mock_bind_logger.info.assert_not_called() @patch("integrations.aws.client.logger") @@ -64,19 +71,22 @@ def test_handle_aws_api_errors_catches_client_error_other(mock_logger): ) mock_func.__name__ = "mock_func_name" mock_func.__module__ = "mock_module" + mock_bind_logger = MagicMock() + mock_logger.bind.return_value = mock_bind_logger decorated_func = aws_client.handle_aws_api_errors(mock_func) result = decorated_func() assert result is False mock_func.assert_called_once() - mock_logger.error.assert_called_once_with( + mock_logger.bind.assert_called_once_with( + module="mock_module", function="mock_func_name", error_code="OtherError" + ) + mock_bind_logger.error.assert_called_once_with( "aws_client_error", - module="mock_module", - function="mock_func_name", error="An error occurred (OtherError) when calling the operation_name operation: An error occurred", ) - mock_logger.info.assert_not_called() + mock_bind_logger.info.assert_not_called() @patch("integrations.aws.client.logger") @@ -85,18 +95,20 @@ def test_handle_aws_api_errors_catches_exception(mock_logger): mock_func.__name__ = "mock_func_name" mock_func.__module__ = "mock_module" decorated_func = aws_client.handle_aws_api_errors(mock_func) - + mock_bind_logger = MagicMock() + mock_logger.bind.return_value = mock_bind_logger result = decorated_func() assert result is False mock_func.assert_called_once() - mock_logger.error.assert_called_once_with( + mock_logger.bind.assert_called_once_with( + module="mock_module", function="mock_func_name" + ) + mock_bind_logger.error.assert_called_once_with( "unexpected_error", - module="mock_module", - function="mock_func_name", error="Exception message", ) - mock_logger.info.assert_not_called() + mock_bind_logger.info.assert_not_called() def test_handle_aws_api_errors_passes_through_return_value(): @@ -199,6 +211,8 @@ def test_paginator_raises_exception_on_non_200_status(mock_logger): mock_client = MagicMock(spec=BaseClient) mock_paginator = MagicMock() mock_client.get_paginator.return_value = mock_paginator + mock_bound_logger = MagicMock() + mock_logger.bind.return_value = mock_bound_logger # Setup: bind() returns this mock # Add the meta attribute to the mock client mock_client.meta = MagicMock() @@ -227,10 +241,12 @@ def test_paginator_raises_exception_on_non_200_status(mock_logger): assert str(excinfo.value) == ( "API call to mock_service.operation failed with status code 500" ) - mock_logger.error.assert_called_once_with( + mock_logger.bind.assert_called_once_with( + service="mock_service", operation="operation" + ) + mock_bound_logger.error.assert_called_once() + mock_bound_logger.error.assert_called_once_with( "api_call_failed_during_pagination", - service="mock_service", - operation="operation", status_code=500, )