@@ -1627,6 +1627,7 @@ def _reduce_batch_minus_min_and_max(x):
16271627 values = [3 , 2 , - 1 , 3 ],
16281628 dense_shape = [4 , 5 ]),
16291629 key = ['a' , 'a' , 'a' , 'b' ],
1630+ reduce_instance_dims = True ,
16301631 expected_key_vocab = [b'a' , b'b' ],
16311632 expected_x_minus_min = [1 , - 3 ],
16321633 expected_x_max = [3 , 3 ],
@@ -1638,25 +1639,52 @@ def _reduce_batch_minus_min_and_max(x):
16381639 testcase_name = 'float' ,
16391640 x = [[1 ], [5 ], [2 ], [3 ]],
16401641 key = ['a' , 'a' , 'a' , 'b' ],
1642+ reduce_instance_dims = True ,
16411643 expected_key_vocab = [b'a' , b'b' ],
16421644 expected_x_minus_min = [- 1 , - 3 ],
16431645 expected_x_max = [5 , 3 ],
16441646 input_signature = [
16451647 tf .TensorSpec ([None , None ], tf .float32 ),
16461648 tf .TensorSpec ([None ], tf .string )
16471649 ]),
1650+ dict (
1651+ testcase_name = 'float_elementwise' ,
1652+ x = [[1 ], [5 ], [2 ], [3 ]],
1653+ key = ['a' , 'a' , 'a' , 'b' ],
1654+ reduce_instance_dims = False ,
1655+ expected_key_vocab = [b'a' , b'b' ],
1656+ expected_x_minus_min = [[- 1 ], [- 3 ]],
1657+ expected_x_max = [[5 ], [3 ]],
1658+ input_signature = [
1659+ tf .TensorSpec ([None , None ], tf .float32 ),
1660+ tf .TensorSpec ([None ], tf .string )
1661+ ]),
16481662 dict (
16491663 testcase_name = 'float3dims' ,
16501664 x = [[[1 , 5 ], [1 , 1 ]], [[5 , 1 ], [5 , 5 ]], [[2 , 2 ], [2 , 5 ]],
16511665 [[3 , - 3 ], [3 , 3 ]]],
16521666 key = ['a' , 'a' , 'a' , 'b' ],
1667+ reduce_instance_dims = True ,
16531668 expected_key_vocab = [b'a' , b'b' ],
16541669 expected_x_minus_min = [- 1 , 3 ],
16551670 expected_x_max = [5 , 3 ],
16561671 input_signature = [
16571672 tf .TensorSpec ([None , None , None ], tf .float32 ),
16581673 tf .TensorSpec ([None ], tf .string )
16591674 ]),
1675+ dict (
1676+ testcase_name = 'float3dims_elementwise' ,
1677+ x = [[[1 , 5 ], [1 , 1 ]], [[5 , 1 ], [5 , 5 ]], [[2 , 2 ], [2 , 5 ]],
1678+ [[3 , - 3 ], [3 , 3 ]]],
1679+ key = ['a' , 'a' , 'a' , 'b' ],
1680+ reduce_instance_dims = False ,
1681+ expected_key_vocab = [b'a' , b'b' ],
1682+ expected_x_minus_min = [[[- 1 , - 1 ], [- 1 , - 1 ]], [[- 3 , 3 ], [- 3 , - 3 ]]],
1683+ expected_x_max = [[[5 , 5 ], [5 , 5 ]], [[3 , - 3 ], [3 , 3 ]]],
1684+ input_signature = [
1685+ tf .TensorSpec ([None , None , None ], tf .float32 ),
1686+ tf .TensorSpec ([None ], tf .string )
1687+ ]),
16601688 dict (
16611689 testcase_name = 'ragged' ,
16621690 x = tf .compat .v1 .ragged .RaggedTensorValue (
@@ -1673,6 +1701,7 @@ def _reduce_batch_minus_min_and_max(x):
16731701 row_splits = np .array ([0 , 2 , 3 , 4 , 5 ])),
16741702 row_splits = np .array ([0 , 2 , 3 , 4 ])),
16751703 row_splits = np .array ([0 , 2 , 3 ])),
1704+ reduce_instance_dims = True ,
16761705 expected_key_vocab = [b'a' , b'b' ],
16771706 expected_x_minus_min = [- 2. , - 3. ],
16781707 expected_x_max = [4. , 5. ],
@@ -1682,12 +1711,13 @@ def _reduce_batch_minus_min_and_max(x):
16821711 ]),
16831712 ]))
16841713 def test_reduce_batch_minus_min_and_max_per_key (
1685- self , x , key , expected_key_vocab , expected_x_minus_min , expected_x_max ,
1686- input_signature , function_handler ):
1714+ self , x , key , reduce_instance_dims , expected_key_vocab ,
1715+ expected_x_minus_min , expected_x_max , input_signature , function_handler ):
16871716
16881717 @function_handler (input_signature = input_signature )
16891718 def _reduce_batch_minus_min_and_max_per_key (x , key ):
1690- return tf_utils .reduce_batch_minus_min_and_max_per_key (x , key )
1719+ return tf_utils .reduce_batch_minus_min_and_max_per_key (
1720+ x , key , reduce_instance_dims = reduce_instance_dims )
16911721
16921722 key_vocab , x_minus_min , x_max = _reduce_batch_minus_min_and_max_per_key (
16931723 x , key )
0 commit comments