-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmisc.py
36 lines (24 loc) · 961 Bytes
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import numpy as np
import os
from options import opt
from torch.autograd import Variable
def getLatestCheckpointName(checkpoints_dir=None, prefix='netG',
                            extensions=('.pt', '.pth')):
    """Return the file name of the highest-indexed checkpoint.

    Checkpoint files are expected to be named ``<prefix>_<index>[_...]<ext>``
    (for example ``netG_12.pth``), where ``<index>`` parses as an integer and
    ``<ext>`` is one of *extensions*. Files that do not follow this pattern
    are ignored rather than causing an exception.

    Args:
        checkpoints_dir: Directory to scan. Defaults to
            ``opt.checkpoints_dir`` so existing zero-argument callers keep
            working unchanged.
        prefix: First ``_``-separated token of the file stem that selects
            which network's checkpoints to consider (``'netG'`` by default).
        extensions: File extensions (including the dot) accepted as
            checkpoints.

    Returns:
        The bare file name (not a full path) of the matching checkpoint with
        the largest integer index, or ``None`` if the directory does not
        exist or contains no matching checkpoint. If several files share the
        largest index, the lexicographically smallest name is returned.
    """
    if checkpoints_dir is None:
        checkpoints_dir = opt.checkpoints_dir
    if not os.path.isdir(checkpoints_dir):
        return None

    best_name = None
    best_index = None
    # Sorted so that ties on the index resolve deterministically instead of
    # depending on os.listdir's arbitrary ordering.
    for file_name in sorted(os.listdir(checkpoints_dir)):
        stem, ext = os.path.splitext(file_name)
        if ext not in extensions:
            # Unrelated files (logs, images, ...) are skipped, not parsed.
            continue
        parts = stem.split('_')
        if len(parts) < 2 or parts[0] != prefix:
            continue
        try:
            index = int(parts[1])
        except ValueError:
            # e.g. "netG_latest.pth": not an indexed checkpoint.
            continue
        # Strict '>' keeps the first (sorted) file among equal indices.
        if best_index is None or index > best_index:
            best_index, best_name = index, file_name
    return best_name