From bc3929ee9d9175a7fa0083b40a150db457966738 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 24 Jun 2024 15:22:56 -0700 Subject: [PATCH 1/5] bugfix --- .../device_communicators/shm_broadcast.py | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index c44bd2f11ee8b..f938709b34d94 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -145,7 +145,6 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int): @contextmanager def acquire_write(self): assert self._is_writer, "Only writers can acquire write" - start_index = self.current_idx start_time = time.time() n_warning = 1 while True: @@ -154,19 +153,21 @@ def acquire_write(self): written_flag = metadata_buffer[0] if written_flag and read_count != self.buffer.n_reader: # this block is written and not read by all readers - # try to write to the next block - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks - if self.current_idx == start_index: - # no empty block found - if time.time( - ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa - logger.warning( - "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL) - n_warning += 1 - # wait for a while (0.1 us) - time.sleep(1e-7) + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # wait for a while (0.1 us) + time.sleep(1e-7) + + # if we wait for a long time, we should warn the user + if time.time( + ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + continue # found a block that is either # (1) not written @@ -188,12 +189,13 @@ def acquire_write(self): metadata_buffer[i] = 0 # mark the block as written metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks break @contextmanager def acquire_read(self): assert self._is_reader, "Only readers can acquire read" - start_index = self.current_idx start_time = time.time() n_warning = 1 while True: @@ -204,19 +206,22 @@ def acquire_read(self): # this block is either # (1) not written # (2) already read by this reader - # try to read the next block - self.current_idx = (self.current_idx + - 1) % self.buffer.max_chunks - if self.current_idx == start_index: - # no block found - if time.time( - ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa - logger.warning( - "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL) - n_warning += 1 - # wait for a while (0.1 us) - time.sleep(1e-7) + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # wait for a while (0.1 us) + time.sleep(1e-7) + + # if we wait for a long time, we should warn the user + if time.time( + ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + continue # found a block that is not read by this reader # let caller read from the buffer @@ -226,6 +231,8 @@ def acquire_read(self): # caller has read from the buffer # set the read flag metadata_buffer[self.reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks break def enqueue(self, obj): From 870cd22190993f26bcd2d65b7e3b4d6182ab7ffa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 24 Jun 2024 15:27:00 -0700 Subject: [PATCH 2/5] add stress tests --- tests/distributed/test_shm_broadcast.py | 49 +++++++++++++++++-------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/tests/distributed/test_shm_broadcast.py b/tests/distributed/test_shm_broadcast.py index d92900ffce00b..2c2466f81bb8a 100644 --- a/tests/distributed/test_shm_broadcast.py +++ b/tests/distributed/test_shm_broadcast.py @@ -1,7 +1,9 @@ import multiprocessing import random import time +from typing import List +import numpy as np import torch.distributed as dist from vllm.distributed.device_communicators.shm_broadcast import ( @@ -9,6 +11,14 @@ from vllm.utils import update_environment_variables +def get_arrays(n: int, seed: int = 0) -> List[np.ndarray]: + np.random.seed(seed) + sizes = np.random.randint(1, 10_000, n) + # on average, each array will have 5k elements + # with int64, each array will have 40kb + return [np.random.randint(1, 100, i) for i in sizes] + + def distributed_run(fn, world_size): number_of_processes = world_size processes = [] @@ -47,24 +57,31 @@ def wrapped_fn(env): def worker_fn(): writer_rank = 2 broadcaster = ShmRingBufferIO.create_from_process_group( - dist.group.WORLD, 1024, 2, writer_rank) + dist.group.WORLD, 1024 * 1024, 2, writer_rank) + if dist.get_rank() == writer_rank: + seed = random.randint(0, 1000) + dist.broadcast_object_list([seed], writer_rank) + else: + recv = [None] + dist.broadcast_object_list(recv, writer_rank) + seed = recv[0] # type: ignore + dist.barrier() + # in case we find a race condition + # print the seed so that we can reproduce the error + print(f"Rank {dist.get_rank()} got seed {seed}") + # test broadcasting with about 400MB of data + N = 10_000 if dist.get_rank() == writer_rank: - time.sleep(random.random()) - broadcaster.broadcast_object(0) - time.sleep(random.random()) - broadcaster.broadcast_object({}) - time.sleep(random.random()) - broadcaster.broadcast_object([]) + arrs = get_arrays(N, seed) + for x in arrs: + broadcaster.broadcast_object(x) + time.sleep(random.random() / 1000) else: - time.sleep(random.random()) - a = broadcaster.broadcast_object(None) - time.sleep(random.random()) - b = broadcaster.broadcast_object(None) - time.sleep(random.random()) - c = broadcaster.broadcast_object(None) - assert a == 0 - assert b == {} - assert c == [] + arrs = get_arrays(N, seed) + for x in arrs: + y = broadcaster.broadcast_object(None) + assert np.array_equal(x, y) + time.sleep(random.random() / 1000) dist.barrier() From 004950d25d932760f653b0875814ec66a00c636f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 24 Jun 2024 15:48:11 -0700 Subject: [PATCH 3/5] use time.monotonic() --- vllm/distributed/device_communicators/shm_broadcast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index f938709b34d94..0166673c891c1 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -145,7 +145,7 @@ def __init__(self, buffer: ShmRingBuffer, reader_rank: int): @contextmanager def acquire_write(self): assert self._is_writer, "Only writers can acquire write" - start_time = time.time() + start_time = time.monotonic() n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: @@ -161,7 +161,7 @@ def acquire_write(self): time.sleep(1e-7) # if we wait for a long time, we should warn the user - if time.time( + if time.monotonic( ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa logger.warning( "No available block found in %s second. ", @@ -196,7 +196,7 @@ def acquire_write(self): @contextmanager def acquire_read(self): assert self._is_reader, "Only readers can acquire read" - start_time = time.time() + start_time = time.monotonic() n_warning = 1 while True: with self.buffer.get_metadata(self.current_idx) as metadata_buffer: @@ -215,7 +215,7 @@ def acquire_read(self): time.sleep(1e-7) # if we wait for a long time, we should warn the user - if time.time( + if time.monotonic( ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning: # noqa logger.warning( "No available block found in %s second. ", From a3b9cf726236439a3d1ecc4d5cad3af11da571b1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 25 Jun 2024 14:40:51 -0700 Subject: [PATCH 4/5] add comments --- vllm/distributed/device_communicators/shm_broadcast.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 0166673c891c1..637ca42d00168 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -157,7 +157,10 @@ def acquire_write(self): # if this block is not ready to write, # we need to wait until it is read by all readers - # wait for a while (0.1 us) + # wait for a while + # if we sleep for too short, it will consume too much CPU + # if we sleep for too long, it will slow down the writer + # 0.1 us is a good balance time.sleep(1e-7) # if we wait for a long time, we should warn the user @@ -211,7 +214,10 @@ def acquire_read(self): # if this block is not ready, # we need to wait until it is written - # wait for a while (0.1 us) + # wait for a while + # if we sleep for too short, it will consume too much CPU + # if we sleep for too long, it will slow down the reader + # 0.1 us is a good balance time.sleep(1e-7) # if we wait for a long time, we should warn the user From ac4cac874bb10bc17691634420a65891cf673ebe Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 25 Jun 2024 14:57:01 -0700 Subject: [PATCH 5/5] use global constant --- .../device_communicators/shm_broadcast.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 637ca42d00168..550271f881df5 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -14,6 +14,12 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL +# time to wait if the queue is full or empty +# if we sleep for too short, it will consume too much CPU +# if we sleep for too long, it will slow down the writer/reader +# 0.1 us is a good balance +RINGBUFFER_SLEEP_INTERVAL = 1e-7 + logger = init_logger(__name__) @@ -158,10 +164,7 @@ def acquire_write(self): # we need to wait until it is read by all readers # wait for a while - # if we sleep for too short, it will consume too much CPU - # if we sleep for too long, it will slow down the writer - # 0.1 us is a good balance - time.sleep(1e-7) + time.sleep(RINGBUFFER_SLEEP_INTERVAL) # if we wait for a long time, we should warn the user if time.monotonic( @@ -215,10 +218,7 @@ def acquire_read(self): # we need to wait until it is written # wait for a while - # if we sleep for too short, it will consume too much CPU - # if we sleep for too long, it will slow down the reader - # 0.1 us is a good balance - time.sleep(1e-7) + time.sleep(RINGBUFFER_SLEEP_INTERVAL) # if we wait for a long time, we should warn the user if time.monotonic(