Commit 6586744 (parent c74e3d1): 1 changed file with 186 additions and 0 deletions.
import os
import csv
import json
import argparse
import re
import collections
import random
import sqlite3

from typing import Optional
from urllib.parse import urlparse
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
    Given a string like "s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32",
    extract the hash ("43426a12ffc2ffa73f5258eb66095659beae9522").
    Returns None if not found.
    """
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None

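# Usage sketch (the path below is the docstring example): the regex captures the
# 4-character prefix directory and the rest of the filename, which are joined
# back into the full 40-character hash.
#
#   parse_pdf_hash("s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32")
#   # -> "43426a12ffc2ffa73f5258eb66095659beae9522"
#   parse_pdf_hash("not-an-s3-path")
#   # -> None
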
def cache_athena_csv_to_db(athena_csv_path: str) -> str:
    """
    Cache the Athena CSV file into an SQLite database.
    Returns the path to the SQLite database.
    """
    db_path = athena_csv_path + ".db"

    if not os.path.exists(db_path):
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Trade durability for speed; the database is only a local cache of the CSV.
        cursor.execute("PRAGMA synchronous = OFF;")
        cursor.execute("PRAGMA journal_mode = MEMORY;")

        # Create the table
        cursor.execute("""
            CREATE TABLE pdf_mapping (
                pdf_hash TEXT PRIMARY KEY,
                uri TEXT
            )
        """)

        # Insert data from CSV in batches of 1000 rows
        with open(athena_csv_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            batch = []
            for row in tqdm(reader):
                batch.append((row["distinct_pdf_hash"], row["uri"]))
                if len(batch) == 1000:
                    cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                    conn.commit()
                    batch = []

            # Insert remaining rows
            if batch:
                cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
                conn.commit()

        conn.close()

    return db_path

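# The Athena export is expected to have a header row with at least the columns
# "distinct_pdf_hash" and "uri" (the names read by the DictReader above). A
# hypothetical row, for illustration only:
#
#   distinct_pdf_hash,uri
#   43426a12ffc2ffa73f5258eb66095659beae9522,https://example.org/papers/1234.pdf
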
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Query the SQLite database to retrieve the URI for a given PDF hash.
    Returns None if the hash is not present.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None

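# Lookup sketch (the CSV path and the hash below are placeholders):
#
#   db_path = cache_athena_csv_to_db("/path/to/athena_export.csv")
#   uri = get_uri_from_db(db_path, "43426a12ffc2ffa73f5258eb66095659beae9522")
#
# Note that each call opens a fresh connection; for very large inputs, reusing a
# single connection (or loading the mapping into a dict) would be faster, but
# this keeps the helper self-contained.
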
def main():
    parser = argparse.ArgumentParser(
        description="Review the silver dataset: provide summary statistics based on source URL, plus a few data samples for review."
    )
    parser.add_argument(
        "--input",
        type=str,
        default="openai_batch_data",
        help="Input folder, which is the output of the buildsilver.py script",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="openai_batch_data_summary",
        help="Output destination (folder)",
    )
    parser.add_argument(
        "--athena-csv",
        type=str,
        default="/home/ubuntu/s2pdf_url_data/c974870d-3b06-4793-9a62-d46d38e2c8b2.csv",
        help="CSV file that maps pdf_hash to uri",
    )
    parser.add_argument(
        "--sample-size",
        type=int,
        default=20,
        help="How many sample rows to include in the sample CSV",
    )

    args = parser.parse_args()

    # Cache the Athena CSV into an SQLite database
    db_path = cache_athena_csv_to_db(args.athena_csv)

    # Process input JSONL files
    all_rows = []

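    # Each input .jsonl line is expected to be a JSON object carrying at least a
    # "custom_id" field that holds the pretty PDF path (presumably an OpenAI batch
    # request file, given the default folder name). A hypothetical minimal line:
    #
    #   {"custom_id": "s3://ai2-s2-pdfs/4342/6a12ffc2ffa73f5258eb66095659beae9522.pdf-32", ...}
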
    for filename in tqdm(os.listdir(args.input)):
        if filename.endswith(".jsonl"):
            filepath = os.path.join(args.input, filename)
            with open(filepath, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        print(f"Error parsing line in {filepath}")
                        continue

                    custom_id = data.get("custom_id")
                    if not custom_id:
                        print("No custom_id found")
                        continue

                    pdf_hash = parse_pdf_hash(custom_id)
                    assert pdf_hash, f"Need to have a pdf_hash {custom_id}"

                    uri = get_uri_from_db(db_path, pdf_hash)

                    domain = None
                    if uri:
                        parsed = urlparse(uri)
                        domain = parsed.netloc

                    all_rows.append((custom_id, uri, domain))

    # Write output CSVs
    os.makedirs(args.output, exist_ok=True)

    output_csv_path = os.path.join(args.output, "custom_id_to_url.csv")
    with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["custom_id", "uri", "domain"])
        for cid, uri, domain in all_rows:
            writer.writerow([cid, uri if uri else "", domain if domain else ""])

    # Count rows per domain and keep the 1000 most common
    domain_counter = collections.Counter()
    for _, _, domain in all_rows:
        if domain:
            domain_counter[domain] += 1

    most_common_domains = domain_counter.most_common(1000)
    domain_csv_path = os.path.join(args.output, "top_1000_domains.csv")
    with open(domain_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["domain", "count"])
        for domain, count in most_common_domains:
            writer.writerow([domain, count])

    # Write a small random sample for manual review
    sample_size = min(args.sample_size, len(all_rows))
    sample_rows = random.sample(all_rows, sample_size) if all_rows else []
    sample_csv_path = os.path.join(args.output, "data_samples.csv")
    with open(sample_csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["custom_id", "uri", "domain"])
        for cid, uri, domain in sample_rows:
            writer.writerow([cid, uri if uri else "", domain if domain else ""])

    print(f"Summary files written to: {args.output}")
    print(f" - Full mapping: {output_csv_path}")
    print(f" - Top domains: {domain_csv_path}")
    print(f" - Samples: {sample_csv_path}")


if __name__ == "__main__":
    main()
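
# Example invocation (a sketch: the script filename and the CSV path are
# placeholders, and the folder names are simply the argparse defaults above):
#
#   python review_silver.py \
#       --input openai_batch_data \
#       --output openai_batch_data_summary \
#       --athena-csv /path/to/athena_export.csv \
#       --sample-size 20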