File tree Expand file tree Collapse file tree 1 file changed +17
-3
lines changed
Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -945,13 +945,27 @@ def common_params(self) -> dict:
945945 }
946946
947947 def save (self , path ):
948+ def numpy_encoder (obj ):
949+ if isinstance (obj , np .ndarray ):
950+ return {
951+ "__numpy__" : True ,
952+ "data" : obj .tolist (),
953+ "dtype" : str (obj .dtype ),
954+ }
955+ return obj
956+
948957 with open (path , "w" ) as f :
949- json .dump (dataclasses .asdict (self ), f , indent = 2 )
958+ json .dump (dataclasses .asdict (self ), f , indent = 2 , default = numpy_encoder )
950959
951960 @classmethod
952961 def load (cls , path ):
962+ def numpy_decoder (dct ):
963+ if "__numpy__" in dct :
964+ return np .array (dct ["data" ], dtype = dct ["dtype" ])
965+ return dct
966+
953967 with open (path ) as f :
954- wd_dict = json .load (f )
968+ wd_dict = json .load (f , object_hook = numpy_decoder )
955969 return cls (** wd_dict )
956970
957971
@@ -1048,7 +1062,7 @@ def match_samples_batch_init(
10481062 sample_times = sample_times .tolist ()
10491063 wd .sample_indexes = sample_indexes
10501064 wd .sample_times = sample_times
1051- num_samples_per_partition = min_work_per_job // variant_data .num_sites
1065+ num_samples_per_partition = int ( min_work_per_job // variant_data .num_sites )
10521066 if num_samples_per_partition == 0 :
10531067 num_samples_per_partition = 1
10541068 wd .num_samples_per_partition = num_samples_per_partition
You can’t perform that action at this time.
0 commit comments