SGLang model provider for Strands Agents SDK with Token-In/Token-Out rollouts for on-policy agentic RL training (no retokenization drift).
This package makes Strands Agents SDK, a serving-oriented agent scaffold, training-ready: it exposes end-to-end, token-level rollouts from SGLang while reusing Strands’ customizable agent loop.
- Token-In/Token-Out rollouts (token IDs + logprobs/masks): no retokenization drift
- Strict, on-policy tool-call parsing: no heuristic repair or post-processing; tool calls are parsed exactly as the model generated them
- Native SGLang /generate endpoint: high-throughput, non-streaming rollouts
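To see how retokenization drift arises, consider a toy greedy longest-match tokenizer. It is a hypothetical stand-in, not a Hugging Face tokenizer or anything in this package. If the model samples "a" and "b" as two separate tokens, decoding gives the string "ab", but re-encoding that string yields the single merged token "ab". The IDs a trainer would then see no longer match the IDs (and logprobs) the model actually produced.

```python
# Toy vocabulary (hypothetical): "ab" is a merged token, but "a" and "b" also exist.
VOCAB = {"a": 0, "b": 1, "ab": 2}
ID_TO_TOKEN = {i: t for t, i in VOCAB.items()}
MAX_TOKEN_LEN = max(len(t) for t in VOCAB)


def encode(text: str) -> list[int]:
    """Greedy longest-match tokenization over VOCAB."""
    ids = []
    i = 0
    while i < len(text):
        # Try the longest candidate piece first, then shorter ones.
        for length in range(min(MAX_TOKEN_LEN, len(text) - i), 0, -1):
            piece = text[i : i + length]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                i += length
                break
        else:
            raise ValueError(f"untokenizable character: {text[i]!r}")
    return ids


def decode(ids: list[int]) -> str:
    """Concatenate the surface strings of the given token IDs."""
    return "".join(ID_TO_TOKEN[i] for i in ids)


# Suppose the model sampled "a" then "b" as two separate tokens.
generated_ids = [0, 1]
text = decode(generated_ids)
retokenized = encode(text)  # the round trip merges them into one token

print(text, generated_ids, retokenized)
print("drift:", retokenized != generated_ids)
```

Keeping the generated token IDs end to end (Token-In/Token-Out) avoids the round trip entirely: the trainer sees exactly the IDs whose logprobs were recorded.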
For RL environment integration, see strands-env.
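The strict tool-call parsing in the feature list above means a malformed tool call is rejected rather than repaired. The sketch below illustrates that policy for Hermes-style output, where each call is a JSON object wrapped in <tool_call> tags. It is not this package's HermesToolParser; the names parse_hermes_tool_calls, ToolCall, and ToolCallParseError are hypothetical, and how the real parser reports failures is not described here.

```python
import json
import re
from dataclasses import dataclass
from typing import Any

# Hermes-style tool calls wrap one JSON object per call in <tool_call> tags.
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)


class ToolCallParseError(ValueError):
    """Raised when generated text contains a malformed tool call (hypothetical)."""


@dataclass
class ToolCall:
    name: str
    arguments: dict[str, Any]


def parse_hermes_tool_calls(text: str) -> list[ToolCall]:
    """Parse tool calls exactly as generated; never repair malformed output."""
    matches = TOOL_CALL_RE.findall(text)
    # An opening tag without a matching close is an error, not something to skip.
    if text.count("<tool_call>") != len(matches):
        raise ToolCallParseError("unterminated <tool_call> block")
    calls = []
    for raw in matches:
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ToolCallParseError(f"invalid JSON in tool call: {exc}") from exc
        if not isinstance(payload, dict):
            raise ToolCallParseError("tool call must be a JSON object")
        name, arguments = payload.get("name"), payload.get("arguments")
        if not isinstance(name, str) or not isinstance(arguments, dict):
            raise ToolCallParseError("tool call needs a string 'name' and an object 'arguments'")
        calls.append(ToolCall(name=name, arguments=arguments))
    return calls


# "multiply" is a made-up tool name used only for illustration.
good = '<tool_call>\n{"name": "multiply", "arguments": {"a": 25, "b": 17}}\n</tool_call>'
print(parse_hermes_tool_calls(good))

bad = '<tool_call>{"name": "multiply", "arguments": {"a": 25, "b": 17}</tool_call>'  # missing "}"
try:
    parse_hermes_tool_calls(bad)
except ToolCallParseError as err:
    print("rejected:", err)
```

Failing loudly keeps the training signal on-policy: a lenient parser that "fixed" the missing brace would execute a call the model never actually emitted.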
Requirements:

- Python 3.10+
- Strands Agents SDK
- SGLang server running with your model
- Hugging Face tokenizer for the model
pip install strands-sglang strands-agents-tools

Or install from source with development dependencies:

git clone https://github.com/horizon-rl/strands-sglang.git
cd strands-sglang
pip install -e ".[dev]"

Launch an SGLang server with your model:

python -m sglang.launch_server \
--model-path Qwen/Qwen3.5-4B \
--port 30000 \
--host 0.0.0.0

Then run an agent and inspect the token-level rollout data:

import asyncio
from transformers import AutoTokenizer
from strands import Agent
from strands_tools import calculator
from strands_sglang import SGLangClient, SGLangModel
from strands_sglang.tool_parsers import get_tool_parser


async def main():
    client = SGLangClient(base_url="http://localhost:30000")
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-4B")
    model = SGLangModel(client=client, tokenizer=tokenizer, tool_parser=get_tool_parser("qwen_xml"))
    agent = Agent(model=model, tools=[calculator])

    result = await agent.invoke_async("What is 25 * 17?")
    print(result)

    # Access token data for RL training
    print(f"Tokens: {model.token_manager.token_ids}")
    print(f"Loss mask: {model.token_manager.loss_mask}")
    print(f"Logprobs: {model.token_manager.logprobs}")

asyncio.run(main())

For RL training with slime, SGLangModel eliminates the retokenization step. See a concrete example at slime/examples/strands_sglang:

import logging
from strands import Agent, tool
from strands_sglang import SGLangModel, ToolLimiter, decode_routed_experts, get_client_from_slime_args
from strands_sglang.tool_parsers import HermesToolParser
from slime.rollout.sglang_rollout import GenerateState
from slime.utils.types import Sample
logger = logging.getLogger(__name__)

SYSTEM_PROMPT = "..."
MAX_TOOL_ITERS = 5
MAX_TOOL_CALLS = None # No limit
@tool
def execute_python_code(code: str):
    """Execute Python code and return the output."""
    ...

async def generate(args, sample: Sample, sampling_params) -> Sample:
    """Generate with tokens captured during generation, no retokenization."""
    state = GenerateState(args)
    model = SGLangModel(
        tokenizer=state.tokenizer,
        client=get_client_from_slime_args(args),  # LRU-cached client
        tool_parser=HermesToolParser(),  # parses wrapped JSON tool calls (Hermes format)
        sampling_params=sampling_params,
        return_routed_experts=True,  # enable R3
    )
    tool_limiter = ToolLimiter(max_tool_iters=MAX_TOOL_ITERS, max_tool_calls=MAX_TOOL_CALLS)
    agent = Agent(
        model=model,
        tools=[execute_python_code],
        hooks=[tool_limiter],
        callback_handler=None,
        system_prompt=SYSTEM_PROMPT,
    )

    # Don't set --apply-chat-template in the rollout args;
    # otherwise the user prompt gets wrapped in the chat template twice
    prompt = sample.prompt if isinstance(sample.prompt, str) else sample.prompt[0]["content"]
    try:
        await agent.invoke_async(prompt)
        sample.status = Sample.Status.COMPLETED
    except Exception as e:
        # Default all failed rollouts to TRUNCATED; customize this logic if needed
        sample.status = Sample.Status.TRUNCATED
        logger.warning(f"TRUNCATED: {type(e).__name__}: {e}")

    # Extract the token trajectory from the token manager
    tm = model.token_manager
    prompt_len = len(tm.segments[0])  # system + user prompt form the first segment
    sample.tokens = tm.token_ids
    sample.loss_mask = tm.loss_mask[prompt_len:]
    sample.rollout_log_probs = tm.logprobs[prompt_len:]
    sample.response_length = len(sample.tokens) - prompt_len
    sample.response = model.tokenizer.decode(sample.tokens[prompt_len:], skip_special_tokens=False)

    # Record tool-call stats for reward computation if needed
    # (multiple parallel tool calls count as one tool_iter)
    sample.tool_iters = tool_limiter.tool_iter_count
    sample.tool_calls = tool_limiter.tool_call_count

    # Decode MoE routed experts for router replay (R3); shape: (seq_len - 1, num_layers, top_k)
    if model.routed_experts is not None:
        # Consider wrapping this call in asyncio.to_thread to avoid blocking the event loop
        sample.routed_experts = decode_routed_experts(
            model.routed_experts,
            seq_len=len(tm.token_ids),
            num_layers=args.num_layers,
            top_k=args.moe_router_topk,
        )

    model.reset()
    agent.cleanup()
    return sample

Run the tests:

# Unit tests
pytest tests/unit/ -v
# Integration tests (requires SGLang server)
pytest tests/integration/ -v --sglang-base-url=http://localhost:30000

Contributions are welcome! Install the pre-commit hooks for code-style and commit-message validation:

pip install -e ".[dev]"
pre-commit install -t pre-commit -t commit-msg

This project uses Conventional Commits. Commit messages must follow this format:

<type>(<scope>): <description>
# Examples:
feat(client): add retry backoff configuration
fix(sglang): handle empty response from server
docs: update usage examples
Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert
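The header rule above can be expressed as a single regular expression. This is a minimal sketch that assumes only the format and type list stated in this README; it is not the project's actual commit-msg hook, whose exact rules live in its pre-commit configuration. The helper name is_valid_header is hypothetical.

```python
import re

# Types listed in this README; the "(<scope>)" part is optional.
ALLOWED_TYPES = ("feat", "fix", "docs", "style", "refactor", "perf",
                 "test", "build", "ci", "chore", "revert")
HEADER_RE = re.compile(
    r"^(?P<type>" + "|".join(ALLOWED_TYPES) + r")"
    r"(?:\((?P<scope>[^()\s]+)\))?"
    r": (?P<description>\S.*)$"
)


def is_valid_header(message: str) -> bool:
    """Check the first line of a commit message against <type>(<scope>): <description>."""
    lines = message.splitlines()
    first_line = lines[0] if lines else ""
    return HEADER_RE.match(first_line) is not None


for msg in [
    "feat(client): add retry backoff configuration",
    "fix(sglang): handle empty response from server",
    "docs: update usage examples",
    "feature: add something",     # type not in the allowed list
    "fix(sglang) missing colon",  # no ": " separator
]:
    print(is_valid_header(msg), msg)
```

Only the first line is checked, so a longer body after a blank line does not affect the result.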
Related projects:

- agent-core-rl-toolkit - RL training toolkit with Bedrock AgentCore
- strands-vllm - Community vLLM provider for Strands Agents SDK
Apache License 2.0 - see LICENSE.