Skip to content

Commit 1a86a02

Browse files
committed
Add a __class__ key to the demography .asdict methods
And convert list input to numpy array. The __class__ key is useful because e.g. for events, the dictionary keys aren't enough to distinguish between the event types. This also means that the classes are properly output when saving provenance, aiding reproducibility.
1 parent 0909034 commit 1a86a02

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

msprime/demography.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ class Population:
223223
"""
224224

225225
def asdict(self):
226-
return dataclasses.asdict(self)
226+
d = dataclasses.asdict(self)
227+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
228+
return d
227229

228230
def validate(self):
229231
if self.initial_size < 0:
@@ -269,6 +271,15 @@ def __post_init__(self):
269271
if self.migration_matrix is None:
270272
N = self.num_populations
271273
self.migration_matrix = np.zeros((N, N))
274+
else:
275+
# convert to a numpy array if it's not already
276+
if not isinstance(self.migration_matrix, np.ndarray):
277+
self.migration_matrix = np.array(self.migration_matrix)
278+
if self.events is not None:
279+
for event in self.events:
280+
if not isinstance(event, DemographicEvent):
281+
raise TypeError("Events must be instances of DemographicEvent")
282+
event.demography = self
272283

273284
# People might get cryptic errors from passing in copies of the same
274285
# population, so check for it.
@@ -1196,6 +1207,7 @@ def asdict(self):
11961207
"populations": [pop.asdict() for pop in self.populations],
11971208
"events": [event.asdict() for event in self.events],
11981209
"migration_matrix": self.migration_matrix.tolist(),
1210+
"__class__": f"{self.__module__}.{self.__class__.__name__}",
11991211
}
12001212

12011213
def debug(self):
@@ -2882,6 +2894,7 @@ def asdict(self):
28822894
initial_size=self.initial_size,
28832895
growth_rate=self.growth_rate,
28842896
metadata=self.metadata,
2897+
__class__=f"{self.__module__}.{self.__class__.__name__}",
28852898
)
28862899

28872900

@@ -2914,11 +2927,14 @@ def _effect(self):
29142927
raise NotImplementedError()
29152928

29162929
def asdict(self):
2917-
return {
2930+
deprecated = {"population_id", "matrix_index", "destination"}
2931+
d = {
29182932
key: getattr(self, key)
29192933
for key in inspect.signature(self.__init__).parameters.keys()
2920-
if hasattr(self, key)
2934+
if hasattr(self, key) and key not in deprecated
29212935
}
2936+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
2937+
return d
29222938

29232939
def _convert_id(self, population_ref):
29242940
"""

0 commit comments

Comments
 (0)