@@ -1786,17 +1786,21 @@ def test_SwitchLayer_masking():
17861786
17871787def test_SwitchLayer_template_const_from ():
17881788 net = TFNetwork (extern_data = ExternData ())
1789+ batch_dim = DimensionTag (kind = DimensionTag .Types .Batch , description = "batch" )
1790+ time_dim = DimensionTag (kind = DimensionTag .Types .Spatial , description = "time" )
1791+ feat_dim = DimensionTag (kind = DimensionTag .Types .Feature , description = "feature" , dimension = 2 )
17891792 # [T]
17901793 condition = InternalLayer (network = net , name = "condition" , output = Data (
1791- "condition_output" , batch_dim_axis = None , time_dim_axis = 0 , feature_dim_axis = None , shape = ( None ,) ))
1794+ "condition_output" , time_dim_axis = 0 , feature_dim_axis = None , dim_tags = [ time_dim ] ))
17921795 true_from = 42
17931796 # [B,F|2,T]
17941797 false_from = InternalLayer (network = net , name = "false_from" , output = Data (
1795- "false_from_output" , batch_dim_axis = 0 , time_dim_axis = 2 , feature_dim_axis = 1 , shape = (2 , None ), dim = 2 ))
1798+ "false_from_output" , batch_dim_axis = 0 , time_dim_axis = 2 , feature_dim_axis = 1 ,
1799+ dim_tags = [batch_dim , feat_dim , time_dim ]))
17961800
17971801 # should be [B,F|2,T]
1798- switch = SwitchLayer .get_out_data_from_opts ('switch' , condition = condition , true_from = true_from ,
1799- false_from = false_from )
1802+ switch = SwitchLayer .get_out_data_from_opts (
1803+ 'switch' , condition = condition , true_from = true_from , false_from = false_from )
18001804 assert switch .batch_ndim == 3
18011805 assert switch .batch_dim_axis == 0 and switch .time_dim_axis == 2 and switch .feature_dim_axis == 1
18021806 assert switch .dim == 2
0 commit comments