Skip to content

Commit 38f70d9

Browse files
hyanwong authored and benjeffery committed
Save mismatch ratio in provenance
1 parent f8f2598 commit 38f70d9

File tree

2 files changed

+41
-14
lines changed

2 files changed

+41
-14
lines changed

tests/test_provenance.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,12 +102,16 @@ def test_no_provenance_match_samples(self, small_sd_fixture):
102102
ts = tsinfer.match_samples(small_sd_fixture, anc_ts, record_provenance=False)
103103
assert ts.num_provenances == small_sd_fixture.num_provenances
104104

105-
def test_provenance_infer(self, small_sd_fixture):
106-
ts = tsinfer.infer(small_sd_fixture)
105+
@pytest.mark.parametrize("mmr", [None, 0.1])
106+
def test_provenance_infer(self, small_sd_fixture, mmr):
107+
ts = tsinfer.infer(
108+
small_sd_fixture, mismatch_ratio=mmr, recombination_rate=1e-8
109+
)
107110
assert ts.num_provenances == small_sd_fixture.num_provenances + 1
108111
record = json.loads(ts.provenance(-1).record)
109112
params = record["parameters"]
110113
assert params["command"] == "infer"
114+
assert params["mismatch_ratio"] == mmr
111115

112116
def test_provenance_generate_ancestors(self, small_sd_fixture):
113117
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
@@ -117,26 +121,34 @@ def test_provenance_generate_ancestors(self, small_sd_fixture):
117121
params = record["parameters"]
118122
assert params["command"] == "generate_ancestors"
119123

120-
def test_provenance_match_ancestors(self, small_sd_fixture):
124+
@pytest.mark.parametrize("mmr", [None, 0.1])
125+
def test_provenance_match_ancestors(self, small_sd_fixture, mmr):
121126
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
122-
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
127+
anc_ts = tsinfer.match_ancestors(
128+
small_sd_fixture, ancestors, mismatch_ratio=mmr, recombination_rate=1e-8
129+
)
123130
assert anc_ts.num_provenances == small_sd_fixture.num_provenances + 2
124131
params = json.loads(anc_ts.provenance(-2).record)["parameters"]
125132
assert params["command"] == "generate_ancestors"
126133
params = json.loads(anc_ts.provenance(-1).record)["parameters"]
127134
assert params["command"] == "match_ancestors"
135+
assert params["mismatch_ratio"] == mmr
128136

129-
def test_provenance_match_samples(self, small_sd_fixture):
137+
@pytest.mark.parametrize("mmr", [None, 0.1])
138+
def test_provenance_match_samples(self, small_sd_fixture, mmr):
130139
ancestors = tsinfer.generate_ancestors(small_sd_fixture)
131140
anc_ts = tsinfer.match_ancestors(small_sd_fixture, ancestors)
132-
ts = tsinfer.match_samples(small_sd_fixture, anc_ts)
141+
ts = tsinfer.match_samples(
142+
small_sd_fixture, anc_ts, mismatch_ratio=mmr, recombination_rate=1e-8
143+
)
133144
assert ts.num_provenances == small_sd_fixture.num_provenances + 3
134145
params = json.loads(ts.provenance(-3).record)["parameters"]
135146
assert params["command"] == "generate_ancestors"
136147
params = json.loads(ts.provenance(-2).record)["parameters"]
137148
assert params["command"] == "match_ancestors"
138149
params = json.loads(ts.provenance(-1).record)["parameters"]
139150
assert params["command"] == "match_samples"
151+
assert params["mismatch_ratio"] == mmr
140152

141153

142154
class TestGetProvenance:

