[core][rlhf] add colocate example for RLHF #12984

Merged
merged 7 commits on Feb 10, 2025
4 changes: 2 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -128,7 +128,7 @@ steps:
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
- examples/offline_inference/rlhf.py
- - examples/offline_inference/ray_placement.py
+ - examples/offline_inference/rlhf_colocate.py
commands:
- pytest -v -s distributed/test_utils.py
- pytest -v -s compile/test_basic_correctness.py
@@ -137,7 +137,7 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- python3 ../examples/offline_inference/rlhf.py
- - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/ray_placement.py
+ - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py

- label: Metrics, Tracing Test # 10min
num_gpus: 2
examples/offline_inference/rlhf_colocate.py (renamed from examples/offline_inference/ray_placement.py)
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
"""
-a simple demonstration to show how to control
-the placement of the vLLM workers with Ray.
-The key is to set VLLM_RAY_PER_WORKER_GPUS and
-VLLM_RAY_BUNDLE_INDICES properly.
+a simple demonstration of how to co-locate
+vLLM workers with training actors on the same GPUs,
+for RLHF-like applications.
+The key points:
+- Control the placement of the vLLM workers with Ray, by setting
+  VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
+- Use cuda-ipc to pass tensors, since NCCL does not work when we have
+  multiple processes on the same GPU.
"""
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -19,7 +24,33 @@ class MyWorker(Worker):

def report_device_id(self) -> str:
from vllm.platforms import current_platform
- return current_platform.get_device_uuid(self.device.index)
+ self.device_uuid = current_platform.get_device_uuid(self.device.index)
+ return self.device_uuid

def update_weights_from_ipc_handles(self, ipc_handles):
handles = ipc_handles[self.device_uuid]
device_id = self.device.index
weights = []
for name, handle in handles.items():
func, args = handle
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()

def check_weights_changed(self):
"""
Check if the weights are updated to 0.
"""
weights_updated = True
for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose(
p, torch.zeros_like(p))
Comment on lines +51 to +52

Collaborator: Would it be better to use p.abs() < 1e-5? So that we don't need to allocate a new zero tensor.

Member (Author): This is just for demonstration; in a real use case, there is no need to check whether the weights are zero. Using torch.allclose here is clearer.
return weights_updated
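# A sketch of the variant suggested in the review thread above (illustrative,
# not code from this PR); it avoids allocating a zeros tensor per parameter:
#   weights_updated = weights_updated and bool((p.abs() < 1e-5).all())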


class MyLLM(LLM):
@@ -40,12 +71,32 @@ def __init__(self, *args, bundle_indices: list, **kwargs):

class RayTrainingActor:

- def report_device_id(self) -> str:
+ def __init__(self):
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0")
for name, p in self.model.named_parameters():
p.data.zero_()
torch.cuda.synchronize()
# the argument for get_device_uuid is the index
# of the GPU in the visible devices.
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from vllm.platforms import current_platform
- return current_platform.get_device_uuid(0)
+ self.device_uuid = current_platform.get_device_uuid(0)

def report_device_id(self) -> str:
return self.device_uuid

def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor
data = {}
for name, p in self.model.named_parameters():
# the training actor might only have a subset of the weights
# and need to all-gather the weights from all the actors.
# for demonstration, here we assume all training actors have
# the full weights.
data[name] = reduce_tensor(p.detach())
return {self.device_uuid: data}


# ray manages 4 GPUs
@@ -78,6 +129,8 @@ def report_device_id(self) -> str:
),
)(RayTrainingActor).remote()
training_actors.append(training_actor)

for bundle_index, training_actor in enumerate(training_actors):
device_id = ray.get(training_actor.report_device_id.remote())
print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id)
@@ -119,3 +172,18 @@ def report_device_id(self) -> str:
# the last two training actors should be
# on the same GPUs as the second inference engine
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

print("gather all the IPC handles from the training actors")
ipc_handles = {}
for actor in training_actors:
ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))

print("update the weights of the inference engines")
for llm in inference_engines:
ray.get(
llm.collective_rpc.remote("update_weights_from_ipc_handles",
args=(ipc_handles, )))
print("check if the weights are updated")
for llm in inference_engines:
assert ray.get(
llm.collective_rpc.remote("check_weights_changed", args=tuple()))
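
The driver section of the example (partly collapsed in this diff) follows a standard Ray pattern: create a placement group with one GPU bundle per device, then pin each actor to a chosen bundle with a fractional num_gpus request so that a vLLM worker can later share the same device. Below is a minimal, self-contained sketch of that pattern; it assumes a 4-GPU node, and the DummyActor class, the 0.4 fraction, and the bundle count are illustrative assumptions rather than code taken from the example.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
class DummyActor:
    def ping(self) -> str:
        return "ok"


ray.init()

# one bundle per GPU on the node
pg = placement_group([{"GPU": 1, "CPU": 0}] * 4)
ray.get(pg.ready())

actors = []
for bundle_index in range(4):
    # num_gpus=0.4 requests only a fraction of the bundle's GPU, leaving room
    # for a colocated vLLM worker on the same device
    actor = DummyActor.options(
        num_cpus=0,
        num_gpus=0.4,
        scheduling_strategy=PlacementGroupSchedulingStrategy(
            placement_group=pg,
            placement_group_capture_child_tasks=True,
            placement_group_bundle_index=bundle_index,
        ),
    ).remote()
    actors.append(actor)

print(ray.get([a.ping.remote() for a in actors]))

The fractional request is what lets Ray schedule both a training actor and a vLLM worker onto the same bundle; the example then compares device UUIDs to verify that the colocation actually happened.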