Skip to content

Commit 6d3e841

Browse files
authored
fix: allow mapped tasks to accept zero-length inputs on rerun (#56162)
* fix: allow mapped tasks to accept zero-length inputs on rerun
* fix: add test for rerun args of different length
* chore: revise comments to align with the changes
* chore: add comments before the task state check
* fix: replace legacy query syntax
1 parent b9d91c3 commit 6d3e841

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

airflow-core/src/airflow/models/dagrun.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1547,7 +1547,12 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
15471547
)
15481548
)
15491549
revised_map_index_task_ids.add(schedulable.task.task_id)
1550-
ready_tis.append(schedulable)
1550+
1551+
# _revise_map_indexes_if_mapped might mark the current task as REMOVED
1552+
# after calculating mapped task length, so we need to re-check
1553+
# the task state to ensure it's still schedulable
1554+
if schedulable.state in SCHEDULEABLE_STATES:
1555+
ready_tis.append(schedulable)
15511556

15521557
# Check if any ti changed state
15531558
tis_filter = TI.filter_for_tis(old_states)

airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from airflow.providers.standard.operators.bash import BashOperator
4444
from airflow.providers.standard.operators.empty import EmptyOperator
4545
from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
46-
from airflow.sdk import DAG, BaseOperator, setup, task, task_group, teardown
46+
from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown
4747
from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference
4848
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
4949
from airflow.stats import Stats
@@ -2253,6 +2253,59 @@ def tg(x, y):
22532253
}
22542254

22552255

2256+
@pytest.mark.parametrize("rerun_length", [0, 1, 2, 3])
def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rerun_length):
    """Check that a cleared mapped task is rescheduled with the rerun's input length.

    The upstream task returns two items on its first try and ``rerun_length``
    items on any later try. After the first run is cleared and the upstream is
    re-executed, the scheduling decision must contain exactly ``rerun_length``
    ``mapped_print_value`` task instances (including zero), and each of them
    must finish in the SUCCESS state.
    """

    @task
    def generate_mapping_args():
        context = get_current_context()
        # Two items on the first try; ``rerun_length`` items on every later try.
        if context["ti"].try_number == 0:
            return list(range(2))
        return list(range(rerun_length))

    @task
    def mapped_print_value(arg):
        return arg

    with dag_maker(session=session):
        args = generate_mapping_args()
        mapped_print_value.expand(arg=args)

    # First Run
    dr = dag_maker.create_dagrun()
    dag_maker.run_ti("generate_mapping_args", dr)

    decision = dr.task_instance_scheduling_decisions(session=session)
    for ti in decision.schedulable_tis:
        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)

    clear_task_instances(dr.get_task_instances(), session=session)

    # Second Run
    ti = dr.get_task_instance(task_id="generate_mapping_args", session=session)
    # Bump try_number by hand so the upstream task takes its "rerun" branch.
    ti.try_number += 1
    session.merge(ti)
    dag_maker.run_ti("generate_mapping_args", dr)

    # Check if the new mapped task instances are correctly scheduled
    decision = dr.task_instance_scheduling_decisions(session=session)
    assert len(decision.schedulable_tis) == rerun_length
    assert all(ti.task_id == "mapped_print_value" for ti in decision.schedulable_tis)

    # Check if mapped task rerun successfully
    for ti in decision.schedulable_tis:
        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
    query = select(TI).where(
        TI.dag_id == dr.dag_id,
        TI.run_id == dr.run_id,
        TI.task_id == "mapped_print_value",
        TI.state == TaskInstanceState.SUCCESS,
    )
    success_tis = session.execute(query).all()
    assert len(success_tis) == rerun_length
2307+
2308+
22562309
def test_operator_mapped_task_group_receives_value(dag_maker, session):
22572310
with dag_maker(session=session):
22582311

0 commit comments

Comments
 (0)