Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api-reference/marvin-fns-run.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ def run_stream(instructions: str | Sequence[UserContent], result_type: type[T] =
```python
def run_tasks(tasks: list[Task[Any]], thread: Thread | str | None = None, raise_on_failure: bool = True, handlers: list[Handler | AsyncHandler] | None = None) -> list[Task[Any]]
```
Run tasks either concurrently (if independent) or sequentially.

### `run_tasks_async`
```python
def run_tasks_async(tasks: list[Task[Any]], thread: Thread | str | None = None, raise_on_failure: bool = True, handlers: list[Handler | AsyncHandler] | None = None) -> list[Task[Any]] | AsyncGenerator[Event, None]
```
Run tasks either concurrently (if independent) or sequentially via orchestrator.

### `run_tasks_stream`
```python
Expand Down
1 change: 0 additions & 1 deletion sandbox/prefect
Submodule prefect deleted from 602ae7
12 changes: 11 additions & 1 deletion src/marvin/agents/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,17 @@ def __enter__(self):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
"""Reset the current actor in context."""
if self._tokens: # Only reset if we have tokens
_current_actor.reset(self._tokens.pop())
try:
_current_actor.reset(self._tokens.pop())
except ValueError as e:
# Token was created in a different async context (e.g., asyncio.gather)
# This happens when tasks run concurrently and is expected behavior
if "was created in a different Context" in str(e):
# This is the expected concurrent execution case - ignore safely
pass
else:
# Some other ValueError - re-raise it
raise

@classmethod
def get_current(cls) -> "Actor | None":
Expand Down
17 changes: 16 additions & 1 deletion src/marvin/engine/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,22 @@ async def run_once(
if actor is None:
actor = tasks[0].get_actor()

assigned_tasks = [t for t in tasks if actor is t.get_actor()]
# Get tasks assigned to this actor
potential_tasks = [t for t in tasks if actor is t.get_actor()]

# For independent tasks, only assign one per turn to avoid EndTurn conflicts
if len(potential_tasks) > 1:
# Check if any tasks depend on each other
has_deps = any(
t2 in t1.depends_on or t1 in t2.depends_on
for t1 in potential_tasks
for t2 in potential_tasks
if t1 != t2
)
# If independent, process one at a time
assigned_tasks = [potential_tasks[0]] if not has_deps else potential_tasks
else:
assigned_tasks = potential_tasks

# Mark tasks as running if they're pending
for task in assigned_tasks:
Expand Down
39 changes: 32 additions & 7 deletions src/marvin/fns/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,43 @@
T = TypeVar("T")


def _tasks_are_independent(tasks: list[Task[Any]]) -> bool:
"""Check if tasks have no dependencies between each other."""
for i, task1 in enumerate(tasks):
for j, task2 in enumerate(tasks):
if i != j:
# Check if task1 depends on task2 or vice versa
if task2 in task1.depends_on or task1 in task2.depends_on:
return False
# Check if they share subtasks or parent relationships
if task1.parent == task2 or task2.parent == task1:
return False
if task1 in task2.subtasks or task2 in task1.subtasks:
return False
return True


async def run_tasks_async(
tasks: list[Task[Any]],
thread: Thread | str | None = None,
raise_on_failure: bool = True,
handlers: list[Handler | AsyncHandler] | None = None,
) -> list[Task[Any]] | AsyncGenerator[Event, None]:
orchestrator = Orchestrator(
tasks=tasks,
thread=thread,
handlers=handlers,
)
await orchestrator.run(raise_on_failure=raise_on_failure)
return tasks
"""Run tasks either concurrently (if independent) or sequentially via orchestrator."""
# If we have multiple independent tasks, run them concurrently
if len(tasks) > 1 and _tasks_are_independent(tasks):
# Run independent tasks concurrently using asyncio.gather
await asyncio.gather(*[task.run_async() for task in tasks])
return tasks
else:
# Use orchestrator for dependent tasks or single tasks
orchestrator = Orchestrator(
tasks=tasks,
thread=thread,
handlers=handlers,
)
await orchestrator.run(raise_on_failure=raise_on_failure)
return tasks


async def run_tasks_stream(
Expand Down Expand Up @@ -79,6 +103,7 @@ def run_tasks(
raise_on_failure: bool = True,
handlers: list[Handler | AsyncHandler] | None = None,
) -> list[Task[Any]]:
"""Run tasks either concurrently (if independent) or sequentially."""
return marvin.utilities.asyncio.run_sync(
run_tasks_async(
tasks=tasks,
Expand Down
191 changes: 191 additions & 0 deletions tests/test_concurrent_execution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Unit tests for concurrent task execution."""

import asyncio

import pytest

from marvin import Task
from marvin.agents.agent import Agent
from marvin.fns.run import _tasks_are_independent, run_tasks, run_tasks_async


class TestTaskIndependenceDetection:
"""Test the independence detection logic."""

def test_independent_tasks(self):
"""Test that truly independent tasks are detected as such."""
task1 = Task("Say 'one'", result_type=str)
task2 = Task("Say 'two'", result_type=str)
task3 = Task("Say 'three'", result_type=str)

assert _tasks_are_independent([task1, task2, task3])

def test_dependent_tasks_depends_on(self):
"""Test that tasks with depends_on are not independent."""
task1 = Task("Say 'one'", result_type=str)
task2 = Task("Say 'two'", result_type=str, depends_on=[task1])

assert not _tasks_are_independent([task1, task2])

def test_dependent_tasks_parent_child(self):
"""Test that parent-child tasks are not independent."""
parent = Task("Parent task", result_type=str)
child = Task("Child task", result_type=str)
parent.subtasks.add(child) # subtasks is a set, not list
child.parent = parent

assert not _tasks_are_independent([parent, child])

def test_single_task_is_independent(self):
"""Test that a single task is considered independent."""
task = Task("Solo task", result_type=str)
assert _tasks_are_independent([task])

def test_empty_task_list(self):
"""Test empty task list."""
assert _tasks_are_independent([])


class TestConcurrentExecution:
"""Test actual concurrent execution behavior."""

@pytest.mark.asyncio
async def test_independent_tasks_run_without_errors(self):
"""Test that independent tasks run without Multiple EndTurn warnings or errors."""
task1 = Task("Say 'one'", result_type=str)
task2 = Task("Say 'two'", result_type=str)
task3 = Task("Say 'three'", result_type=str)

# Should not raise any errors (no Multiple EndTurn warnings, no infinite loops)
results = await run_tasks_async([task1, task2, task3])

assert len(results) == 3
assert all(task.is_successful() for task in results)

# Verify results
result_values = [task.result for task in results]
assert set(result_values) == {"one", "two", "three"}

@pytest.mark.asyncio
async def test_dependent_tasks_run_in_order(self):
"""Test that dependent tasks run in correct order."""
task1 = Task("Say 'A'", result_type=str)
task2 = Task("Say 'B'", result_type=str, depends_on=[task1])
task3 = Task("Say 'C'", result_type=str, depends_on=[task2])

results = await run_tasks_async([task1, task2, task3])

assert len(results) == 3
assert all(task.is_successful() for task in results)

# Verify correct order
result_values = [task.result for task in results]
assert result_values == ["A", "B", "C"]

def test_sync_run_tasks_independent(self):
"""Test synchronous run_tasks with independent tasks."""
task1 = Task("Say 'one'", result_type=str)
task2 = Task("Say 'two'", result_type=str)

results = run_tasks([task1, task2])

assert len(results) == 2
assert all(task.is_successful() for task in results)
assert set(t.result for t in results) == {"one", "two"}

def test_sync_run_tasks_dependent(self):
"""Test synchronous run_tasks with dependent tasks."""
task1 = Task("Say 'A'", result_type=str)
task2 = Task("Say 'B'", result_type=str, depends_on=[task1])

results = run_tasks([task1, task2])

assert len(results) == 2
assert all(task.is_successful() for task in results)

# Verify correct order
result_values = [task.result for task in results]
assert result_values == ["A", "B"]


class TestAsyncioGatherCompatibility:
"""Test that asyncio.gather works without ContextVar errors."""

@pytest.mark.asyncio
async def test_asyncio_gather_no_context_errors(self):
"""Test that asyncio.gather doesn't throw ContextVar errors."""
task1 = Task("Say 'async1'", result_type=str)
task2 = Task("Say 'async2'", result_type=str)
task3 = Task("Say 'async3'", result_type=str)

# This should not raise ContextVar token errors
results = await asyncio.gather(
task1.run_async(), task2.run_async(), task3.run_async()
)

assert len(results) == 3
assert set(results) == {"async1", "async2", "async3"}

@pytest.mark.asyncio
async def test_mixed_execution_patterns(self):
"""Test mixing run_tasks_async and asyncio.gather in same event loop."""
# First batch via run_tasks_async
task1 = Task("Say 'batch1'", result_type=str)
task2 = Task("Say 'batch2'", result_type=str)
batch1_results = await run_tasks_async([task1, task2])

# Second batch via asyncio.gather
task3 = Task("Say 'gather1'", result_type=str)
task4 = Task("Say 'gather2'", result_type=str)
batch2_results = await asyncio.gather(task3.run_async(), task4.run_async())

# Both should work without errors
assert len(batch1_results) == 2
assert all(task.is_successful() for task in batch1_results)
assert len(batch2_results) == 2
assert set(batch2_results) == {"gather1", "gather2"}


class TestContextVarHandling:
"""Test ContextVar token handling across async contexts."""

@pytest.mark.asyncio
async def test_actor_context_across_asyncio_gather(self):
"""Test that Actor context management handles asyncio.gather correctly."""
from marvin.agents.actor import _current_actor

async def task_with_actor(name):
actor = Agent(name=f"Agent_{name}")
# This should not raise an error even with asyncio.gather
with actor:
assert _current_actor.get() == actor
await asyncio.sleep(0.1) # Simulate async work
# Context should be reset without errors
return name

# Test that concurrent context management works
results = await asyncio.gather(
task_with_actor("1"), task_with_actor("2"), task_with_actor("3")
)

assert results == ["1", "2", "3"]
# Context should be None after all tasks complete
assert _current_actor.get() is None

def test_actor_context_sequential(self):
"""Test that Actor context works normally in sequential execution."""
from marvin.agents.actor import _current_actor

actor1 = Agent(name="Sequential_1")
actor2 = Agent(name="Sequential_2")

# Test nested contexts work correctly
assert _current_actor.get() is None

with actor1:
assert _current_actor.get() == actor1
with actor2:
assert _current_actor.get() == actor2
assert _current_actor.get() == actor1

assert _current_actor.get() is None