Skip to content

Commit d10ba5c

Browse files
brillstfx-copybara
authored and committed
Automated rollback of commit 8a28387
PiperOrigin-RevId: 380656197
1 parent 8a28387 commit d10ba5c

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

RELEASE.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
## Bug Fixes and Other Changes
88

9+
* Optimized certain stats generators that need to materialize the input
10+
RecordBatches.
911
* Depends on `protobuf>=3.13,<4`.
1012

1113
## Known Issues

tensorflow_data_validation/types.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,40 @@ def __len__(self) -> int:
136136

137137
def __bool__(self) -> bool:
138138
return bool(self._steps)
139+
140+
141+
# Beam is expected to manage parallelism, so record batches are encoded
# without Arrow's own multi-threading.
_ARROW_CODER_IPC_OPTIONS = pa.ipc.IpcWriteOptions(use_threads=False)
144+
145+
146+
# TODO(b/190756453): Make this into the upstream
147+
# (preference: Arrow, Beam, tfx_bsl).
148+
class _ArrowRecordBatchCoder(beam.coders.Coder):
  """Custom Beam coder for Arrow record batches.

  A record batch is encoded as an Arrow IPC stream that contains that batch
  and nothing else; decoding accepts only streams with exactly one batch.
  """

  def encode(self, value: pa.RecordBatch) -> bytes:
    """Serializes `value` into the bytes of a one-batch Arrow IPC stream.

    Args:
      value: The record batch to serialize.

    Returns:
      The encoded stream as `bytes`.
    """
    sink = pa.BufferOutputStream()
    # The context manager closes the writer even when write_batch() raises,
    # so a failed encode does not leave an open writer behind.
    with pa.ipc.new_stream(
        sink, value.schema, options=_ARROW_CODER_IPC_OPTIONS) as writer:
      writer.write_batch(value)
    return sink.getvalue().to_pybytes()

  def decode(self, encoded: bytes) -> pa.RecordBatch:
    """Reads back the single record batch stored in `encoded`.

    Args:
      encoded: Bytes of an Arrow IPC stream, as produced by `encode()`.

    Returns:
      The only record batch in the stream.

    Raises:
      ValueError: If the stream holds no record batch, or more than one.
    """
    reader = pa.ipc.open_stream(encoded)
    try:
      result = reader.read_next_batch()
    except StopIteration:
      # Do not let StopIteration escape: callers that run inside generators
      # would misinterpret it. Report the malformed payload explicitly.
      raise ValueError(
          "Expected one RecordBatch in the stream, but found none.") from None
    try:
      reader.read_next_batch()
    except StopIteration:
      # Hitting the end of the stream here is the expected outcome.
      pass
    else:
      raise ValueError("Expected only one RecordBatch in the stream.")
    return result

  def to_type_hint(self):
    """Returns the type this coder handles, `pa.RecordBatch`."""
    return pa.RecordBatch
172+
173+
174+
# Register the coder above as the Beam coder for pa.RecordBatch.
beam.coders.typecoders.registry.register_coder(
    pa.RecordBatch, _ArrowRecordBatchCoder)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for types."""
16+
17+
from absl.testing import absltest
18+
import apache_beam as beam
19+
from apache_beam.testing import util
20+
import pyarrow as pa
21+
from tensorflow_data_validation import types # pylint: disable=unused-import
22+
23+
24+
def _make_record_batch(num_cols, num_rows):
  """Builds a record batch of identical large_list<large_binary> columns.

  Args:
    num_cols: Number of columns; they are named "col0", "col1", ...
    num_rows: Number of rows; every cell holds the list [b"kk"].

  Returns:
    A `pa.RecordBatch` with `num_cols` columns and `num_rows` rows.
  """
  cell_type = pa.large_list(pa.large_binary())
  arrays = []
  names = []
  for index in range(num_cols):
    arrays.append(pa.array([[b"kk"]] * num_rows, type=cell_type))
    names.append("col%d" % index)
  return pa.record_batch(arrays, names)
31+
32+
33+
class _Tracker(object):
34+
"""A singleton to track whether _TrackedCoder.encode/decode is called."""
35+
36+
_instance = None
37+
38+
def reset(self):
39+
self.encode_called = False
40+
self.decode_called = False
41+
42+
def __new__(cls):
43+
if cls._instance is None:
44+
cls._instance = object.__new__(cls)
45+
cls._instance.reset()
46+
return cls._instance
47+
48+
49+
class _TrackedCoder(types._ArrowRecordBatchCoder):
  """Record batch coder that flags each encode/decode call on `_Tracker`."""

  def encode(self, value):
    """Marks encode as called, then defers to the parent coder."""
    tracker = _Tracker()
    tracker.encode_called = True
    return super().encode(value)

  def decode(self, encoded):
    """Marks decode as called, then defers to the parent coder."""
    tracker = _Tracker()
    tracker.decode_called = True
    return super().decode(encoded)
58+
59+
60+
class TypesTest(absltest.TestCase):
  """Tests for the Arrow record batch coder defined in `types`."""

  def test_coder(self):
    """A record batch survives an encode/decode round trip unchanged."""
    rb = _make_record_batch(10, 10)
    coder = types._ArrowRecordBatchCoder()
    self.assertTrue(coder.decode(coder.encode(rb)).equals(rb))

  def test_coder_end_to_end(self):
    """The registered coder is invoked when a pipeline moves record batches."""
    registry = beam.coders.typecoders.registry
    # First check that the registration is done.
    self.assertIsInstance(
        registry.get_coder(pa.RecordBatch), types._ArrowRecordBatchCoder)
    # Schedule restoring the real coder before swapping it out, so that a
    # failing assertion or pipeline error below cannot leave the tracking
    # coder registered for later tests.
    self.addCleanup(
        registry.register_coder, pa.RecordBatch, types._ArrowRecordBatchCoder)
    # Then replace the registered coder with our patched one to track whether
    # encode() / decode() are called.
    registry.register_coder(pa.RecordBatch, _TrackedCoder)
    rb = _make_record_batch(1000, 1)

    def pipeline(root):
      sample = (
          root
          | beam.Create([rb] * 20)
          | beam.combiners.Sample.FixedSizeGlobally(5))

      def matcher(actual):
        # The global sample is a single element: a list of 5 record batches.
        self.assertLen(actual, 1)
        actual = actual[0]
        self.assertLen(actual, 5)
        for actual_rb in actual:
          self.assertTrue(actual_rb.equals(rb))

      util.assert_that(sample, matcher)

    _Tracker().reset()
    beam.runners.DirectRunner().run(pipeline)
    self.assertTrue(_Tracker().encode_called)
    self.assertTrue(_Tracker().decode_called)
98+
99+
100+
# Run the tests through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()

0 commit comments

Comments
 (0)