def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
    """Undo the most recent activation of this actor.

    Pops the newest token from ``self._tokens`` and hands it to
    ``_current_actor.reset``. When no tokens are stored the call is a no-op.

    Args:
        exc_type: Exception class raised inside the ``with`` body, if any.
        exc_val: Exception instance raised inside the ``with`` body, if any.
        exc_tb: Traceback for that exception, if any.

    Returns:
        None, so an exception raised inside the ``with`` body is never
        suppressed.

    Raises:
        ValueError: Re-raised from ``reset`` unless its message says the
            token was created in a different Context.
    """
    if not self._tokens:
        # Nothing was pushed for this actor, so there is nothing to undo.
        return

    token = self._tokens.pop()
    try:
        _current_actor.reset(token)
    except ValueError as err:
        # ContextVar.reset() refuses a token minted in another Context.
        # The original author notes this shows up when tasks run
        # concurrently (e.g. under asyncio.gather); in that case skipping
        # the reset is deliberate. Any other ValueError is a real problem.
        if "was created in a different Context" not in str(err):
            raise
def _tasks_are_independent(tasks: list[Task[Any]]) -> bool:
    """Return True when no two tasks in *tasks* are directly related.

    Two tasks count as related when one lists the other in ``depends_on``,
    one is the other's ``parent``, or one appears in the other's
    ``subtasks``. Only direct links between members of *tasks* are
    inspected. An empty or single-element list is independent.

    Args:
        tasks: Tasks to inspect.

    Returns:
        False as soon as a related pair is found, otherwise True.
    """
    from itertools import combinations

    # Every check below is symmetric, so each unordered pair is visited once
    # instead of once per ordering.
    for first, second in combinations(tasks, 2):
        if second in first.depends_on or first in second.depends_on:
            return False
        if first.parent == second or second.parent == first:
            return False
        if first in second.subtasks or second in first.subtasks:
            return False
    return True


async def run_tasks_async(
    tasks: list[Task[Any]],
    thread: Thread | str | None = None,
    raise_on_failure: bool = True,
    handlers: list[Handler | AsyncHandler] | None = None,
) -> list[Task[Any]] | AsyncGenerator[Event, None]:
    """Run tasks concurrently when independent, otherwise via one orchestrator.

    Args:
        tasks: Tasks to run.
        thread: Thread (or thread identifier) forwarded to whichever
            execution path is taken.
        raise_on_failure: Forwarded to the execution path; controls whether
            a failed task raises.
        handlers: Event handlers forwarded to the execution path.

    Returns:
        The same ``tasks`` list that was passed in.
    """
    if len(tasks) > 1 and _tasks_are_independent(tasks):
        # Local import keeps this branch self-contained regardless of what
        # the module imports at top level.
        import asyncio

        # Forward the caller's options so the concurrent path honours them
        # exactly like the orchestrator path below does.
        # NOTE(review): assumes Task.run_async accepts the keywords
        # thread / raise_on_failure / handlers - confirm against Task.
        # NOTE(review): every concurrent run receives the same `thread`;
        # confirm Thread tolerates concurrent writers.
        await asyncio.gather(
            *(
                task.run_async(
                    thread=thread,
                    raise_on_failure=raise_on_failure,
                    handlers=handlers,
                )
                for task in tasks
            )
        )
        return tasks

    # Dependent tasks, a single task, or an empty list: let one orchestrator
    # schedule everything.
    orchestrator = Orchestrator(
        tasks=tasks,
        thread=thread,
        handlers=handlers,
    )
    await orchestrator.run(raise_on_failure=raise_on_failure)
    return tasks
