Skip to content

Commit 7079672

Browse files
hyanwongmergify[bot]
authored andcommitted
Allow sequence length to be specified in sd.subset
1 parent 4f76b6a commit 7079672

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
[0.2.4] - 2022-06-xx
33
********************
44

5+
**Fixes**
6+
7+
- sample_data.subset() now accepts a sequence_length (:pr:`681`, :user:`hyanwong`)
8+
59
**Breaking changes**:
610

711
- Inference now sets time_units on both ancestor and final tree sequences to

tests/test_formats.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,19 @@ def test_file_kwargs(self):
15151515
sd2 = formats.SampleData.load(path)
15161516
assert sd1.data_equal(sd2)
15171517

1518+
def test_sequence_length_change(self):
1519+
ts = tsutil.get_example_ts(10)
1520+
sd1 = formats.SampleData.from_tree_sequence(ts)
1521+
max_site_to_use = sd1.num_sites // 2
1522+
new_seq_length = sd1.sites_position[max_site_to_use + 1]
1523+
assert new_seq_length < sd1.sequence_length
1524+
assert max_site_to_use > 0
1525+
sd2 = sd1.subset(
1526+
sites=np.arange(max_site_to_use),
1527+
sequence_length=new_seq_length,
1528+
)
1529+
assert sd2.sequence_length == new_seq_length
1530+
15181531

15191532
class TestSampleDataMerge:
15201533
"""

tsinfer/formats.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,7 @@ def assert_data_equal(self, other):
13141314
assert self.samples_equal(other)
13151315
assert self.sites_equal(other)
13161316

1317-
def subset(self, individuals=None, sites=None, **kwargs):
1317+
def subset(self, individuals=None, sites=None, *, sequence_length=None, **kwargs):
13181318
"""
13191319
Returns a subset of this sample data file consisting of the specified
13201320
individuals and sites. It is important to note that these are
@@ -1356,7 +1356,9 @@ def subset(self, individuals=None, sites=None, **kwargs):
13561356
sites = set(sites)
13571357
if len(sites) != num_sites:
13581358
raise ValueError("Duplicate site IDS")
1359-
with SampleData(sequence_length=self.sequence_length, **kwargs) as subset:
1359+
if sequence_length is None:
1360+
sequence_length = self.sequence_length
1361+
with SampleData(sequence_length=sequence_length, **kwargs) as subset:
13601362
# NOTE We don't bother filtering the populations, but we could.
13611363
for population in self.populations():
13621364
subset.add_population(population.metadata)

0 commit comments

Comments
 (0)