Commit: Organization
jakep-allenai committed Nov 8, 2024
1 parent ee72b36 commit e5fb7c0
Showing 1 changed file with 52 additions and 53 deletions.
105 changes: 52 additions & 53 deletions pdelfin/beakerpipeline.py
@@ -34,10 +34,14 @@
pdf_s3 = boto3.client('s3')


def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)

# Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

image_base64 = await image_base64
if image_rotation != 0:
image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img:
@@ -50,9 +54,6 @@ def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: i
# Encode the rotated image back to base64
image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')


anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

return {
"chat_messages": [
{
@@ -179,21 +180,64 @@ async def load_pdf_work_queue(args) -> asyncio.Queue:
return queue

async def process_pdf(args, pdf_s3_path):
await asyncio.sleep(1)
return f"pdf: {pdf_s3_path}"
URL = "http://localhost:30000/v1/chat/completions"

with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
# TODO Grab this file async
        tf.write(get_s3_bytes(pdf_s3, pdf_s3_path))
tf.flush()
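        # A possible async variant (sketch, not part of this commit): fetch the bytes
        # in a worker thread so the event loop is not blocked, e.g.
        #   data = await asyncio.to_thread(get_s3_bytes, pdf_s3, pdf_s3_path)
        #   tf.write(data)
        #   tf.flush()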

reader = PdfReader(tf.name)
page_data = []

for page_num in range(1, reader.get_num_pages() + 1):
query = await build_page_query(tf.name, page_num, args.target_longest_image_dim, args.target_anchor_text_len)

# TODO Url.post with the query as json_data
# if the result is a 200 then you can append it to the page_data
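            # One way the request could look (sketch; assumes an async HTTP client
            # such as httpx, which this commit does not import):
            #   async with httpx.AsyncClient() as client:
            #       response = await client.post(URL, json=query)
            #       if response.status_code == 200:
            #           page_data.append(response.json())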




# TODO build dolma doc and return it, or return None if not possible
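    # A sketch of how document_text and pdf_page_spans might be assembled from
    # page_data (assumption; the server response schema is not pinned down in this
    # commit, an OpenAI-style chat-completions shape is assumed):
    #   document_text = ""
    #   pdf_page_spans = []
    #   for page_num, response_json in enumerate(page_data, start=1):
    #       start = len(document_text)
    #       document_text += response_json["choices"][0]["message"]["content"] + "\n"
    #       pdf_page_spans.append((start, len(document_text), page_num))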
metadata = {
"Source-File": pdf.s3_path,
"pdf-total-pages": pdf.num_pages,
}
id_ = hashlib.sha1(document_text.encode()).hexdigest()

dolma_doc = {
"id": id_,
"text": document_text,
"source": "pdelfin",
"added": datetime.datetime.now().strftime("%Y-%m-%d"),
"created": datetime.datetime.now().strftime("%Y-%m-%d"),
"metadata": metadata,
"attributes": {
"pdf_page_numbers": pdf_page_spans
}
}

return dolma_doc




async def worker(args, queue):
while True:
[work_hash, pdfs] = await queue.get()

completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
logger.info(f"Completed {completed_pdfs}")

        # Take all the non-None completed_pdfs and write them as a JSONL to the workspace output location
        # under the proper work_hash location
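        # A possible shape for that step (sketch; the output key layout is an
        # assumption, and `json` would need to be imported):
        #   lines = [json.dumps(doc) for doc in completed_pdfs if doc is not None]
        #   if lines:
        #       body = ("\n".join(lines) + "\n").encode("utf-8")
        #       # e.g. pdf_s3.put_object(Bucket=..., Key=f"<workspace>/output/{work_hash}.jsonl", Body=body)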

queue.task_done()


async def sglang_server_task(args):
model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model')
# TODO cache locally
#download_directory(args.model, model_cache_dir)
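    # A minimal caching sketch (assumption; it reuses the download_directory call
    # commented out above and only downloads when the cache dir is empty):
    #   if not os.path.isdir(model_cache_dir) or not os.listdir(model_cache_dir):
    #       os.makedirs(model_cache_dir, exist_ok=True)
    #       download_directory(args.model, model_cache_dir)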

proc = await asyncio.create_subprocess_exec(
@@ -270,56 +314,11 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())




# TODO
# If there is a beaker flag, then your job is to trigger this script with N replicas on beaker
# If not, then your job is to do the actual work

# Download the model from the best place available


# Register atexit function and signal handlers to guarantee process termination
# def terminate_processes():
# print("Terminating child processes...")
# sglang_process.terminate()
# try:
# sglang_process.wait(timeout=30)
# except subprocess.TimeoutExpired:
# print("Forcing termination of child processes.")
# sglang_process.kill()
# print("Child processes terminated.")

# atexit.register(terminate_processes)

# def signal_handler(sig, frame):
# terminate_processes()
# sys.exit(0)

# signal.signal(signal.SIGINT, signal_handler)
# signal.signal(signal.SIGTERM, signal_handler)


# logger.info(f"Remaining work items: {len(remaining_work_queue)}")

# TODO
# Spawn up to N workers to do:
# In a loop, take a random work item, read in the pdfs, queue in their requests
# Get results back, retry any failed pages
# Check periodically if that work is done in s3, if so, then abandon this work
# Save results back to s3 workspace output folder

# TODO
# Possible future addon, in beaker, discover other nodes on this same job
# Send them a message when you take a work item off the queue

# try:
# while True:
# time.sleep(1)

# if sglang_process.returncode is not None:
# logger.error(f"Sglang server exited with code {sglang_process.returncode} exiting.")
# except KeyboardInterrupt:
# logger.info("Got keyboard interrupt, exiting everything")
# sys.exit(1)
