diff --git a/.gitignore b/.gitignore
index 5fa6202ffff..0bb5d34f4f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+**/*.pt
 **/checkpoints
 **/wget-log

diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
new file mode 100644
index 00000000000..6a0d20e41cc
--- /dev/null
+++ b/docker/Dockerfile.rocm
@@ -0,0 +1,45 @@
+# Build the docker in the repo dir:
+# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .
+# docker images # you can find your built docker
+
+
+FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
+
+# Set working directory
+# WORKDIR $PWD/app
+
+# Set environment variables
+ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+# Install vllm
+RUN pip uninstall -y vllm && \
+    rm -rf vllm && \
+    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \
+    cd vllm && \
+    MAX_JOBS=$(nproc) python3 setup.py install && \
+    cd .. && \
+    rm -rf vllm
+
+# Copy the entire project directory
+COPY . .
+
+# Install dependencies
+RUN pip install "tensordict<0.6" --no-deps && \
+    pip install accelerate \
+        codetiming \
+        datasets \
+        dill \
+        hydra-core \
+        liger-kernel \
+        numpy \
+        pandas \
+        peft \
+        "pyarrow>=15.0.0" \
+        pylatexenc \
+        "ray[data,train,tune,serve]" \
+        torchdata \
+        transformers \
+        wandb \
+        orjson \
+        pybind11 && \
+    pip install -e . --no-deps
\ No newline at end of file
diff --git a/docs/amd_tutorial/amd_build_dockerfile.md b/docs/amd_tutorial/amd_build_dockerfile.md
new file mode 100644
index 00000000000..b067e843a70
--- /dev/null
+++ b/docs/amd_tutorial/amd_build_dockerfile.md
@@ -0,0 +1,170 @@
+# Setup
+
+## Dockerfile.rocm
+```bash
+# Build the docker in the repo dir:
+# docker build -f docker/Dockerfile.rocm -t verl-rocm:03.04.2015 .
+# docker images # you can find your built docker
+#
+FROM rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4
+
+# Set working directory
+# WORKDIR $PWD/app
+
+# Set environment variables
+ENV PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+
+# Install vllm
+RUN pip uninstall -y vllm && \
+    rm -rf vllm && \
+    git clone -b v0.6.3 https://github.com/vllm-project/vllm.git && \
+    cd vllm && \
+    MAX_JOBS=$(nproc) python3 setup.py install && \
+    cd .. && \
+    rm -rf vllm
+
+# Copy the entire project directory
+COPY . .
+
+# Install dependencies
+RUN pip install "tensordict<0.6" --no-deps && \
+    pip install accelerate \
+        codetiming \
+        datasets \
+        dill \
+        hydra-core \
+        liger-kernel \
+        numpy \
+        pandas \
+        peft \
+        "pyarrow>=15.0.0" \
+        pylatexenc \
+        "ray[data,train,tune,serve]" \
+        torchdata \
+        transformers \
+        wandb \
+        orjson \
+        pybind11 && \
+    pip install -e . --no-deps
+```
+
+
+## Build the image:
+```bash
+docker build -t verl-rocm .
+```
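+
+Optionally, confirm that the image was built and tagged as expected (the `verl-rocm` tag comes from the build command above):
+```bash
+docker images | grep verl-rocm
+```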
+
+## Run the container
+```bash
+docker run --rm -it \
+  --device /dev/dri \
+  --device /dev/kfd \
+  -p 8265:8265 \
+  --group-add video \
+  --cap-add SYS_PTRACE \
+  --security-opt seccomp=unconfined \
+  --privileged \
+  -v $HOME/.ssh:/root/.ssh \
+  -v $HOME:$HOME \
+  --shm-size 128G \
+  -w $PWD \
+  verl-rocm \
+  /bin/bash
+```
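+
+Inside the container, a quick sanity check along these lines should confirm that the ROCm build of PyTorch can see the GPUs (the exact device name varies; on MI300X it contains "AMD"):
+```bash
+python3 -c "import torch; print(torch.cuda.is_available(), torch.cuda.device_count(), torch.cuda.get_device_name(0))"
+```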
+
+
+# Example
+
+## PPO
+```bash
+YOUR_PROJECT_NAME=r1-verl-ppo-upstream
+YOUR_RUN_NAME=r1-training_ppo-upstream
+# export HYDRA_FULL_ERROR=1
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
+GPUS_PER_NODE=8
+MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
+python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
+python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
+PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
+    data.train_files=data/gsm8k/train.parquet \
+    data.val_files=data/gsm8k/test.parquet \
+    data.train_batch_size=256 \
+    data.val_batch_size=1312 \
+    data.max_prompt_length=512 \
+    data.max_response_length=256 \
+    actor_rollout_ref.model.path=$MODEL_PATH \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+    critic.optim.lr=1e-5 \
+    critic.model.path=$MODEL_PATH \
+    critic.ppo_micro_batch_size_per_gpu=4 \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.logger=['console'] \
+    trainer.project_name=$YOUR_PROJECT_NAME \
+    trainer.experiment_name=$YOUR_RUN_NAME \
+    +trainer.val_before_train=False \
+    trainer.default_hdfs_dir=null \
+    trainer.n_gpus_per_node=$GPUS_PER_NODE \
+    trainer.nnodes=1 \
+    trainer.save_freq=10 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15 #2>&1 | tee verl_demo.log
+```
+
+
+## GRPO
+```bash
+YOUR_PROJECT_NAME=r1-verl-grpo-upstream
+YOUR_RUN_NAME=r1-training_grpo-upstream
+# export HYDRA_FULL_ERROR=1
+# export FSDP_VERBOSE=1
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
+GPUS_PER_NODE=8
+MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
+# MODEL_PATH=Qwen/Qwen2-7B-Instruct
+python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
+python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=grpo \
+    data.train_files=data/gsm8k/train.parquet \
+    data.val_files=data/gsm8k/test.parquet \
+    data.train_batch_size=1024 \
+    data.val_batch_size=1312 \
+    data.max_prompt_length=512 \
+    data.max_response_length=1024 \
+    actor_rollout_ref.model.path=$MODEL_PATH \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.use_dynamic_bsz=True \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.model.enable_gradient_checkpointing=False \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=False \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console'] \
+    trainer.project_name=$YOUR_PROJECT_NAME \
+    trainer.experiment_name=$YOUR_RUN_NAME \
+    trainer.n_gpus_per_node=$GPUS_PER_NODE \
+    +trainer.val_before_train=False \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15
+```
\ No newline at end of file
diff --git a/docs/amd_tutorial/amd_existing_docker.md b/docs/amd_tutorial/amd_existing_docker.md
new file mode 100644
index 00000000000..49d7828be82
--- /dev/null
+++ b/docs/amd_tutorial/amd_existing_docker.md
@@ -0,0 +1,157 @@
+# Setup
+
+## Docker:
+Find the Docker image here: https://hub.docker.com/r/rocm/vllm/tags (tag: rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4)
+```bash
+docker run --rm -it \
+  --device /dev/dri \
+  --device /dev/kfd \
+  --network host \
+  --ipc host \
+  --group-add video \
+  --cap-add SYS_PTRACE \
+  --security-opt seccomp=unconfined \
+  --privileged \
+  -v /home/yushensu:/home/yushensu \
+  -v $HOME/.ssh:/root/.ssh \
+  --shm-size 128G \
+  --name verl_vllm_upstream \
+  -w $PWD \
+  rocm/vllm:rocm6.2_mi300_ubuntu20.04_py3.9_vllm_0.6.4 \
+  /bin/bash
+```
+
+## Build ROCm vLLM:
+```bash
+pip uninstall -y vllm
+git clone -b v0.6.3 https://github.com/vllm-project/vllm.git
+cd vllm
+export PYTORCH_ROCM_ARCH="gfx90a;gfx942"
+export MAX_JOBS=$(nproc)
+# python3 setup.py develop  # editable install; the cloned repo must be kept
+python3 setup.py install  # installs into site-packages; the cloned repo can be deleted afterwards
+cd ..
+rm -rf vllm
+```
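+
+The ROCm build may report a suffixed version string (for example `0.6.3+rocm624`); a quick check along these lines shows which build is active:
+```bash
+python3 -c "import vllm; print(vllm.__version__)"
+```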
+
+## Install the required packages:
+```bash
+pip install "tensordict<0.6" --no-deps
+
+pip install accelerate \
+    codetiming \
+    datasets \
+    dill \
+    hydra-core \
+    liger-kernel \
+    numpy \
+    pandas \
+    peft \
+    "pyarrow>=15.0.0" \
+    pylatexenc \
+    "ray[data,train,tune,serve]" \
+    torchdata \
+    transformers \
+    wandb \
+    orjson \
+    pybind11
+
+pip install -e . --no-deps
+```
+
+
+# Example
+
+## PPO
+```bash
+YOUR_PROJECT_NAME=r1-verl-ppo-upstream
+YOUR_RUN_NAME=r1-training_ppo-upstream
+# export HYDRA_FULL_ERROR=1
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
+GPUS_PER_NODE=8
+MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
+python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
+python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
+PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
+    data.train_files=data/gsm8k/train.parquet \
+    data.val_files=data/gsm8k/test.parquet \
+    data.train_batch_size=256 \
+    data.val_batch_size=1312 \
+    data.max_prompt_length=512 \
+    data.max_response_length=256 \
+    actor_rollout_ref.model.path=$MODEL_PATH \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
+    critic.optim.lr=1e-5 \
+    critic.model.path=$MODEL_PATH \
+    critic.ppo_micro_batch_size_per_gpu=4 \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.logger=['console'] \
+    trainer.project_name=$YOUR_PROJECT_NAME \
+    trainer.experiment_name=$YOUR_RUN_NAME \
+    +trainer.val_before_train=False \
+    trainer.default_hdfs_dir=null \
+    trainer.n_gpus_per_node=$GPUS_PER_NODE \
+    trainer.nnodes=1 \
+    trainer.save_freq=10 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15 #2>&1 | tee verl_demo.log
+```
+
+
+## GRPO
+```bash
+YOUR_PROJECT_NAME=r1-verl-grpo-upstream
+YOUR_RUN_NAME=r1-training_grpo-upstream
+# export HYDRA_FULL_ERROR=1
+# export FSDP_VERBOSE=1
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+export ROCR_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES
+GPUS_PER_NODE=8
+MODEL_PATH=Qwen/Qwen2.5-0.5B-Instruct
+# MODEL_PATH=Qwen/Qwen2-7B-Instruct
+python3 examples/data_preprocess/gsm8k.py --local_dir data/gsm8k
+python3 -c "import transformers; transformers.pipeline('text-generation', model='$MODEL_PATH')"
+python3 -m verl.trainer.main_ppo \
+    algorithm.adv_estimator=grpo \
+    data.train_files=data/gsm8k/train.parquet \
+    data.val_files=data/gsm8k/test.parquet \
+    data.train_batch_size=1024 \
+    data.val_batch_size=1312 \
+    data.max_prompt_length=512 \
+    data.max_response_length=1024 \
+    actor_rollout_ref.model.path=$MODEL_PATH \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
+    actor_rollout_ref.actor.use_dynamic_bsz=True \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.model.enable_gradient_checkpointing=False \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=False \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console'] \
+    trainer.project_name=$YOUR_PROJECT_NAME \
+    trainer.experiment_name=$YOUR_RUN_NAME \
+    trainer.n_gpus_per_node=$GPUS_PER_NODE \
+    +trainer.val_before_train=False \
+    trainer.nnodes=1 \
+    trainer.save_freq=-1 \
+    trainer.test_freq=10 \
+    trainer.total_epochs=15
+```
\ No newline at end of file
diff --git a/verl/single_controller/base/worker.py b/verl/single_controller/base/worker.py
index 92496fbfd2a..dbe4cc600fa 100644
--- a/verl/single_controller/base/worker.py
+++ b/verl/single_controller/base/worker.py
@@ -118,6 +118,19 @@ def _configure_before_init(self, register_center_name: str, rank: int):
     def __init__(self, cuda_visible_devices=None) -> None:
         # construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
         import os
+
+        ###
+        # [SUPPORT AMD: torch]
+        import torch
+        ###
+
+        ###
+        # [SUPPORT AMD: torch]
+        if "AMD" in torch.cuda.get_device_name():
+            os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('ROCR_VISIBLE_DEVICES')
+            os.environ['LOCAL_RANK'] = os.environ.get('RAY_LOCAL_RANK')
+        ###
+
         world_size = int(os.environ['WORLD_SIZE'])
         rank = int(os.environ['RANK'])
         self._rank = rank
@@ -129,6 +142,18 @@ def __init__(self, cuda_visible_devices=None) -> None:
         local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
         local_rank = int(os.getenv("LOCAL_RANK", "0"))

+        ###
+        # [SUPPORT AMD: torch]
+        if "AMD" in torch.cuda.get_device_name():
+            self.local_rank = int(os.environ['LOCAL_RANK'])
+        ###
+
+        ###
+        # [SUPPORT AMD: torch]
+        if "AMD" in torch.cuda.get_device_name():
+            cuda_visible_devices = str(local_rank)
+        ###
+
         store = {
             '_world_size': world_size,
             '_rank': rank,
@@ -143,6 +168,13 @@ def __init__(self, cuda_visible_devices=None) -> None:
         meta = WorkerMeta(store=store)
         self._configure_with_meta(meta=meta)

+        ###
+        # [SUPPORT AMD: torch]
+        # torch.cuda.set_device(local_rank)
+        if "AMD" in torch.cuda.get_device_name():
+            torch.cuda.set_device(int(cuda_visible_devices))
+        ###
+
     def _configure_with_meta(self, meta: WorkerMeta):
         """
         This function should only be called inside by WorkerGroup
diff --git a/verl/third_party/vllm/__init__.py b/verl/third_party/vllm/__init__.py
index ab925c1a901..77847f039b0 100644
--- a/verl/third_party/vllm/__init__.py
+++ b/verl/third_party/vllm/__init__.py
@@ -47,6 +47,11 @@ def get_version(pkg):
     from .vllm_v_0_6_3.llm import LLM
     from .vllm_v_0_6_3.llm import LLMEngine
     from .vllm_v_0_6_3 import parallel_state
+elif package_version == '0.6.3+rocm624':
+    vllm_version = '0.6.3'
+    from .vllm_v_0_6_3.llm import LLM
+    from .vllm_v_0_6_3.llm import LLMEngine
+    from .vllm_v_0_6_3 import parallel_state
 elif vs.parse(package_version) >= vs.parse('0.6.6.post2.dev252+g8027a724'):
     # From 0.6.6.post2 on, vllm supports SPMD inference
     # See https://github.com/vllm-project/vllm/pull/12071
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index e39b75e330f..45ea331b095 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -335,6 +335,19 @@ def compute_timing_metrics(batch, timing_raw):
     }


+def compute_throughout_metrics(batch, timing_raw, n_gpus):
+    total_num_tokens = sum(batch.meta_info['global_token_num'])
+    time = timing_raw["step"]
+    # estimated_flops, promised_flops = flops_function.estimate_flops(num_tokens, time)
+    # f'Actual TFLOPs/s/GPU': estimated_flops/(n_gpus),
+    # f'Theoretical TFLOPs/s/GPU': promised_flops,
+    return {
+        f'total_num_tokens': total_num_tokens,
+        f'time_per_step': time,
+        f'Tokens/Sec/GPU': total_num_tokens / (time * n_gpus),
+    }
+
+
 @contextmanager
 def _timer(name: str, timing_raw: Dict[str, float]):
     with Timer(name=name, logger=None) as timer:
@@ -1021,6 +1034,11 @@ def fit(self):
                 metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                 metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

+                config = self.config
+                n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
+                # Implement actual tflops and theoretical tflops
+                metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
+
                 # TODO: make a canonical logger that supports various backend
                 logger.log(data=metrics, step=self.global_steps)
diff --git a/verl/utils/flops_counter.py b/verl/utils/flops_counter.py
index fd61d51a6eb..40d5afea5f7 100644
--- a/verl/utils/flops_counter.py
+++ b/verl/utils/flops_counter.py
@@ -32,7 +32,10 @@ def unit_convert(number, level):
     device_name = torch.cuda.get_device_name()
     flops = float("inf")  # INF flops for unkown gpu type
-    if "H100" in device_name or "H800" in device_name:
+
+    if "MI300X" in device_name:
+        flops = 1336e12
+    elif "H100" in device_name or "H800" in device_name:
         flops = 989e12
     elif "A100" in device_name or "A800" in device_name:
         flops = 312e12
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index a6cf99406a2..34337b5428b 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -264,11 +264,11 @@ def update_policy(self, data: DataProto):
                 self.actor_optimizer.zero_grad()

                 for data in micro_batches:
+                    # Support all hardware
                     if isinstance(data, DataProto):
-                        data = {**data.batch.cuda(), **data.non_tensor_batch}
+                        data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
                     else:
-                        data = data.cuda()  # actor device is cpu when using offload
-
+                        data = data.to(torch.cuda.current_device())  # actor device is cpu when using offload
                     responses = data['responses']
                     response_length = responses.size(1)
                     attention_mask = data['attention_mask']
diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py
index 1d167e87891..68473363371 100644
--- a/verl/workers/critic/dp_critic.py
+++ b/verl/workers/critic/dp_critic.py
@@ -198,11 +198,11 @@ def update_critic(self, data: DataProto):
                 self.critic_optimizer.zero_grad()

                 for data in micro_batches:
+                    # Support all devices
                     if isinstance(data, DataProto):
-                        data = {**data.batch.cuda(), **data.non_tensor_batch}
+                        data = {**data.batch.to(torch.cuda.current_device()), **data.non_tensor_batch}
                     else:
-                        data = data.cuda()  # critic device is cpu when using offload
-
+                        data = data.to(torch.cuda.current_device())  # critic device is cpu when using offload
                     input_ids = data['input_ids']
                     responses = data['responses']
                     attention_mask = data['attention_mask']
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index f01bd2816ea..ad07f628acd 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -416,7 +416,8 @@ def init_model(self):

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_actor(self, data: DataProto):
-        data = data.to('cuda')
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())

         assert self._is_actor
         if self._is_offload_param:
@@ -424,7 +425,8 @@ def update_actor(self, data: DataProto):
         if self._is_offload_optimizer:
             load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

-        data.batch = data.batch.cuda()
+        # Support all hardware
+        data.batch = data.batch.to(torch.cuda.current_device())

         log_gpu_memory_usage('Before update policy', logger=logger)

@@ -459,13 +461,15 @@ def update_actor(self, data: DataProto):

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def generate_sequences(self, prompts: DataProto):
-        prompts = prompts.to('cuda')
+        # Support all hardware
+        prompts = prompts.to(torch.cuda.current_device())

         assert self._is_rollout
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)

-        prompts.batch = prompts.batch.cuda()
+        # Support all hardware
+        prompts.batch = prompts.batch.to(torch.cuda.current_device())

         meta_info = {
             'eos_token_id': self.generation_config.eos_token_id
@@ -504,7 +508,9 @@ def compute_log_prob(self, data: DataProto):
         assert self._is_actor
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)
-        data = data.to('cuda')
+
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())
         # we should always recompute old_log_probs when it is HybridEngine
         data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu
         data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
@@ -537,7 +543,8 @@ def compute_log_prob(self, data: DataProto):
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref

-        data = data.to('cuda')
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())

         micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
         data.meta_info['micro_batch_size'] = micro_batch_size
@@ -782,7 +789,9 @@ def init_model(self):

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_values(self, data: DataProto):
-        data = data.to('cuda')
+
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.critic_module)
@@ -804,7 +813,8 @@ def compute_values(self, data: DataProto):

     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def update_critic(self, data: DataProto):
-        data = data.to('cuda')
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())
         if self._is_offload_param:
             load_fsdp_model_to_gpu(self.critic_module)
         if self._is_offload_optimizer:
@@ -1103,11 +1113,13 @@ def _switch_chat_template(self, data: DataProto):
     def compute_rm_score(self, data: DataProto):
         import itertools
         from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
-        data = data.to('cuda')
+        # Support all hardware
+        data = data.to(torch.cuda.current_device())
         if self._do_switch_chat_template:
             rm_data = self._switch_chat_template(data)
-            rm_data.batch = rm_data.batch.cuda()
+            # Support all hardware
+            rm_data.batch = rm_data.batch.to(torch.cuda.current_device())

         # perform forward computation
         with self.ulysses_sharding_manager:
diff --git a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py
index 50eca90bacc..0d6d4c3d818 100644
--- a/verl/workers/rollout/vllm_rollout/__init__.py
+++ b/verl/workers/rollout/vllm_rollout/__init__.py
@@ -14,6 +14,11 @@

 from importlib.metadata import version, PackageNotFoundError

+###
+# [SUPPORT AMD:]
+import torch
+###
+

 def get_version(pkg):
     try:
@@ -25,6 +30,17 @@ def get_version(pkg):
 package_name = 'vllm'
 package_version = get_version(package_name)

+###
+# package_version = get_version(package_name)
+# [SUPPORT AMD:]
+if "AMD" in torch.cuda.get_device_name():
+    import re
+    package_version = version(package_name)
+    package_version = re.match(r'(\d+\.\d+\.?\d*)', package_version).group(1)
+else:
+    package_version = get_version(package_name)
+###
+
 if package_version <= '0.6.3':
     vllm_mode = 'customized'
     from .vllm_rollout import vLLMRollout
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index b2ad76ceec7..9bedc8bd036 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -74,6 +74,7 @@ def __enter__(self):
         log_gpu_memory_usage('After state_dict() in sharding manager memory', logger=logger)
         # Copy, not share memory
         load_format = 'hf' if self.full_params else 'dtensor'
+
         if vllm_version in ('0.4.2', '0.5.4', '0.6.3'):
             self.inference_engine.sync_model_weights(params, load_format=load_format)
         else: