Skip to content

Commit 9d95661

Browse files
committed
update cov2d if previous layer channel is ?
1 parent 6ebf7cd commit 9d95661

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tensorlayer/layers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1917,9 +1917,14 @@ def Conv2d(net, n_filter=32, filter_size=(3, 3), strides=(1, 1), act = None,
19171917
assert len(strides) == 2, "len(strides) should be 2, Conv2d and Conv2dLayer are different."
19181918
if act is None:
19191919
act = tf.identity
1920+
1921+
try:
1922+
pre_channel = int(net.outputs.get_shape()[-1])
1923+
except: # if pre_channel is ?, it happens when using Spatial Transformer Net
1924+
pre_channel = 1
19201925
net = Conv2dLayer(net,
19211926
act = act,
1922-
shape = [filter_size[0], filter_size[1], int(net.outputs.get_shape()[-1]), n_filter], # 32 features for each 5x5 patch
1927+
shape = [filter_size[0], filter_size[1], pre_channel, n_filter], # 32 features for each 5x5 patch
19231928
strides = [1, strides[0], strides[1], 1],
19241929
padding = padding,
19251930
W_init = W_init,

0 commit comments

Comments
 (0)