Skip to content

Commit 4f76b6a

Browse files
hyanwongmergify[bot]
authored andcommitted
Set time_units
1 parent 34bd064 commit 4f76b6a

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
********************
2+
[0.2.4] - 2022-06-xx
3+
********************
4+
5+
**Breaking changes**:
6+
7+
- Inference now sets time_units on both ancestor and final tree sequences to
8+
tskit.TIME_UNITS_UNCALIBRATED, stopping accidental use of branch length
9+
calculations on the ts. (:pr:`680`, :user:`hyanwong`)
10+
111
********************
212
[0.2.3] - 2022-04-08
313
********************

tests/test_inference.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,17 @@ def test_multi_char_alleles(self):
16441644
self.verify(sample_data, mismatch_ratio=100, recombination_rate=1e-9)
16451645
self.verify(sample_data, mismatch_ratio=0.01, recombination_rate=1e-3)
16461646

1647+
def test_time_units(self):
1648+
with tsinfer.SampleData(1.0) as sample_data:
1649+
sample_data.add_site(0.5, [0, 1, 1])
1650+
ancestor_data = tsinfer.generate_ancestors(sample_data)
1651+
ancestors_ts = tsinfer.match_ancestors(sample_data, ancestor_data)
1652+
assert ancestors_ts.time_units == tskit.TIME_UNITS_UNCALIBRATED
1653+
ancestors_ts = tsinfer.match_ancestors(
1654+
sample_data, ancestor_data, time_units="generations"
1655+
)
1656+
assert ancestors_ts.time_units == "generations"
1657+
16471658

16481659
class TestAncestorsTreeSequenceFlags:
16491660
"""
@@ -1917,6 +1928,30 @@ def test_partial_bad_indexes(self):
19171928
with pytest.raises(ValueError):
19181929
tsinfer.match_samples(sd, a_ts, indexes=bad_samples)
19191930

1931+
def test_time_units_default_uncalibrated(self):
1932+
with tsinfer.SampleData(1.0) as sample_data:
1933+
sample_data.add_site(0.5, [0, 1, 1])
1934+
ts = tsinfer.infer(sample_data)
1935+
assert ts.time_units == tskit.TIME_UNITS_UNCALIBRATED
1936+
1937+
def test_time_units_passed_through(self):
1938+
with tsinfer.SampleData(1.0) as sample_data:
1939+
sample_data.add_site(0.5, [0, 1, 1])
1940+
ts = tsinfer.infer(sample_data)
1941+
assert ts.time_units == tskit.TIME_UNITS_UNCALIBRATED
1942+
ancestor_data = tsinfer.generate_ancestors(sample_data)
1943+
ancestors_ts = tsinfer.match_ancestors(
1944+
sample_data, ancestor_data, time_units="generations"
1945+
)
1946+
ts = tsinfer.match_samples(sample_data, ancestors_ts)
1947+
assert ts.time_units == "generations"
1948+
1949+
def test_time_units_in_infer(self):
1950+
with tsinfer.SampleData(1.0) as sample_data:
1951+
sample_data.add_site(0.5, [1, 1])
1952+
ts = tsinfer.infer(sample_data, time_units="generations")
1953+
assert ts.time_units == "generations"
1954+
19201955

19211956
class AlgorithmsExactlyEqualMixin:
19221957
"""

tsinfer/inference.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ def infer(
223223
precision=None,
224224
engine=constants.C_ENGINE,
225225
progress_monitor=None,
226+
time_units=None,
226227
):
227228
"""
228229
infer(sample_data, *, recombination_rate=None, mismatch_ratio=None,\
@@ -298,6 +299,7 @@ def infer(
298299
precision=precision,
299300
path_compression=path_compression,
300301
progress_monitor=progress_monitor,
302+
time_units=time_units,
301303
)
302304
inferred_ts = match_samples(
303305
sample_data,
@@ -393,6 +395,7 @@ def match_ancestors(
393395
engine=constants.C_ENGINE,
394396
progress_monitor=None,
395397
extended_checks=False,
398+
time_units=None,
396399
):
397400
"""
398401
match_ancestors(sample_data, ancestor_data, *, recombination_rate=None,\
@@ -432,9 +435,11 @@ def match_ancestors(
432435
progress_monitor = _get_progress_monitor(progress_monitor, match_ancestors=True)
433436
sample_data._check_finalised()
434437
ancestor_data._check_finalised()
438+
435439
matcher = AncestorMatcher(
436440
sample_data,
437441
ancestor_data,
442+
time_units=time_units,
438443
recombination_rate=recombination_rate,
439444
recombination=recombination,
440445
mismatch_ratio=mismatch_ratio,
@@ -1240,9 +1245,12 @@ def convert_inference_mutations(self, tables):
12401245

12411246

12421247
class AncestorMatcher(Matcher):
1243-
def __init__(self, sample_data, ancestor_data, **kwargs):
1248+
def __init__(self, sample_data, ancestor_data, time_units=None, **kwargs):
12441249
super().__init__(sample_data, ancestor_data.sites_position[:], **kwargs)
12451250
self.ancestor_data = ancestor_data
1251+
if time_units is None:
1252+
time_units = tskit.TIME_UNITS_UNCALIBRATED
1253+
self.time_units = time_units
12461254
self.num_ancestors = self.ancestor_data.num_ancestors
12471255
self.epoch = self.ancestor_data.ancestors_time[:]
12481256

@@ -1394,10 +1402,10 @@ def match_ancestors(self):
13941402
logger.info("Finished ancestor matching")
13951403
return ts
13961404

1397-
def get_ancestors_tree_sequence(self):
1405+
def get_ancestors_tables(self):
13981406
"""
1399-
Return the ancestors tree sequence. Only inference sites are included in this
1400-
tree sequence. All nodes have the sample flag bit set, and if a node
1407+
Return the ancestors tree sequence tables. Only inference sites are included in
1408+
this tree sequence. All nodes have the sample flag bit set, and if a node
14011409
corresponds to an ancestor in the ancestors file, it is indicated via metadata.
14021410
"""
14031411
logger.debug("Building ancestors tree sequence")
@@ -1461,18 +1469,18 @@ def get_ancestors_tree_sequence(self):
14611469
len(tables.sites),
14621470
)
14631471
)
1464-
return tables.tree_sequence()
1472+
return tables
14651473

14661474
def store_output(self):
14671475
if self.num_ancestors > 0:
1468-
ts = self.get_ancestors_tree_sequence()
1476+
tables = self.get_ancestors_tables()
14691477
else:
14701478
# Allocate an empty tree sequence.
14711479
tables = tskit.TableCollection(
14721480
sequence_length=self.ancestor_data.sequence_length
14731481
)
1474-
ts = tables.tree_sequence()
1475-
return ts
1482+
tables.time_units = self.time_units
1483+
return tables.tree_sequence()
14761484

14771485

14781486
class SampleMatcher(Matcher):

0 commit comments

Comments
 (0)