Skip to content

Commit d69e02f

Browse files
zoyahavtfx-copybara
authored andcommitted
Temporary fix for TFT analyzers which do not reduce instance dims to work with numpy 1.24
PiperOrigin-RevId: 508126642
1 parent 004e3d0 commit d69e02f

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

tensorflow_transform/analyzer_nodes.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,9 +433,17 @@ def encode_cache(self, accumulator):
433433
return tf.compat.as_bytes(json.dumps(primitive_accumulator))
434434

435435
def decode_cache(self, encoded_accumulator):
436-
return np.array(
437-
json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype
438-
)
436+
# TODO(b/268341036): Set dtype correctly for combiners for numpy 1.24.
437+
try:
438+
return np.array(
439+
json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype
440+
)
441+
except ValueError:
442+
if self._dtype != object:
443+
return np.array(
444+
json.loads(tf.compat.as_text(encoded_accumulator)), dtype=object
445+
)
446+
raise
439447

440448

441449
class AnalyzerDef(nodes.OperationDef, metaclass=abc.ABCMeta):

tensorflow_transform/beam/analyzer_cache_test.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_validate_dataset_keys(self):
8181
coder=analyzer_nodes._VocabularyAccumulatorCoder(),
8282
value=[b'\x8a', 29]),
8383
dict(
84-
testcase_name='_VocabularyAccumulatorCoderClassAccumulator',
84+
testcase_name='_WeightedMeanAndVarAccumulatorPerKey',
8585
coder=analyzer_nodes._VocabularyAccumulatorCoder(),
8686
value=[
8787
b'A',
@@ -92,6 +92,19 @@ def test_validate_dataset_keys(self):
9292
weight=np.array(0.),
9393
)
9494
]),
95+
dict(
96+
testcase_name='_WeightedMeanAndVarAccumulatorKeepDims',
97+
coder=analyzer_nodes.JsonNumpyCacheCoder(),
98+
# TODO(b/268341036): Remove this complication once np 1.24 issue is
99+
# fixed.
100+
value=analyzer_nodes.JsonNumpyCacheCoder(object).decode_cache(
101+
analyzer_nodes.JsonNumpyCacheCoder().encode_cache(
102+
analyzers._WeightedMeanAndVarAccumulator(
103+
count=np.array(0),
104+
mean=np.array([], dtype=np.float64),
105+
variance=np.array([], dtype=np.float64),
106+
weight=np.array(0.0))))
107+
),
95108
dict(
96109
testcase_name='_QuantilesAccumulatorCoderClassAccumulator',
97110
coder=analyzers._QuantilesSketchCacheCoder(),

0 commit comments

Comments
 (0)