Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 101 additions & 102 deletions trl/experimental/openenv/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import subprocess
import sys
import time
from contextlib import suppress
from pathlib import Path

import requests
Expand Down Expand Up @@ -66,113 +67,111 @@
# Launch the Echo Environment server (uvicorn) as a child process, capturing its output.
# NOTE(review): stdout/stderr are pipes that are only read on startup failure; if the
# server logs heavily during a long run the pipe buffer may fill and stall it — consider
# redirecting to a file or DEVNULL. TODO confirm.
server_process = subprocess.Popen(
    [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "0.0.0.0", "--port", "8001"],
    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
    cwd=work_dir,
)

# Wait for server to start
print("⏳ Waiting for server to start...")
time.sleep(5)

# Check if server is running
try:
    response = requests.get(f"{ENV_URL}/health", timeout=2)
    print("\n✅ Echo Environment server is running!")
except Exception as e:
    print(f"\n❌ Server failed to start: {e}")
    print("\n📋 Checking error output...")
    # Stop the child before reading its output: reading the pipe of a process that is
    # still alive blocks until EOF, which would hang here instead of reporting the error.
    # This also avoids leaving an orphaned server behind when the exception propagates.
    if server_process.poll() is None:
        server_process.terminate()
    try:
        _, stderr = server_process.communicate(timeout=5)
    except subprocess.TimeoutExpired:
        # The server ignored the terminate request; force it down and drain the pipes.
        server_process.kill()
        _, stderr = server_process.communicate()
    if stderr:
        print(stderr)
    raise


# Create HTTP client for Echo Environment
client = EchoEnv(base_url=f"{ENV_URL}")


def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]:
    """
    Generate completions through the generation server and score each one with the Echo environment.

    Args:
        prompts: Prompt strings sent to the generation endpoint.
        images: Optional images for vision models; not used by this function.
        args: GRPOConfig supplying the sampling parameters forwarded in the request.
        processing_class: Tokenizer/processor whose `batch_decode` turns completion ids into text.

    Returns:
        The JSON payload returned by the generation endpoint, with an added `env_reward` list holding one
        environment reward per decoded completion.

    Raises:
        requests.HTTPError: If the generation endpoint responds with an error status.
    """
    # Sampling request for TRL's custom /generate/ endpoint; unset top_k / min_p fall back to -1 / 0.0.
    sampling_request = {
        "prompts": prompts,
        "n": args.num_generations,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k if args.top_k is not None else -1,
        "min_p": args.min_p if args.min_p is not None else 0.0,
        "max_tokens": args.max_completion_length,
        "repetition_penalty": args.repetition_penalty,
    }
    gen_response = requests.post(GEN_URL, json=sampling_request)

    # Surface the server's error body before raise_for_status() turns it into an exception.
    if gen_response.status_code != 200:
        print(f"Error response: {gen_response.text}")
    gen_response.raise_for_status()

    rollout = gen_response.json()
    decoded = processing_class.batch_decode(rollout["completion_ids"], skip_special_tokens=True)

    # Flush env
    client.reset()

    # Step the environment once per completion, in order, and keep only the reward.
    rollout["env_reward"] = [client.step(EchoAction(message=text)).reward for text in decoded]
    return rollout


# Training prompts: the "train[:1000]" split spec limits loading to the first 1000 rows of the train split.
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")


def reward_from_env(completions, **kwargs):
    """Return per-completion rewards taken from the `env_reward` keyword argument.

    Args:
        completions: Completions being scored; only their count is used, for the fallback.
        **kwargs: Extra fields; `env_reward` is read if present.

    Returns:
        The `env_reward` values converted to floats, or a list of 0.0 with one entry per completion when
        `env_reward` is missing or empty.
    """
    rewards = kwargs.get("env_reward", [])
    if not rewards:
        # No environment rewards were propagated: score every completion as zero.
        return [0.0] * len(completions)
    return list(map(float, rewards))


# GRPO configuration: generation is served by vLLM in server mode.
training_args = GRPOConfig(
    output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout",
    # Generation backend
    use_vllm=True,
    vllm_mode="server",
    # Sampling
    num_generations=8,
    max_completion_length=2048,
    # Schedule and batching
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    # Logging
    logging_steps=1,
    report_to=["trackio", "wandb"],
)

