Skip to content

Commit bdaf19c

Browse files
committed
test_ConvLayer_time_dim_out
1 parent 347a7e4 commit bdaf19c

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,35 @@ def test_ConvLayer_feature_dim_unspecified():
28512851
assert out.output.feature_dim_axis_or_unspecified is NotSpecified
28522852

28532853

2854+
def test_ConvLayer_time_dim_out():
2855+
config = Config({"extern_data": {"data": {"dim": 7}}})
2856+
with make_scope() as session:
2857+
net = TFNetwork(config=config)
2858+
in_layer = SourceLayer(name="input", network=net, data_key="data", output=net.extern_data.get_default_input_data())
2859+
layer_desc = {
2860+
'name': "conv",
2861+
"network": net,
2862+
"sources": [in_layer],
2863+
'filter_size': (4,),
2864+
'strides': 10,
2865+
'padding': 'valid',
2866+
'n_out': 64,
2867+
'activation': 'abs'}
2868+
conv_out = ConvLayer.get_out_data_from_opts(**layer_desc)
2869+
print("conv out:", conv_out)
2870+
out_time = conv_out.get_time_dim_tag()
2871+
assert in_layer.output.dim_tags[1].is_spatial_dim()
2872+
assert out_time != in_layer.output.dim_tags[1]
2873+
layer_desc["output"] = conv_out
2874+
with ConvLayer.cls_layer_scope("conv"):
2875+
conv_layer = ConvLayer(**layer_desc)
2876+
net.layers["conv"] = conv_layer
2877+
net.initialize_params(session)
2878+
session.run(
2879+
(conv_layer.output.placeholder, conv_layer.output.get_sequence_lengths()),
2880+
feed_dict=make_feed_dict(net.extern_data))
2881+
2882+
28542883
def test_conv_layer_NCHW():
28552884
with make_scope() as session:
28562885
import numpy as np

0 commit comments

Comments
 (0)