Skip to content

Commit aab9765

Browse files
authored
Add Trainer port option for client-server strategies (#198)
1 parent 91c85ae commit aab9765

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

agentlightning/trainer/trainer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ class Trainer(TrainerLegacy):
114114
or a dictionary with the initialization parameters for the exporter.
115115
Deprecated. Use [`adapter`][agentlightning.Trainer.adapter] instead."""
116116

117+
port: Optional[int]
118+
"""Port forwarded to [`ClientServerExecutionStrategy`][agentlightning.ClientServerExecutionStrategy]."""
119+
117120
def __init__(
118121
self,
119122
*,
@@ -126,6 +129,7 @@ def __init__(
126129
store: ComponentSpec[LightningStore] = None,
127130
runner: ComponentSpec[Runner[Any]] = None,
128131
strategy: ComponentSpec[ExecutionStrategy] = None,
132+
port: Optional[int] = None,
129133
algorithm: ComponentSpec[Algorithm] = None,
130134
llm_proxy: ComponentSpec[LLMProxy] = None,
131135
n_workers: Optional[int] = None,
@@ -139,6 +143,10 @@ def __init__(
139143
Each keyword accepts either a concrete instance, a class, a callable factory, a
140144
registry string, or a lightweight configuration dictionary (see
141145
[`build_component()`][agentlightning.trainer.init_utils.build_component]).
146+
147+
When ``port`` is provided it is forwarded to
148+
[`ClientServerExecutionStrategy`][agentlightning.ClientServerExecutionStrategy]
149+
instances constructed (or supplied) for the trainer.
142150
"""
143151
# Do not call super().__init__() here.
144152
# super().__init__() will call TrainerLegacy's initialization, which is not intended.
@@ -209,7 +217,13 @@ def __init__(
209217
self.store = self._make_store(store)
210218
self.runner = self._make_runner(runner)
211219

212-
self.strategy = self._make_strategy(strategy, n_runners=self.n_runners)
220+
self.port = port
221+
222+
self.strategy = self._make_strategy(
223+
strategy,
224+
n_runners=self.n_runners,
225+
port=port,
226+
)
213227
if hasattr(self.strategy, "n_runners"):
214228
strategy_runners = getattr(self.strategy, "n_runners")
215229
if isinstance(strategy_runners, int) and strategy_runners > 0:
@@ -284,13 +298,20 @@ def _make_strategy(
284298
strategy: ComponentSpec[ExecutionStrategy],
285299
*,
286300
n_runners: int,
301+
port: Optional[int] = None,
287302
) -> ExecutionStrategy:
288303
"""Resolve the execution strategy and seed defaults such as `n_runners`."""
289304
if isinstance(strategy, ExecutionStrategy):
305+
if port is not None and isinstance(strategy, ClientServerExecutionStrategy):
306+
strategy.server_port = port
290307
return strategy
291308
optional_defaults: Dict[str, Callable[[], Any]] = {"n_runners": lambda: n_runners}
309+
if port is not None:
310+
optional_defaults["server_port"] = lambda: port
292311

293312
def default_factory() -> ExecutionStrategy:
313+
if port is not None:
314+
return ClientServerExecutionStrategy(n_runners=n_runners, server_port=port)
294315
return ClientServerExecutionStrategy(n_runners=n_runners)
295316

296317
return build_component(

tests/trainer/test_trainer_init.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,46 @@ def test_trainer_with_client_server_strategy_dict() -> None:
7777
assert trainer.strategy.server_port == 9999
7878

7979

80+
def test_trainer_port_forwarded_to_client_server_strategy() -> None:
81+
"""Test that the top-level port argument configures the client-server strategy."""
82+
trainer = agl.Trainer(
83+
algorithm=agl.Baseline(),
84+
n_runners=4,
85+
port=8081,
86+
)
87+
88+
assert isinstance(trainer.strategy, agl.ClientServerExecutionStrategy)
89+
assert trainer.strategy.server_port == 8081
90+
91+
92+
def test_trainer_port_ignored_for_non_client_server_strategy() -> None:
93+
"""Test that port has no effect when using a non client-server strategy."""
94+
trainer = agl.Trainer(
95+
algorithm=agl.Baseline(),
96+
n_runners=1,
97+
port=8082,
98+
strategy="shm",
99+
)
100+
101+
assert isinstance(trainer.strategy, agl.SharedMemoryExecutionStrategy)
102+
assert not hasattr(trainer.strategy, "server_port")
103+
104+
105+
def test_trainer_port_overrides_existing_client_server_strategy() -> None:
106+
"""Test that provided port overrides an initialized client-server strategy."""
107+
strategy = agl.ClientServerExecutionStrategy(server_port=9000)
108+
109+
trainer = agl.Trainer(
110+
algorithm=agl.Baseline(),
111+
n_runners=1,
112+
strategy=strategy,
113+
port=9100,
114+
)
115+
116+
assert trainer.strategy is strategy
117+
assert trainer.strategy.server_port == 9100 # type: ignore
118+
119+
80120
def test_trainer_with_env_vars_for_execution_strategy(monkeypatch: pytest.MonkeyPatch) -> None:
81121
"""Test that execution strategy supports environment variables to override values."""
82122
algorithm = agl.Baseline()

0 commit comments

Comments
 (0)