Skip to content

Commit ce07620

Browse files
author
Youngmin Baek
committed
fix error in loading the model in cpu mode
1 parent 25c2055 commit ce07620

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

test.py

100755100644
Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,18 @@
2727

2828
from craft import CRAFT
2929

30+
from collections import OrderedDict
31+
def copyStateDict(state_dict):
32+
if list(state_dict.keys())[0].startswith("module"):
33+
start_idx = 1
34+
else:
35+
start_idx = 0
36+
new_state_dict = OrderedDict()
37+
for k, v in state_dict.items():
38+
name = ".".join(k.split(".")[start_idx:])
39+
new_state_dict[name] = v
40+
return new_state_dict
41+
3042
def str2bool(v):
3143
return v.lower() in ("yes", "y", "true", "t", "1")
3244

@@ -96,13 +108,14 @@ def test_net(net, image, text_threshold, link_threshold, low_text, cuda):
96108
# load net
97109
net = CRAFT() # initialize
98110

111+
print('Loading weights from checkpoint (' + args.trained_model + ')')
112+
net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
113+
99114
if args.cuda:
100115
net = net.cuda()
101116
net = torch.nn.DataParallel(net)
102117
cudnn.benchmark = False
103118

104-
print('Loading weights from checkpoint (' + args.trained_model + ')')
105-
net.load_state_dict(torch.load(args.trained_model))
106119
net.eval()
107120

108121
t = time.time()

0 commit comments

Comments
 (0)