Skip to content

Commit 347a7e4

Browse files
committed
test_CombineLayer_match_unknown
1 parent 42cd1bf commit 347a7e4

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

tests/test_TFNetworkLayer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,18 @@ def test_CombineLayer_broadcast_multiple():
705705
assert out_v.shape == out.output.batch_shape
706706

707707

708+
def test_CombineLayer_match_unknown():
709+
with make_scope() as session:
710+
dat1 = Data(name="undefined", shape=(None, 3))
711+
# Create placeholders to have this dyn size clearly defined.
712+
dat2 = Data(name="defined", shape=(None, 3), auto_create_placeholders=True)
713+
net = TFNetwork(extern_data=ExternData())
714+
layer1 = InternalLayer(name="layer1_undefined", network=net, output=dat1)
715+
layer2 = InternalLayer(name="layer2_defined", network=net, output=dat2)
716+
out = CombineLayer.get_out_data_from_opts(name="combine", network=net, sources=[layer1, layer2])
717+
assert out.dim_tags[:2] == dat2.dim_tags[:2] and out.batch_shape == dat2.batch_shape
718+
719+
708720
def test_CombineLayer_different_batch_axis():
709721
# ["base:enc_ctx", "weight_feedback", "s_transformed"]
710722
# base:enc_ctx: Data(name='enc_ctx_output', shape=(None, 14), batch_dim_axis=1)

0 commit comments

Comments
 (0)