Skip to content

Commit 18a0d90

Browse files
committed
Fixes a tfdv bug caused by slicing on a feature missing from a RecordBatch.
PiperOrigin-RevId: 397144053
1 parent 0e3c11b commit 18a0d90

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
large numbers of examples.
1313
* Depends on
1414
`tensorflow>=1.15.2,!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.*,!=2.4.*,!=2.5.*,<3`.
15+
* Fixed a bug wherein slicing on a feature missing from some batches could
16+
produce slice keys derived from a different feature.
1517

1618
## Known Issues
1719

tensorflow_data_validation/utils/slicing_util.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,12 @@ def feature_value_slicer(record_batch: pa.RecordBatch) -> Iterable[
123123
"""
124124
per_feature_parent_indices = []
125125
for feature_name, values in six.iteritems(features):
126-
feature_array = record_batch.column(
127-
record_batch.schema.get_field_index(feature_name))
126+
idx = record_batch.schema.get_field_index(feature_name)
127+
# If the feature name does not appear in the schema for this record batch,
128+
# drop it from the set of sliced features.
129+
if idx < 0:
130+
continue
131+
feature_array = record_batch.column(idx)
128132
flattened, value_parent_indices = arrow_util.flatten_nested(
129133
feature_array, True)
130134
non_missing_values = np.asarray(flattened)
@@ -138,7 +142,10 @@ def feature_value_slicer(record_batch: pa.RecordBatch) -> Iterable[
138142
if values is not None:
139143
df = df.loc[df[feature_name].isin(values)]
140144
per_feature_parent_indices.append(df)
141-
145+
# If there are no features to slice on, yield no output.
146+
# TODO(b/200081813): Produce output with an appropriate placeholder key.
147+
if not per_feature_parent_indices:
148+
return
142149
# Join dataframes based on parent indices.
143150
# Note that we want the parent indices per slice key to be sorted in the
144151
# merged dataframe. The individual dataframes have the parent indices in

tensorflow_data_validation/utils/slicing_util_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class SlicingUtilTest(absltest.TestCase):
3131
def _check_results(self, got, expected):
3232
got_dict = {g[0]: g[1] for g in got}
3333
expected_dict = {e[0]: e[1] for e in expected}
34+
3435
self.assertCountEqual(got_dict.keys(), expected_dict.keys())
3536
for k, got_record_batch in got_dict.items():
3637
expected_record_batch = expected_dict[k]
@@ -80,6 +81,25 @@ def test_get_feature_value_slicer(self):
8081
slicing_util.get_feature_value_slicer(features)(input_record_batch),
8182
expected_result)
8283

84+
def test_get_feature_value_slicer_one_feature_not_in_batch(self):
85+
features = {'not_an_actual_feature': None, 'a': None}
86+
input_record_batch = pa.RecordBatch.from_arrays([
87+
pa.array([[1], [2, 1]]),
88+
pa.array([['dog'], ['cat']]),
89+
], ['a', 'b'])
90+
expected_result = [
91+
(u'a_1',
92+
pa.RecordBatch.from_arrays(
93+
[pa.array([[1], [2, 1]]),
94+
pa.array([['dog'], ['cat']])], ['a', 'b'])),
95+
(u'a_2',
96+
pa.RecordBatch.from_arrays(
97+
[pa.array([[2, 1]]), pa.array([['cat']])], ['a', 'b'])),
98+
]
99+
self._check_results(
100+
slicing_util.get_feature_value_slicer(features)(input_record_batch),
101+
expected_result)
102+
83103
def test_get_feature_value_slicer_single_feature(self):
84104
features = {'a': [2]}
85105
input_record_batch = pa.RecordBatch.from_arrays([
@@ -118,6 +138,18 @@ def test_get_feature_value_slicer_feature_not_in_record_batch(self):
118138
slicing_util.get_feature_value_slicer(features)(input_record_batch),
119139
expected_result)
120140

141+
def test_get_feature_value_slicer_feature_not_in_record_batch_all_values(
142+
self):
143+
features = {'c': None}
144+
input_record_batch = pa.RecordBatch.from_arrays([
145+
pa.array([[1], [2, 1]]),
146+
pa.array([['dog'], ['cat']]),
147+
], ['a', 'b'])
148+
expected_result = []
149+
self._check_results(
150+
slicing_util.get_feature_value_slicer(features)(input_record_batch),
151+
expected_result)
152+
121153
def test_get_feature_value_slicer_bytes_feature_valid_utf8(self):
122154
features = {'b': None}
123155
input_record_batch = pa.RecordBatch.from_arrays([

0 commit comments

Comments
 (0)