@@ -31,6 +31,7 @@ class SlicingUtilTest(absltest.TestCase):
3131 def _check_results (self , got , expected ):
3232 got_dict = {g [0 ]: g [1 ] for g in got }
3333 expected_dict = {e [0 ]: e [1 ] for e in expected }
34+
3435 self .assertCountEqual (got_dict .keys (), expected_dict .keys ())
3536 for k , got_record_batch in got_dict .items ():
3637 expected_record_batch = expected_dict [k ]
@@ -80,6 +81,25 @@ def test_get_feature_value_slicer(self):
8081 slicing_util .get_feature_value_slicer (features )(input_record_batch ),
8182 expected_result )
8283
84+ def test_get_feature_value_slicer_one_feature_not_in_batch (self ):
85+ features = {'not_an_actual_feature' : None , 'a' : None }
86+ input_record_batch = pa .RecordBatch .from_arrays ([
87+ pa .array ([[1 ], [2 , 1 ]]),
88+ pa .array ([['dog' ], ['cat' ]]),
89+ ], ['a' , 'b' ])
90+ expected_result = [
91+ (u'a_1' ,
92+ pa .RecordBatch .from_arrays (
93+ [pa .array ([[1 ], [2 , 1 ]]),
94+ pa .array ([['dog' ], ['cat' ]])], ['a' , 'b' ])),
95+ (u'a_2' ,
96+ pa .RecordBatch .from_arrays (
97+ [pa .array ([[2 , 1 ]]), pa .array ([['cat' ]])], ['a' , 'b' ])),
98+ ]
99+ self ._check_results (
100+ slicing_util .get_feature_value_slicer (features )(input_record_batch ),
101+ expected_result )
102+
83103 def test_get_feature_value_slicer_single_feature (self ):
84104 features = {'a' : [2 ]}
85105 input_record_batch = pa .RecordBatch .from_arrays ([
@@ -118,6 +138,18 @@ def test_get_feature_value_slicer_feature_not_in_record_batch(self):
118138 slicing_util .get_feature_value_slicer (features )(input_record_batch ),
119139 expected_result )
120140
141+ def test_get_feature_value_slicer_feature_not_in_record_batch_all_values (
142+ self ):
143+ features = {'c' : None }
144+ input_record_batch = pa .RecordBatch .from_arrays ([
145+ pa .array ([[1 ], [2 , 1 ]]),
146+ pa .array ([['dog' ], ['cat' ]]),
147+ ], ['a' , 'b' ])
148+ expected_result = []
149+ self ._check_results (
150+ slicing_util .get_feature_value_slicer (features )(input_record_batch ),
151+ expected_result )
152+
121153 def test_get_feature_value_slicer_bytes_feature_valid_utf8 (self ):
122154 features = {'b' : None }
123155 input_record_batch = pa .RecordBatch .from_arrays ([
0 commit comments