From d0cda4457b7271e2a4045f84120db23fc49dffe4 Mon Sep 17 00:00:00 2001 From: Andrew Lauder Date: Fri, 13 Dec 2024 13:36:28 -0800 Subject: [PATCH 1/3] start adding local model support to generate_model.py --- scripts/generate_model.py | 47 ++++++++++++++++++++++++++++++++++---- tests/test_text_decoder.py | 39 +++++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 12 deletions(-) diff --git a/scripts/generate_model.py b/scripts/generate_model.py index 8ff08ad..2bfa4e3 100644 --- a/scripts/generate_model.py +++ b/scripts/generate_model.py @@ -35,8 +35,9 @@ def cli(): parser.add_argument( "--model-version", required=True, - help="Whisper model version string that matches Hugging Face model hub name, " - "e.g. openai/whisper-tiny.en", + help="Whisper model version string that can be either:\n" + "1. A Hugging Face model hub name (e.g. openai/whisper-tiny.en)\n" + "2. A local directory containing the model files" ) parser.add_argument( "--generate-quantized-variants", @@ -135,12 +136,16 @@ def upload_version(local_folder_path, model_version): # Dump required metadata before upload for filename in ["config.json", "generation_config.json"]: - with open(hf_hub_download(repo_id=model_version, - filename=filename), "r") as f: + if os.path.exists(model_version): # Local path + config_path = os.path.join(model_version, filename) + else: # HF hub path + config_path = hf_hub_download(repo_id=model_version, filename=filename) + + with open(config_path, "r") as f: model_file = json.load(f) with open(os.path.join(local_folder_path, filename), "w") as f: json.dump(model_file, f) - logger.info(f"Copied over {filename} from the original {model_version} repo") + logger.info(f"Copied over {filename} from the original model") # Get whisperkittools commit hash wkt_commit_hash = subprocess.run( @@ -262,3 +267,35 @@ def get_dir_size(root_dir): if not os.path.islink(path): size_in_mb += os.path.getsize(path) return size_in_mb / 1e6 + + +def load_whisper_model(model_path: str, torch_dtype=None): + """Load a Whisper model from either Hugging Face hub or local path + + Args: + model_path: Either a Hugging Face model ID or local directory path + torch_dtype: Optional torch dtype to load the model in + + Returns: + The loaded Whisper model + """ + from transformers import WhisperForConditionalGeneration + + try: + # First try loading as a local path + if os.path.exists(model_path): + return WhisperForConditionalGeneration.from_pretrained( + model_path, + local_files_only=True, + torch_dtype=torch_dtype + ) + # If not a valid path, try loading from HF hub + return WhisperForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch_dtype + ) + except Exception as e: + raise ValueError( + f"Could not load model from '{model_path}'. " + "Make sure it is either a valid local path or Hugging Face model ID." + ) from e diff --git a/tests/test_text_decoder.py b/tests/test_text_decoder.py index a83a8c4..78af7b7 100644 --- a/tests/test_text_decoder.py +++ b/tests/test_text_decoder.py @@ -14,7 +14,7 @@ from argmaxtools import test_utils as argmaxtools_test_utils from argmaxtools.utils import get_fastest_device, get_logger from tqdm import tqdm -from transformers import AutoTokenizer +from transformers import AutoTokenizer, WhisperForConditionalGeneration from transformers.models.whisper import modeling_whisper from whisperkit import test_utils, text_decoder @@ -44,6 +44,36 @@ TEST_TOKEN_TIMESTAMPS = True +def load_whisper_model(model_path: str, torch_dtype=None): + """Load a Whisper model from either Hugging Face hub or local path + + Args: + model_path: Either a Hugging Face model ID or local directory path + torch_dtype: Optional torch dtype to load the model in + + Returns: + The loaded Whisper model + """ + try: + # First try loading as a local path + if os.path.exists(model_path): + return WhisperForConditionalGeneration.from_pretrained( + model_path, + local_files_only=True, + torch_dtype=torch_dtype + ) + # If not a valid path, try loading from HF hub + return WhisperForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch_dtype + ) + except Exception as e: + raise ValueError( + f"Could not load model from '{model_path}'. " + "Make sure it is either a valid local path or Hugging Face model ID." + ) from e + + class TestWhisperTextDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase): @classmethod def setUpClass(cls): @@ -55,12 +85,7 @@ def setUpClass(cls): cls.test_output_names.pop(cls.test_output_names.index("alignment_heads_weights")) # Original model - orig_torch_model = ( - modeling_whisper.WhisperForConditionalGeneration.from_pretrained( - TEST_WHISPER_VERSION, - torch_dtype=TEST_TORCH_DTYPE, - ) - ) + orig_torch_model = load_whisper_model(TEST_WHISPER_VERSION, TEST_TORCH_DTYPE) cls.orig_torch_model = ( orig_torch_model.model.decoder.to(TEST_DEV).to(TEST_TORCH_DTYPE).eval() ) From 3529ebd8b7372da3c82b5080ffe7c22b5829d26c Mon Sep 17 00:00:00 2001 From: Andrew Lauder Date: Fri, 13 Dec 2024 14:58:08 -0800 Subject: [PATCH 2/3] local model generation --- scripts/generate_model.py | 4 +--- tests/test_audio_encoder.py | 17 +++++++++++++---- tests/test_text_decoder.py | 3 +++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/scripts/generate_model.py b/scripts/generate_model.py index 2bfa4e3..291b150 100644 --- a/scripts/generate_model.py +++ b/scripts/generate_model.py @@ -89,9 +89,7 @@ def cli(): args.test_model_version = args.model_version args.palettizer_tests = args.generate_quantized_variants args.context_prefill_tests = args.generate_decoder_context_prefill_data - args.persistent_cache_dir = os.path.join( - args.output_dir, args.model_version.replace("/", "_") - ) + args.persistent_cache_dir = args.output_dir if args.repo_path_suffix is not None: args.persistent_cache_dir += f"_{args.repo_path_suffix}" diff --git a/tests/test_audio_encoder.py b/tests/test_audio_encoder.py index 7add9af..d49a282 100644 --- a/tests/test_audio_encoder.py +++ b/tests/test_audio_encoder.py @@ -134,10 +134,19 @@ class TestWhisperMelSpectrogram( @classmethod def setUpClass(cls): - with open( - hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r" - ) as f: - n_mels = json.load(f)["num_mel_bins"] + # Try loading config from local path first + config_path = os.path.join(TEST_WHISPER_VERSION, "config.json") + if os.path.exists(config_path): + logger.info(f"Loading config from local path: {config_path}") + with open(config_path, "r") as f: + n_mels = json.load(f)["num_mel_bins"] + else: + # Fall back to downloading from HF hub + logger.info(f"Loading config from Hugging Face hub: {TEST_WHISPER_VERSION}") + with open( + hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r" + ) as f: + n_mels = json.load(f)["num_mel_bins"] logger.info( f"WhisperMelSpectrogram: n_mels={n_mels} for {TEST_WHISPER_VERSION}" diff --git a/tests/test_text_decoder.py b/tests/test_text_decoder.py index 78af7b7..d05aed9 100644 --- a/tests/test_text_decoder.py +++ b/tests/test_text_decoder.py @@ -54,15 +54,18 @@ def load_whisper_model(model_path: str, torch_dtype=None): Returns: The loaded Whisper model """ + logger.info(f"Attempting to load model from: {model_path}") try: # First try loading as a local path if os.path.exists(model_path): + logger.info(f"Loading model from local path: {model_path}") return WhisperForConditionalGeneration.from_pretrained( model_path, local_files_only=True, torch_dtype=torch_dtype ) # If not a valid path, try loading from HF hub + logger.info(f"Loading model from Hugging Face hub: {model_path}") return WhisperForConditionalGeneration.from_pretrained( model_path, torch_dtype=torch_dtype From 7138f5954fc5ae80a1f260000a6b3c051ba26beb Mon Sep 17 00:00:00 2001 From: Andrew Lauder Date: Tue, 31 Dec 2024 15:02:04 -0800 Subject: [PATCH 3/3] Remove unused/duplicate load_whisper_model method --- scripts/generate_model.py | 34 +--------------------------------- 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/scripts/generate_model.py b/scripts/generate_model.py index 291b150..d1832e6 100644 --- a/scripts/generate_model.py +++ b/scripts/generate_model.py @@ -264,36 +264,4 @@ def get_dir_size(root_dir): path = os.path.join(parent, f) if not os.path.islink(path): size_in_mb += os.path.getsize(path) - return size_in_mb / 1e6 - - -def load_whisper_model(model_path: str, torch_dtype=None): - """Load a Whisper model from either Hugging Face hub or local path - - Args: - model_path: Either a Hugging Face model ID or local directory path - torch_dtype: Optional torch dtype to load the model in - - Returns: - The loaded Whisper model - """ - from transformers import WhisperForConditionalGeneration - - try: - # First try loading as a local path - if os.path.exists(model_path): - return WhisperForConditionalGeneration.from_pretrained( - model_path, - local_files_only=True, - torch_dtype=torch_dtype - ) - # If not a valid path, try loading from HF hub - return WhisperForConditionalGeneration.from_pretrained( - model_path, - torch_dtype=torch_dtype - ) - except Exception as e: - raise ValueError( - f"Could not load model from '{model_path}'. " - "Make sure it is either a valid local path or Hugging Face model ID." - ) from e + return size_in_mb / 1e6 \ No newline at end of file