4
4
5
5
6
6
def make_grid (tensor , nrow = 8 , padding = 2 ,
7
- normalize = False , range = None , scale_each = False ):
7
+ normalize = False , range = None , scale_each = False , pad_value = 0 ):
8
8
"""
9
9
Given a 4D mini-batch Tensor of shape (B x C x H x W),
10
10
or a list of images all of the same size,
@@ -19,6 +19,8 @@ def make_grid(tensor, nrow=8, padding=2,
19
19
scale_each=True will scale each image in the batch of images separately rather than
20
20
computing the (min, max) over all images.
21
21
22
+ pad_value=<float> sets the value for the padded pixels.
23
+
22
24
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
23
25
"""
24
26
# if list of tensors, convert to a 4D mini-batch Tensor
@@ -65,7 +67,7 @@ def norm_range(t, range):
65
67
xmaps = min (nrow , nmaps )
66
68
ymaps = int (math .ceil (float (nmaps ) / xmaps ))
67
69
height , width = int (tensor .size (2 ) + padding ), int (tensor .size (3 ) + padding )
68
- grid = tensor .new (3 , height * ymaps + 1 + padding // 2 , width * xmaps + 1 + padding // 2 ).fill_ (0 )
70
+ grid = tensor .new (3 , height * ymaps + 1 + padding // 2 , width * xmaps + 1 + padding // 2 ).fill_ (pad_value )
69
71
k = 0
70
72
for y in irange (ymaps ):
71
73
for x in irange (xmaps ):
@@ -79,7 +81,7 @@ def norm_range(t, range):
79
81
80
82
81
83
def save_image (tensor , filename , nrow = 8 , padding = 2 ,
82
- normalize = False , range = None , scale_each = False ):
84
+ normalize = False , range = None , scale_each = False , pad_value = 0 ):
83
85
"""
84
86
Saves a given Tensor into an image file.
85
87
If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
@@ -88,7 +90,7 @@ def save_image(tensor, filename, nrow=8, padding=2,
88
90
"""
89
91
from PIL import Image
90
92
tensor = tensor .cpu ()
91
- grid = make_grid (tensor , nrow = nrow , padding = padding ,
93
+ grid = make_grid (tensor , nrow = nrow , padding = padding , pad_value = pad_value ,
92
94
normalize = normalize , range = range , scale_each = scale_each )
93
95
ndarr = grid .mul (255 ).clamp (0 , 255 ).byte ().permute (1 , 2 , 0 ).numpy ()
94
96
im = Image .fromarray (ndarr )
0 commit comments