File tree Expand file tree Collapse file tree 2 files changed +25
-4
lines changed
Expand file tree Collapse file tree 2 files changed +25
-4
lines changed Original file line number Diff line number Diff line change @@ -433,9 +433,17 @@ def encode_cache(self, accumulator):
433433 return tf .compat .as_bytes (json .dumps (primitive_accumulator ))
434434
435435 def decode_cache (self , encoded_accumulator ):
436- return np .array (
437- json .loads (tf .compat .as_text (encoded_accumulator )), dtype = self ._dtype
438- )
436+ # TODO(b/268341036): Set dtype correctly for combiners for numpy 1.24.
437+ try :
438+ return np .array (
439+ json .loads (tf .compat .as_text (encoded_accumulator )), dtype = self ._dtype
440+ )
441+ except ValueError :
442+ if self ._dtype != object :
443+ return np .array (
444+ json .loads (tf .compat .as_text (encoded_accumulator )), dtype = object
445+ )
446+ raise
439447
440448
441449class AnalyzerDef (nodes .OperationDef , metaclass = abc .ABCMeta ):
Original file line number Diff line number Diff line change @@ -81,7 +81,7 @@ def test_validate_dataset_keys(self):
8181 coder = analyzer_nodes ._VocabularyAccumulatorCoder (),
8282 value = [b'\x8a ' , 29 ]),
8383 dict (
84- testcase_name = '_VocabularyAccumulatorCoderClassAccumulator ' ,
84+ testcase_name = '_WeightedMeanAndVarAccumulatorPerKey ' ,
8585 coder = analyzer_nodes ._VocabularyAccumulatorCoder (),
8686 value = [
8787 b'A' ,
@@ -92,6 +92,19 @@ def test_validate_dataset_keys(self):
9292 weight = np .array (0. ),
9393 )
9494 ]),
95+ dict (
96+ testcase_name = '_WeightedMeanAndVarAccumulatorKeepDims' ,
97+ coder = analyzer_nodes .JsonNumpyCacheCoder (),
98+ # TODO(b/268341036): Remove this complication once np 1.24 issue is
99+ # fixed.
100+ value = analyzer_nodes .JsonNumpyCacheCoder (object ).decode_cache (
101+ analyzer_nodes .JsonNumpyCacheCoder ().encode_cache (
102+ analyzers ._WeightedMeanAndVarAccumulator (
103+ count = np .array (0 ),
104+ mean = np .array ([], dtype = np .float64 ),
105+ variance = np .array ([], dtype = np .float64 ),
106+ weight = np .array (0.0 ))))
107+ ),
95108 dict (
96109 testcase_name = '_QuantilesAccumulatorCoderClassAccumulator' ,
97110 coder = analyzers ._QuantilesSketchCacheCoder (),
You can’t perform that action at this time.
0 commit comments