@@ -718,6 +718,24 @@ def test_CombineLayer_match_unknown():
718718 assert out .dim_tags [:2 ] == dat2 .dim_tags [:2 ] and out .batch_shape == dat2 .batch_shape
719719
720720
721+ def test_CombineLayer_match_unknown_derived ():
722+ with make_scope () as session :
723+ dat1 = Data (name = "undefined" , shape = (None , 3 ))
724+ assert dat1 .dim_tags [1 ].undefined
725+ dat1_derived_dim_tags = list (dat1 .dim_tags )
726+ dat1_derived_dim_tags [1 ] = DimensionTag (
727+ kind = DimensionTag .Types .Spatial , description = "undefined_derived_dim" , derived_from_tag = dat1 .dim_tags [1 ])
728+ dat1_derived = Data (name = "undefined_derived" , dim_tags = dat1_derived_dim_tags )
729+ assert dat1_derived .dim_tags [1 ].undefined
730+ # Create placeholders to have this dyn size clearly defined.
731+ dat2 = Data (name = "defined" , shape = (None , 3 ), auto_create_placeholders = True )
732+ net = TFNetwork (extern_data = ExternData ())
733+ layer1 = InternalLayer (name = "layer1_undefined_derived" , network = net , output = dat1_derived )
734+ layer2 = InternalLayer (name = "layer2_defined" , network = net , output = dat2 )
735+ out = CombineLayer .get_out_data_from_opts (name = "combine" , network = net , sources = [layer1 , layer2 ])
736+ assert out .dim_tags [:2 ] == dat2 .dim_tags [:2 ] and out .batch_shape == dat2 .batch_shape
737+
738+
721739def test_CombineLayer_different_batch_axis ():
722740 # ["base:enc_ctx", "weight_feedback", "s_transformed"]
723741 # base:enc_ctx: Data(name='enc_ctx_output', shape=(None, 14), batch_dim_axis=1)
0 commit comments