-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathutils.py
More file actions
51 lines (42 loc) · 1.39 KB
/
utils.py
File metadata and controls
51 lines (42 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
''' Imports '''
import pickle
import torch
import os
def get_checkpoints(path):
    """Load every record pickled back-to-back into the file at ``path``.

    The file is read by calling ``pickle.load`` repeatedly until it stops
    yielding records.

    Args:
        path: Path of the checkpoint file.

    Returns:
        list: The unpickled records in file order. Reading stops at end of
        file, or at the first record that fails to unpickle (for example a
        partially written trailing record). Every record before that point
        is returned.

    Raises:
        OSError: If ``path`` cannot be opened (e.g. FileNotFoundError).

    Note:
        ``pickle.load`` can execute arbitrary code. Only call this on
        checkpoint files you trust.
    """
    checkpoints = []
    with open(path, 'rb') as f:
        while True:
            try:
                checkpoints.append(pickle.load(f))
            except EOFError:
                # Clean end of file: no more records.
                break
            except Exception:
                # A truncated or corrupt trailing record is treated as the
                # end of the data. Catching Exception (not BaseException)
                # lets KeyboardInterrupt / SystemExit propagate instead of
                # silently returning a partial list.
                break
    return checkpoints
def _has_nan_free_params(checkpoint):
    """Return True if ``checkpoint`` has a 'g_params' mapping with no NaN values."""
    if 'g_params' not in checkpoint:
        return False
    # Values are passed to torch.isnan, so they are presumably parameter
    # tensors. TODO confirm against the code that writes checkpoints.
    return not any(torch.isnan(value).any() for value in checkpoint['g_params'].values())


def _rewrite_checkpoints(path, checkpoints):
    """Replace the file at ``path`` with ``checkpoints`` pickled back-to-back."""
    tmp_path = os.fspath(path) + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            for record in checkpoints:
                pickle.dump(record, f)
        # Swap the file in only after it is fully written. An interruption
        # then leaves the old file intact instead of a half-written one.
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if something failed before the swap.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def get_last_params(path, truncate=False):
    """Return the newest checkpoint whose 'g_params' contain no NaN values.

    Records are examined from newest to oldest. Records without a
    'g_params' key, or whose 'g_params' values contain a NaN, are skipped.

    Args:
        path: Path of the checkpoint file.
        truncate: If True, every record newer than the returned one is
            removed from the file. The file is rewritten only when at least
            one record would be dropped.

    Returns:
        The newest valid checkpoint record.

    Raises:
        ValueError: If no record qualifies. The file is left untouched in
            that case.
    """
    checkpoints = get_checkpoints(path)
    last_valid = next(
        (i for i in reversed(range(len(checkpoints))) if _has_nan_free_params(checkpoints[i])),
        None,
    )
    if last_valid is None:
        # Fail before any truncation so no data is destroyed.
        raise ValueError(f"no checkpoint with NaN-free 'g_params' found in {path!r}")
    if truncate and last_valid < len(checkpoints) - 1:
        _rewrite_checkpoints(path, checkpoints[:last_valid + 1])
    return checkpoints[last_valid]
def get_checkpoint(path, epoch, truncate=False):
    """Return the checkpoint for ``epoch`` from the checkpoint file at ``path``.

    Args:
        path: Path of the checkpoint file.
        epoch: The string 'latest', or an epoch number (an int or anything
            ``int()`` accepts, e.g. '12').
        truncate: Only used when ``epoch == 'latest'``; forwarded to
            ``get_last_params``. Ignored for a specific epoch.

    Returns:
        For 'latest', the result of ``get_last_params(path, truncate)``.
        Otherwise, the first record whose 'epoch' equals ``int(epoch)``, or
        None if no record matches.

    Raises:
        ValueError: If ``epoch`` is neither 'latest' nor convertible to int.
    """
    if epoch == 'latest':
        return get_last_params(path, truncate)
    target_epoch = int(epoch)  # convert once, not on every comparison
    for checkpoint in get_checkpoints(path):
        # Records without an 'epoch' key cannot match; skip them instead of
        # raising KeyError.
        if checkpoint.get('epoch') == target_epoch:
            return checkpoint
    return None