Skip to content

Commit

Permalink
Add rm_pt.py helper script for removing checkpoint files
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: facebookresearch/fairseq#681

Differential Revision: D15147107

fbshipit-source-id: 4452c98059586a4d748868a7659329285a76d5ef
  • Loading branch information
Myle Ott authored and yzpang committed Feb 19, 2021
1 parent 2ac15cc commit 9cee2b2
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 3 deletions.
6 changes: 6 additions & 0 deletions scripts/average_checkpoints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import collections
Expand Down
2 changes: 0 additions & 2 deletions scripts/build_sym_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

"""
Use this script in order to build symmetric alignments for your translation
dataset.
Expand Down
1 change: 0 additions & 1 deletion scripts/read_binarized.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import argparse

Expand Down
133 changes: 133 additions & 0 deletions scripts/rm_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
#!/usr/bin/env python
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import argparse
import os
import re
import shutil
import sys


# Every checkpoint file name this script acts on: ``checkpoint<N>.pt``,
# ``checkpoint_<A>_<B>.pt`` and ``checkpoint_<word>.pt`` (for example
# ``checkpoint_best.pt`` / ``checkpoint_last.pt``).
pt_regexp = re.compile(r'checkpoint(\d+|_\d+_\d+|_[a-z]+)\.pt')
# ``checkpoint<N>.pt``; group 1 captures N.
pt_regexp_epoch_based = re.compile(r'checkpoint(\d+)\.pt')
# ``checkpoint_<A>_<B>.pt``; group 1 captures the trailing number B.
pt_regexp_update_based = re.compile(r'checkpoint_\d+_(\d+)\.pt')


def _numbered_checkpoints(files):
    """Collect ``(number, filename)`` pairs for the numbered checkpoints.

    A name is considered numbered when it fully matches either
    ``pt_regexp_epoch_based`` or ``pt_regexp_update_based``; the number is
    the integer value of that pattern's first group. Every other name
    (including ``checkpoint_best.pt`` and ``checkpoint_last.pt``) is skipped.

    Args:
        files: iterable of bare file names (no directory component).

    Returns:
        A list of ``(int, str)`` tuples in the order the names were given.
    """
    entries = []
    for f in files:
        # Match objects are always truthy, so ``or`` falls through to the
        # second pattern only when the first one did not match.
        m = pt_regexp_epoch_based.fullmatch(f) or pt_regexp_update_based.fullmatch(f)
        if m is not None:
            entries.append((int(m.group(1)), m.group(0)))
    return entries


def last_n_checkpoints(files, n):
    """Return the names of the ``n`` highest-numbered checkpoints.

    Both naming schemes are ranked together by their extracted number
    (ties are broken by file name).

    Args:
        files: iterable of bare file names.
        n: how many checkpoints to return; values ``<= 0`` select nothing.

    Returns:
        A list of at most ``n`` file names, highest number first.
    """
    if n <= 0:
        return []
    ranked = sorted(_numbered_checkpoints(files), reverse=True)
    return [name for _number, name in ranked[:n]]


def every_n_checkpoints(files, n):
    """Return every ``n``-th numbered checkpoint, counted in ascending order.

    The numbered checkpoints are sorted by their extracted number and the
    entries at positions ``n``, ``2n``, ``3n``, ... (1-based) are selected.
    Selection is by position in the sorted list, not by divisibility of the
    number itself.

    Args:
        files: iterable of bare file names.
        n: selection interval; values ``<= 0`` select nothing (a zero
            interval would otherwise be an invalid slice step).

    Returns:
        A list of the selected file names, lowest number first.
    """
    if n <= 0:
        return []
    ranked = sorted(_numbered_checkpoints(files))
    return [name for _number, name in ranked[n - 1::n]]


def _replace_symlink_with_copy(path):
    """Turn the symlink at ``path`` into a regular file copied from its target.

    The target is first copied to ``path + '.tmp'`` and then moved over the
    link with ``os.replace``. The link is therefore only touched once a
    complete copy exists, so a dangling link or a failed copy (e.g. disk
    full) leaves the original link in place instead of losing the entry.

    Args:
        path: path of the symlink to dereference.

    Raises:
        OSError: if the target cannot be copied or the swap fails.
    """
    realpath = os.path.realpath(path)
    tmp_path = path + '.tmp'
    print('cp {} {}'.format(realpath, tmp_path))
    try:
        shutil.copyfile(realpath, tmp_path)
    except OSError:
        # Do not leave a partial copy behind; the symlink is still intact.
        if os.path.lexists(tmp_path):
            os.remove(tmp_path)
        raise
    print('mv {} {}'.format(tmp_path, path))
    os.replace(tmp_path, path)


def main():
    """Command-line entry point: prune checkpoint files under the given dirs.

    Walks every directory in ``root_dirs`` and sorts each file whose name
    matches ``pt_regexp`` into one of three groups:

    * preserve and dereference - kept files that are symlinks (unless
      ``--no-dereference`` is given);
    * preserve - other kept files;
    * delete - everything else.

    A file is kept when it is ``checkpoint_last.pt`` / ``checkpoint_best.pt``
    (unless ``--delete-last`` / ``--delete-best``) or is selected by
    ``--save-last`` / ``--save-every``. Nothing is kept in a directory whose
    name starts with ``test_`` unless ``--preserve-test`` is given.

    The plan is printed and only executed after the user answers ``y``.
    Exits with status 0 when there is nothing to do or the user declines,
    and with status 1 when stdin is closed at the prompt.
    """
    parser = argparse.ArgumentParser(
        description=(
            'Recursively delete checkpoint files from `root_dir`, '
            'but preserve checkpoint_best.pt and checkpoint_last.pt'
        )
    )
    parser.add_argument('root_dirs', nargs='*')
    parser.add_argument('--save-last', type=int, default=0, help='number of last checkpoints to save')
    parser.add_argument('--save-every', type=int, default=0, help='interval of checkpoints to save')
    parser.add_argument('--preserve-test', action='store_true',
                        help='preserve checkpoints in dirs that start with test_ prefix (default: delete them)')
    parser.add_argument('--delete-best', action='store_true', help='delete checkpoint_best.pt')
    parser.add_argument('--delete-last', action='store_true', help='delete checkpoint_last.pt')
    parser.add_argument('--no-dereference', action='store_true', help='don\'t dereference symlinks')
    args = parser.parse_args()

    files_to_desymlink = []
    files_to_preserve = []
    files_to_delete = []
    for root_dir in args.root_dirs:
        for root, _subdirs, files in os.walk(root_dir):
            # Numbered checkpoints selected for keeping in this directory.
            to_save = set()
            if args.save_last > 0:
                to_save.update(last_n_checkpoints(files, args.save_last))
            if args.save_every > 0:
                to_save.update(every_n_checkpoints(files, args.save_every))

            # normpath strips a trailing separator (e.g. "test_run/" from
            # shell completion); without it basename() would be '' and the
            # test_ rule would silently not apply to that directory.
            dir_name = os.path.basename(os.path.normpath(root))
            keep_allowed = args.preserve_test or not dir_name.startswith('test_')

            for file in files:
                if not pt_regexp.fullmatch(file):
                    continue
                full_path = os.path.join(root, file)
                keep = keep_allowed and (
                    (file == 'checkpoint_last.pt' and not args.delete_last)
                    or (file == 'checkpoint_best.pt' and not args.delete_best)
                    or file in to_save
                )
                if not keep:
                    files_to_delete.append(full_path)
                elif os.path.islink(full_path) and not args.no_dereference:
                    files_to_desymlink.append(full_path)
                else:
                    files_to_preserve.append(full_path)

    if not files_to_desymlink and not files_to_delete:
        print('Nothing to do.')
        sys.exit(0)

    print('Operations to perform (in order):')
    for file in files_to_desymlink:
        print(' - preserve (and dereference symlink): ' + file)
    for file in files_to_preserve:
        print(' - preserve: ' + file)
    for file in files_to_delete:
        print(' - delete: ' + file)

    while True:
        try:
            resp = input('Continue? (Y/N): ')
        except EOFError:
            # stdin closed (non-interactive run): never delete unconfirmed.
            print('No confirmation received, aborting.')
            sys.exit(1)
        answer = resp.strip().lower()
        if answer == 'y':
            break
        elif answer == 'n':
            sys.exit(0)

    print('Executing...')
    # Dereference first: a kept symlink may point at a file that is about to
    # be deleted, so its content must be copied before the deletions run.
    for file in files_to_desymlink:
        _replace_symlink_with_copy(file)
    for file in files_to_delete:
        print('rm ' + file)
        os.remove(file)


if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()
Empty file modified scripts/sacrebleu_pregen.sh
100755 → 100644
Empty file.

0 comments on commit 9cee2b2

Please sign in to comment.