Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion agentlightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ class Trainer(TrainerLegacy):
or a dictionary with the initialization parameters for the exporter.
Deprecated. Use [`adapter`][agentlightning.Trainer.adapter] instead."""

port: Optional[int]
"""Port forwarded to [`ClientServerExecutionStrategy`][agentlightning.ClientServerExecutionStrategy]."""

def __init__(
self,
*,
Expand All @@ -126,6 +129,7 @@ def __init__(
store: ComponentSpec[LightningStore] = None,
runner: ComponentSpec[Runner[Any]] = None,
strategy: ComponentSpec[ExecutionStrategy] = None,
port: Optional[int] = None,
algorithm: ComponentSpec[Algorithm] = None,
llm_proxy: ComponentSpec[LLMProxy] = None,
n_workers: Optional[int] = None,
Expand All @@ -139,6 +143,10 @@ def __init__(
Each keyword accepts either a concrete instance, a class, a callable factory, a
registry string, or a lightweight configuration dictionary (see
[`build_component()`][agentlightning.trainer.init_utils.build_component]).

When ``port`` is provided it is forwarded to
[`ClientServerExecutionStrategy`][agentlightning.ClientServerExecutionStrategy]
instances constructed (or supplied) for the trainer.
"""
# Do not call super().__init__() here.
# super().__init__() will call TrainerLegacy's initialization, which is not intended.
Expand Down Expand Up @@ -209,7 +217,13 @@ def __init__(
self.store = self._make_store(store)
self.runner = self._make_runner(runner)

self.strategy = self._make_strategy(strategy, n_runners=self.n_runners)
self.port = port

self.strategy = self._make_strategy(
strategy,
n_runners=self.n_runners,
port=port,
)
if hasattr(self.strategy, "n_runners"):
strategy_runners = getattr(self.strategy, "n_runners")
if isinstance(strategy_runners, int) and strategy_runners > 0:
Expand Down Expand Up @@ -284,13 +298,20 @@ def _make_strategy(
strategy: ComponentSpec[ExecutionStrategy],
*,
n_runners: int,
port: Optional[int] = None,
) -> ExecutionStrategy:
"""Resolve the execution strategy and seed defaults such as `n_runners`."""
if isinstance(strategy, ExecutionStrategy):
if port is not None and isinstance(strategy, ClientServerExecutionStrategy):
strategy.server_port = port
return strategy
optional_defaults: Dict[str, Callable[[], Any]] = {"n_runners": lambda: n_runners}
if port is not None:
optional_defaults["server_port"] = lambda: port

def default_factory() -> ExecutionStrategy:
if port is not None:
return ClientServerExecutionStrategy(n_runners=n_runners, server_port=port)
return ClientServerExecutionStrategy(n_runners=n_runners)

Comment on lines +313 to 316
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conditional logic duplicates the ClientServerExecutionStrategy instantiation. Simplify by constructing kwargs conditionally and creating the strategy once.

Suggested change
if port is not None:
return ClientServerExecutionStrategy(n_runners=n_runners, server_port=port)
return ClientServerExecutionStrategy(n_runners=n_runners)
kwargs = {"n_runners": n_runners}
if port is not None:
kwargs["server_port"] = port
return ClientServerExecutionStrategy(**kwargs)

Copilot uses AI. Check for mistakes.
return build_component(
Expand Down
40 changes: 40 additions & 0 deletions tests/trainer/test_trainer_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,46 @@ def test_trainer_with_client_server_strategy_dict() -> None:
assert trainer.strategy.server_port == 9999


def test_trainer_port_forwarded_to_client_server_strategy() -> None:
"""Test that the top-level port argument configures the client-server strategy."""
trainer = agl.Trainer(
algorithm=agl.Baseline(),
n_runners=4,
port=8081,
)

assert isinstance(trainer.strategy, agl.ClientServerExecutionStrategy)
assert trainer.strategy.server_port == 8081


def test_trainer_port_ignored_for_non_client_server_strategy() -> None:
"""Test that port has no effect when using a non client-server strategy."""
trainer = agl.Trainer(
algorithm=agl.Baseline(),
n_runners=1,
port=8082,
strategy="shm",
)

assert isinstance(trainer.strategy, agl.SharedMemoryExecutionStrategy)
assert not hasattr(trainer.strategy, "server_port")


def test_trainer_port_overrides_existing_client_server_strategy() -> None:
"""Test that provided port overrides an initialized client-server strategy."""
strategy = agl.ClientServerExecutionStrategy(server_port=9000)

trainer = agl.Trainer(
algorithm=agl.Baseline(),
n_runners=1,
strategy=strategy,
port=9100,
)

assert trainer.strategy is strategy
assert trainer.strategy.server_port == 9100 # type: ignore


def test_trainer_with_env_vars_for_execution_strategy(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that execution strategy supports environment variables to override values."""
algorithm = agl.Baseline()
Expand Down
Loading