-
Notifications
You must be signed in to change notification settings - Fork 41
Expand file tree
/
Copy pathtest.py
More file actions
executable file
·30 lines (26 loc) · 801 Bytes
/
test.py
File metadata and controls
executable file
·30 lines (26 loc) · 801 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np
import cv2
import torch
import data
from options.test_options import TestOptions
import models
def _to_uint8_bgr_images(generated):
    """Convert a batch of generator outputs into uint8 images ready for cv2.imwrite.

    ``generated`` must be a 4-D tensor, since each item is transposed with
    ``(1, 2, 0)``. It is presumably (batch, channels, height, width) with values
    nominally in [-1, 1] (TODO confirm against the model's output layer).
    Values are clamped to [-1, 1], mapped linearly onto [0, 255], and rounded to
    the nearest integer.

    Returns a list with one C-contiguous (height, width, channels) uint8 array
    per batch item, with the channel axis reversed.
    """
    scaled = (torch.clamp(generated, -1, 1) + 1) / 2 * 255
    # Round before the uint8 cast: astype() truncates, so e.g. 254.9 would
    # otherwise become 254 and every pixel would be biased downwards.
    batch = torch.round(scaled).cpu().numpy().astype(np.uint8)
    # CHW -> HWC, then reverse the channel axis because cv2.imwrite expects BGR
    # order (assumes the model emits RGB). ascontiguousarray hands OpenCV a
    # plain buffer instead of a negatively-strided view.
    return [
        np.ascontiguousarray(image.transpose((1, 2, 0))[:, :, ::-1])
        for image in batch
    ]


opt = TestOptions().parse()
dataloader = data.create_dataloader(opt)
model = models.create_model(opt)
model.eval()

# test
for i, data_i in enumerate(dataloader):
    # Stop once opt.how_many samples have been covered; batches are processed
    # whole, so up to opt.batchSize - 1 extra images may be written.
    if i * opt.batchSize >= opt.how_many:
        break
    with torch.no_grad():
        generated, _ = model(data_i, mode='inference')
    # NOTE(review): the prediction is written to data_i['path'] itself. If the
    # dataset sets 'path' to the *input* image location, this overwrites the
    # source images; confirm the dataset supplies an output path here.
    img_path = data_i['path']
    for b, pred_im in enumerate(_to_uint8_bgr_images(generated)):
        print('process image... %s' % img_path[b])
        # cv2.imwrite reports failure via its return value rather than raising;
        # surface it instead of silently dropping the output.
        if not cv2.imwrite(img_path[b], pred_im):
            print('warning: failed to write image %s' % img_path[b])