diff --git a/ansible/ansible.cfg b/ansible/ansible.cfg index f8f1e4ff048..b83e0452068 100644 --- a/ansible/ansible.cfg +++ b/ansible/ansible.cfg @@ -124,7 +124,8 @@ connection_plugins = plugins/connection lookup_plugins = plugins/lookup # vars_plugins = /usr/share/ansible_plugins/vars_plugins filter_plugins = plugins/filter -callback_whitelist = profile_tasks +# Disable profile tasks callback to avoid possible deadlock +# callback_whitelist = profile_tasks # by default callbacks are not loaded for /bin/ansible, enable this if you # want, for example, a notification or logging callback to also apply to @@ -190,7 +191,7 @@ become_ask_pass=False # ssh arguments to use # Leaving off ControlPersist will result in poor performance, so use # paramiko on older platforms rather than removing it -ssh_args = -o ControlMaster=auto -o ControlPersist=7200s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o ServerAliveInterval=30 -o ServerAliveCountMax=70 +ssh_args = -o ControlMaster=auto -o ControlPersist=7200s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -o ServerAliveInterval=30 -o ServerAliveCountMax=70 -o TCPKeepAlive=yes # The path to use for the ControlPath sockets. 
This defaults to diff --git a/tests/common/devices/base.py b/tests/common/devices/base.py index 9db40b94d7a..5644d0e0724 100644 --- a/tests/common/devices/base.py +++ b/tests/common/devices/base.py @@ -55,21 +55,20 @@ def __init__(self, ansible_adhoc, hostname, *args, **kwargs): def __getattr__(self, module_name): if self.host.has_module(module_name): - self.module_name = module_name - self.module = getattr(self.host, module_name) - - return self._run + def _run_wrapper(*module_args, **kwargs): + return self._run(module_name, *module_args, **kwargs) + return _run_wrapper raise AttributeError( "'%s' object has no attribute '%s'" % (self.__class__, module_name) ) - def _run(self, *module_args, **complex_args): + def _run(self, module_name, *module_args, **complex_args): previous_frame = inspect.currentframe().f_back filename, line_number, function_name, lines, index = inspect.getframeinfo(previous_frame) verbose = complex_args.pop('verbose', True) - + module = getattr(self.host, module_name) if verbose: logger.debug( "{}::{}#{}: [{}] AnsibleModule::{}, args={}, kwargs={}".format( @@ -77,7 +76,7 @@ def _run(self, *module_args, **complex_args): function_name, line_number, self.hostname, - self.module_name, + module_name, json.dumps(module_args, cls=AnsibleHostBase.CustomEncoder), json.dumps(complex_args, cls=AnsibleHostBase.CustomEncoder) ) @@ -89,7 +88,7 @@ def _run(self, *module_args, **complex_args): function_name, line_number, self.hostname, - self.module_name + module_name ) ) @@ -98,7 +97,7 @@ def _run(self, *module_args, **complex_args): if module_async: def run_module(module_args, complex_args): - return self.module(*module_args, **complex_args)[self.hostname] + return module(*module_args, **complex_args)[self.hostname] pool = ThreadPool() result = pool.apply_async(run_module, (module_args, complex_args)) return pool, result @@ -106,9 +105,9 @@ def run_module(module_args, complex_args): module_args = json.loads(json.dumps(module_args, 
cls=AnsibleHostBase.CustomEncoder)) complex_args = json.loads(json.dumps(complex_args, cls=AnsibleHostBase.CustomEncoder)) - adhoc_res: AdHocResult = self.module(*module_args, **complex_args) + adhoc_res: AdHocResult = module(*module_args, **complex_args) - if self.module_name == "meta": + if module_name == "meta": # The meta module is special in Ansible - it doesn't execute on remote hosts, it controls Ansible's behavior # There are no per-host ModuleResults contained within it return @@ -123,7 +122,7 @@ def run_module(module_args, complex_args): function_name, line_number, self.hostname, - self.module_name, json.dumps(hostname_res, cls=AnsibleHostBase.CustomEncoder) + module_name, json.dumps(hostname_res, cls=AnsibleHostBase.CustomEncoder) ) ) else: @@ -133,14 +132,14 @@ def run_module(module_args, complex_args): function_name, line_number, self.hostname, - self.module_name, + module_name, hostname_res.is_failed, hostname_res.get('rc', None) ) ) if (hostname_res.is_failed or 'exception' in hostname_res) and not module_ignore_errors: - raise RunAnsibleModuleFail("run module {} failed".format(self.module_name), hostname_res) + raise RunAnsibleModuleFail("run module {} failed".format(module_name), hostname_res) return hostname_res diff --git a/tests/common/devices/duthosts.py b/tests/common/devices/duthosts.py index a1165950dd1..b4a7ff03abf 100644 --- a/tests/common/devices/duthosts.py +++ b/tests/common/devices/duthosts.py @@ -19,9 +19,9 @@ class DutHosts(object): """ class _Nodes(list): """ Internal class representing a list of MultiAsicSonicHosts """ - def _run_on_nodes(self, *module_args, **complex_args): + def _run_on_nodes(self, module, *module_args, **complex_args): """ Delegate the call to each of the nodes, return the results in a dict.""" - return {node.hostname: getattr(node, self.attr)(*module_args, **complex_args) for node in self} + return {node.hostname: getattr(node, module)(*module_args, **complex_args) for node in self} def __getattr__(self, attr): 
""" To support calling ansible modules on a list of MultiAsicSonicHost @@ -32,8 +32,9 @@ def __getattr__(self, attr): a dictionary with key being the MultiAsicSonicHost's hostname, and value being the output of ansible module on that MultiAsicSonicHost """ - self.attr = attr - return self._run_on_nodes + def _run_on_nodes_wrapper(*module_args, **complex_args): + return self._run_on_nodes(attr, *module_args, **complex_args) + return _run_on_nodes_wrapper def __eq__(self, o): """ To support eq operator on the DUTs (nodes) in the testbed """ @@ -62,6 +63,12 @@ def __init__(self, ansible_adhoc, tbinfo, request, duts, target_hostname=None, i self.request = request self.duts = duts self.is_parallel_run = target_hostname is not None + # Initialize _nodes to None to avoid recursion in __getattr__ + self._nodes = None + self._nodes_for_parallel = None + self._supervisor_nodes = None + self._frontend_nodes = None + # TODO: Initialize the nodes in parallel using multi-threads? if self.is_parallel_run: self.parallel_run_stage = NON_INITIAL_CHECKS_STAGE diff --git a/tests/common/devices/multi_asic.py b/tests/common/devices/multi_asic.py index cc297403cfe..d22455f2b2b 100644 --- a/tests/common/devices/multi_asic.py +++ b/tests/common/devices/multi_asic.py @@ -123,10 +123,11 @@ def critical_services_tracking_list(self): def get_default_critical_services_list(self): return self._DEFAULT_SERVICES - def _run_on_asics(self, *module_args, **complex_args): + def _run_on_asics(self, multi_asic_attr, *module_args, **complex_args): """ Run an asible module on asics based on 'asic_index' keyword in complex_args Args: + multi_asic_attr: name of the ansible module to run module_args: other ansible module args passed from the caller complex_args: other ansible keyword args @@ -147,7 +148,7 @@ def _run_on_asics(self, *module_args, **complex_args): """ if "asic_index" not in complex_args: # Default ASIC/namespace - return getattr(self.sonichost, self.multi_asic_attr)(*module_args, 
**complex_args) + return getattr(self.sonichost, multi_asic_attr)(*module_args, **complex_args) else: asic_complex_args = copy.deepcopy(complex_args) asic_index = asic_complex_args.pop("asic_index") @@ -156,11 +157,11 @@ def _run_on_asics(self, *module_args, **complex_args): if self.sonichost.facts['num_asic'] == 1: if asic_index != 0: raise ValueError("Trying to run module '{}' against asic_index '{}' on a single asic dut '{}'" - .format(self.multi_asic_attr, asic_index, self.sonichost.hostname)) - return getattr(self.asic_instance(asic_index), self.multi_asic_attr)(*module_args, **asic_complex_args) + .format(multi_asic_attr, asic_index, self.sonichost.hostname)) + return getattr(self.asic_instance(asic_index), multi_asic_attr)(*module_args, **asic_complex_args) elif type(asic_index) == str and asic_index.lower() == "all": # All ASICs/namespace - return [getattr(asic, self.multi_asic_attr)(*module_args, **asic_complex_args) for asic in self.asics] + return [getattr(asic, multi_asic_attr)(*module_args, **asic_complex_args) for asic in self.asics] else: raise ValueError("Argument 'asic_index' must be an int or string 'all'.") @@ -357,8 +358,9 @@ def __getattr__(self, attr): """ sonic_asic_attr = getattr(SonicAsic, attr, None) if not attr.startswith("_") and sonic_asic_attr and callable(sonic_asic_attr): - self.multi_asic_attr = attr - return self._run_on_asics + def _run_on_asics_wrapper(*module_args, **complex_args): + return self._run_on_asics(attr, *module_args, **complex_args) + return _run_on_asics_wrapper else: return getattr(self.sonichost, attr) # For backward compatibility diff --git a/tests/common/plugins/parallel_fixture/README.md b/tests/common/plugins/parallel_fixture/README.md new file mode 100644 index 00000000000..60219b003dd --- /dev/null +++ b/tests/common/plugins/parallel_fixture/README.md @@ -0,0 +1,230 @@ +# Parallel Fixture Manager Design Document + +## 1. 
Overview + +The **Parallel Fixture Manager** is a pytest plugin designed to optimize test execution time by parallelizing the setup and teardown of fixtures. The sonic-mgmt fixture setup/teardowns often involve blocking I/O operations such as device configuration, service restarts, or waiting for convergence. By offloading these tasks to a thread pool, the manager allows multiple same-level fixtures to setup/teardown concurrently, which could reduce the overall test execution time. + +![test execution](images/pytest.jpg) + +## 2. Requirements + +The Parallel Fixture Manager is designed to address specific challenges in the SONiC testing infrastructure. The key requirements are: + +* **Test Fixture Setup/Teardown Parallelization** +* **Scope-Based Synchronization** + * The system must strictly enforce pytest scoping rules: + 1. All background tasks associated with a specific scope (Session, Module, Class, Function) in setup must complete successfully before the test runner proceeds to a narrower scope or executes the test function. + 2. All background tasks associated with a specific scope (Session, Module, Class, Function) in teardown must complete successfully before the test runner proceeds to a broader scope or finishes the test execution. +* **Fail-Fast Reliability** + * The system must immediately detect the exception and abort the ongoing test setup to prevent cascading failures, resource wastage, and misleading test results. +* **Non-Intrusive Integration** + * The system must expose a minimal and intuitive API. Existing fixtures should be able to adopt parallel execution patterns with minimal code changes, preserving the standard pytest fixture structure. +* **Safe Termination & Cleanup** + * The system must handle interruptions and timeouts gracefully. It must ensure that background threads are properly terminated and resources are cleaned up, even in the event of a test failure or user interruption. + +## 3. 
Architecture + +### 3.1 Core Components + +* **`ParallelFixtureManager`**: The central thread pool controller exposed as a session-scoped fixture (`parallel_manager`). + * **Executor**: Uses `concurrent.futures.ThreadPoolExecutor` to execute tasks. + * **Monitor Thread**: A daemon thread (`_monitor_workers`) that polls active futures to log task execution status and any exception in worker thread. + * **Task Queues**: Maintains separate lists of futures for setup and teardown tasks, categorized by scope. +* **`TaskScope`**: Enum defining the lifecycle scopes: `SESSION`, `MODULE`, `CLASS`, and `FUNCTION`. +* **`Barriers`**: Autouse fixtures that enforce synchronization. They block the main thread until all background tasks for a specific scope are complete. + * Setup Barriers: + * `setup_barrier_session` + * `setup_barrier_module` + * `setup_barrier_class` + * `setup_barrier_function` + * Teardown Barriers: + * `teardown_barrier_session` + * `teardown_barrier_module` + * `teardown_barrier_class` + * `teardown_barrier_function` + +### 3.2 Execution Lifecycle + +The manager hooks into the pytest lifecycle to coordinate parallel execution: + +```mermaid +sequenceDiagram + participant Pytest + participant ParallelManager + participant ThreadPool + participant MonitorThread + participant Fixture + participant Barrier + + Note over Pytest,Barrier: Session Setup Phase + Pytest->>ParallelManager: Create (session scope) + ParallelManager->>ThreadPool: Initialize worker threads + ParallelManager->>MonitorThread: Start monitoring + + Pytest->>Fixture: Execute fixture (session scope) + Fixture->>ParallelManager: submit_setup_task(SESSION, func) + ParallelManager->>ThreadPool: Submit task to thread pool + ThreadPool-->>ParallelManager: Return future + Fixture-->>Pytest: Yield immediately + + Pytest->>Barrier: setup_barrier_session + Barrier->>ParallelManager: wait_for_setup_tasks(SESSION) + ParallelManager->>ThreadPool: Wait for all session tasks + ThreadPool-->>ParallelManager: 
Tasks complete + + Note over Pytest,Barrier: Module/Class/Function Scopes + Pytest->>Fixture: Execute fixture (module scope) + Fixture->>ParallelManager: submit_setup_task(MODULE, func) + ParallelManager->>ThreadPool: Submit to pool + Fixture-->>Pytest: Yield + + Pytest->>Barrier: setup_barrier_module + Barrier->>ParallelManager: wait_for_setup_tasks(MODULE) + ParallelManager->>ThreadPool: Wait for module tasks + + Note over Pytest,Barrier: Test Execution + Pytest->>ParallelManager: pytest_runtest_call hook + ParallelManager->>ThreadPool: Ensure all tasks complete + ParallelManager->>ThreadPool: Terminate executor + Pytest->>Pytest: Run test function + + Note over Pytest,Barrier: Teardown Phase + Pytest->>ParallelManager: pytest_runtest_teardown hook + ParallelManager->>ThreadPool: Reset and create new executor + ParallelManager->>MonitorThread: Restart monitoring + + Pytest->>Fixture: Teardown fixture + Fixture->>ParallelManager: submit_teardown_task(scope, func) + ParallelManager->>ThreadPool: Submit teardown task + Fixture-->>Pytest: Return + + Pytest->>Barrier: teardown_barrier_function + Barrier->>ParallelManager: wait_for_teardown_tasks(FUNCTION) + ParallelManager->>ThreadPool: Wait for function teardowns + + Pytest->>ParallelManager: pytest_runtest_logreport hook + ParallelManager->>ThreadPool: Terminate executor + ParallelManager->>MonitorThread: Stop monitoring + +``` + +#### Setup Phase + +1. **Submission**: Fixtures submit setup functions using `parallel_manager.submit_setup_task(scope, func, *args, **kwargs)`. +2. **Non-Blocking Return/Yield**: The fixture yields/Returns immediately, allowing pytest to proceed to the next fixture. +3. **Barrier Enforcement**: At the end of a scope (e.g., after all module-scoped fixtures have run), a barrier fixture waits for all submitted tasks of that scope to complete. + +#### Test Execution Phase + +1. **Pre-Test Wait**: Before the test function runs, the manager ensures all setup tasks are finished. +2. 
**Termination**: The setup executor is terminated to ensure a stable environment during the test. + +#### Teardown Phase + +1. **Restart**: The manager is restarted to handle teardown tasks. +2. **Submission**: Fixtures submit teardown functions using `parallel_manager.submit_teardown_task(scope, func, *args, **kwargs)`. +3. **Barrier Enforcement**: Teardown barriers wait for tasks to complete before moving to the next scope. + +## 4. Exception Handling Strategy + +The system implements a **Fail-Fast** strategy to detect exceptions in the background threads and fail the main Pytest thread in a timely manner, which helps prevent cascading failures and wasted execution time. + +* **Background Exception Logging**: The monitor thread detects and logs exceptions in worker threads as they happen. +* **Checkpoints**: + * **`pytest_fixture_setup`**: Before starting *any* fixture, the manager checks if a background task has failed. If so, it raises `ParallelTaskRuntimeError` immediately, aborting the test setup. + * **Barriers**: When waiting at a barrier, exceptions from failed tasks are re-raised in the main thread. +* **Forced Termination**: In cases of interrupts or critical failures, `ctypes` is used to inject exceptions into worker threads to force immediate termination. + +## 5. Pytest Hooks Integration + +The plugin relies on several pytest hooks to orchestrate the flow: + +* **`pytest_runtest_setup`**: Dynamically reorders fixtures to ensure that barrier fixtures always execute **after** all other fixtures of the same scope. + +* **`pytest_fixture_setup`**: Acts as an exception-handling checkpoint, interrupting the test execution if any exceptions are detected in the background threads. +* **`pytest_runtest_call`**: Acts as a final gate before the test runs, ensuring all setup tasks are done and terminating the setup executor. +* **`pytest_exception_interact`**: Handles exceptions during setup/call to terminate the manager gracefully. 
+* **`pytest_runtest_teardown`**: Restarts the parallel manager to prepare for the teardown phase. +* **`pytest_runtest_logreport`**: Terminates the parallel manager gracefully after teardown is complete. + +## 6. Deadlock Handling + +The Parallel Fixture Manager introduces multi-threading to the test execution environment. When combined with multi-processing (e.g., Ansible execution, `multiprocessing.Process`), this creates a risk of deadlocks, particularly involving logging locks. + +A common scenario is: +1. Thread A (main thread) calls a logging function and acquires the logging lock. +2. Thread B (parallel fixture manager worker) forks a new process (e.g., to run an Ansible task). +3. The child process inherits the memory state, including the held logging lock. +4. Since Thread A does not exist in the child process, the lock remains held indefinitely. +5. If the child process tries to log something, it attempts to acquire the lock and deadlocks. + +To prevent this, the framework (in `tests/common/helpers/parallel.py`) leverages `os.register_at_fork` hooks to: +* Acquire logging locks before forking. +* Release logging locks after forking (in both parent and child). +* Handle Ansible display locks similarly. + +This ensures that **locks are always in a released state within the child process immediately after forking**. 
+ +```mermaid +sequenceDiagram + participant Main as Main Thread + participant Worker as Parallel Fixture Manager Worker + participant Logger as Logging Module + participant Handler as Log Handler + participant Fork as Fork Operation + participant Child as Child Process + + Note over Main,Child: Before Fix (Deadlock Scenario) + Worker->>Logger: Write log (acquires handler lock) + activate Handler + Main->>Fork: fork() called + Note over Fork: Lock state copied to child + Fork->>Child: Child process created + Note over Child: Child inherits locked handler + Child->>Handler: Attempt to log + Note over Child: DEADLOCK: Lock already held by Parallel Fixture Manager Worker + deactivate Handler + + Note over Main,Child: After Fix (Safe Fork) + Main->>Logger: _fix_logging_handler_fork_lock() + Logger->>Handler: Register at_fork handlers + Note over Handler: before=lock.acquire
after_in_parent=lock.release
after_in_child=lock.release + Main->>Fork: fork() called + Fork->>Handler: Execute before fork (acquire lock) + activate Handler + Fork->>Child: Child process created + Fork->>Handler: Execute after_in_parent (release lock) + deactivate Handler + Fork->>Child: Execute after_in_child (release lock) + Child->>Handler: Attempt to log + Note over Child: SUCCESS: Lock is free + Handler-->>Child: Log written successfully +``` + +## 7. Usage Example + +Fixtures interact with the parallel manager via the `parallel_manager` fixture. + +```python +import pytest +from tests.common.plugins.parallel_fixture import TaskScope + +@pytest.fixture(scope="module") +def heavy_initialization(parallel_manager): + def setup_logic(): + # Configure DUT + ... + + def teardown_logic(): + # Cleanup resources + ... + + # Submit setup task to run in background + future = parallel_manager.submit_setup_task(TaskScope.MODULE, setup_logic) + + # Yield immediately to let other fixtures start + yield + + # Submit teardown only if setup completed successfully + if parallel_manager.is_task_finished(future): + parallel_manager.submit_teardown_task(TaskScope.MODULE, teardown_logic) +``` diff --git a/tests/common/plugins/parallel_fixture/__init__.py b/tests/common/plugins/parallel_fixture/__init__.py new file mode 100644 index 00000000000..c8300e95b66 --- /dev/null +++ b/tests/common/plugins/parallel_fixture/__init__.py @@ -0,0 +1,699 @@ +import bisect +import contextlib +import ctypes +import enum +import functools +import logging +import pytest +import threading +import time +import traceback +import sys + +from concurrent.futures import CancelledError +from concurrent.futures import FIRST_EXCEPTION +from concurrent.futures import ALL_COMPLETED +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError +from concurrent.futures import wait + + +class TaskScope(enum.Enum): + """Defines the lifecycle scopes for parallel task.""" + 
SESSION = 0 + MODULE = 1 + CLASS = 2 + FUNCTION = 3 + + +class ParallelTaskRuntimeError(Exception): + pass + + +class ParallelTaskTerminatedError(Exception): + pass + + +def raise_async_exception(tid, exc_type): + """Injects an exception into the specified thread.""" + if not isinstance(tid, int): + raise TypeError("Thread ID must be an integer") + + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), + ctypes.py_object(exc_type)) + if res == 0: + logging.warning("[Parallel Fixture] Thread %s not found when raising async exception", tid) + elif res > 1: + # Clear the exception to restore interpreter consistency + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None) + raise SystemError("PyThreadState_SetAsyncExc affected multiple threads") + + +_log_context = threading.local() +_original_log_factory = logging.getLogRecordFactory() + + +def _prefixed_log_factory(*args, **kwargs): + record = _original_log_factory(*args, **kwargs) + # Check if we are inside a parallel task wrapper + prefix = getattr(_log_context, "prefix", None) + if prefix: + # Prepend the prefix to the log message + # This handles standard logging.info("msg") calls + record.msg = f"{prefix} {record.msg}" + return record + + +# Apply the factory globally +logging.setLogRecordFactory(_prefixed_log_factory) + + +class ParallelFixtureManager(object): + + DEFAULT_WAIT_TIMEOUT = 180 + THREAD_POOL_POLLING_INTERVAL = 0.1 + + TASK_SCOPE_SESSION = TaskScope.SESSION + TASK_SCOPE_MODULE = TaskScope.MODULE + TASK_SCOPE_CLASS = TaskScope.CLASS + TASK_SCOPE_FUNCTION = TaskScope.FUNCTION + + class ParallelTaskFuture(Future): + """A Future subclass that supports timeout handling with thread interruption.""" + + @property + def default_result(self): + if hasattr(self, '_default_result'): + return self._default_result + return None + + @default_result.setter + def default_result(self, value): + self._default_result = value + + @property + def timeout(self): + if hasattr(self, 
'_timeout'): + return self._timeout + return None + + @timeout.setter + def timeout(self, value): + self._timeout = value + + def result(self, timeout=None, interrupt_when_timeout=False, + return_default_on_timeout=False): + try: + return super().result(timeout=timeout) + except TimeoutError: + task_name = self.task_name + if self.cancel(): + logging.warning("[Parallel Fixture] Task %s timed out and was cancelled.", task_name) + elif self.running() and interrupt_when_timeout: + task_context = getattr(self, 'task_context', None) + if task_context and hasattr(task_context, 'tid'): + tid = task_context.tid + if tid: + logging.warning( + "[Parallel Fixture] Task %s timed out. Interrupting thread %s.", + task_name, tid + ) + raise_async_exception(tid, ParallelTaskTerminatedError) + else: + logging.warning("[Parallel Fixture] Task %s timed out but TID not found.", task_name) + if return_default_on_timeout: + logging.info("[Parallel Fixture] Task %s returning default result on timeout: %s", + task_name, self.default_result) + return self.default_result + raise + + class ParallelTaskContext(object): + """Context information for a parallel task.""" + def __init__(self, tid=None, start_time=None, end_time=None, task_name=None): + self.tid = tid + self.start_time = start_time + self.end_time = end_time + self.task_name = task_name + + def __init__(self, worker_count): + self.terminated = False + self.worker_count = worker_count + self.executor = ThreadPoolExecutor(max_workers=worker_count) + + # Initialize buckets for all defined scopes + self.setup_futures = {scope: [] for scope in TaskScope} + self.teardown_futures = {scope: [] for scope in TaskScope} + self.current_scope = None + + # Start the background monitor thread + self.monitor_lock = threading.Lock() + self.active_futures = set() + self.done_futures = set() + self.is_monitor_running = True + self.monitor_thread = threading.Thread(target=self._monitor_workers, daemon=True) + self.monitor_thread.start() + + def 
_monitor_workers(self): + """Monitor thread pool tasks.""" + i = 0 + while True: + future_threads = {} + with self.monitor_lock: + done_futures = set() + for f in self.active_futures: + tid = f.task_context.tid + if tid is not None: + future_threads[tid] = f + if f.done(): + done_futures.add(f) + if f.exception(): + logging.info("[Parallel Fixture] Detect exception from task %s: %s", + f.task_name, f.exception()) + else: + logging.info("[Parallel Fixture] Detect task %s is done", f.task_name) + self.active_futures -= done_futures + self.done_futures |= done_futures + + if i % 100 == 0: + # Log the running task of each thread pool worker + # every 10 seconds + log_msg = ["[Parallel Fixture] Current worker threads status:"] + current_time = time.time() + if self.executor._threads: + current_threads = list(self.executor._threads) + current_threads.sort(key=lambda t: (len(t.name), t.name)) + for thread in current_threads: + if thread.is_alive(): + if thread.ident in future_threads: + start_time = future_threads[thread.ident].task_context.start_time + log_msg.append(f"Thread {thread.name}: " + f"{future_threads[thread.ident].task_name}, " + f"{current_time - start_time}s") + else: + log_msg.append(f"Thread {thread.name}: idle") + else: + log_msg.append(f"Thread {thread.name}: terminated") + else: + log_msg.append("No alive worker thread found.") + logging.info("\n".join(log_msg)) + + if not self.is_monitor_running: + break + + time.sleep(ParallelFixtureManager.THREAD_POOL_POLLING_INTERVAL) + i += 1 + + def _resolve_scope(self, scope): + """Ensure scope is a TaskScope Enum member.""" + if isinstance(scope, TaskScope): + return scope + try: + return TaskScope(scope) + except ValueError: + raise ValueError(f"Invalid scope '{scope}'. 
" + f"Must be one of {[e.value for e in TaskScope]}") + + def _cancel_futures(self, futures): + for future in futures: + future.cancel() + + def _wait_for_futures(self, futures, timeout, + wait_strategy=FIRST_EXCEPTION, reraise=True, + raise_timeout_error=True): + if not futures: + return + + # Wait for all futures to complete + done, not_done = wait(futures, timeout=timeout, return_when=wait_strategy) + + # Check for exceptions in completed tasks + for future in done: + if future.exception(): + # If any exception is raised, cancel the rest + self._cancel_futures(not_done) + if reraise: + raise ParallelTaskRuntimeError from future.exception() + + # Wait timeout, cancel the rest + if not_done: + # Attempt cancel to cleanup + self._cancel_futures(not_done) + if raise_timeout_error: + raise TimeoutError( + f"Parallel Tasks Timed Out! " + f"{len(not_done)} tasks failed to complete within {timeout}s: " + f"{[f.task_name for f in not_done]}" + ) + + def _format_task_name(self, func, *args, **kwargs): + task_name = func.__name__ + if args or kwargs: + task_name += f"({args}, {kwargs})" + return task_name + + def _wrap_task(self, func, task_context): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + tid = threading.get_ident() + task_context.tid = tid + task_context.start_time = time.time() + current_thread = threading.current_thread().name + + prefix = f"[Parallel Fixture][{current_thread}][{task_context.task_name}]" + # Set thread-local context for logging module + _log_context.prefix = prefix + try: + return func(*args, **kwargs) + except Exception: + _, exc_value, exc_traceback = sys.exc_info() + logging.error("[Parallel Fixture] Task %s exception:\n%s", + task_context.task_name, + traceback.format_exc()) + raise exc_value.with_traceback(exc_traceback) + finally: + _log_context.prefix = None + task_context.end_time = time.time() + logging.debug("[Parallel Fixture] Task %s finished in %.2f seconds", + task_context.task_name, task_context.end_time - 
task_context.start_time) + + return wrapper + + def wait_for_tasks_completion(self, futures, timeout=DEFAULT_WAIT_TIMEOUT, + wait_strategy=ALL_COMPLETED, reraise=True): + """Block until all given tasks are done.""" + logging.debug("[Parallel Fixture] Waiting for tasks to finish, timeout: %s", timeout) + self._wait_for_futures(futures, timeout, wait_strategy, reraise) + + def submit_setup_task(self, scope, func, *args, **kwargs): + """Submit a setup task to the parallel fixture manager.""" + scope = self._resolve_scope(scope) + task_name = self._format_task_name(func, *args, **kwargs) + logging.info("[Parallel Fixture] Submit setup task (%s): %s", scope, task_name) + task_context = ParallelFixtureManager.ParallelTaskContext(task_name=task_name) + wrapped_func = self._wrap_task(func, task_context) + future = self.executor.submit(wrapped_func, *args, **kwargs) + future.__class__ = ParallelFixtureManager.ParallelTaskFuture + future.task_name = task_name + future.task_context = task_context + self.setup_futures[scope].append(future) + with self.monitor_lock: + self.active_futures.add(future) + return future + + def submit_teardown_task(self, scope, func, *args, **kwargs): + """Submit a teardown task to the parallel fixture manager.""" + scope = self._resolve_scope(scope) + task_name = self._format_task_name(func, *args, **kwargs) + logging.info("[Parallel Fixture] Submit teardown task (%s): %s", scope, task_name) + task_context = ParallelFixtureManager.ParallelTaskContext(task_name=task_name) + wrapped_func = self._wrap_task(func, task_context) + future = self.executor.submit(wrapped_func, *args, **kwargs) + future.__class__ = ParallelFixtureManager.ParallelTaskFuture + future.task_name = task_name + future.task_context = task_context + self.teardown_futures[scope].append(future) + with self.monitor_lock: + self.active_futures.add(future) + return future + + def wait_for_setup_tasks(self, scope, + timeout=DEFAULT_WAIT_TIMEOUT, + wait_strategy=FIRST_EXCEPTION, 
reraise=True): + """Block until all setup tasks in a specific scope are done.""" + logging.debug("[Parallel Fixture] Waiting for setup tasks to finish, scope: %s, timeout: %s", scope, timeout) + scope = self._resolve_scope(scope) + futures = self.setup_futures.get(scope, []) + self._wait_for_futures(futures, timeout, wait_strategy, reraise) + + def wait_for_teardown_tasks(self, scope, + timeout=DEFAULT_WAIT_TIMEOUT, + wait_strategy=FIRST_EXCEPTION, reraise=True): + """Block until all teardown tasks in a specific scope are done.""" + logging.debug("[Parallel Fixture] Waiting for teardown tasks to finish, scope: %s, timeout: %s", scope, timeout) + scope = self._resolve_scope(scope) + futures = self.teardown_futures.get(scope, []) + self._wait_for_futures(futures, timeout, wait_strategy, reraise) + + def terminate(self): + """Terminate the parallel fixture manager.""" + + if self.terminated: + return + + logging.info("[Parallel Fixture] Terminating parallel fixture manager") + + self.terminated = True + + # Stop the monitor + self.is_monitor_running = False + self.monitor_thread.join(10) + if self.monitor_thread.is_alive(): + logging.warning("[Parallel Fixture] Monitor thread failed to terminate.") + + # Cancel any pending futures + for future in self.active_futures: + future.cancel() + + # Force terminate the thread pool workers that are still running + running_futures = [future for future in self.active_futures if not future.done()] + logging.debug("[Parallel Fixture] Running tasks to be terminated: %s", [_.task_name for _ in running_futures]) + if running_futures: + logging.debug("[Parallel Fixture] Force interrupt thread pool workers") + running_futures_tids = [future.task_context.tid for future in running_futures + if future.task_context.tid is not None] + for thread in self.executor._threads: + if thread.is_alive() and thread.ident in running_futures_tids: + raise_async_exception(thread.ident, ParallelTaskTerminatedError) + + logging.debug("[Parallel Fixture] 
Current worker threads: %s", + [thread.is_alive() for thread in self.executor._threads]) + # Wait for all threads to terminate + self.executor.shutdown(wait=True) + logging.debug("[Parallel Fixture] Current worker threads: %s", + [thread.is_alive() for thread in self.executor._threads]) + + cancel_futures = [] + stopped_futures = [] + pending_futures = [] + done_futures = self.done_futures + for future in self.active_futures: + try: + exc = future.exception(0.1) + if isinstance(exc, ParallelTaskTerminatedError): + stopped_futures.append(future) + except CancelledError: + cancel_futures.append(future) + except TimeoutError: + # NOTE: should never hit this as all futures are either + # cancelled or stopped with ParallelTaskTerminatedError + pending_futures.append(future) + + logging.debug(f"[Parallel Fixture] The fixture manager is terminated:\n" + f"stopped tasks {[_.task_name for _ in stopped_futures]},\n" + f"canceled tasks {[_.task_name for _ in cancel_futures]},\n" + f"pending tasks {[_.task_name for _ in pending_futures]},\n" + f"done tasks {[(_.task_name, _.exception()) for _ in done_futures]}.") + + def reset(self): + """Reset the parallel fixture manager.""" + if not self.terminated: + raise RuntimeError("Cannot reset a running parallel fixture manager.") + + logging.info("[Parallel Fixture] Resetting parallel fixture manager") + # Reinitialize buckets for all defined scopes + self.setup_futures = {scope: [] for scope in TaskScope} + self.teardown_futures = {scope: [] for scope in TaskScope} + self.current_scope = None + + self.active_futures.clear() + self.done_futures.clear() + self.executor = ThreadPoolExecutor(max_workers=self.worker_count) + self.is_monitor_running = True + self.monitor_thread = threading.Thread(target=self._monitor_workers, daemon=True) + self.monitor_thread.start() + self.terminated = False + + def check_for_exception(self): + """Check done futures and re-raise any exception.""" + with self.monitor_lock: + for future in 
self.done_futures: + if future.exception(): + raise ParallelTaskRuntimeError from future.exception() + + def is_task_finished(self, future): + return future.done() and future.exception() is None + + def __del__(self): + self.terminate() + + +@contextlib.contextmanager +def log_function_call_duration(func_name): + start = time.time() + logging.debug("[Parallel Fixture] Start %s", func_name) + yield + logging.debug("[Parallel Fixture] End %s, duration %s", func_name, time.time() - start) + + +# ----------------------------------------------------------------- +# the parallel manager fixture +# ----------------------------------------------------------------- + + +_PARALLEL_MANAGER = None + + +@pytest.fixture(scope="session", autouse=True) +def parallel_manager(tbinfo): + dut_count = len(tbinfo.get("duts", [])) + worker_count = min(dut_count * 8, 16) + global _PARALLEL_MANAGER + _PARALLEL_MANAGER = ParallelFixtureManager(worker_count=worker_count) + _PARALLEL_MANAGER.current_scope = TaskScope.SESSION + return _PARALLEL_MANAGER + + +# ----------------------------------------------------------------- +# the setup barrier fixtures +# ----------------------------------------------------------------- + + +@pytest.fixture(scope="session", autouse=True) +def setup_barrier_session(parallel_manager): + """Barrier to wait for all session level setup tasks to finish.""" + with log_function_call_duration("setup_barrier_session"): + parallel_manager.wait_for_setup_tasks(TaskScope.SESSION) + parallel_manager.current_scope = TaskScope.MODULE + yield + return + + +@pytest.fixture(scope="module", autouse=True) +def setup_barrier_module(parallel_manager): + """Barrier to wait for all module level setup tasks to finish.""" + with log_function_call_duration("setup_barrier_module"): + parallel_manager.wait_for_setup_tasks(TaskScope.MODULE) + parallel_manager.current_scope = TaskScope.CLASS + yield + return + + +@pytest.fixture(scope="class", autouse=True) +def 
@pytest.fixture(scope="class", autouse=True)
def setup_barrier_class(parallel_manager):
    """Barrier to wait for all class level setup tasks to finish."""
    with log_function_call_duration("setup_barrier_class"):
        parallel_manager.wait_for_setup_tasks(TaskScope.CLASS)
        parallel_manager.current_scope = TaskScope.FUNCTION
    yield


@pytest.fixture(scope="function", autouse=True)
def setup_barrier_function(parallel_manager):
    """Barrier to wait for all function level setup tasks to finish."""
    with log_function_call_duration("setup_barrier_function"):
        parallel_manager.wait_for_setup_tasks(TaskScope.FUNCTION)
        parallel_manager.current_scope = None
    yield


# -----------------------------------------------------------------
# the teardown barrier fixtures
# -----------------------------------------------------------------


@pytest.fixture(scope="session", autouse=True)
def teardown_barrier_session(parallel_manager):
    """Barrier to wait for all session level teardown tasks to finish."""
    yield
    with log_function_call_duration("teardown_barrier_session"):
        parallel_manager.wait_for_teardown_tasks(TaskScope.SESSION)
        parallel_manager.current_scope = None


@pytest.fixture(scope="module", autouse=True)
def teardown_barrier_module(parallel_manager):
    """Barrier to wait for all module level teardown tasks to finish."""
    yield
    with log_function_call_duration("teardown_barrier_module"):
        parallel_manager.wait_for_teardown_tasks(TaskScope.MODULE)
        parallel_manager.current_scope = TaskScope.SESSION


@pytest.fixture(scope="class", autouse=True)
def teardown_barrier_class(parallel_manager):
    """Barrier to wait for all class level teardown tasks to finish."""
    yield
    with log_function_call_duration("teardown_barrier_class"):
        parallel_manager.wait_for_teardown_tasks(TaskScope.CLASS)
        parallel_manager.current_scope = TaskScope.MODULE


@pytest.fixture(scope="function", autouse=True)
def teardown_barrier_function(parallel_manager):
    """Barrier to wait for all function level teardown tasks to finish."""
    yield
    with log_function_call_duration("teardown_barrier_function"):
        parallel_manager.wait_for_teardown_tasks(TaskScope.FUNCTION)
        parallel_manager.current_scope = TaskScope.CLASS


# -----------------------------------------------------------------
# pytest hooks
# -----------------------------------------------------------------


@pytest.hookimpl(wrapper=True)
def pytest_runtest_setup(item):
    """
    HOOK: Runs once BEFORE every fixture setup.
    Reorder the setup/teardown barriers to ensure barriers should run
    after ALL fixtures of the same-scope.
    """
    logging.debug("[Parallel Fixture] Setup barrier fixtures")

    barriers = {
        TaskScope.SESSION.value: ["teardown_barrier_session",
                                  "setup_barrier_session"],
        TaskScope.MODULE.value: ["teardown_barrier_module",
                                 "setup_barrier_module"],
        TaskScope.CLASS.value: ["teardown_barrier_class",
                                "setup_barrier_class"],
        TaskScope.FUNCTION.value: ["teardown_barrier_function",
                                   "setup_barrier_function"],
    }
    fixtureinfo = item._fixtureinfo
    names = fixtureinfo.names_closure[:]

    logging.debug("[Parallel Fixture] Fixture order before:\n%s", names)

    # Strip any barrier fixture already present; they are re-inserted below
    # at the position that puts them last within their scope.
    for pair in barriers.values():
        for barrier_name in pair:
            if barrier_name in names:
                names.remove(barrier_name)

    # Derive a scope value for each remaining fixture. Fixtures with no
    # definition or an unrecognized scope inherit the previous fixture's
    # scope (session when there is no previous one).
    scopes = []
    for name in names:
        fixture_defs = fixtureinfo.name2fixturedefs.get(name, [])
        fallback = scopes[-1] if scopes else TaskScope.SESSION.value
        if not fixture_defs:
            scope_value = fallback
        else:
            try:
                scope_value = TaskScope[fixture_defs[0].scope.upper()].value
            except Exception:
                logging.debug("[Parallel Fixture] Unknown fixture scope for %r,"
                              "default to previous scope", fixture_defs)
                scope_value = fallback
        scopes.append(scope_value)

    # NOTE: Inject the barriers to ensure they are running last
    # in the fixtures of the same scope.
    for scope_value, pair in barriers.items():
        for barrier_name in pair:
            # Setup barriers go after the last fixture of their scope,
            # teardown barriers before the first one.
            locate = bisect.bisect_right if barrier_name.startswith("setup") else bisect.bisect_left
            insert_pos = locate(scopes, scope_value)
            names.insert(insert_pos, barrier_name)
            scopes.insert(insert_pos, scope_value)

    logging.debug("[Parallel Fixture] Fixture order after:\n%s", names)
    fixtureinfo.names_closure[:] = names

    yield


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    """
    HOOK: Runs BEFORE every fixture setup.
    If a background task failed while the PREVIOUS fixture was running,
    we catch it here and stop the next fixture from starting.
    """
    manager = _PARALLEL_MANAGER
    if manager:
        logging.debug("[Parallel Fixture] Check for fixture exceptions before running %r", fixturedef)
        manager.check_for_exception()


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    """
    HOOK: Runs BEFORE the test function starts.
    Happy path to terminate the parallel fixture manager.
    All tasks should be done as those barrier fixtures should catch them
    all.
    """
    logging.debug("[Parallel Fixture] Wait for tasks to finish before test function")
    manager = _PARALLEL_MANAGER
    if not manager:
        return
    try:
        for scope in TaskScope:
            manager.wait_for_setup_tasks(scope)
    finally:
        manager.terminate()
def pytest_exception_interact(call, report):
    """
    HOOK: Runs WHEN an exception occurs.
    Sad path to terminate the parallel fixture manager.
    When a ParallelTaskRuntimeError is detected, tries to poll
    the rest running tasks and terminate the parallel manager.
    """
    parallel_manager = _PARALLEL_MANAGER
    if parallel_manager and report.when == "setup":
        # The task failure was already surfaced as ParallelTaskRuntimeError;
        # don't raise it a second time while draining the remaining tasks.
        reraise = not isinstance(call.excinfo.value, ParallelTaskRuntimeError)
        logging.debug("[Parallel Fixture] Wait for tasks to finish after exception occurred in setup %s",
                      call.excinfo.value)
        try:
            for scope in TaskScope:
                parallel_manager.wait_for_setup_tasks(scope, wait_strategy=ALL_COMPLETED, reraise=reraise)
        finally:
            parallel_manager.terminate()


def pytest_runtest_teardown(item, nextitem):
    """
    HOOK: Runs once BEFORE all fixture teardown.
    Reset the parallel manager.
    """
    logging.debug("[Parallel Fixture] Reset parallel manager before teardown")
    parallel_manager = _PARALLEL_MANAGER
    if parallel_manager:
        # If the manager wasn't terminated yet, it means pytest_runtest_call
        # never ran (test was skipped or failed early during setup).
        # Terminate it first before resetting.
        if not parallel_manager.terminated:
            logging.debug("[Parallel Fixture] Test was skipped or failed early, "
                          "waiting for tasks to finish before terminating")
            try:
                for scope in TaskScope:
                    parallel_manager.wait_for_setup_tasks(scope)
            finally:
                # Single terminate() in the finally is sufficient; the extra
                # unconditional terminate() the original made right after this
                # block was a guaranteed no-op (terminated flag already set).
                parallel_manager.terminate()

        parallel_manager.reset()
        parallel_manager.current_scope = TaskScope.FUNCTION


def pytest_runtest_logreport(report):
    """
    HOOK: Runs once AFTER all fixture setup/teardown.
    Terminate the parallel manager.
    """
    if report.when != "teardown":
        return
    logging.debug("[Parallel Fixture] Terminate parallel manager after teardown")
    parallel_manager = _PARALLEL_MANAGER
    if parallel_manager:
        parallel_manager.terminate()