tsinfer/inference.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ def infer(
271271
:type recombination_rate: float, msprime.RateMap
272272
:param float mismatch_ratio: The probability of a mismatch relative to the median
273273
probability of recombination between adjacent sites: can only be used if a
274-
recombination rate has been set (default: 1)
274+
recombination rate has been set (default: ``None`` treated as 1 if
275+
``recombination_rate`` is set).
275276
:param bool path_compression: Whether to merge edges that share identical
276277
paths (essentially taking advantage of shared recombination breakpoints).
277278
:param bool post_process: Whether to run the :func:`post_process` method on the
@@ -336,7 +337,10 @@ def infer(
336337
)
337338
if record_provenance:
338339
tables = inferred_ts.dump_tables()
339-
record = provenance.get_provenance_dict(command="infer")
340+
record = provenance.get_provenance_dict(
341+
command="infer",
342+
mismatch_ratio=mismatch_ratio,
343+
)
340344
tables.provenances.add_row(record=json.dumps(record))
341345
inferred_ts = tables.tree_sequence()
342346
return inferred_ts
@@ -452,7 +456,8 @@ def match_ancestors(
452456
:type recombination_rate: float, msprime.RateMap
453457
:param float mismatch_ratio: The probability of a mismatch relative to the median
454458
probability of recombination between adjacent sites: can only be used if a
455-
recombination rate has been set (default: 1)
459+
recombination rate has been set (default: ``None`` treated as 1 if
460+
``recombination_rate`` is set).
456461
:param bool path_compression: Whether to merge edges that share identical
457462
paths (essentially taking advantage of shared recombination breakpoints).
458463
:param int num_threads: The number of match worker threads to use. If
@@ -486,7 +491,9 @@ def match_ancestors(
486491
tables.provenances.add_row(timestamp=timestamp, record=json.dumps(record))
487492
if record_provenance:
488493
record = provenance.get_provenance_dict(
489-
command="match_ancestors", source={"uuid": ancestor_data.uuid}
494+
command="match_ancestors",
495+
source={"uuid": ancestor_data.uuid},
496+
mismatch_ratio=mismatch_ratio,
490497
)
491498
tables.provenances.add_row(record=json.dumps(record))
492499
ts = tables.tree_sequence()
@@ -541,7 +548,8 @@ def augment_ancestors(
541548
:type recombination_rate: float, msprime.RateMap
542549
:param float mismatch_ratio: The probability of a mismatch relative to the median
543550
probability of recombination between adjacent sites: can only be used if a
544-
recombination rate has been set (default: 1)
551+
recombination rate has been set (default: ``None`` treated as 1 if
552+
``recombination_rate`` is set).
545553
:param bool path_compression: Whether to merge edges that share identical
546554
paths (essentially taking advantage of shared recombination breakpoints).
547555
:param int num_threads: The number of match worker threads to use. If
@@ -574,7 +582,10 @@ def augment_ancestors(
574582
ts = manager.get_augmented_ancestors_tree_sequence(sample_indexes)
575583
if record_provenance:
576584
tables = ts.dump_tables()
577-
record = provenance.get_provenance_dict(command="augment_ancestors")
585+
record = provenance.get_provenance_dict(
586+
command="augment_ancestors",
587+
mismatch_ratio=mismatch_ratio,
588+
)
578589
tables.provenances.add_row(record=json.dumps(record))
579590
ts = tables.tree_sequence()
580591
return ts
@@ -628,7 +639,8 @@ def match_samples(
628639
:type recombination_rate: float, msprime.RateMap
629640
:param float mismatch_ratio: The probability of a mismatch relative to the median
630641
probability of recombination between adjacent sites: can only be used if a
631-
recombination rate has been set (default: 1)
642+
recombination rate has been set (default: ``None`` treated as 1 if
643+
``recombination_rate`` is set).
632644
:param bool path_compression: Whether to merge edges that share identical
633645
paths (essentially taking advantage of shared recombination breakpoints).
634646
:param array_like indexes: An array of indexes into the sample_data file of
@@ -706,7 +718,10 @@ def match_samples(
706718
if record_provenance:
707719
tables = ts.dump_tables()
708720
# We don't have a source here because tree sequence files don't have a UUID yet.
709-
record = provenance.get_provenance_dict(command="match_samples")
721+
record = provenance.get_provenance_dict(
722+
command="match_samples",
723+
mismatch_ratio=mismatch_ratio,
724+
)
710725
tables.provenances.add_row(record=json.dumps(record))
711726
ts = tables.tree_sequence()
712727
return ts

0 commit comments

Comments
 (0)