Skip to content

Commit 67fde6c

Browse files
committed
fix: add test for rerun args of different length
1 parent dbb57e2 commit 67fde6c

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

airflow-core/tests/unit/models/test_dagrun.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from airflow.providers.standard.operators.bash import BashOperator
4444
from airflow.providers.standard.operators.empty import EmptyOperator
4545
from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator
46-
from airflow.sdk import DAG, BaseOperator, setup, task, task_group, teardown
46+
from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown
4747
from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference
4848
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
4949
from airflow.stats import Stats
@@ -2253,6 +2253,55 @@ def tg(x, y):
22532253
}
22542254

22552255

2256+
@pytest.mark.parametrize("rerun_length", [0, 1, 2, 3])
2257+
def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rerun_length):
2258+
@task
2259+
def generate_mapping_args():
2260+
context = get_current_context()
2261+
if context["ti"].try_number == 0:
2262+
args = [i for i in range(2)]
2263+
else:
2264+
args = [i for i in range(rerun_length)]
2265+
return args
2266+
2267+
@task
2268+
def mapped_print_value(arg):
2269+
return arg
2270+
2271+
with dag_maker(session=session):
2272+
args = generate_mapping_args()
2273+
mapped_print_value.expand(arg=args)
2274+
2275+
# First Run
2276+
dr = dag_maker.create_dagrun()
2277+
dag_maker.run_ti("generate_mapping_args", dr)
2278+
2279+
decision = dr.task_instance_scheduling_decisions(session=session)
2280+
for ti in decision.schedulable_tis:
2281+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2282+
2283+
clear_task_instances(dr.get_task_instances(), session=session)
2284+
2285+
# Second Run
2286+
ti = dr.get_task_instance(task_id="generate_mapping_args", session=session)
2287+
ti.try_number += 1
2288+
session.merge(ti)
2289+
dag_maker.run_ti("generate_mapping_args", dr)
2290+
2291+
# Check if the correct number of new mapped task instances are created/scheduled.
2292+
decision = dr.task_instance_scheduling_decisions(session=session)
2293+
assert len(decision.schedulable_tis) == rerun_length
2294+
assert all([ti.task_id == "mapped_print_value" for ti in decision.schedulable_tis])
2295+
2296+
for ti in decision.schedulable_tis:
2297+
dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
2298+
query = select(TI).filter_by(
2299+
dag_id=dr.dag_id, run_id=dr.run_id, task_id="mapped_print_value", state=TaskInstanceState.SUCCESS
2300+
)
2301+
success_tis = session.execute(query).all()
2302+
assert len(success_tis) == rerun_length
2303+
2304+
22562305
def test_operator_mapped_task_group_receives_value(dag_maker, session):
22572306
with dag_maker(session=session):
22582307

0 commit comments

Comments
 (0)