From 5250044000f7b9916d3c5a6ca6809c7f88c8e354 Mon Sep 17 00:00:00 2001
From: Brendan Slabe
Date: Tue, 19 Nov 2024 20:29:06 +0100
Subject: [PATCH] Make Prompts Reusable for Very Large Experiments (#880)

first commit
---
 .../container/benchmark_serving.py | 36 +++++++++----------
 1 file changed, 17 insertions(+), 19 deletions(-)

diff --git a/benchmarks/benchmark/tools/profile-generator/container/benchmark_serving.py b/benchmarks/benchmark/tools/profile-generator/container/benchmark_serving.py
index f99a0a357..bb19efa4b 100644
--- a/benchmarks/benchmark/tools/profile-generator/container/benchmark_serving.py
+++ b/benchmarks/benchmark/tools/profile-generator/container/benchmark_serving.py
@@ -52,9 +52,8 @@ async def on_request_end(session, trace_config_ctx, params):
 gcs_client = None
 gcs_bucket = None
 
-def sample_requests(
+def get_filtered_dataset(
     dataset_path: str,
-    num_requests: int,
     max_input_len: int,
     max_output_len: int,
     tokenizer: PreTrainedTokenizerBase,
@@ -64,12 +63,11 @@
   if use_dummy_text:
     dummy_prompt_token_ids = [0] * max_input_len
     dummy_prompt = tokenizer.decode(dummy_prompt_token_ids)
-    dummy_requests = [(
-        dummy_prompt,
-        max_input_len,
-        max_output_len,
-    )] * num_requests
-    return dummy_requests
+    return [(
+        dummy_prompt,
+        max_input_len,
+        max_output_len,
+    )]
 
   # Load the dataset.
   with open(dataset_path) as f:
@@ -106,18 +104,15 @@
       continue
     filtered_dataset.append((prompt, prompt_len, output_len))
 
-  # Sample the requests.
-  sampled_requests = random.sample(filtered_dataset, num_requests)
-  return sampled_requests
+  return filtered_dataset
 
-
-async def get_request(
+async def generate_next_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
 ) -> AsyncGenerator[Tuple[str, int, int], None]:
   """Gets request async."""
-  input_requests = iter(input_requests)
-  for request in input_requests:
+  request = random.choice(input_requests)
+  while True:
     yield request
 
     if request_rate == float("inf"):
@@ -383,9 +378,8 @@ async def benchmark(
     model: str,
 ) -> Tuple[List[Tuple[int, int, float]], List[float], Dict[str, int]]:
   """Runs benchmark with asynchronous requests."""
-  input_requests = sample_requests(
+  input_requests = get_filtered_dataset(
       args.dataset,
-      args.num_prompts,
       args.max_input_length,
       args.max_output_length,
       tokenizer,
@@ -393,7 +387,10 @@
   )
   benchmark_start_time = time.time()
   tasks: List[asyncio.Task] = []
-  async for request in get_request(input_requests, args.request_rate):
+  prompts_sent: int = 0
+  async for request in generate_next_request(input_requests, args.request_rate):
+    if args.num_prompts <= prompts_sent:
+      break
     prompt, prompt_len, output_len = request
     if args.stream_request:
       task = asyncio.create_task(
@@ -428,6 +425,7 @@
           )
       )
     tasks.append(task)
+    prompts_sent += 1
   results = await asyncio.gather(*tasks)
   combined_latencies = []
   combined_ttfts = []
@@ -442,7 +440,7 @@
       combined_ttfts.append(ttft)
 
   benchmark_duration = time.time() - benchmark_start_time
-  print_and_save_result(args, benchmark_duration, len(input_requests), model, combined_latencies, combined_ttfts, combined_errors)
+  print_and_save_result(args, benchmark_duration, prompts_sent, model, combined_latencies, combined_ttfts, combined_errors)
   return combined_latencies, combined_ttfts, combined_errors
 
 def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors):
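
Note on the sampling change: random.sample draws without replacement, so the old code could never issue more requests than the filtered dataset contained; random.choice draws with replacement, which is what makes prompts reusable for very large experiments. A minimal stdlib-only illustration (the dataset literal here is made up):

import random

dataset = ["prompt-a", "prompt-b", "prompt-c"]

# Old behavior: sampling without replacement caps num_prompts at len(dataset).
try:
  random.sample(dataset, 5)
except ValueError as err:
  print(f"random.sample failed: {err}")

# New behavior: choosing with replacement lets num_prompts exceed the dataset.
print([random.choice(dataset) for _ in range(5)])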
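
The control flow after this patch can be sketched as below: an endless generator paced at request_rate, consumed until num_prompts requests have been sent. This is a simplified, self-contained approximation rather than the benchmark itself; the exponential sleep stands in for the pacing code outside this hunk, and main, dataset, and request_rate are illustrative.

import asyncio
import random
from typing import AsyncGenerator, List, Tuple

async def generate_next_request(
    input_requests: List[Tuple[str, int, int]],
    request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
  # As in the patch, random.choice is evaluated once, so a single randomly
  # chosen request is yielded repeatedly for the whole run.
  request = random.choice(input_requests)
  while True:
    yield request
    if request_rate == float("inf"):
      continue  # no pacing: issue requests back to back
    # Poisson arrivals: exponentially distributed gaps between requests.
    await asyncio.sleep(random.expovariate(request_rate))

async def main() -> None:
  dataset = [("hello", 1, 8), ("world", 1, 8)]  # (prompt, prompt_len, output_len)
  num_prompts = 5  # may now exceed len(dataset), since prompts are reused
  prompts_sent = 0
  async for prompt, prompt_len, output_len in generate_next_request(
      dataset, request_rate=10.0):
    if num_prompts <= prompts_sent:
      break
    print(prompt, prompt_len, output_len)
    prompts_sent += 1

asyncio.run(main())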