Skip to content

Commit 53d7843

Browse files
committed
test_CombineLayer_match_unknown_derived
1 parent 00c5046 commit 53d7843

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,24 @@ def test_CombineLayer_match_unknown():
718718
assert out.dim_tags[:2] == dat2.dim_tags[:2] and out.batch_shape == dat2.batch_shape
719719

720720

721+
def test_CombineLayer_match_unknown_derived():
722+
with make_scope() as session:
723+
dat1 = Data(name="undefined", shape=(None, 3))
724+
assert dat1.dim_tags[1].undefined
725+
dat1_derived_dim_tags = list(dat1.dim_tags)
726+
dat1_derived_dim_tags[1] = DimensionTag(
727+
kind=DimensionTag.Types.Spatial, description="undefined_derived_dim", derived_from_tag=dat1.dim_tags[1])
728+
dat1_derived = Data(name="undefined_derived", dim_tags=dat1_derived_dim_tags)
729+
assert dat1_derived.dim_tags[1].undefined
730+
# Create placeholders to have this dyn size clearly defined.
731+
dat2 = Data(name="defined", shape=(None, 3), auto_create_placeholders=True)
732+
net = TFNetwork(extern_data=ExternData())
733+
layer1 = InternalLayer(name="layer1_undefined_derived", network=net, output=dat1_derived)
734+
layer2 = InternalLayer(name="layer2_defined", network=net, output=dat2)
735+
out = CombineLayer.get_out_data_from_opts(name="combine", network=net, sources=[layer1, layer2])
736+
assert out.dim_tags[:2] == dat2.dim_tags[:2] and out.batch_shape == dat2.batch_shape
737+
738+
721739
def test_CombineLayer_different_batch_axis():
722740
# ["base:enc_ctx", "weight_feedback", "s_transformed"]
723741
# base:enc_ctx: Data(name='enc_ctx_output', shape=(None, 14), batch_dim_axis=1)

0 commit comments

Comments
 (0)