Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Changelog


## [1.3.5] - 2025-XX-XX

**Breaking changes**:

- The `.asdict()` methods for Demography, Population, and Event classes in the
demography submodule now return a `__class__` key. This is also stored in their
provenance entries, to help recreate demography objects from provenance.
({pr}`{2368}, {user}`hyanwong`)

## [1.3.4] - 2025-05-01

**Bug fixes**:
Expand Down
26 changes: 23 additions & 3 deletions msprime/demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ class Population:
"""

def asdict(self):
return dataclasses.asdict(self)
d = dataclasses.asdict(self)
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
return d

def validate(self):
if self.initial_size < 0:
Expand Down Expand Up @@ -269,6 +271,19 @@ def __post_init__(self):
if self.migration_matrix is None:
N = self.num_populations
self.migration_matrix = np.zeros((N, N))
else:
# convert to a numpy array if it's not already
if not isinstance(self.migration_matrix, np.ndarray):
self.migration_matrix = np.array(self.migration_matrix)
if self.migration_matrix.shape != (self.num_populations, self.num_populations):
raise ValueError(
"Migration matrix must be square and match the number of populations"
)
if self.events is not None:
for event in self.events:
if not isinstance(event, DemographicEvent):
raise TypeError("Events must be instances of DemographicEvent")
event.demography = self

# People might get cryptic errors from passing in copies of the same
# population, so check for it.
Expand Down Expand Up @@ -1196,6 +1211,7 @@ def asdict(self):
"populations": [pop.asdict() for pop in self.populations],
"events": [event.asdict() for event in self.events],
"migration_matrix": self.migration_matrix.tolist(),
"__class__": f"{self.__module__}.{self.__class__.__name__}",
}

def debug(self):
Expand Down Expand Up @@ -2882,6 +2898,7 @@ def asdict(self):
initial_size=self.initial_size,
growth_rate=self.growth_rate,
metadata=self.metadata,
__class__=f"{self.__module__}.{self.__class__.__name__}",
)


Expand Down Expand Up @@ -2914,11 +2931,14 @@ def _effect(self):
raise NotImplementedError()

def asdict(self):
return {
deprecated = {"population_id", "matrix_index", "destination"}
d = {
key: getattr(self, key)
for key in inspect.signature(self.__init__).parameters.keys()
if hasattr(self, key)
if hasattr(self, key) and key not in deprecated
}
d["__class__"] = f"{self.__module__}.{self.__class__.__name__}"
return d

def _convert_id(self, population_ref):
"""
Expand Down
75 changes: 75 additions & 0 deletions tests/test_demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
Test cases for demographic events in msprime.
"""
import importlib
import io
import itertools
import json
Expand Down Expand Up @@ -4292,6 +4293,41 @@ def test_duplicate_population_name(self):
[msprime.Population(10, name="A"), msprime.Population(11, name="A")]
)

def test_init_with_pops(self):
pop1 = msprime.Population(10)
pop2 = msprime.Population(10)
demography = msprime.Demography([pop1, pop2])
assert len(demography.populations) == 2
assert demography.populations[0].id == 0
assert demography.populations[1].id == 1

def test_init_with_matrix(self):
pop1 = msprime.Population(10)
pop2 = msprime.Population(10)
# Test passing in as a list of lists, not a numpy array
demography = msprime.Demography([pop1, pop2], migration_matrix=[[1, 1], [1, 1]])
assert demography.migration_matrix.shape == (2, 2)
assert np.all(demography.migration_matrix == 1)

def test_init_with_bad_matrix(self):
pop1 = msprime.Population(10)
pop2 = msprime.Population(10)
with pytest.raises(ValueError, match="must be square"):
msprime.Demography([pop1, pop2], migration_matrix=[[1, 1]])

def test_init_with_events(self):
pop = msprime.Population(10)
event = msprime.PopulationParametersChange(1, initial_size=1)
demography = msprime.Demography([pop], events=[event])
assert len(demography.events) == 1
assert np.all(demography.migration_matrix == 0)
assert demography.events[0].demography == demography

def test_bad_init_with_events(self):
pop = msprime.Population(10)
with pytest.raises(TypeError, match="instances of DemographicEvent"):
msprime.Demography([pop], events=[None])

def test_duplicate_populations(self):
pop = msprime.Population(10)
with pytest.raises(ValueError, match="must be distinct"):
Expand Down Expand Up @@ -4562,6 +4598,45 @@ def test_validate_resolves_defaults(self):
assert validated["B"].initially_active
assert not validated["C"].initially_active

def test_population_asdict(self):
# Test that we can instantiate the components of a demography object
demography = msprime.Demography()
demography.add_population(
name="A",
initial_size=1234,
growth_rate=0.234,
description="ASDF",
extra_metadata={"a": "B", "c": 1234},
default_sampling_time=0.2,
initially_active=True,
)
popdict = demography.populations[0].asdict()
module, classname = popdict.pop("__class__").rsplit(".", 1)
popdict.pop("id", None) # can't create a population with ID set
cls = getattr(importlib.import_module(module), classname)
obj = cls(**popdict)
assert isinstance(obj, msprime.demography.Population)
assert obj.name == "A"
assert obj.initial_size == 1234

def test_event_asdict(self):
demography = msprime.Demography(
[msprime.Population(1234), msprime.Population(4321)],
events=[
msprime.PopulationParametersChange(2, initial_size=5, population_id=1)
],
)
assert len(demography.events) == 1
eventdict = demography.events[0].asdict()
assert "population_id" not in eventdict # deprecated param
assert eventdict["population"] == 1
module, classname = eventdict.pop("__class__").rsplit(".", 1)
cls = getattr(importlib.import_module(module), classname)
obj = cls(**eventdict)
assert isinstance(obj, msprime.demography.PopulationParametersChange)
assert obj.time == 2
assert obj.initial_size == 5


class TestDemographyCopy:
def test_empty(self):
Expand Down
Loading