[OpenENV] Openenv rollout_func signature proposal #4344
Open
kashif
wants to merge
14
commits into
huggingface:main
Choose a base branch
from
kashif:openenv-trainers
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,257
−157
Open
Commits (14, all by kashif; the diff below shows changes from 3 commits):

- 7c9c22a initial online_dpo
- c995f66 refactor rollout_func to take trainer as argument
- 39bca3a fix doc strings
- ecbf032 undo
- 506b9d8 added generate_async
- 15b56cf call rollout in colocate mode
- f60d1d1 call openspiel sequentially
- de22709 generate logprobs
- 022505a Add async wrapper for vLLM colocate mode
- 1397375 Merge branch 'main' into openenv-trainers
- aedd0d5 init
- 382c70f wordle with collocate vllm support
- 79eed82 Merge branch 'main' into openenv-trainers
- 25b3342 fix typo
New file, +241 lines (the path is taken from the usage string in the file itself): examples/scripts/openenv/echo_online_dpo_with_model.py
````python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: T201
import os
import subprocess
import sys
import time
from pathlib import Path

import requests
import torch
from datasets import load_dataset
from envs.echo_env import EchoEnv
from envs.echo_env.models import EchoAction
from transformers import AutoTokenizer

from trl import OnlineDPOConfig, OnlineDPOTrainer, RichProgressCallback
from trl.models import unwrap_model_for_generation


"""
Online DPO training with OpenEnv's Echo environment using the TRAINER'S MODEL for generation.
This example shows how to use a custom rollout function that:
1. Generates completions using the trainer's model (no vLLM server needed!)
2. Computes environment rewards from OpenEnv
3. Returns both for training

Setup:

```sh
pip install git+https://github.com/meta-pytorch/OpenEnv.git
```

Usage (single GPU - everything on one device!):

```sh
python examples/scripts/openenv/echo_online_dpo_with_model.py
```
"""

ENV_URL = "http://127.0.0.1:8001"

print("⚡ Starting FastAPI server for Echo Environment...")
# Workaround if you can't run the env with Docker
work_dir = str(Path.cwd().parent.absolute())
server_process = subprocess.Popen(
    [sys.executable, "-m", "uvicorn", "envs.echo_env.server.app:app", "--host", "127.0.0.1", "--port", "8001"],
    env={**os.environ, "PYTHONPATH": f"{work_dir}/src"},
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
    cwd=work_dir,
)

print("⏳ Waiting for server to start...")
time.sleep(5)

try:
    response = requests.get(f"{ENV_URL}/health", timeout=2)
    print("\n✅ Echo Environment server is running!")
except Exception as e:
    print(f"\n❌ Server failed to start: {e}")
    print("\n📋 Checking error output...")
    server_process.poll()
    if server_process.stderr:
        stderr = server_process.stderr.read()
        if stderr:
            print(stderr)
    raise


# Create HTTP client for Echo Environment
client = EchoEnv(base_url=f"{ENV_URL}")


def rollout_func_with_model(prompts: list[str], trainer: OnlineDPOTrainer) -> dict:
    """
    Custom rollout function that generates completions using the trainer's model and computes environment rewards.

    This function demonstrates the NEW signature that accepts a 'trainer' parameter, allowing direct access
    to the model for generation without needing vLLM.

    Args:
        prompts: List of prompts to generate from
        trainer: The OnlineDPOTrainer instance (provides access to model, accelerator, etc.)

    Returns:
        Dict containing prompt_ids, completion_ids, and env_reward
    """
    if trainer is None:
        raise ValueError(
            "This rollout function requires the trainer parameter. "
            "Make sure you're using a version of OnlineDPOTrainer that supports this feature."
        )

    print(f"🎲 Generating completions for {len(prompts)} prompts using trainer's model...")

    device = trainer.accelerator.device

    # 1. Tokenize prompts
    processing_class = trainer.processing_class
    args = trainer.args
    prompt_inputs = processing_class(
        text=prompts,
        return_tensors="pt",
        padding=True,
        padding_side="left",
        max_length=args.max_length,
        truncation=True,
        add_special_tokens=False,
    )

    # Move to device
    prompt_inputs = {k: v.to(device) for k, v in prompt_inputs.items()}

    # 2. Generate 2 completions per prompt using the trainer's model
    all_prompt_ids = []
    all_completion_ids = []
    all_completions_text = []

    # Unwrap model for generation (handles FSDP, DeepSpeed, etc.)
    with unwrap_model_for_generation(
        trainer.model, trainer.accelerator, gather_deepspeed3_params=args.ds3_gather_for_generation
    ) as unwrapped_model:
        unwrapped_model.eval()
        with torch.no_grad():
            for gen_idx in range(2):  # OnlineDPO requires exactly 2 completions per prompt
                print(f"  Generation {gen_idx + 1}/2...")

                # Generate
                outputs = unwrapped_model.generate(
                    **prompt_inputs,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature if args.temperature > 0 else 1.0,
                    top_p=args.top_p,
                    top_k=args.top_k if args.top_k is not None else 50,
                    do_sample=True if args.temperature > 0 else False,
                    pad_token_id=processing_class.pad_token_id,
                    eos_token_id=processing_class.eos_token_id,
                )

                # Extract completions (remove prompt part)
                prompt_length = prompt_inputs["input_ids"].shape[1]
                completion_ids = outputs[:, prompt_length:]

                # Decode completions
                completions_text = processing_class.batch_decode(completion_ids, skip_special_tokens=True)

                # Store results
                for i in range(len(prompts)):
                    all_prompt_ids.append(prompt_inputs["input_ids"][i].tolist())
                    all_completion_ids.append(completion_ids[i].tolist())
                    all_completions_text.append(completions_text[i])

        unwrapped_model.train()

    print(f"  ✓ Generated {len(all_completions_text)} completions")

    # 3. Step through the environment to get rewards for each completion
    print("🌍 Computing environment rewards...")
    env_result = client.reset()
    env_rewards = []
    for msg in all_completions_text:
        env_result = client.step(EchoAction(message=msg))
        env_rewards.append(env_result.reward)

    print(f"  ✓ Computed {len(env_rewards)} rewards")

    # 4. Return results in the expected format
    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "env_reward": env_rewards,  # Extra field passed to reward function
    }


def reward_from_env(completions, **kwargs):
    """Reward function that uses the environment reward from kwargs."""
    env_rewards = kwargs.get("env_reward", [])
    if env_rewards:
        return [float(reward) for reward in env_rewards]
    else:
        # Fallback if env_reward is not available
        return [0.0] * len(completions)


# Load dataset and tokenizer
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train[:100]")  # Small dataset for testing
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Training configuration
training_args = OnlineDPOConfig(
    output_dir="Qwen2.5-0.5B-OnlineDPO-Echo-ModelGen",
    use_vllm=False,  # ← No vLLM! Use trainer's model instead
    logging_steps=1,
    report_to="none",
    num_train_epochs=1,
    max_new_tokens=64,  # Shorter for faster generation
    max_length=512,  # Max total sequence length
    temperature=0.7,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=1,
    learning_rate=1e-5,
    bf16=True,
)

print("\n🏋️ Creating trainer...")
trainer = OnlineDPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    processing_class=tokenizer,
    reward_funcs=reward_from_env,
    args=training_args,
    train_dataset=dataset,
    rollout_func=rollout_func_with_model,  # ← Custom rollout with model access!
    callbacks=[RichProgressCallback()],
)

print("\n🚀 Starting training...")
print("=" * 80)
trainer.train()
print("=" * 80)

# Give time for background threads to finish
time.sleep(5)

print("\n🛑 Terminating Echo Environment server...")
server_process.terminate()

print("\n✅ Training complete!")
````
Review comment:

When using the current `rollout_func`, I found the definition of the vLLM payload, and how you need to use `requests`, kind of messy / low level, mainly because I have to unpack and redefine the config (now the trainer).

Here's an idea: could we define a `generate` function within the trainer that uses closures over the trainer scope to set configuration defaults? The user could then pass that `generate` function as a param of `rollout_func` without having to deal with the underlying vLLM logic, and still override any of its config params.
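
A minimal, self-contained sketch of that closure idea. `FakeArgs`/`FakeTrainer` stand in for `OnlineDPOConfig`/`OnlineDPOTrainer`, and `make_generate_fn`/`_backend_generate` are hypothetical names; none of this is existing TRL API, it only illustrates the pattern being proposed:

```python
# Hypothetical sketch of a trainer-scoped generate() closure.
# FakeArgs/FakeTrainer are stand-ins; nothing here is TRL API.
from dataclasses import dataclass, field


@dataclass
class FakeArgs:
    max_new_tokens: int = 64
    temperature: float = 0.7
    top_p: float = 1.0


@dataclass
class FakeTrainer:
    args: FakeArgs = field(default_factory=FakeArgs)

    def _backend_generate(self, prompts: list[str], **params) -> list[str]:
        # Stand-in for the real generation path (vLLM server,
        # colocate mode, or the trainer's own model).
        return [f"completion({p!r}, {params})" for p in prompts]


def make_generate_fn(trainer: FakeTrainer):
    """Build a generate() closure whose defaults come from the trainer's config."""
    defaults = {
        "max_new_tokens": trainer.args.max_new_tokens,
        "temperature": trainer.args.temperature,
        "top_p": trainer.args.top_p,
    }

    def generate(prompts: list[str], **overrides) -> list[str]:
        params = {**defaults, **overrides}  # user overrides win over trainer defaults
        return trainer._backend_generate(prompts, **params)

    return generate


def my_rollout_func(prompts: list[str], generate) -> dict:
    # The rollout never builds a vLLM payload or calls requests directly;
    # it just calls generate(), overriding only what it needs.
    return {"completions": generate(prompts, temperature=1.0)}


trainer = FakeTrainer()
generate = make_generate_fn(trainer)
print(my_rollout_func(["hello world"], generate))
```

With this shape, the rollout function depends only on the `generate` callable, so the backend (server vLLM, colocate mode, or the trainer's own model as in the example above) could change without touching user code.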