diff --git a/scripts/generate_model.py b/scripts/generate_model.py
index 8ff08ad..291b150 100644
--- a/scripts/generate_model.py
+++ b/scripts/generate_model.py
@@ -35,8 +35,9 @@ def cli():
     parser.add_argument(
         "--model-version",
         required=True,
-        help="Whisper model version string that matches Hugging Face model hub name, "
-        "e.g. openai/whisper-tiny.en",
+        help="Whisper model version string that can be either:\n"
+        "1. A Hugging Face model hub name (e.g. openai/whisper-tiny.en)\n"
+        "2. A local directory containing the model files"
     )
     parser.add_argument(
         "--generate-quantized-variants",
@@ -88,9 +89,7 @@ def cli():
     args.test_model_version = args.model_version
     args.palettizer_tests = args.generate_quantized_variants
     args.context_prefill_tests = args.generate_decoder_context_prefill_data
-    args.persistent_cache_dir = os.path.join(
-        args.output_dir, args.model_version.replace("/", "_")
-    )
+    args.persistent_cache_dir = args.output_dir
 
     if args.repo_path_suffix is not None:
         args.persistent_cache_dir += f"_{args.repo_path_suffix}"
@@ -135,12 +134,16 @@ def upload_version(local_folder_path, model_version):
 
     # Dump required metadata before upload
     for filename in ["config.json", "generation_config.json"]:
-        with open(hf_hub_download(repo_id=model_version,
-                  filename=filename), "r") as f:
+        if os.path.exists(model_version):  # Local path
+            config_path = os.path.join(model_version, filename)
+        else:  # HF hub path
+            config_path = hf_hub_download(repo_id=model_version, filename=filename)
+
+        with open(config_path, "r") as f:
             model_file = json.load(f)
         with open(os.path.join(local_folder_path, filename), "w") as f:
             json.dump(model_file, f)
-        logger.info(f"Copied over {filename} from the original {model_version} repo")
+        logger.info(f"Copied over {filename} from the original model")
 
     # Get whisperkittools commit hash
     wkt_commit_hash = subprocess.run(
@@ -262,3 +265,35 @@ def get_dir_size(root_dir):
         if not os.path.islink(path):
             size_in_mb += os.path.getsize(path)
     return size_in_mb / 1e6
+
+
+def load_whisper_model(model_path: str, torch_dtype=None):
+    """Load a Whisper model from either Hugging Face hub or local path
+
+    Args:
+        model_path: Either a Hugging Face model ID or local directory path
+        torch_dtype: Optional torch dtype to load the model in
+
+    Returns:
+        The loaded Whisper model
+    """
+    from transformers import WhisperForConditionalGeneration
+
+    try:
+        # First try loading as a local path
+        if os.path.exists(model_path):
+            return WhisperForConditionalGeneration.from_pretrained(
+                model_path,
+                local_files_only=True,
+                torch_dtype=torch_dtype
+            )
+        # If not a valid path, try loading from HF hub
+        return WhisperForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype
+        )
+    except Exception as e:
+        raise ValueError(
+            f"Could not load model from '{model_path}'. "
+            "Make sure it is either a valid local path or Hugging Face model ID."
+        ) from e
diff --git a/tests/test_audio_encoder.py b/tests/test_audio_encoder.py
index 7add9af..d49a282 100644
--- a/tests/test_audio_encoder.py
+++ b/tests/test_audio_encoder.py
@@ -134,10 +134,19 @@ class TestWhisperMelSpectrogram(
     @classmethod
     def setUpClass(cls):
-        with open(
-            hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
-        ) as f:
-            n_mels = json.load(f)["num_mel_bins"]
+        # Try loading config from local path first
+        config_path = os.path.join(TEST_WHISPER_VERSION, "config.json")
+        if os.path.exists(config_path):
+            logger.info(f"Loading config from local path: {config_path}")
+            with open(config_path, "r") as f:
+                n_mels = json.load(f)["num_mel_bins"]
+        else:
+            # Fall back to downloading from HF hub
+            logger.info(f"Loading config from Hugging Face hub: {TEST_WHISPER_VERSION}")
+            with open(
+                hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
+            ) as f:
+                n_mels = json.load(f)["num_mel_bins"]
 
         logger.info(
             f"WhisperMelSpectrogram: n_mels={n_mels} for {TEST_WHISPER_VERSION}"
         )
diff --git a/tests/test_text_decoder.py b/tests/test_text_decoder.py
index a83a8c4..d05aed9 100644
--- a/tests/test_text_decoder.py
+++ b/tests/test_text_decoder.py
@@ -14,7 +14,7 @@
 from argmaxtools import test_utils as argmaxtools_test_utils
 from argmaxtools.utils import get_fastest_device, get_logger
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, WhisperForConditionalGeneration
 from transformers.models.whisper import modeling_whisper
 
 from whisperkit import test_utils, text_decoder
@@ -44,6 +44,39 @@
 TEST_TOKEN_TIMESTAMPS = True
 
 
+def load_whisper_model(model_path: str, torch_dtype=None):
+    """Load a Whisper model from either Hugging Face hub or local path
+
+    Args:
+        model_path: Either a Hugging Face model ID or local directory path
+        torch_dtype: Optional torch dtype to load the model in
+
+    Returns:
+        The loaded Whisper model
+    """
+    logger.info(f"Attempting to load model from: {model_path}")
+    try:
+        # First try loading as a local path
+        if os.path.exists(model_path):
+            logger.info(f"Loading model from local path: {model_path}")
+            return WhisperForConditionalGeneration.from_pretrained(
+                model_path,
+                local_files_only=True,
+                torch_dtype=torch_dtype
+            )
+        # If not a valid path, try loading from HF hub
+        logger.info(f"Loading model from Hugging Face hub: {model_path}")
+        return WhisperForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype
+        )
+    except Exception as e:
+        raise ValueError(
+            f"Could not load model from '{model_path}'. "
+            "Make sure it is either a valid local path or Hugging Face model ID."
+        ) from e
+
+
 class TestWhisperTextDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -55,12 +88,7 @@ def setUpClass(cls):
         cls.test_output_names.pop(cls.test_output_names.index("alignment_heads_weights"))
 
         # Original model
-        orig_torch_model = (
-            modeling_whisper.WhisperForConditionalGeneration.from_pretrained(
-                TEST_WHISPER_VERSION,
-                torch_dtype=TEST_TORCH_DTYPE,
-            )
-        )
+        orig_torch_model = load_whisper_model(TEST_WHISPER_VERSION, TEST_TORCH_DTYPE)
         cls.orig_torch_model = (
             orig_torch_model.model.decoder.to(TEST_DEV).to(TEST_TORCH_DTYPE).eval()
         )
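
Usage sketch for the new helper (illustrative, not part of the diff): both call sites resolve the same way, so either a hub ID or a checkout directory works. The local directory name below is hypothetical, and importing from scripts.generate_model assumes the repo root is on PYTHONPATH.

    import torch
    from scripts.generate_model import load_whisper_model

    # Hub ID: falls through to from_pretrained against the Hugging Face hub
    hub_model = load_whisper_model("openai/whisper-tiny.en", torch_dtype=torch.float16)

    # Existing directory: loaded with local_files_only=True, no network access
    # ("./my-finetuned-whisper" is a placeholder path for illustration)
    local_model = load_whisper_model("./my-finetuned-whisper")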