Skip to content

Commit 77a1aba

Browse files
author
chenzhenyang
committed
Merge remote-tracking branch 'origin/main' into refactor-trainengine-api
2 parents 4a77121 + 3880f4a commit 77a1aba

87 files changed

Lines changed: 3381 additions & 1474 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ our project just as you enjoy real-world milk tea (cheers).
2424
[multi-turn agentic rollout](https://inclusionai.github.io/AReaL/customization/agent.html)
2525
workflows within a single file, and smooth integration with
2626
[other agentic tooling frameworks](https://inclusionai.github.io/AReaL/tutorial/agentic_rl.html).
27-
- 🚀 **Scalability**: Through algorithm-system co-design, AReaL delivers **stable** fully
27+
- 📈 **Scalability**: Through algorithm-system co-design, AReaL delivers **stable** fully
2828
asynchronous RL training with **industry-leading speed**. AReaL seamlessly adapts to
2929
diverse computational environments, scaling from a single node to 1,000+ GPUs.
30-
- 🔪 **Cutting-Edge Performance**: AReaL produces state-of-the-art
30+
- **Cutting-Edge Performance**: AReaL produces state-of-the-art
3131
[math](/blog/AReaL_v0_2.md), [coding](/blog/AReaL_v0_3.md), and
3232
[search agents](https://github.com/inclusionAI/ASearcher) with exceptional
3333
capabilities.

areal/api/cli_args.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,25 @@
44
from dataclasses import MISSING as dataclass_missing
55
from dataclasses import asdict, dataclass, field, fields
66
from pathlib import Path
7-
from typing import Any, TypeVar
7+
from typing import TYPE_CHECKING, Any, TypeVar
88

99
import uvloop
1010
import yaml
1111
from hydra import compose as hydra_compose
1212
from hydra import initialize as hydra_init
1313
from hydra.core.global_hydra import GlobalHydra
1414
from omegaconf import MISSING, DictConfig, OmegaConf
15-
from transformers import PreTrainedTokenizerFast
1615

17-
from areal.platforms import current_platform
1816
from areal.utils import logging, name_resolve, pkg_version
1917
from areal.utils.constants import (
2018
PROX_LOGP_METHOD_RECOMPUTE,
2119
PROX_LOGP_METHODS_ALL,
2220
)
2321
from areal.utils.pkg_version import is_version_less
2422

23+
if TYPE_CHECKING:
24+
from transformers import PreTrainedTokenizerFast
25+
2526
uvloop.install()
2627

2728
logger = logging.getLogger("CLI args")
@@ -174,7 +175,7 @@ def new(self, **kwargs):
174175
args.update(kwargs)
175176
return GenerationHyperparameters(**args)
176177

177-
def new_with_stop_and_pad_token_ids(self, tokenizer: PreTrainedTokenizerFast):
178+
def new_with_stop_and_pad_token_ids(self, tokenizer: "PreTrainedTokenizerFast"):
178179
"""Create a new generation hyperparameters with stop and pad token ids added."""
179180
new_stop_token_ids = self.stop_token_ids.copy()
180181
if tokenizer.pad_token_id not in new_stop_token_ids:
@@ -183,8 +184,29 @@ def new_with_stop_and_pad_token_ids(self, tokenizer: PreTrainedTokenizerFast):
183184
new_stop_token_ids.append(tokenizer.eos_token_id)
184185
return self.new(stop_token_ids=new_stop_token_ids)
185186

186-
def to_openai_args_dict(
187+
def to_openai_completions_args_dict(
188+
self, exclude_args: list[str] | None = None
189+
) -> dict[str, Any]:
190+
return self.to_openai_args_dict(
191+
exclude_args=exclude_args, api_format="completions"
192+
)
193+
194+
def to_openai_responses_args_dict(
187195
self, exclude_args: list[str] | None = None
196+
) -> dict[str, Any]:
197+
return self.to_openai_args_dict(
198+
exclude_args=exclude_args, api_format="responses"
199+
)
200+
201+
def to_openai_agents_model_settings_dict(
202+
self, exclude_args: list[str] | None = None
203+
) -> dict[str, Any]:
204+
return self.to_openai_args_dict(
205+
exclude_args=exclude_args, api_format="openai-agents"
206+
)
207+
208+
def to_openai_args_dict(
209+
self, exclude_args: list[str] | None = None, api_format: str = "completions"
188210
) -> dict[str, Any]:
189211
"""Convert the generation hyperparameters to a dictionary of arguments for OpenAI client."""
190212
final_exclude_args = set(exclude_args) if exclude_args is not None else set()
@@ -195,14 +217,22 @@ def to_openai_args_dict(
195217
"top_k", # Not supported by OpenAI
196218
"stop_token_ids", # Not supported by OpenAI
197219
"lora_name", # Not supported by OpenAI
220+
"max_tokens", # deprecated by "completions", not used in "responses", should be `max_new_tokens` in "openai-agents"
198221
}
199222
)
200223
# TODO: move the excluded args into extra body, so they can be passed through the client request
201224

202-
mapping = {
203-
"n_samples": "n",
204-
"max_new_tokens": "max_completion_tokens",
205-
}
225+
mapping = {"n_samples": "n"}
226+
if api_format == "completions":
227+
mapping["max_new_tokens"] = "max_completion_tokens"
228+
elif api_format == "responses":
229+
mapping["max_new_tokens"] = "max_output_tokens"
230+
elif api_format == "openai-agents":
231+
# NOTE: max_tokens in openai-agents means `max_new_tokens` in sglang/vllm. This is not a bug
232+
mapping["max_new_tokens"] = "max_tokens"
233+
else:
234+
raise ValueError(f"Unsupported API format: {api_format}")
235+
206236
res = {}
207237
for k, v in asdict(self).items():
208238
if k in final_exclude_args:
@@ -224,7 +254,10 @@ def to_openai_args_dict(
224254
f"Unsupported arg for openai format: `{k}` with value {current_value}"
225255
)
226256
continue
227-
res[mapping.get(k, k)] = v
257+
key = mapping.get(k, k)
258+
if key in res:
259+
logger.warning(f"Overriding key: {key} from {k} with value: {v}")
260+
res[key] = v
228261

229262
return res
230263

@@ -624,6 +657,20 @@ class PPOActorConfig(TrainEngineConfig):
624657
metadata={"help": "KL divergence estimator", "choices": ["k1", "k2", "k3"]},
625658
)
626659

660+
# SAPO (Soft Adaptive Policy Optimization) - https://arxiv.org/abs/2511.20347
661+
use_sapo_loss: bool = field(
662+
default=False,
663+
metadata={"help": "Use SAPO loss (mutually exclusive with PPO clipping)"},
664+
)
665+
sapo_tau_pos: float = field(
666+
default=1.0,
667+
metadata={"help": "SAPO temperature for positive advantages"},
668+
)
669+
sapo_tau_neg: float = field(
670+
default=1.05,
671+
metadata={"help": "SAPO temperature for negative advantages"},
672+
)
673+
627674
# Asynchronous RL
628675
recompute_logprob: bool = field(
629676
default=False,
@@ -956,6 +1003,8 @@ def build_args(
9561003
args["lora_target_modules"] = [
9571004
x.replace("-linear", "") for x in args["lora_target_modules"]
9581005
]
1006+
from areal.platforms import current_platform
1007+
9591008
args = dict(
9601009
# Model and tokenizer
9611010
tokenizer_path=sglang_config.model_path,

areal/api/reward_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _cleanup_executor(cls, executor_key):
107107
if cls._instance_counts[executor_key] <= 0:
108108
if executor_key in cls._executors:
109109
executor = cls._executors.pop(executor_key)
110-
executor.shutdown(wait=True)
110+
executor.shutdown(wait=False, cancel_futures=True)
111111
logger.debug(
112112
f"ProcessPoolExecutor with {executor_key} workers shut down"
113113
)

areal/core/async_task_runner.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,18 +234,29 @@ def initialize(self, logger=None):
234234
self.thread.start()
235235
self._loop_ready.wait()
236236

237-
def destroy(self):
237+
def destroy(self, timeout: float = 30.0):
238238
"""Shutdown the task runner and wait for thread cleanup.
239239
240240
This method signals the background thread to exit and waits for
241241
it to complete. All pending tasks will be cancelled.
242+
243+
Parameters
244+
----------
245+
timeout : float, optional
246+
Maximum time in seconds to wait for thread to exit.
247+
Default is 30.0 seconds.
242248
"""
243249
self.exiting.set()
244250
self.paused.clear()
245251

246252
self._signal_new_input()
247253
if self.thread is not None:
248-
self.thread.join()
254+
self.thread.join(timeout=timeout)
255+
if self.thread.is_alive():
256+
if self.logger:
257+
self.logger.warning(
258+
f"Background thread did not exit within {timeout}s timeout."
259+
)
249260

250261
def register_shutdown_hook(self, hook: Callable[[], Awaitable[None]]) -> None:
251262
"""Register an async cleanup function to be called during shutdown.

areal/engine/ppo/actor.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
dynamic_sampling,
2828
ppo_actor_loss_fn,
2929
reward_overlong_penalty,
30+
sapo_loss_fn,
3031
)
3132
from areal.utils.perf_tracer import trace_perf
3233

@@ -353,6 +354,10 @@ def ppo_update(self, data: dict[str, Any]) -> None:
353354
importance_sampling_level=self.config.importance_sampling_level,
354355
current_version=current_version,
355356
prox_logp_method=self.config.prox_logp_method,
357+
use_sapo_loss=self.config.use_sapo_loss,
358+
sapo_tau_pos=self.config.sapo_tau_pos,
359+
sapo_tau_neg=self.config.sapo_tau_neg,
360+
use_decoupled_loss=self.config.use_decoupled_loss,
356361
),
357362
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
358363
)
@@ -405,6 +410,10 @@ def grpo_loss_fn(
405410
importance_sampling_level: str = "token",
406411
current_version: int | None = None,
407412
prox_logp_method: str = PROX_LOGP_METHOD_RECOMPUTE,
413+
use_sapo_loss: bool = False,
414+
sapo_tau_pos: float = 1.0,
415+
sapo_tau_neg: float = 1.05,
416+
use_decoupled_loss: bool = False,
408417
vocab_min_logits: torch.Tensor | None = None,
409418
vocab_max_logits: torch.Tensor | None = None,
410419
):
@@ -431,19 +440,37 @@ def grpo_loss_fn(
431440
if m2_threshold is not None:
432441
loss_mask = _apply_m2po_masking(old_logp, prox_logp, loss_mask, m2_threshold)
433442

434-
loss, stat = ppo_actor_loss_fn(
435-
logprobs=logprobs,
436-
old_logprobs=old_logp,
437-
advantages=advantages,
438-
eps_clip=eps_clip,
439-
eps_clip_higher=eps_clip_higher,
440-
loss_mask=loss_mask,
441-
c_clip=c_clip,
442-
proximal_logprobs=prox_logp,
443-
behav_imp_weight_cap=behav_imp_weight_cap,
444-
importance_sampling_level=importance_sampling_level,
445-
cu_seqlens=input_data.get("cu_seqlens"),
446-
)
443+
# Use SAPO or PPO loss
444+
if use_sapo_loss:
445+
if use_decoupled_loss:
446+
raise ValueError(
447+
"SAPO is not compatible with `use_decoupled_loss=True`. "
448+
"Please set `actor.use_decoupled_loss=false` in your configuration."
449+
)
450+
loss, stat = sapo_loss_fn(
451+
logprobs=logprobs,
452+
old_logprobs=old_logp,
453+
advantages=advantages,
454+
tau_pos=sapo_tau_pos,
455+
tau_neg=sapo_tau_neg,
456+
loss_mask=loss_mask,
457+
importance_sampling_level=importance_sampling_level,
458+
cu_seqlens=input_data.get("cu_seqlens"),
459+
)
460+
else:
461+
loss, stat = ppo_actor_loss_fn(
462+
logprobs=logprobs,
463+
old_logprobs=old_logp,
464+
advantages=advantages,
465+
eps_clip=eps_clip,
466+
eps_clip_higher=eps_clip_higher,
467+
loss_mask=loss_mask,
468+
c_clip=c_clip,
469+
proximal_logprobs=prox_logp,
470+
behav_imp_weight_cap=behav_imp_weight_cap,
471+
importance_sampling_level=importance_sampling_level,
472+
cu_seqlens=input_data.get("cu_seqlens"),
473+
)
447474

448475
# Log training statistics
449476
stats_tracker.denominator(
@@ -483,14 +510,24 @@ def grpo_loss_fn(
483510
denominator="n_tokens",
484511
)
485512

486-
clip_mask = stat["clip_mask"]
487-
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
488-
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
489-
stats_tracker.stat(
490-
clipped_new_logp=clipped_new_logp,
491-
clipped_old_logp=clipped_old_logp,
492-
denominator="clipped_tokens",
493-
)
513+
# Log SAPO-specific statistics
514+
if use_sapo_loss:
515+
stats_tracker.stat(
516+
sapo_soft_gate=stat["sapo_soft_gate"],
517+
sapo_scaled_gate_pos=stat["sapo_scaled_gate_pos"],
518+
sapo_scaled_gate_neg=stat["sapo_scaled_gate_neg"],
519+
denominator="n_valid_tokens",
520+
)
521+
else:
522+
# Log clipping statistics (PPO only)
523+
clip_mask = stat["clip_mask"]
524+
clipped_new_logp = torch.where(clip_mask, logprobs.detach(), 0.0)
525+
clipped_old_logp = torch.where(clip_mask, old_logp, 0.0)
526+
stats_tracker.stat(
527+
clipped_new_logp=clipped_new_logp,
528+
clipped_old_logp=clipped_old_logp,
529+
denominator="clipped_tokens",
530+
)
494531

495532
# Log proximal approximation metrics
496533
compute_logp_mask = stat.get("behave_mask", loss_mask)

0 commit comments

Comments
 (0)