class TestConcurrentExecution:
    """End-to-end checks for run_tasks and run_tasks_async."""

    @pytest.mark.asyncio
    async def test_independent_tasks_run_without_errors(self):
        """Three unrelated tasks all finish successfully with the expected results."""
        pending = [
            Task("Say 'one'", result_type=str),
            Task("Say 'two'", result_type=str),
            Task("Say 'three'", result_type=str),
        ]

        # The call itself must complete without raising.
        finished = await run_tasks_async(pending)

        assert len(finished) == 3
        for done in finished:
            assert done.is_successful()

        # Results are compared as a set: completion order is not asserted.
        assert {done.result for done in finished} == {"one", "two", "three"}

    @pytest.mark.asyncio
    async def test_dependent_tasks_run_in_order(self):
        """A three-step depends_on chain completes with results in list order."""
        first = Task("Say 'A'", result_type=str)
        second = Task("Say 'B'", result_type=str, depends_on=[first])
        third = Task("Say 'C'", result_type=str, depends_on=[second])

        finished = await run_tasks_async([first, second, third])

        assert len(finished) == 3
        for done in finished:
            assert done.is_successful()

        # Returned list mirrors the input list, position by position.
        assert [done.result for done in finished] == ["A", "B", "C"]

    def test_sync_run_tasks_independent(self):
        """The synchronous wrapper handles two unrelated tasks."""
        pending = [
            Task("Say 'one'", result_type=str),
            Task("Say 'two'", result_type=str),
        ]

        finished = run_tasks(pending)

        assert len(finished) == 2
        for done in finished:
            assert done.is_successful()
        assert {done.result for done in finished} == {"one", "two"}

    def test_sync_run_tasks_dependent(self):
        """The synchronous wrapper handles a two-step depends_on chain."""
        first = Task("Say 'A'", result_type=str)
        second = Task("Say 'B'", result_type=str, depends_on=[first])

        finished = run_tasks([first, second])

        assert len(finished) == 2
        for done in finished:
            assert done.is_successful()

        # Returned list mirrors the input list, position by position.
        assert [done.result for done in finished] == ["A", "B"]
class TestContextVarHandling:
    """Test ContextVar token handling across async contexts."""

    @pytest.mark.asyncio
    async def test_actor_context_across_asyncio_gather(self):
        """Actors entered inside gathered coroutines do not leak into the caller."""
        from marvin.agents.actor import _current_actor

        async def task_with_actor(name):
            actor = Agent(name=f"Agent_{name}")
            with actor:
                assert _current_actor.get() == actor
                # Yield to the event loop so the three coroutines interleave
                # while each still holds its actor; no real delay is needed.
                await asyncio.sleep(0)
            return name

        results = await asyncio.gather(
            task_with_actor("1"), task_with_actor("2"), task_with_actor("3")
        )

        assert results == ["1", "2", "3"]
        # Each gathered coroutine ran in its own Task context, so the
        # caller's view of the variable is untouched.
        assert _current_actor.get() is None

    def test_actor_exit_in_different_context(self):
        """Exiting in a Context other than the one entered must not raise.

        The gather test above enters and exits each actor inside a single
        Task context, so it never hits the cross-Context ``reset`` failure.
        This test forces that situation directly.
        """
        import contextvars

        from marvin.agents.actor import _current_actor

        before = _current_actor.get()
        actor = Agent(name="CrossContext")

        # Enter inside a copied Context so the stored token belongs to that
        # copy rather than to the Context this test runs in.
        # NOTE(review): assumes __enter__ stores its token on the actor
        # (mirroring the pop in __exit__) - confirm.
        contextvars.copy_context().run(actor.__enter__)

        # reset() now receives a token from a different Context; __exit__ is
        # expected to swallow that specific ValueError.
        actor.__exit__(None, None, None)

        assert _current_actor.get() is before

    def test_actor_context_sequential(self):
        """Nested ``with`` blocks restore the previous actor on exit."""
        from marvin.agents.actor import _current_actor

        actor1 = Agent(name="Sequential_1")
        actor2 = Agent(name="Sequential_2")

        assert _current_actor.get() is None

        with actor1:
            assert _current_actor.get() == actor1
            with actor2:
                assert _current_actor.get() == actor2
            # Leaving the inner block restores the outer actor.
            assert _current_actor.get() == actor1

        assert _current_actor.get() is None