Skip to content

Commit ed64181

Browse files
authored
Merge pull request #965 from benjeffery/fix-json
Encode numpy arrays in sample batch json
2 parents f9de549 + fe04950 commit ed64181

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

tsinfer/inference.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -945,13 +945,27 @@ def common_params(self) -> dict:
945945
}
946946

947947
def save(self, path):
948+
def numpy_encoder(obj):
949+
if isinstance(obj, np.ndarray):
950+
return {
951+
"__numpy__": True,
952+
"data": obj.tolist(),
953+
"dtype": str(obj.dtype),
954+
}
955+
return obj
956+
948957
with open(path, "w") as f:
949-
json.dump(dataclasses.asdict(self), f, indent=2)
958+
json.dump(dataclasses.asdict(self), f, indent=2, default=numpy_encoder)
950959

951960
@classmethod
952961
def load(cls, path):
962+
def numpy_decoder(dct):
963+
if "__numpy__" in dct:
964+
return np.array(dct["data"], dtype=dct["dtype"])
965+
return dct
966+
953967
with open(path) as f:
954-
wd_dict = json.load(f)
968+
wd_dict = json.load(f, object_hook=numpy_decoder)
955969
return cls(**wd_dict)
956970

957971

@@ -1048,7 +1062,7 @@ def match_samples_batch_init(
10481062
sample_times = sample_times.tolist()
10491063
wd.sample_indexes = sample_indexes
10501064
wd.sample_times = sample_times
1051-
num_samples_per_partition = min_work_per_job // variant_data.num_sites
1065+
num_samples_per_partition = int(min_work_per_job // variant_data.num_sites)
10521066
if num_samples_per_partition == 0:
10531067
num_samples_per_partition = 1
10541068
wd.num_samples_per_partition = num_samples_per_partition

0 commit comments

Comments
 (0)