|
43 | 43 | from airflow.providers.standard.operators.bash import BashOperator |
44 | 44 | from airflow.providers.standard.operators.empty import EmptyOperator |
45 | 45 | from airflow.providers.standard.operators.python import PythonOperator, ShortCircuitOperator |
46 | | -from airflow.sdk import DAG, BaseOperator, setup, task, task_group, teardown |
| 46 | +from airflow.sdk import DAG, BaseOperator, get_current_context, setup, task, task_group, teardown |
47 | 47 | from airflow.sdk.definitions.deadline import AsyncCallback, DeadlineAlert, DeadlineReference |
48 | 48 | from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG |
49 | 49 | from airflow.stats import Stats |
@@ -2253,6 +2253,59 @@ def tg(x, y): |
2253 | 2253 | } |
2254 | 2254 |
|
2255 | 2255 |
|
| 2256 | +@pytest.mark.parametrize("rerun_length", [0, 1, 2, 3]) |
| 2257 | +def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rerun_length): |
| 2258 | + @task |
| 2259 | + def generate_mapping_args(): |
| 2260 | + context = get_current_context() |
| 2261 | + if context["ti"].try_number == 0: |
| 2262 | + args = [i for i in range(2)] |
| 2263 | + else: |
| 2264 | + args = [i for i in range(rerun_length)] |
| 2265 | + return args |
| 2266 | + |
| 2267 | + @task |
| 2268 | + def mapped_print_value(arg): |
| 2269 | + return arg |
| 2270 | + |
| 2271 | + with dag_maker(session=session): |
| 2272 | + args = generate_mapping_args() |
| 2273 | + mapped_print_value.expand(arg=args) |
| 2274 | + |
| 2275 | + # First Run |
| 2276 | + dr = dag_maker.create_dagrun() |
| 2277 | + dag_maker.run_ti("generate_mapping_args", dr) |
| 2278 | + |
| 2279 | + decision = dr.task_instance_scheduling_decisions(session=session) |
| 2280 | + for ti in decision.schedulable_tis: |
| 2281 | + dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index) |
| 2282 | + |
| 2283 | + clear_task_instances(dr.get_task_instances(), session=session) |
| 2284 | + |
| 2285 | + # Second Run |
| 2286 | + ti = dr.get_task_instance(task_id="generate_mapping_args", session=session) |
| 2287 | + ti.try_number += 1 |
| 2288 | + session.merge(ti) |
| 2289 | + dag_maker.run_ti("generate_mapping_args", dr) |
| 2290 | + |
| 2291 | + # Check if the new mapped task instances are correctly scheduled |
| 2292 | + decision = dr.task_instance_scheduling_decisions(session=session) |
| 2293 | + assert len(decision.schedulable_tis) == rerun_length |
| 2294 | + assert all([ti.task_id == "mapped_print_value" for ti in decision.schedulable_tis]) |
| 2295 | + |
| 2296 | + # Check if mapped task rerun successfully |
| 2297 | + for ti in decision.schedulable_tis: |
| 2298 | + dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index) |
| 2299 | + query = select(TI).where( |
| 2300 | + TI.dag_id == dr.dag_id, |
| 2301 | + TI.run_id == dr.run_id, |
| 2302 | + TI.task_id == "mapped_print_value", |
| 2303 | + TI.state == TaskInstanceState.SUCCESS, |
| 2304 | + ) |
| 2305 | + success_tis = session.execute(query).all() |
| 2306 | + assert len(success_tis) == rerun_length |
| 2307 | + |
| 2308 | + |
2256 | 2309 | def test_operator_mapped_task_group_receives_value(dag_maker, session): |
2257 | 2310 | with dag_maker(session=session): |
2258 | 2311 |
|
|
0 commit comments