44from dataclasses import MISSING as dataclass_missing
55from dataclasses import asdict , dataclass , field , fields
66from pathlib import Path
7- from typing import Any , TypeVar
7+ from typing import TYPE_CHECKING , Any , TypeVar
88
99import uvloop
1010import yaml
1111from hydra import compose as hydra_compose
1212from hydra import initialize as hydra_init
1313from hydra .core .global_hydra import GlobalHydra
1414from omegaconf import MISSING , DictConfig , OmegaConf
15- from transformers import PreTrainedTokenizerFast
1615
17- from areal .platforms import current_platform
1816from areal .utils import logging , name_resolve , pkg_version
1917from areal .utils .constants import (
2018 PROX_LOGP_METHOD_RECOMPUTE ,
2119 PROX_LOGP_METHODS_ALL ,
2220)
2321from areal .utils .pkg_version import is_version_less
2422
23+ if TYPE_CHECKING :
24+ from transformers import PreTrainedTokenizerFast
25+
# Install uvloop as the asyncio event-loop policy for the whole process.
# NOTE(review): this runs at import time as a module side effect — confirm
# downstream consumers expect the event loop to be swapped on import.
uvloop.install()

# Module-level logger; name "CLI args" matches this module's purpose.
logger = logging.getLogger("CLI args")
@@ -174,7 +175,7 @@ def new(self, **kwargs):
174175 args .update (kwargs )
175176 return GenerationHyperparameters (** args )
176177
177- def new_with_stop_and_pad_token_ids (self , tokenizer : PreTrainedTokenizerFast ):
178+ def new_with_stop_and_pad_token_ids (self , tokenizer : " PreTrainedTokenizerFast" ):
178179 """Create a new generation hyperparameters with stop and pad token ids added."""
179180 new_stop_token_ids = self .stop_token_ids .copy ()
180181 if tokenizer .pad_token_id not in new_stop_token_ids :
@@ -183,8 +184,29 @@ def new_with_stop_and_pad_token_ids(self, tokenizer: PreTrainedTokenizerFast):
183184 new_stop_token_ids .append (tokenizer .eos_token_id )
184185 return self .new (stop_token_ids = new_stop_token_ids )
185186
186- def to_openai_args_dict (
187+ def to_openai_completions_args_dict (
188+ self , exclude_args : list [str ] | None = None
189+ ) -> dict [str , Any ]:
190+ return self .to_openai_args_dict (
191+ exclude_args = exclude_args , api_format = "completions"
192+ )
193+
194+ def to_openai_responses_args_dict (
187195 self , exclude_args : list [str ] | None = None
196+ ) -> dict [str , Any ]:
197+ return self .to_openai_args_dict (
198+ exclude_args = exclude_args , api_format = "responses"
199+ )
200+
201+ def to_openai_agents_model_settings_dict (
202+ self , exclude_args : list [str ] | None = None
203+ ) -> dict [str , Any ]:
204+ return self .to_openai_args_dict (
205+ exclude_args = exclude_args , api_format = "openai-agents"
206+ )
207+
208+ def to_openai_args_dict (
209+ self , exclude_args : list [str ] | None = None , api_format : str = "completions"
188210 ) -> dict [str , Any ]:
189211 """Convert the generation hyperparameters to a dictionary of arguments for OpenAI client."""
190212 final_exclude_args = set (exclude_args ) if exclude_args is not None else set ()
@@ -195,14 +217,22 @@ def to_openai_args_dict(
195217 "top_k" , # Not supported by OpenAI
196218 "stop_token_ids" , # Not supported by OpenAI
197219 "lora_name" , # Not supported by OpenAI
220+ "max_tokens" , # deprecated by "completions", not used in "responses", should be `max_new_tokens` in "openai-agents"
198221 }
199222 )
200223 # TODO: move the excluded args into extra body, so they can be passed through the client request
201224
202- mapping = {
203- "n_samples" : "n" ,
204- "max_new_tokens" : "max_completion_tokens" ,
205- }
225+ mapping = {"n_samples" : "n" }
226+ if api_format == "completions" :
227+ mapping ["max_new_tokens" ] = "max_completion_tokens"
228+ elif api_format == "responses" :
229+ mapping ["max_new_tokens" ] = "max_output_tokens"
230+ elif api_format == "openai-agents" :
231+ # NOTE: max_tokens in openai-agents means `max_new_tokens` in sglang/vllm. This is not a bug
232+ mapping ["max_new_tokens" ] = "max_tokens"
233+ else :
234+ raise ValueError (f"Unsupported API format: { api_format } " )
235+
206236 res = {}
207237 for k , v in asdict (self ).items ():
208238 if k in final_exclude_args :
@@ -224,7 +254,10 @@ def to_openai_args_dict(
224254 f"Unsupported arg for openai format: `{ k } ` with value { current_value } "
225255 )
226256 continue
227- res [mapping .get (k , k )] = v
257+ key = mapping .get (k , k )
258+ if key in res :
259+ logger .warning (f"Overriding key: { key } from { k } with value: { v } " )
260+ res [key ] = v
228261
229262 return res
230263
@@ -624,6 +657,20 @@ class PPOActorConfig(TrainEngineConfig):
624657 metadata = {"help" : "KL divergence estimator" , "choices" : ["k1" , "k2" , "k3" ]},
625658 )
626659
# SAPO (Soft Adaptive Policy Optimization) - https://arxiv.org/abs/2511.20347
# When enabled, SAPO's soft gating replaces PPO's hard clipping; the two are
# mutually exclusive per the help text below.
use_sapo_loss: bool = field(
    default=False,
    metadata={"help": "Use SAPO loss (mutually exclusive with PPO clipping)"},
)
# Temperature applied to tokens with positive advantages.
sapo_tau_pos: float = field(
    default=1.0,
    metadata={"help": "SAPO temperature for positive advantages"},
)
# Temperature applied to tokens with negative advantages. Defaults slightly
# above tau_pos (1.05 vs 1.0); presumably to penalize negative-advantage
# updates more softly — TODO confirm against the SAPO paper.
sapo_tau_neg: float = field(
    default=1.05,
    metadata={"help": "SAPO temperature for negative advantages"},
)
673+
627674 # Asynchronous RL
628675 recompute_logprob : bool = field (
629676 default = False ,
@@ -956,6 +1003,8 @@ def build_args(
9561003 args ["lora_target_modules" ] = [
9571004 x .replace ("-linear" , "" ) for x in args ["lora_target_modules" ]
9581005 ]
1006+ from areal .platforms import current_platform
1007+
9591008 args = dict (
9601009 # Model and tokenizer
9611010 tokenizer_path = sglang_config .model_path ,
0 commit comments