Skip to content

Commit e51c1bd

Browse files
committed
update deconv2d fn for auto batchsize
1 parent 3df125c commit e51c1bd

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tensorlayer/layers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1959,8 +1959,14 @@ def DeConv2d(net, n_out_channel = 32, filter_size=(3, 3),
19591959
assert len(strides) == 2, "len(strides) should be 2, DeConv2d and DeConv2dLayer are different."
19601960
if act is None:
19611961
act = tf.identity
1962-
if batch_size is None:
1963-
batch_size = tf.shape(net.outputs)[0]
1962+
# if batch_size is None:
1963+
# batch_size = tf.shape(net.outputs)[0]
1964+
fixed_batch_size = net.outputs.get_shape().with_rank_at_least(1)[0]
1965+
if fixed_batch_size.value:
1966+
batch_size = fixed_batch_size.value
1967+
else:
1968+
from tensorflow.python.ops import array_ops
1969+
batch_size = array_ops.shape(net.outputs)[0]
19641970
net = DeConv2dLayer(layer = net,
19651971
act = act,
19661972
shape = [filter_size[0], filter_size[1], n_out_channel, int(net.outputs.get_shape()[-1])],

0 commit comments

Comments
 (0)