Skip to content

Commit 6b6156c

Browse files
hyanwongpetrelharp
andauthored
Raise error if pop passed to samples is not an integer (#3344)
Co-authored-by: Peter Ralph <[email protected]>
1 parent dbaae94 commit 6b6156c

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

python/CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
--------------------
2+
[1.0.x] - YYYY-MM-DD
3+
--------------------
4+
5+
**Bugfixes**
6+
7+
- ``ts.samples(population=...)`` now raises a ``ValueError`` if the population
8+
ID is e.g. a population name, rather than silently returning no samples.
9+
(:user:`hyanwong`, :pr:`3344`)
10+
111
--------------------
212
[1.0.0] - 2025-11-27
313
--------------------

python/tests/test_highlevel.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,17 @@ def test_samples(self):
992992
]
993993
assert total == ts.num_samples
994994

995+
@pytest.mark.parametrize("pop", ["string", "", "0", np.arange(2), 0.0, 0.5, np.nan])
996+
def test_bad_samples(self, pop):
997+
ts = tskit.Tree.generate_balanced(4).tree_sequence
998+
with pytest.raises(ValueError, match="must be an integer ID"):
999+
ts.samples(population=pop)
1000+
1001+
@pytest.mark.parametrize("pop", [0, np.int32(0), np.int64(0), np.uint32(0)])
1002+
def test_good_samples(self, pop):
1003+
ts = msprime.sim_ancestry(2)
1004+
assert np.array_equiv(ts.samples(population=pop), ts.samples())
1005+
9951006
@pytest.mark.parametrize("time", [0, 0.1, 1 / 3, 1 / 4, 5 / 7])
9961007
def test_samples_time(self, time):
9971008
ts = self.get_tree_sequence(num_demes=2, n=20, times=[time, 0.2, 1, 15])

python/tskit/trees.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6533,6 +6533,9 @@ def samples(self, population=None, *, population_id=None, time=None):
65336533
samples = self._ll_tree_sequence.get_samples()
65346534
keep = np.full(shape=samples.shape, fill_value=True)
65356535
if population is not None:
6536+
if not isinstance(population, numbers.Integral):
6537+
raise ValueError("`population` must be an integer ID")
6538+
population = int(population)
65366539
sample_population = self.nodes_population[samples]
65376540
keep = np.logical_and(keep, sample_population == population)
65386541
if time is not None:

0 commit comments

Comments
 (0)