Skip to content

Commit feba064

Browse files
committed
test_SwitchLayer_template_const_from fix
1 parent c905ca1 commit feba064

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,17 +1786,21 @@ def test_SwitchLayer_masking():
17861786

17871787
def test_SwitchLayer_template_const_from():
17881788
net = TFNetwork(extern_data=ExternData())
1789+
batch_dim = DimensionTag(kind=DimensionTag.Types.Batch, description="batch")
1790+
time_dim = DimensionTag(kind=DimensionTag.Types.Spatial, description="time")
1791+
feat_dim = DimensionTag(kind=DimensionTag.Types.Feature, description="feature", dimension=2)
17891792
# [T]
17901793
condition = InternalLayer(network=net, name="condition", output=Data(
1791-
"condition_output", batch_dim_axis=None, time_dim_axis=0, feature_dim_axis=None, shape=(None,)))
1794+
"condition_output", time_dim_axis=0, feature_dim_axis=None, dim_tags=[time_dim]))
17921795
true_from = 42
17931796
# [B,F|2,T]
17941797
false_from = InternalLayer(network=net, name="false_from", output=Data(
1795-
"false_from_output", batch_dim_axis=0, time_dim_axis=2, feature_dim_axis=1, shape=(2, None), dim=2))
1798+
"false_from_output", batch_dim_axis=0, time_dim_axis=2, feature_dim_axis=1,
1799+
dim_tags=[batch_dim, feat_dim, time_dim]))
17961800

17971801
# should be [B,F|2,T]
1798-
switch = SwitchLayer.get_out_data_from_opts('switch', condition=condition, true_from=true_from,
1799-
false_from=false_from)
1802+
switch = SwitchLayer.get_out_data_from_opts(
1803+
'switch', condition=condition, true_from=true_from, false_from=false_from)
18001804
assert switch.batch_ndim == 3
18011805
assert switch.batch_dim_axis == 0 and switch.time_dim_axis == 2 and switch.feature_dim_axis == 1
18021806
assert switch.dim == 2

0 commit comments

Comments
 (0)