Make Prompts Reusable for Very Large Experiments (#880)
first commit
Bslabe123 authored Nov 19, 2024
1 parent 4d63dc3 commit 5250044
Showing 1 changed file with 17 additions and 19 deletions.
@@ -52,9 +52,8 @@ async def on_request_end(session, trace_config_ctx, params):
 gcs_client = None
 gcs_bucket = None
 
-def sample_requests(
+def get_filtered_dataset(
     dataset_path: str,
-    num_requests: int,
     max_input_len: int,
     max_output_len: int,
     tokenizer: PreTrainedTokenizerBase,
@@ -64,12 +63,11 @@ def sample_requests(
     if use_dummy_text:
         dummy_prompt_token_ids = [0] * max_input_len
         dummy_prompt = tokenizer.decode(dummy_prompt_token_ids)
-        dummy_requests = [(
-            dummy_prompt,
-            max_input_len,
-            max_output_len,
-        )] * num_requests
-        return dummy_requests
+        return [(
+            dummy_prompt,
+            max_input_len,
+            max_output_len,
+        )]
 
     # Load the dataset.
     with open(dataset_path) as f:
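A note on the dummy-text path: because the request generator introduced in the next hunk re-draws from the returned list indefinitely, a single dummy entry now stands in for the `num_requests` identical copies the old code built. A toy check of the equivalence (all values hypothetical):

max_input_len, max_output_len, num_requests = 4, 8, 3
dummy_prompt = "x " * max_input_len  # stand-in for tokenizer.decode([0] * max_input_len)

old_style = [(dummy_prompt, max_input_len, max_output_len)] * num_requests
new_style = [(dummy_prompt, max_input_len, max_output_len)]

# Drawing from the one-entry list num_requests times reproduces the old batch.
assert old_style == new_style * num_requests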
@@ -106,18 +104,15 @@ def sample_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))
 
-    # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
-    return sampled_requests
+    return filtered_dataset
 
-
-async def get_request(
+async def generate_next_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
 ) -> AsyncGenerator[Tuple[str, int, int], None]:
     """Gets request async."""
-    input_requests = iter(input_requests)
-    for request in input_requests:
+    request = random.choice(input_requests)
+    while True:
         yield request
 
     if request_rate == float("inf"):
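This is the heart of the change: `get_request` iterated once over a pre-sampled batch and stopped, while `generate_next_request` draws from the full filtered dataset and yields without end, so the number of prompts issued is no longer bounded by the dataset size. A minimal, self-contained sketch of the caller-terminated pattern (names like `endless_requests` are hypothetical; the real function also sleeps between yields to honor `request_rate`):

import asyncio
import random
from typing import AsyncGenerator, List, Tuple

async def endless_requests(
    input_requests: List[Tuple[str, int, int]],
) -> AsyncGenerator[Tuple[str, int, int], None]:
    # Yield entries from a finite list forever; prompts are reused freely.
    while True:
        yield random.choice(input_requests)

async def demo() -> None:
    dataset = [("hello", 1, 16), ("world", 1, 16)]
    issued = 0
    async for prompt, prompt_len, output_len in endless_requests(dataset):
        issued += 1
        if issued == 5:  # the caller, not the generator, decides when to stop
            break
    print(f"issued {issued} requests from a {len(dataset)}-entry dataset")

asyncio.run(demo())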
@@ -383,17 +378,19 @@ async def benchmark(
     model: str,
 ) -> Tuple[List[Tuple[int, int, float]], List[float], Dict[str, int]]:
     """Runs benchmark with asynchronous requests."""
-    input_requests = sample_requests(
+    input_requests = get_filtered_dataset(
         args.dataset,
-        args.num_prompts,
         args.max_input_length,
         args.max_output_length,
         tokenizer,
         args.use_dummy_text,
     )
     benchmark_start_time = time.time()
     tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, args.request_rate):
+    prompts_sent: int = 0
+    async for request in generate_next_request(input_requests, args.request_rate):
+        if args.num_prompts <= prompts_sent:
+            break
         prompt, prompt_len, output_len = request
         if args.stream_request:
             task = asyncio.create_task(
@@ -428,6 +425,7 @@ async def benchmark(
             )
         )
         tasks.append(task)
+        prompts_sent += 1
     results = await asyncio.gather(*tasks)
     combined_latencies = []
     combined_ttfts = []
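Because the generator never terminates on its own, the consuming loop in `benchmark` now owns the stopping condition: `prompts_sent` counts each task created, and the loop breaks once `args.num_prompts` is reached. A condensed sketch of that control flow (the `send_request` coroutine is a hypothetical stand-in for the real HTTP call):

import asyncio
import random

async def send_request(prompt: str) -> float:
    await asyncio.sleep(0.01)  # stand-in for the real HTTP request
    return 0.01                # pretend per-request latency

async def endless(dataset):
    while True:
        yield random.choice(dataset)

async def benchmark_sketch(num_prompts: int) -> None:
    tasks = []
    prompts_sent = 0
    async for prompt, _, _ in endless([("hi", 1, 8), ("there", 1, 8)]):
        if num_prompts <= prompts_sent:  # same guard as in the hunk above
            break
        tasks.append(asyncio.create_task(send_request(prompt)))
        prompts_sent += 1  # incremented once per task, as in the diff
    latencies = await asyncio.gather(*tasks)
    print(f"sent {prompts_sent} requests; mean latency {sum(latencies) / len(latencies):.3f}s")

asyncio.run(benchmark_sketch(5))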
@@ -442,7 +440,7 @@ async def benchmark(
         combined_ttfts.append(ttft)
 
     benchmark_duration = time.time() - benchmark_start_time
-    print_and_save_result(args, benchmark_duration, len(input_requests), model, combined_latencies, combined_ttfts, combined_errors)
+    print_and_save_result(args, benchmark_duration, prompts_sent, model, combined_latencies, combined_ttfts, combined_errors)
     return combined_latencies, combined_ttfts, combined_errors
 
 def save_json_results(args: argparse.Namespace, benchmark_result, server_metrics, model, errors):
