Skip to content

Commit 4c53dd6

Browse files
committed
Adds a helper to safely fetch a column from a RecordBatch and updates a few uses.
PiperOrigin-RevId: 397365531
1 parent 18a0d90 commit 4c53dd6

File tree

5 files changed

+52
-20
lines changed

5 files changed

+52
-20
lines changed

tensorflow_data_validation/arrow/arrow_util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,3 +342,29 @@ def get_nest_level(array_type: pa.DataType) -> int:
342342
if pa.types.is_null(array_type):
343343
result += 1
344344
return result
345+
346+
347+
def get_column(record_batch: pa.RecordBatch,
               feature_name: types.FeatureName,
               missing_ok: bool = False) -> Optional[pa.Array]:
  """Looks up a single column of `record_batch` by feature name.

  Note that `pa.Schema.get_field_index` returns -1 both when no field has
  the given name and when several fields share it, so both situations are
  treated as "missing" here.

  Args:
    record_batch: A pa.RecordBatch.
    feature_name: The name of a feature (column) within record_batch.
    missing_ok: If True, returns None for missing feature names.

  Returns:
    The column with the specified name, or None if missing_ok is true and
    a column with the specified name is missing, or more than one exist.

  Raises:
    KeyError: If a column with the specified name is missing, or more than
      one exist, and missing_ok is False.
  """
  field_index = record_batch.schema.get_field_index(feature_name)
  if field_index >= 0:
    return record_batch.column(field_index)
  if not missing_ok:
    raise KeyError('missing column %s' % feature_name)
  return None

tensorflow_data_validation/arrow/arrow_util_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,14 @@ def testFlattenNestedNonList(self):
340340
self.assertTrue(flattened.equals(pa.array([1, 2])))
341341
np.testing.assert_array_equal(parent_indices, [0, 1])
342342

343+
def testGetColumn(self):
  """Covers lookup of a present column, a missing one, and the raising path."""
  # A present column is returned and compares equal to the expected array.
  present = arrow_util.get_column(_INPUT_RECORD_BATCH, "f1")
  self.assertTrue(present.equals(pa.array([[1], [2, 3]])))
  # A missing column with missing_ok=True yields None instead of raising.
  self.assertIsNone(
      arrow_util.get_column(_INPUT_RECORD_BATCH, "xyz", missing_ok=True))
  # The same missing column without missing_ok raises KeyError.
  with self.assertRaises(KeyError):
    arrow_util.get_column(_INPUT_RECORD_BATCH, "xyz")
343351

344352
if __name__ == "__main__":
345353
absltest.main()

tensorflow_data_validation/arrow/decoded_examples_to_arrow_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import pyarrow as pa
2525
import six
26+
from tensorflow_data_validation.arrow import arrow_util
2627
from tensorflow_data_validation.arrow import decoded_examples_to_arrow
2728

2829

@@ -189,8 +190,7 @@ def test_conversion(self, input_examples, expected_output):
189190
input_examples)
190191
self.assertLen(expected_output, record_batch.num_columns)
191192
for feature_name, expected_arrow_array in six.iteritems(expected_output):
192-
actual = record_batch.column(
193-
record_batch.schema.get_field_index(feature_name))
193+
actual = arrow_util.get_column(record_batch, feature_name)
194194
self.assertTrue(
195195
expected_arrow_array.equals(actual),
196196
"{} vs {}".format(expected_arrow_array, actual))

tensorflow_data_validation/statistics/stats_impl.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -339,15 +339,14 @@ def _filter_features(
339339
Returns:
340340
An Arrow RecordBatch containing only features on the allowlist.
341341
"""
342-
schema = record_batch.schema
343-
column_names = set(schema.names)
344342
columns_to_select = []
345343
column_names_to_select = []
346344
for feature_name in feature_allowlist:
347-
if feature_name in column_names:
348-
columns_to_select.append(
349-
record_batch.column(schema.get_field_index(feature_name)))
350-
column_names_to_select.append(feature_name)
345+
col = arrow_util.get_column(record_batch, feature_name, missing_ok=True)
346+
if col is None:
347+
continue
348+
columns_to_select.append(col)
349+
column_names_to_select.append(feature_name)
351350
return pa.RecordBatch.from_arrays(columns_to_select, column_names_to_select)
352351

353352

@@ -523,8 +522,7 @@ def add_input(self, accumulator: List[float],
523522
examples: pa.RecordBatch) -> List[float]:
524523
accumulator[0] += examples.num_rows
525524
if self._weight_feature:
526-
weights_column = examples.column(
527-
examples.schema.get_field_index(self._weight_feature))
525+
weights_column = arrow_util.get_column(examples, self._weight_feature)
528526
accumulator[1] += np.sum(np.asarray(weights_column.flatten()))
529527
return accumulator
530528

@@ -787,13 +785,13 @@ def generate_partial_statistics_in_memory(
787785
"""
788786
result = []
789787
if options.feature_allowlist:
790-
schema = record_batch.schema
791-
columns = [
792-
record_batch.column(schema.get_field_index(f))
793-
for f in options.feature_allowlist
794-
]
795-
record_batch = pa.RecordBatch.from_arrays(columns,
796-
list(options.feature_allowlist))
788+
columns, features = [], []
789+
for feature_name in options.feature_allowlist:
790+
c = arrow_util.get_column(record_batch, feature_name, missing_ok=True)
791+
if c is not None:
792+
columns.append(c)
793+
features.append(feature_name)
794+
record_batch = pa.RecordBatch.from_arrays(columns, features)
797795
for generator in stats_generators:
798796
result.append(
799797
generator.add_input(generator.create_accumulator(), record_batch))

tensorflow_data_validation/utils/slicing_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ def feature_value_slicer(record_batch: pa.RecordBatch) -> Iterable[
123123
"""
124124
per_feature_parent_indices = []
125125
for feature_name, values in six.iteritems(features):
126-
idx = record_batch.schema.get_field_index(feature_name)
126+
feature_array = arrow_util.get_column(
127+
record_batch, feature_name, missing_ok=True)
127128
# If the feature name does not appear in the schema for this record batch,
128129
# drop it from the set of sliced features.
129-
if idx < 0:
130+
if feature_array is None:
130131
continue
131-
feature_array = record_batch.column(idx)
132132
flattened, value_parent_indices = arrow_util.flatten_nested(
133133
feature_array, True)
134134
non_missing_values = np.asarray(flattened)

0 commit comments

Comments
 (0)