Skip to content

Commit 7ee9107

Browse files
committed
Add a __class__ key to the demography .asdict methods
And convert list input to numpy array. The __class__ key is useful because e.g. for events, the dictionary keys aren't enough to distinguish between the event types. This also means that the classes are properly output when saving provenance, aiding reproducibility.
1 parent 0909034 commit 7ee9107

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

msprime/demography.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ class Population:
223223
"""
224224

225225
def asdict(self):
226-
return dataclasses.asdict(self)
226+
d = dataclasses.asdict(self)
227+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
228+
return d
227229

228230
def validate(self):
229231
if self.initial_size < 0:
@@ -269,6 +271,15 @@ def __post_init__(self):
269271
if self.migration_matrix is None:
270272
N = self.num_populations
271273
self.migration_matrix = np.zeros((N, N))
274+
else:
275+
# convert to a numpy array if it's not already
276+
if not isinstance(self.migration_matrix, np.ndarray):
277+
self.migration_matrix = np.array(self.migration_matrix)
278+
if self.events is not None:
279+
for event in self.events:
280+
if not isinstance(event, DemographicEvent):
281+
raise TypeError("Events must be instances of DemographicEvent")
282+
event.demography = self
272283

273284
# People might get cryptic errors from passing in copies of the same
274285
# population, so check for it.
@@ -1196,6 +1207,7 @@ def asdict(self):
11961207
"populations": [pop.asdict() for pop in self.populations],
11971208
"events": [event.asdict() for event in self.events],
11981209
"migration_matrix": self.migration_matrix.tolist(),
1210+
"__class__": f"{self.__module__}.{self.__class__.__name__}",
11991211
}
12001212

12011213
def debug(self):
@@ -2882,6 +2894,7 @@ def asdict(self):
28822894
initial_size=self.initial_size,
28832895
growth_rate=self.growth_rate,
28842896
metadata=self.metadata,
2897+
__class__=f"{self.__module__}.{self.__class__.__name__}",
28852898
)
28862899

28872900

@@ -2914,11 +2927,14 @@ def _effect(self):
29142927
raise NotImplementedError()
29152928

29162929
def asdict(self):
2917-
return {
2930+
deprecated = {"population_id", "matrix_index", "destination"}
2931+
d = {
29182932
key: getattr(self, key)
29192933
for key in inspect.signature(self.__init__).parameters.keys()
2920-
if hasattr(self, key)
2934+
if hasattr(self, key) and key not in deprecated
29212935
}
2936+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
2937+
return d
29222938

29232939
def _convert_id(self, population_ref):
29242940
"""

tests/test_demography.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020
Test cases for demographic events in msprime.
2121
"""
22+
import importlib
2223
import io
2324
import itertools
2425
import json
@@ -4292,6 +4293,35 @@ def test_duplicate_population_name(self):
42924293
[msprime.Population(10, name="A"), msprime.Population(11, name="A")]
42934294
)
42944295

4296+
def test_init_with_pops(self):
4297+
pop1 = msprime.Population(10)
4298+
pop2 = msprime.Population(10)
4299+
demography = msprime.Demography([pop1, pop2])
4300+
assert len(demography.populations) == 2
4301+
assert demography.populations[0].id == 0
4302+
assert demography.populations[1].id == 1
4303+
4304+
def test_init_with_matrix(self):
4305+
pop1 = msprime.Population(10)
4306+
pop2 = msprime.Population(10)
4307+
# Test passing in as a list of lists, not a numpy array
4308+
demography = msprime.Demography([pop1, pop2], migration_matrix=[[1, 1], [1, 1]])
4309+
assert demography.migration_matrix.shape == (2, 2)
4310+
assert np.all(demography.migration_matrix == 1)
4311+
4312+
def test_init_with_events(self):
4313+
pop = msprime.Population(10)
4314+
event = msprime.PopulationParametersChange(1, initial_size=1)
4315+
demography = msprime.Demography([pop], events=[event])
4316+
assert len(demography.events) == 1
4317+
assert np.all(demography.migration_matrix == 0)
4318+
assert demography.events[0].demography == demography
4319+
4320+
def test_bad_init_with_events(self):
4321+
pop = msprime.Population(10)
4322+
with pytest.raises(TypeError, match="instances of DemographicEvent"):
4323+
msprime.Demography([pop], events=[None])
4324+
42954325
def test_duplicate_populations(self):
42964326
pop = msprime.Population(10)
42974327
with pytest.raises(ValueError, match="must be distinct"):
@@ -4562,6 +4592,45 @@ def test_validate_resolves_defaults(self):
45624592
assert validated["B"].initially_active
45634593
assert not validated["C"].initially_active
45644594

4595+
def test_population_asdict(self):
4596+
# Test that we can instantiate the components of a demography object
4597+
demography = msprime.Demography()
4598+
demography.add_population(
4599+
name="A",
4600+
initial_size=1234,
4601+
growth_rate=0.234,
4602+
description="ASDF",
4603+
extra_metadata={"a": "B", "c": 1234},
4604+
default_sampling_time=0.2,
4605+
initially_active=True,
4606+
)
4607+
popdict = demography.populations[0].asdict()
4608+
module, classname = popdict.pop("__class__").rsplit(".", 1)
4609+
popdict.pop("id", None) # can't create a population with ID set
4610+
cls = getattr(importlib.import_module(module), classname)
4611+
obj = cls(**popdict)
4612+
assert isinstance(obj, msprime.demography.Population)
4613+
assert obj.name == "A"
4614+
assert obj.initial_size == 1234
4615+
4616+
def test_event_asdict(self):
4617+
demography = msprime.Demography(
4618+
[msprime.Population(1234), msprime.Population(4321)],
4619+
events=[
4620+
msprime.PopulationParametersChange(2, initial_size=5, population_id=1)
4621+
],
4622+
)
4623+
assert len(demography.events) == 1
4624+
eventdict = demography.events[0].asdict()
4625+
assert "population_id" not in eventdict # deprecated param
4626+
assert eventdict["population"] == 1
4627+
module, classname = eventdict.pop("__class__").rsplit(".", 1)
4628+
cls = getattr(importlib.import_module(module), classname)
4629+
obj = cls(**eventdict)
4630+
assert isinstance(obj, msprime.demography.PopulationParametersChange)
4631+
assert obj.time == 2
4632+
assert obj.initial_size == 5
4633+
45654634

45664635
class TestDemographyCopy:
45674636
def test_empty(self):

0 commit comments

Comments
 (0)