Skip to content

Commit 081e65f

Browse files
committed
Add a __class__ key to the demography .asdict methods
Also convert a list-valued migration matrix to a numpy array. The __class__ key is useful because, e.g. for events, the dictionary keys alone aren't enough to distinguish between the event types. It also means that the classes are properly recorded when saving provenance, aiding reproducibility.
1 parent 0909034 commit 081e65f

File tree

3 files changed

+105
-3
lines changed

3 files changed

+105
-3
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66

77
- Fix wheels on OSX-13 ({pr}`2355`, {user}`benjeffery`)
88

9+
**Breaking changes**:
10+
11+
- The `.asdict()` methods for Demography, Population, and Event classes in the
12+
demography submodule now return a `__class__` key. This is also stored in their
13+
provenance entries, to help recreate demography objects from provenance.
14+
({pr}`2368`, {user}`hyanwong`)
15+
916
## [1.3.3] - 2024-08-07
1017

1118
Bugfix release for issues with Dirac and Beta coalescent models.

msprime/demography.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ class Population:
223223
"""
224224

225225
def asdict(self):
226-
return dataclasses.asdict(self)
226+
d = dataclasses.asdict(self)
227+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
228+
return d
227229

228230
def validate(self):
229231
if self.initial_size < 0:
@@ -269,6 +271,19 @@ def __post_init__(self):
269271
if self.migration_matrix is None:
270272
N = self.num_populations
271273
self.migration_matrix = np.zeros((N, N))
274+
else:
275+
# convert to a numpy array if it's not already
276+
if not isinstance(self.migration_matrix, np.ndarray):
277+
self.migration_matrix = np.array(self.migration_matrix)
278+
if self.migration_matrix.shape != (self.num_populations, self.num_populations):
279+
raise ValueError(
280+
"Migration matrix must be square and match the number of populations"
281+
)
282+
if self.events is not None:
283+
for event in self.events:
284+
if not isinstance(event, DemographicEvent):
285+
raise TypeError("Events must be instances of DemographicEvent")
286+
event.demography = self
272287

273288
# People might get cryptic errors from passing in copies of the same
274289
# population, so check for it.
@@ -1196,6 +1211,7 @@ def asdict(self):
11961211
"populations": [pop.asdict() for pop in self.populations],
11971212
"events": [event.asdict() for event in self.events],
11981213
"migration_matrix": self.migration_matrix.tolist(),
1214+
"__class__": f"{self.__module__}.{self.__class__.__name__}",
11991215
}
12001216

12011217
def debug(self):
@@ -2882,6 +2898,7 @@ def asdict(self):
28822898
initial_size=self.initial_size,
28832899
growth_rate=self.growth_rate,
28842900
metadata=self.metadata,
2901+
__class__=f"{self.__module__}.{self.__class__.__name__}",
28852902
)
28862903

28872904

@@ -2914,11 +2931,14 @@ def _effect(self):
29142931
raise NotImplementedError()
29152932

29162933
def asdict(self):
2917-
return {
2934+
deprecated = {"population_id", "matrix_index", "destination"}
2935+
d = {
29182936
key: getattr(self, key)
29192937
for key in inspect.signature(self.__init__).parameters.keys()
2920-
if hasattr(self, key)
2938+
if hasattr(self, key) and key not in deprecated
29212939
}
2940+
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
2941+
return d
29222942

29232943
def _convert_id(self, population_ref):
29242944
"""

tests/test_demography.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"""
2020
Test cases for demographic events in msprime.
2121
"""
22+
import importlib
2223
import io
2324
import itertools
2425
import json
@@ -4292,6 +4293,41 @@ def test_duplicate_population_name(self):
42924293
[msprime.Population(10, name="A"), msprime.Population(11, name="A")]
42934294
)
42944295

4296+
def test_init_with_pops(self):
4297+
pop1 = msprime.Population(10)
4298+
pop2 = msprime.Population(10)
4299+
demography = msprime.Demography([pop1, pop2])
4300+
assert len(demography.populations) == 2
4301+
assert demography.populations[0].id == 0
4302+
assert demography.populations[1].id == 1
4303+
4304+
def test_init_with_matrix(self):
4305+
pop1 = msprime.Population(10)
4306+
pop2 = msprime.Population(10)
4307+
# Test passing in as a list of lists, not a numpy array
4308+
demography = msprime.Demography([pop1, pop2], migration_matrix=[[1, 1], [1, 1]])
4309+
assert demography.migration_matrix.shape == (2, 2)
4310+
assert np.all(demography.migration_matrix == 1)
4311+
4312+
def test_init_with_bad_matrix(self):
4313+
pop1 = msprime.Population(10)
4314+
pop2 = msprime.Population(10)
4315+
with pytest.raises(ValueError, match="must be square"):
4316+
msprime.Demography([pop1, pop2], migration_matrix=[[1, 1]])
4317+
4318+
def test_init_with_events(self):
4319+
pop = msprime.Population(10)
4320+
event = msprime.PopulationParametersChange(1, initial_size=1)
4321+
demography = msprime.Demography([pop], events=[event])
4322+
assert len(demography.events) == 1
4323+
assert np.all(demography.migration_matrix == 0)
4324+
assert demography.events[0].demography == demography
4325+
4326+
def test_bad_init_with_events(self):
4327+
pop = msprime.Population(10)
4328+
with pytest.raises(TypeError, match="instances of DemographicEvent"):
4329+
msprime.Demography([pop], events=[None])
4330+
42954331
def test_duplicate_populations(self):
42964332
pop = msprime.Population(10)
42974333
with pytest.raises(ValueError, match="must be distinct"):
@@ -4562,6 +4598,45 @@ def test_validate_resolves_defaults(self):
45624598
assert validated["B"].initially_active
45634599
assert not validated["C"].initially_active
45644600

4601+
def test_population_asdict(self):
4602+
# Test that we can instantiate the components of a demography object
4603+
demography = msprime.Demography()
4604+
demography.add_population(
4605+
name="A",
4606+
initial_size=1234,
4607+
growth_rate=0.234,
4608+
description="ASDF",
4609+
extra_metadata={"a": "B", "c": 1234},
4610+
default_sampling_time=0.2,
4611+
initially_active=True,
4612+
)
4613+
popdict = demography.populations[0].asdict()
4614+
module, classname = popdict.pop("__class__").rsplit(".", 1)
4615+
popdict.pop("id", None) # can't create a population with ID set
4616+
cls = getattr(importlib.import_module(module), classname)
4617+
obj = cls(**popdict)
4618+
assert isinstance(obj, msprime.demography.Population)
4619+
assert obj.name == "A"
4620+
assert obj.initial_size == 1234
4621+
4622+
def test_event_asdict(self):
4623+
demography = msprime.Demography(
4624+
[msprime.Population(1234), msprime.Population(4321)],
4625+
events=[
4626+
msprime.PopulationParametersChange(2, initial_size=5, population_id=1)
4627+
],
4628+
)
4629+
assert len(demography.events) == 1
4630+
eventdict = demography.events[0].asdict()
4631+
assert "population_id" not in eventdict # deprecated param
4632+
assert eventdict["population"] == 1
4633+
module, classname = eventdict.pop("__class__").rsplit(".", 1)
4634+
cls = getattr(importlib.import_module(module), classname)
4635+
obj = cls(**eventdict)
4636+
assert isinstance(obj, msprime.demography.PopulationParametersChange)
4637+
assert obj.time == 2
4638+
assert obj.initial_size == 5
4639+
45654640

45664641
class TestDemographyCopy:
45674642
def test_empty(self):

0 commit comments

Comments
 (0)