From d4f3cfff4d46e742982895d9bb46054f5ba88fad Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Wed, 27 Nov 2024 19:11:20 +0000
Subject: [PATCH] More reliable Weka

---
 pdelfin/s3_utils.py | 82 +++++++++++++++++++++++++++++----------------
 pdelfin/version.py  |  2 +-
 2 files changed, 54 insertions(+), 30 deletions(-)

diff --git a/pdelfin/s3_utils.py b/pdelfin/s3_utils.py
index 4223ca9..241652a 100644
--- a/pdelfin/s3_utils.py
+++ b/pdelfin/s3_utils.py
@@ -8,7 +8,7 @@ import time
 import requests
 import concurrent.futures
-import hashlib  # Added for MD5 hash computation
+import hashlib
 
 from urllib.parse import urlparse
 from pathlib import Path
 
@@ -81,6 +81,7 @@ def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end
 
     return obj['Body'].read()
 
+
 def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, backoff_factor: int = 2):
     attempt = 0
 
@@ -106,6 +107,7 @@ def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, back
 
     logger.error(f"Failed to get_s3_bytes for {pdf_s3_path} after {max_retries} retries.")
     raise Exception("Failed to get_s3_bytes after retries")
 
+
 def put_s3_bytes(s3_client, s3_path: str, data: bytes):
     bucket, key = parse_s3_path(s3_path)
 
@@ -160,6 +162,7 @@ def is_running_on_gcp():
     except requests.RequestException:
         return False
 
+
 def download_directory(model_choices: List[str], local_dir: str):
     """
     Download the model to a specified local directory.
@@ -242,7 +245,12 @@ def should_download(blob, local_file_path):
             return compare_hashes_gcs(blob, local_file_path)
 
         def download_blob(blob, local_file_path):
-            blob.download_to_filename(local_file_path)
+            try:
+                blob.download_to_filename(local_file_path)
+                logger.info(f"Successfully downloaded {blob.name} to {local_file_path}")
+            except Exception as e:
+                logger.error(f"Failed to download {blob.name} to {local_file_path}: {e}")
+                raise
 
         items = blobs
     elif storage_type in ('s3', 'weka'):
@@ -272,21 +280,30 @@ def download_blob(blob, local_file_path):
         for page in pages:
             if 'Contents' in page:
                 objects.extend(page['Contents'])
+            else:
+                logger.warning(f"No contents found in page: {page}")
 
         total_files = len(objects)
         logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")
 
         transfer_config = TransferConfig(
             multipart_threshold=8 * 1024 * 1024,
             multipart_chunksize=8 * 1024 * 1024,
-            max_concurrency=100,
+            max_concurrency=10,  # Reduced for WekaFS compatibility
             use_threads=True
         )
         def should_download(obj, local_file_path):
-            return compare_hashes_s3(obj, local_file_path)
+            return compare_hashes_s3(obj, local_file_path, storage_type)
 
         def download_blob(obj, local_file_path):
-            s3_client.download_file(bucket_name, obj['Key'], local_file_path, Config=transfer_config)
+            logger.info(f"Starting download of {obj['Key']} to {local_file_path}")
+            try:
+                with open(local_file_path, 'wb') as f:
+                    s3_client.download_fileobj(bucket_name, obj['Key'], f, Config=transfer_config)
+                logger.info(f"Successfully downloaded {obj['Key']} to {local_file_path}")
+            except Exception as e:
+                logger.error(f"Failed to download {obj['Key']} to {local_file_path}: {e}")
+                raise
 
         items = objects
     else:
@@ -307,8 +324,11 @@ def download_blob(obj, local_file_path):
                 total_files -= 1  # Decrement total_files as we're skipping this file
 
     if total_files > 0:
-        for _ in tqdm(concurrent.futures.as_completed(futures),
-                      total=total_files, desc=f"Downloading from {storage_type.upper()}"):
-            pass
+        for future in tqdm(concurrent.futures.as_completed(futures),
+                           total=total_files, desc=f"Downloading from {storage_type.upper()}"):
+            try:
+                future.result()
+            except Exception as e:
+                logger.error(f"Error occurred during download: {e}")
     else:
         logger.info("All files are up-to-date. No downloads needed.")
 
@@ -336,31 +356,35 @@ def compare_hashes_gcs(blob, local_file_path: str) -> bool:
         return True
 
 
-def compare_hashes_s3(obj, local_file_path: str) -> bool:
+def compare_hashes_s3(obj, local_file_path: str, storage_type: str) -> bool:
     """Compare MD5 hashes or sizes for S3 objects (including Weka)."""
     if os.path.exists(local_file_path):
-        etag = obj['ETag'].strip('"')
-        if '-' in etag:
-            remote_size = obj['Size']
-            local_size = os.path.getsize(local_file_path)
-            if remote_size == local_size:
-                logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
-                return False
-            else:
-                logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
-                return True
+        if storage_type == 'weka':
+            return True
         else:
-            hash_md5 = hashlib.md5()
-            with open(local_file_path, "rb") as f:
-                for chunk in iter(lambda: f.read(8192), b""):
-                    hash_md5.update(chunk)
-            local_md5 = hash_md5.hexdigest()
-            if etag == local_md5:
-                logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
-                return False
+            etag = obj['ETag'].strip('"')
+            if '-' in etag:
+                # Multipart upload, compare sizes
+                remote_size = obj['Size']
+                local_size = os.path.getsize(local_file_path)
+                if remote_size == local_size:
+                    logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
+                    return False
+                else:
+                    logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
+                    return True
             else:
-                logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
-                return True
+                hash_md5 = hashlib.md5()
+                with open(local_file_path, "rb") as f:
+                    for chunk in iter(lambda: f.read(8192), b""):
+                        hash_md5.update(chunk)
+                local_md5 = hash_md5.hexdigest()
+                if etag == local_md5:
+                    logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
+                    return False
+                else:
+                    logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
+                    return True
     else:
         logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
-        return True
\ No newline at end of file
+        return True
diff --git a/pdelfin/version.py b/pdelfin/version.py
index 4bde64c..a9ee47d 100644
--- a/pdelfin/version.py
+++ b/pdelfin/version.py
@@ -2,7 +2,7 @@
 _MINOR = "1"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "49"
+_PATCH = "50"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
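
Review note on the compare_hashes_s3 change: the new storage_type parameter makes the function return True (re-download) whenever a local copy exists and the backend is Weka, instead of trusting an ETag or size comparison, presumably because Weka's S3-compatible gateway does not guarantee MD5-style ETags. Below is a minimal sketch of the resulting decision logic, assuming the patched pdelfin.s3_utils module is importable; the temp file and the object dict are hypothetical stand-ins for a real list_objects_v2 listing entry:

    import hashlib
    import os
    import tempfile

    from pdelfin.s3_utils import compare_hashes_s3

    # Fabricate a local file and a matching single-part listing entry.
    data = b"hello weka"
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(data)
        local_path = f.name

    obj = {
        "ETag": '"%s"' % hashlib.md5(data).hexdigest(),  # single-part ETag is the MD5
        "Size": len(data),
    }

    # Plain S3: the ETag matches the local MD5, so no download is needed.
    assert compare_hashes_s3(obj, local_path, "s3") is False

    # Weka: hashes are not trusted, so an existing file is always re-downloaded.
    assert compare_hashes_s3(obj, local_path, "weka") is True

    os.remove(local_path)

The apparent trade-off is bandwidth for correctness: together with the lower max_concurrency and the per-future result() checks in download_directory, failed or stale Weka downloads now surface in the logs rather than silently leaving partial files behind.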