Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@
"pool": { "type": "string", "default": "default_pool" },
"pool_slots": { "type": "number", "default": 1 },
"execution_timeout": { "$ref": "#/definitions/timedelta" },
"retry_delay": { "$ref": "#/definitions/timedelta" },
"retry_delay": { "$ref": "#/definitions/timedelta", "default": 300.0 },
"retry_exponential_backoff": { "type": "boolean", "default": false },
"max_retry_delay": { "$ref": "#/definitions/timedelta" },
"params": { "$ref": "#/definitions/params" },
Expand Down
15 changes: 11 additions & 4 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,7 +1291,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):

resources: dict[str, Any] | None = None
retries: int = 0
retry_delay: datetime.timedelta
retry_delay: datetime.timedelta = datetime.timedelta(seconds=300)
retry_exponential_backoff: bool = False
run_as_user: str | None = None

Expand Down Expand Up @@ -2057,19 +2057,26 @@ def generate_client_defaults(cls) -> dict[str, Any]:
for k, v in OPERATOR_DEFAULTS.items():
if k not in cls.get_serialized_fields():
continue
# Exclude values that are the same as the schema defaults
if k in schema_defaults and schema_defaults[k] == v:
continue

# Exclude values that are None or empty collections
if v is None or v in [[], (), set(), {}]:
continue

# Check schema defaults first with raw value comparison (fast path)
if k in schema_defaults and schema_defaults[k] == v:
continue

# Use the existing serialize method to ensure consistent format
serialized_value = cls.serialize(v)
# Extract just the value part, consistent with serialize_to_json behavior
if isinstance(serialized_value, dict) and Encoding.TYPE in serialized_value:
serialized_value = serialized_value[Encoding.VAR]

# For cases where raw comparison failed but serialized values might match
# (e.g., timedelta vs float), check again with serialized value
if k in schema_defaults and schema_defaults[k] == serialized_value:
continue

client_defaults[k] = serialized_value

return client_defaults
Expand Down
68 changes: 56 additions & 12 deletions airflow-core/tests/unit/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import annotations

import contextlib
import copy
import dataclasses
import importlib
Expand Down Expand Up @@ -99,6 +100,43 @@
if TYPE_CHECKING:
from airflow.sdk.definitions.context import Context


@contextlib.contextmanager
def operator_defaults(overrides):
    """
    Temporarily patch OPERATOR_DEFAULTS, restoring the original state on exit.

    Each key in ``overrides`` is written into ``OPERATOR_DEFAULTS`` for the
    duration of the context. On exit (including when the body raises) every
    touched key is put back exactly as it was: keys that previously existed get
    their old value back -- even when that value was ``None`` -- and keys that
    did not exist are removed again.

    The ``SerializedBaseOperator.generate_client_defaults`` cache is cleared
    after applying the overrides and again after restoring them, so cached
    results never outlive the defaults they were computed from.

    Because this is built with ``contextlib.contextmanager`` it can be used
    either as a ``with`` block or as a decorator on a test function.

    :param overrides: Mapping of OPERATOR_DEFAULTS key to the temporary value.

    Example:
        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
            # Test code with modified operator defaults
    """
    from airflow.sdk.bases.operator import OPERATOR_DEFAULTS

    # Sentinel marking "key was absent". A plain ``.get(key)`` cannot tell an
    # absent key apart from a key whose real default is ``None``, which would
    # cause such a default to be deleted instead of restored.
    missing = object()

    original_values = {}
    try:
        # Record each original value right before overwriting it, so a failure
        # part-way through still lets ``finally`` undo what was already applied.
        for key, value in overrides.items():
            original_values[key] = OPERATOR_DEFAULTS.get(key, missing)
            OPERATOR_DEFAULTS[key] = value

        # Clear the cache to ensure fresh generation
        SerializedBaseOperator.generate_client_defaults.cache_clear()

        yield
    finally:
        # Cleanup: restore original values
        for key, original_value in original_values.items():
            if original_value is missing:
                # Key didn't exist originally; remove it. ``pop`` with a default
                # tolerates the body having already removed the key.
                OPERATOR_DEFAULTS.pop(key, None)
            else:
                # Restore original value (which may legitimately be None)
                OPERATOR_DEFAULTS[key] = original_value

        # Clear cache again to restore normal behavior
        SerializedBaseOperator.generate_client_defaults.cache_clear()


AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]


Expand All @@ -117,14 +155,13 @@
VAR = Encoding.VAR
serialized_simple_dag_ground_truth = {
"__version": 3,
"client_defaults": {"tasks": {"retry_delay": 300.0}},
"dag": {
"default_args": {
"__type": "dict",
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 300.0},
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
},
},
Expand Down Expand Up @@ -165,7 +202,7 @@
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"ui_color": "#f0ede4",
"template_ext": [".sh", ".bash"],
Expand Down Expand Up @@ -224,7 +261,7 @@
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
"template_fields": ["bash_command"],
Expand Down Expand Up @@ -294,7 +331,7 @@ def make_simple_dag():
schedule=timedelta(days=1),
default_args={
"retries": 1,
"retry_delay": timedelta(minutes=5),
"retry_delay": timedelta(minutes=4),
"max_retry_delay": timedelta(minutes=10),
"depends_on_past": False,
},
Expand Down Expand Up @@ -3072,7 +3109,7 @@ def test_handle_v1_serdag():
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 300.0},
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
"sla": {"__type": "timedelta", "__var": 100.0},
},
Expand Down Expand Up @@ -3110,7 +3147,7 @@ def test_handle_v1_serdag():
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"sla": 100.0,
"downstream_task_ids": [],
Expand Down Expand Up @@ -3173,7 +3210,7 @@ def test_handle_v1_serdag():
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"sla": 100.0,
"downstream_task_ids": [],
Expand Down Expand Up @@ -3383,7 +3420,7 @@ def test_handle_v2_serdag():
"__var": {
"depends_on_past": False,
"retries": 1,
"retry_delay": {"__type": "timedelta", "__var": 300.0},
"retry_delay": {"__type": "timedelta", "__var": 240.0},
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
},
},
Expand Down Expand Up @@ -3425,7 +3462,7 @@ def test_handle_v2_serdag():
"__var": {
"task_id": "bash_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"downstream_task_ids": [],
"ui_color": "#f0ede4",
Expand Down Expand Up @@ -3491,7 +3528,7 @@ def test_handle_v2_serdag():
"__var": {
"task_id": "custom_task",
"retries": 1,
"retry_delay": 300.0,
"retry_delay": 240.0,
"max_retry_delay": 600.0,
"downstream_task_ids": [],
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
Expand Down Expand Up @@ -4004,8 +4041,9 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
assert result == encoded_op

@operator_defaults({"retries": 2})
def test_multiple_tasks_share_client_defaults(self):
"""Test that multiple tasks can share the same client_defaults."""
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
with DAG(dag_id="test_dag") as dag:
BashOperator(task_id="task1", bash_command="echo 1")
BashOperator(task_id="task2", bash_command="echo 2")
Expand All @@ -4024,6 +4062,10 @@ def test_multiple_tasks_share_client_defaults(self):
deserialized_task1 = deserialized_dag.get_task("task1")
deserialized_task2 = deserialized_dag.get_task("task2")

# Both tasks should have retries=2 from client_defaults
assert deserialized_task1.retries == 2
assert deserialized_task2.retries == 2

# Both tasks should have the same default values from client_defaults
for field in client_defaults:
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
Expand All @@ -4035,6 +4077,7 @@ def test_multiple_tasks_share_client_defaults(self):
class TestMappedOperatorSerializationAndClientDefaults:
"""Test MappedOperator serialization with client defaults and callback properties."""

@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_application(self):
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
with DAG(dag_id="test_mapped_dag") as dag:
Expand Down Expand Up @@ -4099,6 +4142,7 @@ def test_mapped_operator_client_defaults_application(self):
),
],
)
@operator_defaults({"retry_delay": 200.0})
def test_mapped_operator_client_defaults_optimization(
self, task_config, dag_id, task_id, non_default_fields
):
Expand Down
3 changes: 3 additions & 0 deletions scripts/in_container/run_schema_defaults_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import json
import sys
from datetime import timedelta
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -80,6 +81,8 @@ def get_server_side_operator_defaults() -> dict[str, Any]:
if isinstance(default_value, (set, tuple)):
# Convert to list since schema.json is pure JSON
default_value = list(default_value)
elif isinstance(default_value, timedelta):
default_value = default_value.total_seconds()
server_defaults[field_name] = default_value

return server_defaults
Expand Down
Loading