Skip to content

Commit f2db115

Browse files
zsdonghaowagamamaz
authored andcommitted
update sampling layers (#901)
1 parent 6706764 commit f2db115

File tree

1 file changed

+36
-10
lines changed

1 file changed

+36
-10
lines changed

tensorlayer/layers/image_resampling.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,31 @@ def __init__(
5959

6060
if not isinstance(size, (list, tuple)) and len(size) == 2:
6161
raise AssertionError()
62-
6362
if len(self.inputs.get_shape()) == 3:
6463
if is_scale:
65-
size_h = size[0] * tf.shape(self.inputs)[0]
66-
size_w = size[1] * tf.shape(self.inputs)[1]
64+
input_shape = self.inputs.shape.as_list()
65+
if input_shape[0] is not None:
66+
size_h = size[0] * input_shape[0]
67+
else:
68+
size_h = size[0] * tf.shape(self.inputs)[0]
69+
if input_shape[1] is not None:
70+
size_w = size[1] * input_shape[1]
71+
else:
72+
size_w = size[1] * tf.shape(self.inputs)[1]
6773
size = [size_h, size_w]
6874

6975
elif len(self.inputs.get_shape()) == 4:
7076
if is_scale:
71-
size_h = size[0] * tf.shape(self.inputs)[1]
72-
size_w = size[1] * tf.shape(self.inputs)[2]
77+
input_shape = self.inputs.shape.as_list()
78+
if input_shape[1] is not None:
79+
size_h = size[0] * input_shape[1]
80+
else:
81+
size_h = size[0] * tf.shape(self.inputs)[1]
82+
if input_shape[2] is not None:
83+
size_w = size[1] * input_shape[2]
84+
else:
85+
size_w = size[1] * tf.shape(self.inputs)[2]
7386
size = [size_h, size_w]
74-
7587
else:
7688
raise Exception("Donot support shape %s" % tf.shape(self.inputs))
7789

@@ -135,14 +147,28 @@ def __init__(
135147

136148
if len(self.inputs.get_shape()) == 3:
137149
if is_scale:
138-
size_h = size[0] * tf.shape(self.inputs)[0]
139-
size_w = size[1] * tf.shape(self.inputs)[1]
150+
input_shape = self.inputs.shape.as_list()
151+
if input_shape[0] is not None:
152+
size_h = size[0] * input_shape[0]
153+
else:
154+
size_h = size[0] * tf.shape(self.inputs)[0]
155+
if input_shape[1] is not None:
156+
size_w = size[1] * input_shape[1]
157+
else:
158+
size_w = size[1] * tf.shape(self.inputs)[1]
140159
size = [size_h, size_w]
141160

142161
elif len(self.inputs.get_shape()) == 4:
143162
if is_scale:
144-
size_h = size[0] * tf.shape(self.inputs)[1]
145-
size_w = size[1] * tf.shape(self.inputs)[2]
163+
input_shape = self.inputs.shape.as_list()
164+
if input_shape[1] is not None:
165+
size_h = size[0] * input_shape[1]
166+
else:
167+
size_h = size[0] * tf.shape(self.inputs)[1]
168+
if input_shape[2] is not None:
169+
size_w = size[1] * input_shape[2]
170+
else:
171+
size_w = size[1] * tf.shape(self.inputs)[2]
146172
size = [size_h, size_w]
147173

148174
else:

0 commit comments

Comments
 (0)