Skip to content

Commit 1f976d0

Browse files
dheerajturaga and kaxil authored
Fix scheduler crash during 3.0 to 3.1 migration when retry_delay is None (#56202)
* Kaxil's suggestions
* make default a float because some tests are complaining
* Fix test
* fixup! Fix test
* fixup! fixup! Fix test

---------

Co-authored-by: Kaxil Naik <[email protected]>
1 parent 9d4447d commit 1f976d0

File tree

4 files changed

+71
-17
lines changed

4 files changed

+71
-17
lines changed

airflow-core/src/airflow/serialization/schema.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@
283283
"pool": { "type": "string", "default": "default_pool" },
284284
"pool_slots": { "type": "number", "default": 1 },
285285
"execution_timeout": { "$ref": "#/definitions/timedelta" },
286-
"retry_delay": { "$ref": "#/definitions/timedelta" },
286+
"retry_delay": { "$ref": "#/definitions/timedelta", "default": 300.0 },
287287
"retry_exponential_backoff": { "type": "boolean", "default": false },
288288
"max_retry_delay": { "$ref": "#/definitions/timedelta" },
289289
"params": { "$ref": "#/definitions/params" },

airflow-core/src/airflow/serialization/serialized_objects.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,7 +1291,7 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
12911291

12921292
resources: dict[str, Any] | None = None
12931293
retries: int = 0
1294-
retry_delay: datetime.timedelta
1294+
retry_delay: datetime.timedelta = datetime.timedelta(seconds=300)
12951295
retry_exponential_backoff: bool = False
12961296
run_as_user: str | None = None
12971297

@@ -2057,19 +2057,26 @@ def generate_client_defaults(cls) -> dict[str, Any]:
20572057
for k, v in OPERATOR_DEFAULTS.items():
20582058
if k not in cls.get_serialized_fields():
20592059
continue
2060-
# Exclude values that are the same as the schema defaults
2061-
if k in schema_defaults and schema_defaults[k] == v:
2062-
continue
20632060

20642061
# Exclude values that are None or empty collections
20652062
if v is None or v in [[], (), set(), {}]:
20662063
continue
20672064

2065+
# Check schema defaults first with raw value comparison (fast path)
2066+
if k in schema_defaults and schema_defaults[k] == v:
2067+
continue
2068+
20682069
# Use the existing serialize method to ensure consistent format
20692070
serialized_value = cls.serialize(v)
20702071
# Extract just the value part, consistent with serialize_to_json behavior
20712072
if isinstance(serialized_value, dict) and Encoding.TYPE in serialized_value:
20722073
serialized_value = serialized_value[Encoding.VAR]
2074+
2075+
# For cases where raw comparison failed but serialized values might match
2076+
# (e.g., timedelta vs float), check again with serialized value
2077+
if k in schema_defaults and schema_defaults[k] == serialized_value:
2078+
continue
2079+
20732080
client_defaults[k] = serialized_value
20742081

20752082
return client_defaults

airflow-core/tests/unit/serialization/test_dag_serialization.py

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from __future__ import annotations
2121

22+
import contextlib
2223
import copy
2324
import dataclasses
2425
import importlib
@@ -99,6 +100,43 @@
99100
if TYPE_CHECKING:
100101
from airflow.sdk.definitions.context import Context
101102

103+
104+
@contextlib.contextmanager
def operator_defaults(overrides):
    """
    Temporarily patch OPERATOR_DEFAULTS, restoring the original state on exit.

    Each key in ``overrides`` is set in ``OPERATOR_DEFAULTS`` for the duration
    of the context. On exit, keys that existed before get their previous value
    back (including a previous value of ``None``), and keys that did not exist
    before are removed. The ``generate_client_defaults`` cache is cleared on
    entry and on exit so cached results never reflect stale defaults.

    Usable as a context manager or as a decorator.

    Example:
        with operator_defaults({"retries": 2, "retry_delay": 200.0}):
            # Test code with modified operator defaults

    :param overrides: mapping of OPERATOR_DEFAULTS key -> temporary value.
    """
    from airflow.sdk.bases.operator import OPERATOR_DEFAULTS

    # Sentinel distinguishing "key was absent" from "key was present with value
    # None"; using ``.get(key)`` alone would conflate the two and delete keys
    # whose legitimate original value is None.
    missing = object()
    original_values = {}
    try:
        # Store original values and apply overrides
        for key, value in overrides.items():
            original_values[key] = OPERATOR_DEFAULTS.get(key, missing)
            OPERATOR_DEFAULTS[key] = value

        # Clear the cache to ensure fresh generation
        SerializedBaseOperator.generate_client_defaults.cache_clear()

        yield
    finally:
        # Cleanup: only keys recorded before a possible mid-loop failure are
        # touched, so a partially applied override set is still rolled back.
        for key, original_value in original_values.items():
            if original_value is missing:
                # Key didn't exist originally, remove it
                OPERATOR_DEFAULTS.pop(key, None)
            else:
                # Restore original value (which may legitimately be None)
                OPERATOR_DEFAULTS[key] = original_value

        # Clear cache again to restore normal behavior
        SerializedBaseOperator.generate_client_defaults.cache_clear()
139+
102140
AIRFLOW_REPO_ROOT_PATH = Path(airflow.__file__).parents[3]
103141

104142

@@ -117,14 +155,13 @@
117155
VAR = Encoding.VAR
118156
serialized_simple_dag_ground_truth = {
119157
"__version": 3,
120-
"client_defaults": {"tasks": {"retry_delay": 300.0}},
121158
"dag": {
122159
"default_args": {
123160
"__type": "dict",
124161
"__var": {
125162
"depends_on_past": False,
126163
"retries": 1,
127-
"retry_delay": {"__type": "timedelta", "__var": 300.0},
164+
"retry_delay": {"__type": "timedelta", "__var": 240.0},
128165
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
129166
},
130167
},
@@ -165,7 +202,7 @@
165202
"__var": {
166203
"task_id": "bash_task",
167204
"retries": 1,
168-
"retry_delay": 300.0,
205+
"retry_delay": 240.0,
169206
"max_retry_delay": 600.0,
170207
"ui_color": "#f0ede4",
171208
"template_ext": [".sh", ".bash"],
@@ -224,7 +261,7 @@
224261
"__var": {
225262
"task_id": "custom_task",
226263
"retries": 1,
227-
"retry_delay": 300.0,
264+
"retry_delay": 240.0,
228265
"max_retry_delay": 600.0,
229266
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
230267
"template_fields": ["bash_command"],
@@ -294,7 +331,7 @@ def make_simple_dag():
294331
schedule=timedelta(days=1),
295332
default_args={
296333
"retries": 1,
297-
"retry_delay": timedelta(minutes=5),
334+
"retry_delay": timedelta(minutes=4),
298335
"max_retry_delay": timedelta(minutes=10),
299336
"depends_on_past": False,
300337
},
@@ -3072,7 +3109,7 @@ def test_handle_v1_serdag():
30723109
"__var": {
30733110
"depends_on_past": False,
30743111
"retries": 1,
3075-
"retry_delay": {"__type": "timedelta", "__var": 300.0},
3112+
"retry_delay": {"__type": "timedelta", "__var": 240.0},
30763113
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
30773114
"sla": {"__type": "timedelta", "__var": 100.0},
30783115
},
@@ -3110,7 +3147,7 @@ def test_handle_v1_serdag():
31103147
"__var": {
31113148
"task_id": "bash_task",
31123149
"retries": 1,
3113-
"retry_delay": 300.0,
3150+
"retry_delay": 240.0,
31143151
"max_retry_delay": 600.0,
31153152
"sla": 100.0,
31163153
"downstream_task_ids": [],
@@ -3173,7 +3210,7 @@ def test_handle_v1_serdag():
31733210
"__var": {
31743211
"task_id": "custom_task",
31753212
"retries": 1,
3176-
"retry_delay": 300.0,
3213+
"retry_delay": 240.0,
31773214
"max_retry_delay": 600.0,
31783215
"sla": 100.0,
31793216
"downstream_task_ids": [],
@@ -3383,7 +3420,7 @@ def test_handle_v2_serdag():
33833420
"__var": {
33843421
"depends_on_past": False,
33853422
"retries": 1,
3386-
"retry_delay": {"__type": "timedelta", "__var": 300.0},
3423+
"retry_delay": {"__type": "timedelta", "__var": 240.0},
33873424
"max_retry_delay": {"__type": "timedelta", "__var": 600.0},
33883425
},
33893426
},
@@ -3425,7 +3462,7 @@ def test_handle_v2_serdag():
34253462
"__var": {
34263463
"task_id": "bash_task",
34273464
"retries": 1,
3428-
"retry_delay": 300.0,
3465+
"retry_delay": 240.0,
34293466
"max_retry_delay": 600.0,
34303467
"downstream_task_ids": [],
34313468
"ui_color": "#f0ede4",
@@ -3491,7 +3528,7 @@ def test_handle_v2_serdag():
34913528
"__var": {
34923529
"task_id": "custom_task",
34933530
"retries": 1,
3494-
"retry_delay": 300.0,
3531+
"retry_delay": 240.0,
34953532
"max_retry_delay": 600.0,
34963533
"downstream_task_ids": [],
34973534
"_operator_extra_links": {"Google Custom": "_link_CustomOpLink"},
@@ -4004,8 +4041,9 @@ def test_apply_defaults_to_encoded_op_none_inputs(self):
40044041
result = SerializedBaseOperator._apply_defaults_to_encoded_op(encoded_op, None)
40054042
assert result == encoded_op
40064043

4044+
@operator_defaults({"retries": 2})
40074045
def test_multiple_tasks_share_client_defaults(self):
4008-
"""Test that multiple tasks can share the same client_defaults."""
4046+
"""Test that multiple tasks can share the same client_defaults when there are actually non-default values."""
40094047
with DAG(dag_id="test_dag") as dag:
40104048
BashOperator(task_id="task1", bash_command="echo 1")
40114049
BashOperator(task_id="task2", bash_command="echo 2")
@@ -4024,6 +4062,10 @@ def test_multiple_tasks_share_client_defaults(self):
40244062
deserialized_task1 = deserialized_dag.get_task("task1")
40254063
deserialized_task2 = deserialized_dag.get_task("task2")
40264064

4065+
# Both tasks should have retries=2 from client_defaults
4066+
assert deserialized_task1.retries == 2
4067+
assert deserialized_task2.retries == 2
4068+
40274069
# Both tasks should have the same default values from client_defaults
40284070
for field in client_defaults:
40294071
if hasattr(deserialized_task1, field) and hasattr(deserialized_task2, field):
@@ -4035,6 +4077,7 @@ def test_multiple_tasks_share_client_defaults(self):
40354077
class TestMappedOperatorSerializationAndClientDefaults:
40364078
"""Test MappedOperator serialization with client defaults and callback properties."""
40374079

4080+
@operator_defaults({"retry_delay": 200.0})
40384081
def test_mapped_operator_client_defaults_application(self):
40394082
"""Test that client_defaults are correctly applied to MappedOperator during deserialization."""
40404083
with DAG(dag_id="test_mapped_dag") as dag:
@@ -4099,6 +4142,7 @@ def test_mapped_operator_client_defaults_application(self):
40994142
),
41004143
],
41014144
)
4145+
@operator_defaults({"retry_delay": 200.0})
41024146
def test_mapped_operator_client_defaults_optimization(
41034147
self, task_config, dag_id, task_id, non_default_fields
41044148
):

scripts/in_container/run_schema_defaults_check.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import json
3030
import sys
31+
from datetime import timedelta
3132
from pathlib import Path
3233
from typing import Any
3334

@@ -80,6 +81,8 @@ def get_server_side_operator_defaults() -> dict[str, Any]:
8081
if isinstance(default_value, (set, tuple)):
8182
# Convert to list since schema.json is pure JSON
8283
default_value = list(default_value)
84+
elif isinstance(default_value, timedelta):
85+
default_value = default_value.total_seconds()
8386
server_defaults[field_name] = default_value
8487

8588
return server_defaults

0 commit comments

Comments
 (0)