Skip to content

Commit 952398a

Browse files
committed
Add a __class__ key to the demography .asdict methods
And convert list input to numpy array. The __class__ key is useful because e.g. for events, the dictionary keys aren't enough to distinguish between the event types. This also means that the classes are properly output when saving provenance, aiding reproducibility.
1 parent 0909034 commit 952398a

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

msprime/demography.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ class Population:
223223
"""
224224

225225
def asdict(self):
226-
return dataclasses.asdict(self)
226+
d = dataclasses.asdict(self)
227+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
228+
return d
227229

228230
def validate(self):
229231
if self.initial_size < 0:
@@ -269,6 +271,15 @@ def __post_init__(self):
269271
if self.migration_matrix is None:
270272
N = self.num_populations
271273
self.migration_matrix = np.zeros((N, N))
274+
else:
275+
# convert to a numpy array if it's not already
276+
if not isinstance(self.migration_matrix, np.ndarray):
277+
self.migration_matrix = np.array(self.migration_matrix)
278+
if self.events is not None:
279+
for event in self.events:
280+
if not isinstance(event, DemographicEvent):
281+
raise TypeError("Events must be instances of DemographicEvent")
282+
event.demography = self
272283

273284
# People might get cryptic errors from passing in copies of the same
274285
# population, so check for it.
@@ -1196,6 +1207,8 @@ def asdict(self):
11961207
"populations": [pop.asdict() for pop in self.populations],
11971208
"events": [event.asdict() for event in self.events],
11981209
"migration_matrix": self.migration_matrix.tolist(),
1210+
"__class__": f"{self.__module__}.{self.__class__.__name__}",
1211+
11991212
}
12001213

12011214
def debug(self):
@@ -2882,6 +2895,8 @@ def asdict(self):
28822895
initial_size=self.initial_size,
28832896
growth_rate=self.growth_rate,
28842897
metadata=self.metadata,
2898+
__class__=f"{self.__module__}.{self.__class__.__name__}",
2899+
28852900
)
28862901

28872902

@@ -2914,11 +2929,14 @@ def _effect(self):
29142929
raise NotImplementedError()
29152930

29162931
def asdict(self):
2917-
return {
2932+
deprecated = {"population_id", "matrix_index", "destination"}
2933+
d = {
29182934
key: getattr(self, key)
29192935
for key in inspect.signature(self.__init__).parameters.keys()
2920-
if hasattr(self, key)
2936+
if hasattr(self, key) and key not in deprecated
29212937
}
2938+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
2939+
return d
29222940

29232941
def _convert_id(self, population_ref):
29242942
"""

0 commit comments

Comments
 (0)