From e2bbd0eec9dc67c630325a745cdacb6adebccf3d Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 10 Dec 2024 17:18:10 +0000
Subject: [PATCH] Adding some long context stats

---
 pdelfin/beakerpipeline.py | 44 ++++++++++++++++++++++++++++++++-------
 1 file changed, 36 insertions(+), 8 deletions(-)

diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py
index fdff656..6a62dba 100644
--- a/pdelfin/beakerpipeline.py
+++ b/pdelfin/beakerpipeline.py
@@ -717,6 +717,8 @@ def submit_beaker_job(args):
 
 
 def print_stats(args):
+    LONG_CONTEXT_THRESHOLD = 32768
+
     # Get total work items and completed items
     index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
     output_glob = os.path.join(args.workspace, "results", "*.jsonl")
@@ -741,20 +743,35 @@ def process_output_file(s3_path):
             total_fallback_pages = 0
             processed_paths = set()
 
+            # Counters for long context docs within a single file
+            long_context_docs = 0
+            long_context_tokens = 0
+
             for line in data.decode('utf-8').splitlines():
                 if line.strip():
                     doc = json.loads(line)
                     doc_count += 1
-                    total_input_tokens += doc["metadata"].get("total-input-tokens", 0)
-                    total_output_tokens += doc["metadata"].get("total-output-tokens", 0)
-                    total_pages += doc["metadata"].get("pdf-total-pages", 0)
-                    total_fallback_pages += doc["metadata"].get("total-fallback-pages", 0)
+                    doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
+                    doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
+                    doc_pages = doc["metadata"].get("pdf-total-pages", 0)
+                    doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)
+
+                    total_input_tokens += doc_input_tokens
+                    total_output_tokens += doc_output_tokens
+                    total_pages += doc_pages
+                    total_fallback_pages += doc_fallback_pages
                     processed_paths.add(doc["metadata"]["Source-File"])
-
-            return doc_count, total_input_tokens, total_output_tokens, total_pages, total_fallback_pages, processed_paths
+
+                    # Check if this doc exceeds the long context threshold
+                    if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
+                        long_context_docs += 1
+                        long_context_tokens += doc_output_tokens
+
+            return (doc_count, total_input_tokens, total_output_tokens, total_pages,
+                    total_fallback_pages, processed_paths, long_context_docs, long_context_tokens)
         except Exception as e:
             logger.warning(f"Error processing {s3_path}: {e}")
-            return 0, 0, 0, 0, 0, set()
+            return 0, 0, 0, 0, 0, set(), 0, 0
 
     print("\nProcessing output files...")
     docs_total = 0
@@ -765,6 +782,10 @@ def process_output_file(s3_path):
     all_processed_paths = set()
     original_paths = set()
 
+    # Counters for long context documents across all files
+    long_context_docs_count = 0
+    long_context_tokens_total = 0
+
     # First collect all original PDF paths
     for done_work_item in done_work_items:
         if match := re.search(r"output_(\w+).jsonl", done_work_item):
@@ -775,13 +796,16 @@ def process_output_file(s3_path):
         futures = {executor.submit(process_output_file, item): item for item in done_work_items}
 
         for future in tqdm(as_completed(futures), total=len(futures)):
-            doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths = future.result()
+            (doc_count, input_tokens, output_tokens, pages, fallback_pages,
+             processed_paths, long_context_docs, long_context_tokens) = future.result()
             docs_total += doc_count
             input_tokens_total += input_tokens
             output_tokens_total += output_tokens
             pages_total += pages
             fallback_pages_total += fallback_pages
             all_processed_paths.update(processed_paths)
+            long_context_docs_count += long_context_docs
+            long_context_tokens_total += long_context_tokens
 
     skipped_paths = original_paths - all_processed_paths
 
@@ -803,6 +827,10 @@ def process_output_file(s3_path):
     print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
     print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")
 
+    # Print long context documents stats
+    print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
+    print(f"Total tokens in long context documents: {long_context_tokens_total:,}")
+
 
 async def main():
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')