Skip to content

Commit

Permalink
support manifold in average_checkpoint.py
Browse files Browse the repository at this point in the history
Summary: use PathManager to support averaging checkpoints.

Reviewed By: myleott

Differential Revision: D20725346

fbshipit-source-id: 44b91f8652826da72c82087f8fbab7ae7d179423
  • Loading branch information
Weiyi Zheng authored and facebook-github-bot committed Mar 30, 2020
1 parent 90a8bda commit c8f26a1
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
22 changes: 13 additions & 9 deletions scripts/average_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import os
import re

from fairseq.file_io import PathManager


def average_checkpoints(inputs):
"""Loads checkpoints from inputs and returns a model with averaged weights.
Expand All @@ -27,13 +29,14 @@ def average_checkpoints(inputs):
new_state = None
num_models = len(inputs)

for f in inputs:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
for fpath in inputs:
with PathManager.open(fpath, 'rb') as f:
state = torch.load(
f,
map_location=(
lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
),
)
# Copies over the settings from the first checkpoint
if new_state is None:
new_state = state
Expand Down Expand Up @@ -74,7 +77,7 @@ def last_n_checkpoints(paths, n, update_based, upper_bound=None):
pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
else:
pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
files = os.listdir(path)
files = PathManager.ls(path)

entries = []
for f in files:
Expand Down Expand Up @@ -135,7 +138,8 @@ def main():
print('averaging checkpoints: ', args.inputs)

new_state = average_checkpoints(args.inputs)
torch.save(new_state, args.output)
with PathManager.open(args.output, 'wb') as f:
torch.save(new_state, f)
print('Finished writing averaged checkpoint to {}.'.format(args.output))


Expand Down
19 changes: 19 additions & 0 deletions scripts/fb_average_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.average_checkpoints import main
from fairseq.file_io import PathManager

# support fb specific path mananger
try:
from fvcore.fb.manifold import ManifoldPathHandler
PathManager.register_handler(ManifoldPathHandler(max_parallel=16, timeout_sec=1800))
except Exception:
pass


if __name__ == '__main__':
main()

0 comments on commit c8f26a1

Please # to comment.