Generate CoreML models from local transformer models or HF repos #28

Open · wants to merge 2 commits into main
51 changes: 43 additions & 8 deletions scripts/generate_model.py
@@ -35,8 +35,9 @@ def cli():
     parser.add_argument(
         "--model-version",
         required=True,
-        help="Whisper model version string that matches Hugging Face model hub name, "
-        "e.g. openai/whisper-tiny.en",
+        help="Whisper model version string that can be either:\n"
+        "1. A Hugging Face model hub name (e.g. openai/whisper-tiny.en)\n"
+        "2. A local directory containing the model files"
     )
     parser.add_argument(
         "--generate-quantized-variants",
@@ -88,9 +89,7 @@ def cli():
     args.test_model_version = args.model_version
     args.palettizer_tests = args.generate_quantized_variants
     args.context_prefill_tests = args.generate_decoder_context_prefill_data
-    args.persistent_cache_dir = os.path.join(
-        args.output_dir, args.model_version.replace("/", "_")
-    )
+    args.persistent_cache_dir = args.output_dir
     if args.repo_path_suffix is not None:
         args.persistent_cache_dir += f"_{args.repo_path_suffix}"

@@ -135,12 +134,16 @@ def upload_version(local_folder_path, model_version):

     # Dump required metadata before upload
     for filename in ["config.json", "generation_config.json"]:
-        with open(hf_hub_download(repo_id=model_version,
-                  filename=filename), "r") as f:
+        if os.path.exists(model_version):  # Local path
+            config_path = os.path.join(model_version, filename)
+        else:  # HF hub path
+            config_path = hf_hub_download(repo_id=model_version, filename=filename)
+
+        with open(config_path, "r") as f:
             model_file = json.load(f)
         with open(os.path.join(local_folder_path, filename), "w") as f:
             json.dump(model_file, f)
-        logger.info(f"Copied over {filename} from the original {model_version} repo")
+        logger.info(f"Copied over {filename} from the original model")

     # Get whisperkittools commit hash
     wkt_commit_hash = subprocess.run(
@@ -262,3 +265,35 @@ def get_dir_size(root_dir):
         if not os.path.islink(path):
             size_in_mb += os.path.getsize(path)
     return size_in_mb / 1e6
+
+
+def load_whisper_model(model_path: str, torch_dtype=None):
+    """Load a Whisper model from either Hugging Face hub or local path
+
+    Args:
+        model_path: Either a Hugging Face model ID or local directory path
+        torch_dtype: Optional torch dtype to load the model in
+
+    Returns:
+        The loaded Whisper model
+    """
+    from transformers import WhisperForConditionalGeneration
+
+    try:
+        # First try loading as a local path
+        if os.path.exists(model_path):
+            return WhisperForConditionalGeneration.from_pretrained(
+                model_path,
+                local_files_only=True,
+                torch_dtype=torch_dtype
+            )
+        # If not a valid path, try loading from HF hub
+        return WhisperForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype
+        )
+    except Exception as e:
+        raise ValueError(
+            f"Could not load model from '{model_path}'. "
+            "Make sure it is either a valid local path or Hugging Face model ID."
+        ) from e
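The dispatch in load_whisper_model hinges on os.path.exists: any argument that resolves to an existing path is loaded with local_files_only=True, and everything else falls through to a hub download. A minimal usage sketch, where the local checkpoint directory is hypothetical and the import path is assumed:

import torch

from generate_model import load_whisper_model  # import path assumed; adjust to your layout

# Hugging Face hub ID: resolved through from_pretrained's download/cache machinery
model = load_whisper_model("openai/whisper-tiny.en", torch_dtype=torch.float16)

# Local directory: must contain config.json and weights written by save_pretrained()
local_model = load_whisper_model("./whisper-tiny-finetuned")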
17 changes: 13 additions & 4 deletions tests/test_audio_encoder.py
@@ -134,10 +134,19 @@ class TestWhisperMelSpectrogram(

     @classmethod
     def setUpClass(cls):
-        with open(
-            hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
-        ) as f:
-            n_mels = json.load(f)["num_mel_bins"]
+        # Try loading config from local path first
+        config_path = os.path.join(TEST_WHISPER_VERSION, "config.json")
+        if os.path.exists(config_path):
+            logger.info(f"Loading config from local path: {config_path}")
+            with open(config_path, "r") as f:
+                n_mels = json.load(f)["num_mel_bins"]
+        else:
+            # Fall back to downloading from HF hub
+            logger.info(f"Loading config from Hugging Face hub: {TEST_WHISPER_VERSION}")
+            with open(
+                hf_hub_download(repo_id=TEST_WHISPER_VERSION, filename="config.json"), "r"
+            ) as f:
+                n_mels = json.load(f)["num_mel_bins"]

         logger.info(
             f"WhisperMelSpectrogram: n_mels={n_mels} for {TEST_WHISPER_VERSION}"
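This resolve-locally-then-fall-back pattern mirrors the one added to upload_version in scripts/generate_model.py. A compact way to express it, shown here only as a sketch (resolve_config_path is a hypothetical helper, not part of this diff):

import json
import os

from huggingface_hub import hf_hub_download

def resolve_config_path(model_version: str, filename: str = "config.json") -> str:
    # Prefer a file inside a local checkpoint directory; otherwise fetch from the hub cache
    local_candidate = os.path.join(model_version, filename)
    if os.path.exists(local_candidate):
        return local_candidate
    return hf_hub_download(repo_id=model_version, filename=filename)

with open(resolve_config_path("openai/whisper-tiny.en"), "r") as f:
    n_mels = json.load(f)["num_mel_bins"]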
42 changes: 35 additions & 7 deletions tests/test_text_decoder.py
@@ -14,7 +14,7 @@
 from argmaxtools import test_utils as argmaxtools_test_utils
 from argmaxtools.utils import get_fastest_device, get_logger
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, WhisperForConditionalGeneration
 from transformers.models.whisper import modeling_whisper

 from whisperkit import test_utils, text_decoder
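The new top-level import binds the same class that modeling_whisper exposes, since transformers re-exports it at the package root, so replacing modeling_whisper.WhisperForConditionalGeneration with the shorter name further down is behavior-preserving. A quick sanity check:

from transformers import WhisperForConditionalGeneration
from transformers.models.whisper import modeling_whisper

# Both names refer to the same class object
assert WhisperForConditionalGeneration is modeling_whisper.WhisperForConditionalGeneration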
@@ -44,6 +44,39 @@
 TEST_TOKEN_TIMESTAMPS = True


+def load_whisper_model(model_path: str, torch_dtype=None):
+    """Load a Whisper model from either Hugging Face hub or local path
+
+    Args:
+        model_path: Either a Hugging Face model ID or local directory path
+        torch_dtype: Optional torch dtype to load the model in
+
+    Returns:
+        The loaded Whisper model
+    """
+    logger.info(f"Attempting to load model from: {model_path}")
+    try:
+        # First try loading as a local path
+        if os.path.exists(model_path):
+            logger.info(f"Loading model from local path: {model_path}")
+            return WhisperForConditionalGeneration.from_pretrained(
+                model_path,
+                local_files_only=True,
+                torch_dtype=torch_dtype
+            )
+        # If not a valid path, try loading from HF hub
+        logger.info(f"Loading model from Hugging Face hub: {model_path}")
+        return WhisperForConditionalGeneration.from_pretrained(
+            model_path,
+            torch_dtype=torch_dtype
+        )
+    except Exception as e:
+        raise ValueError(
+            f"Could not load model from '{model_path}'. "
+            "Make sure it is either a valid local path or Hugging Face model ID."
+        ) from e
+
+
 class TestWhisperTextDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -55,12 +88,7 @@ def setUpClass(cls):
             cls.test_output_names.pop(cls.test_output_names.index("alignment_heads_weights"))

         # Original model
-        orig_torch_model = (
-            modeling_whisper.WhisperForConditionalGeneration.from_pretrained(
-                TEST_WHISPER_VERSION,
-                torch_dtype=TEST_TORCH_DTYPE,
-            )
-        )
+        orig_torch_model = load_whisper_model(TEST_WHISPER_VERSION, TEST_TORCH_DTYPE)
         cls.orig_torch_model = (
             orig_torch_model.model.decoder.to(TEST_DEV).to(TEST_TORCH_DTYPE).eval()
         )
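After this change, setUpClass can consume a local fine-tuned checkpoint as easily as a hub ID. A sketch of the new code path, assuming a hypothetical local directory written by save_pretrained():

import torch

# Load from a local directory and pull out the decoder, mirroring setUpClass
model = load_whisper_model("./whisper-tiny-finetuned", torch_dtype=torch.float16)
decoder = model.model.decoder.to("cpu").eval()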