# Wire the custom rollout and the environment-backed reward into the trainer, then train.
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_from_env,
    rollout_func=rollout_func,
)
trainer.train()
# Probe the health endpoint to confirm the server answers before going any further
try:
    response = requests.get(f"{ENV_URL}/health", timeout=2)
    print("\n✅ Echo Environment server is running!")
except Exception as err:
    # Report the failure, then let the original exception propagate unchanged.
    print(f"\n❌ Server failed to start: {err}")
    raise

# HTTP client used for all later calls to the Echo Environment
client = EchoEnv(base_url=f"{ENV_URL}")

def rollout_func(prompts: list[str], images: list | None, args: GRPOConfig, processing_class) -> dict[str, list]:
    """
    Generate completions through the generation server and score each one with the Echo environment.

    Args:
        prompts: Prompt strings sent to the generation endpoint.
        images: Optional images for vision models; not used by this function.
        args: GRPOConfig supplying the sampling parameters forwarded in the request.
        processing_class: Tokenizer/processor whose `batch_decode` turns completion ids into text.

    Returns:
        The JSON payload returned by the generation endpoint, with an added `env_reward` list holding one
        environment reward per decoded completion.

    Raises:
        requests.HTTPError: If the generation endpoint responds with an error status.
    """
    # Sampling request for TRL's custom /generate/ endpoint; unset top_k / min_p fall back to -1 / 0.0.
    sampling_request = {
        "prompts": prompts,
        "n": args.num_generations,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k if args.top_k is not None else -1,
        "min_p": args.min_p if args.min_p is not None else 0.0,
        "max_tokens": args.max_completion_length,
        "repetition_penalty": args.repetition_penalty,
    }
    gen_response = requests.post(GEN_URL, json=sampling_request)

    # Surface the server's error body before raise_for_status() turns it into an exception.
    if gen_response.status_code != 200:
        print(f"Error response: {gen_response.text}")
    gen_response.raise_for_status()

    rollout = gen_response.json()
    decoded = processing_class.batch_decode(rollout["completion_ids"], skip_special_tokens=True)

    # Flush env
    client.reset()

    # Step the environment once per completion, in order, and keep only the reward.
    rollout["env_reward"] = [client.step(EchoAction(message=text)).reward for text in decoded]
    return rollout

# Training prompts: the "train[:1000]" split spec limits loading to the first 1000 rows of the train split.
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:1000]")

def reward_from_env(completions, **kwargs):
    """Return per-completion rewards taken from the `env_reward` keyword argument.

    Args:
        completions: Completions being scored; only their count is used, for the fallback.
        **kwargs: Extra fields; `env_reward` is read if present.

    Returns:
        The `env_reward` values converted to floats, or a list of 0.0 with one entry per completion when
        `env_reward` is missing or empty.
    """
    rewards = kwargs.get("env_reward", [])
    if not rewards:
        # No environment rewards were propagated: score every completion as zero.
        return [0.0] * len(completions)
    return list(map(float, rewards))

# GRPO configuration: generation is served by vLLM in server mode.
training_args = GRPOConfig(
    output_dir="scratch/Qwen2.5-0.5B-GRPO-Rollout",
    # Generation backend
    use_vllm=True,
    vllm_mode="server",
    # Sampling
    num_generations=8,
    max_completion_length=2048,
    # Schedule and batching
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    # Logging
    logging_steps=1,
    report_to=["trackio", "wandb"],
)

# Wire the custom rollout and the environment-backed reward into the trainer, then train.
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    train_dataset=dataset,
    reward_funcs=reward_from_env,
    rollout_func=rollout_func,
)
trainer.train()
finally:
print("\n🛑 Stopping Echo Environment server...")
if server_process.poll() is None:
server_process.terminate()
try:
server_process.wait(timeout=5)
except subprocess.TimeoutExpired:
print("⚠️ Termination timeout reached. Forcing shutdown.")
server_process.kill()
with suppress(subprocess.TimeoutExpired):
server_process.wait(timeout=5)
Loading