Skip to content

Commit

Permalink
Add weights_only=False to torch.load() for pytorch 2.6+ compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
aiknownc committed Feb 15, 2025
1 parent e5fcc45 commit a5a9083
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion audiocraft/data/jasco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self,
'saliency_files': self.saliency_files,
'trk2idx': self.trk2idx}, f"{chroma_root}/cache.pkl")
else:
tmp = torch.load(f"{chroma_root}/cache.pkl")
tmp = torch.load(f"{chroma_root}/cache.pkl", weights_only=False)
self.tracks = tmp['tracks']
self.saliency_files = tmp['saliency_files']
self.trk2idx = tmp['trk2idx']
Expand Down
6 changes: 3 additions & 3 deletions audiocraft/models/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def _get_state_dict(
assert isinstance(file_or_url_or_id, str)

if os.path.isfile(file_or_url_or_id):
return torch.load(file_or_url_or_id, map_location=device)
return torch.load(file_or_url_or_id, map_location=device, weights_only=False)

if os.path.isdir(file_or_url_or_id):
            file = f"{file_or_url_or_id}/{filename}"
return torch.load(file, map_location=device)
return torch.load(file, map_location=device, weights_only=False)

elif file_or_url_or_id.startswith('https://'):
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
Expand All @@ -68,7 +68,7 @@ def _get_state_dict(
library_name="audiocraft",
library_version=audiocraft.__version__,
)
return torch.load(file, map_location=device)
return torch.load(file, map_location=device, weights_only=False)


def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
Expand Down
4 changes: 2 additions & 2 deletions audiocraft/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _get_cache_path(self, path: tp.Union[Path, str]):
def _get_full_embed_from_cache(cache: Path):
"""Loads full pre-computed embedding from the cache."""
try:
embed = torch.load(cache, 'cpu')
embed = torch.load(cache, 'cpu', weights_only=False)
except Exception as exc:
logger.error("Error loading %s: %r", cache, exc)
embed = None
Expand Down Expand Up @@ -279,7 +279,7 @@ def _load_one(self, index: int):
items = items[start: start + self.batch_size]
assert len(items) == self.batch_size
entries = []
entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore
        entries = [torch.load(item.open(mode), 'cpu', weights_only=False) for item in items]  # type: ignore
transposed = zip(*entries)
out = []
for part in transposed:
Expand Down
2 changes: 1 addition & 1 deletion audiocraft/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
if rank0_checkpoint_path.exists():
check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
state = torch.load(checkpoint_path, 'cpu')
state = torch.load(checkpoint_path, 'cpu', weights_only=False)
logger.info("Checkpoint loaded from %s", checkpoint_path)
return state

Expand Down
6 changes: 3 additions & 3 deletions audiocraft/utils/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path
"""Export only the best state from the given EnCodec checkpoint. This
should be used if you trained your own EnCodec model.
"""
pkg = torch.load(checkpoint_path, 'cpu')
pkg = torch.load(checkpoint_path, 'cpu', weights_only=False)
new_pkg = {
'best_state': pkg['best_state']['model'],
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
Expand All @@ -43,7 +43,7 @@ def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Un
to the model used.
"""
if Path(pretrained_encodec).exists():
pkg = torch.load(pretrained_encodec)
pkg = torch.load(pretrained_encodec, weights_only=False)
assert 'best_state' in pkg
assert 'xp.cfg' in pkg
assert 'version' in pkg
Expand All @@ -61,7 +61,7 @@ def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Un
def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
"""Export only the best state from the given MusicGen or AudioGen checkpoint.
"""
pkg = torch.load(checkpoint_path, 'cpu')
pkg = torch.load(checkpoint_path, 'cpu', weights_only=False)
if pkg['fsdp_best_state']:
best_state = pkg['fsdp_best_state']['model']
else:
Expand Down
4 changes: 2 additions & 2 deletions audiocraft/utils/export_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _clean_lm_cfg(cfg: DictConfig):


def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
pkg = torch.load(checkpoint_path, 'cpu')
pkg = torch.load(checkpoint_path, 'cpu', weights_only=False)
new_pkg = {
'best_state': pkg['ema']['state']['model'],
'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
Expand All @@ -53,7 +53,7 @@ def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path


def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
pkg = torch.load(checkpoint_path, 'cpu')
pkg = torch.load(checkpoint_path, 'cpu', weights_only=False)
if pkg['fsdp_best_state']:
best_state = pkg['fsdp_best_state']['model']
else:
Expand Down

0 comments on commit a5a9083

Please sign in to comment.