Skip to content

Commit fffa7ae

Browse files
awohns authored and jeromekelleher committed
added simple kc_distance_tree function
1 parent e0d55ed commit fffa7ae

File tree

1 file changed

+92
-4
lines changed

1 file changed

+92
-4
lines changed

python/tests/test_topology.py

Lines changed: 92 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
import random
3030
import json
3131
import sys
32+
import math
3233

3334
import numpy as np
3435
import msprime
@@ -120,11 +121,11 @@ def generate_segments(n, sequence_length=100, seed=None):
120121
return segs
121122

122123

123-
def kc_distance_tree(tree1, tree2, topo_v_age=0):
124+
def kc_distance_tree(tree1, tree2, lambda_param=0):
124125
"""
125126
Returns the Kendall-Colijn distance between the specified pair of trees.
126-
topo_v_age determines weight of topology vs branch lengths in calculating
127-
the distance. Set topo_v_age at 0 to only consider topology, set at 1 to
127+
lambda_param determines weight of topology vs branch lengths in calculating
128+
the distance. Set lambda_param at 0 to only consider topology, set at 1 to
128129
only consider branch lengths. See Kendall & Colijn (2016):
129130
https://academic.oup.com/mbe/article/33/10/2735/2925548
130131
"""
@@ -155,7 +156,57 @@ def kc_distance_tree(tree1, tree2, topo_v_age=0):
155156
M[tree_index][pair_index] = tree.time(tree.root) - time
156157
if len(tree.children(u)) == 0:
157158
M[tree_index][u + n] = tree.branch_length(u)
158-
return np.linalg.norm((1 - topo_v_age) * (m[0] - m[1]) + topo_v_age * (M[0] - M[1]))
159+
return np.linalg.norm((1 - lambda_param) *
160+
(m[0] - m[1]) + lambda_param * (M[0] - M[1]))
161+
162+
163+
def kc_distance_tree_simple(tree1, tree2, lambda_param=0):
    """
    Simplified version of the kc_distance_tree() function above.
    Written without Python features to aid writing C implementation.

    Returns the Kendall-Colijn distance between tree1 and tree2. For each
    tree two vectors of length n + k are built, where k is the number of
    samples and n = k * (k - 1) / 2 is the number of sample pairs:

    - m holds, for every sample pair, the number of edges between the root
      and the pair's MRCA, followed by a 1 for every sample.
    - M holds, for every sample pair, the time between the root and the
      pair's MRCA, followed by the branch length above every leaf.

    The result is the Euclidean norm of
    (1 - lambda_param) * (m1 - m2) + lambda_param * (M1 - M2), so
    lambda_param=0 only considers topology and lambda_param=1 only
    considers branch lengths.

    NOTE(review): the vector indexing uses sample and leaf node IDs
    directly, so it assumes the samples are the nodes 0..k-1 and that every
    leaf is one of them -- confirm against the callers.

    Raises ValueError if the two trees do not have the same samples, or if
    either tree does not have exactly one root.
    """
    samples = tree1.tree_sequence.samples()
    other_samples = tree2.tree_sequence.samples()
    # zip() stops at the shorter input, so sample lists of different length
    # must be rejected explicitly before the element-wise comparison.
    if len(samples) != len(other_samples):
        raise ValueError("Trees must have the same samples")
    for sample1, sample2 in zip(samples, other_samples):
        if sample1 != sample2:
            raise ValueError("Trees must have the same samples")
    if not len(tree1.roots) == len(tree2.roots) == 1:
        raise ValueError("Trees must have one root")

    k = len(samples)
    n = (k * (k - 1)) // 2
    # The trailing k entries of m are never overwritten and stay at 1.
    m = [np.ones(n + k), np.ones(n + k)]
    M = [np.zeros(n + k), np.zeros(n + k)]
    path_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)]
    time_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)]

    for tree_index, tree in enumerate([tree1, tree2]):
        root_time = tree.time(tree.root)
        # Iterative depth-first traversal recording, for every node, its
        # depth in edges and its time below the root.
        stack = [(tree.root, 0, root_time)]
        while len(stack) > 0:
            u, depth, node_time = stack.pop()
            children = tree.children(u)
            for v in children:
                stack.append((v, depth + 1, tree.time(v)))
            path_distance[tree_index][u] = depth
            time_distance[tree_index][u] = root_time - node_time
            if len(children) == 0:
                M[tree_index][u + n] = tree.branch_length(u)

        for index, n1 in enumerate(samples):
            for n2 in samples[index + 1:]:
                mrca = tree.mrca(n1, n2)
                # Position of the pair (n1, n2), n1 < n2, in the flattened
                # upper triangle of the k x k pair matrix.
                pair_index = n1 * (n1 - 2 * k + 1) // -2 + n2 - n1 - 1
                assert m[tree_index][pair_index] == 1
                m[tree_index][pair_index] = path_distance[tree_index][mrca]
                M[tree_index][pair_index] = time_distance[tree_index][mrca]

    distance_sum = 0
    for i in range(n + k):
        vT1 = (m[0][i] * (1 - lambda_param)) + (lambda_param * M[0][i])
        vT2 = (m[1][i] * (1 - lambda_param)) + (lambda_param * M[1][i])
        distance_sum += (vT1 - vT2) ** 2

    return math.sqrt(distance_sum)
159210

160211

161212
class TestKCMetric(unittest.TestCase):
@@ -168,28 +219,35 @@ def test_same_tree_zero_distance(self):
168219
ts = msprime.simulate(n, random_seed=seed)
169220
tree = ts.first()
170221
self.assertEqual(kc_distance_tree(tree, tree), 0)
222+
self.assertEqual(kc_distance_tree_simple(tree, tree), 0)
171223
ts = msprime.simulate(n, random_seed=seed)
172224
tree2 = ts.first()
173225
self.assertEqual(kc_distance_tree(tree, tree2), 0)
226+
self.assertEqual(kc_distance_tree_simple(tree, tree2), 0)
174227

175228
def test_sample_2_zero_distance(self):
    """
    Pairs of two-sample trees simulated from consecutive seeds must be at
    distance 0 under both implementations (default lambda_param).
    """
    # All trees with 2 leaves must be equal distance from each other.
    for first_seed in range(1, 10):
        first = msprime.simulate(2, random_seed=first_seed).first()
        second = msprime.simulate(2, random_seed=first_seed + 1).first()
        for distance in (kc_distance_tree, kc_distance_tree_simple):
            self.assertEqual(distance(first, second), 0)
181235

182236
def test_different_samples_error(self):
    """
    Both implementations must raise ValueError when given a 10-sample tree
    and a 2-sample tree.
    """
    ten_sample_tree = msprime.simulate(10, random_seed=1).first()
    two_sample_tree = msprime.simulate(2, random_seed=1).first()
    with self.assertRaises(ValueError):
        kc_distance_tree(ten_sample_tree, two_sample_tree)
    with self.assertRaises(ValueError):
        kc_distance_tree_simple(ten_sample_tree, two_sample_tree)
186241

187242
def validate_trees(self, n):
    """
    Cross-checks the two KC distance implementations on pairs of simulated
    trees with n samples, built from consecutive random seeds.

    For lambda_param 0 and 1 this asserts that kc_distance_tree_simple()
    agrees with kc_distance_tree(), and that each implementation gives the
    same value when the two trees are swapped. Comparing a function only
    with a second call to itself on the same arguments would pass even if
    the result were wrong.
    """
    for seed in range(1, 10):
        tree1 = msprime.simulate(n, random_seed=seed).first()
        tree2 = msprime.simulate(n, random_seed=seed + 1).first()
        for lambda_param in (0, 1):
            reference = kc_distance_tree(tree1, tree2, lambda_param)
            simple = kc_distance_tree_simple(tree1, tree2, lambda_param)
            # The simplified implementation must match the reference one.
            self.assertAlmostEqual(simple, reference)
            # The distance must not depend on the order of the arguments.
            self.assertAlmostEqual(
                kc_distance_tree(tree2, tree1, lambda_param), reference)
            self.assertAlmostEqual(
                kc_distance_tree_simple(tree2, tree1, lambda_param), simple)
193251

194252
def test_sample_3(self):
    """
    Runs validate_trees() on trees with 3 samples.
    """
    self.validate_trees(n=3)
@@ -227,12 +285,24 @@ def validate_nonbinary_trees(self, n):
227285
kc_distance_tree(tree1, tree2), kc_distance_tree(tree1, tree2))
228286
self.assertAlmostEqual(
229287
kc_distance_tree(tree2, tree1), kc_distance_tree(tree2, tree1))
288+
self.assertAlmostEqual(
289+
kc_distance_tree_simple(tree1, tree2),
290+
kc_distance_tree_simple(tree1, tree2))
291+
self.assertAlmostEqual(
292+
kc_distance_tree_simple(tree2, tree1),
293+
kc_distance_tree_simple(tree2, tree1))
230294
# compare to a binary tree also
231295
tree2 = msprime.simulate(n, random_seed=seed + 1).first()
232296
self.assertAlmostEqual(
233297
kc_distance_tree(tree1, tree2), kc_distance_tree(tree1, tree2))
234298
self.assertAlmostEqual(
235299
kc_distance_tree(tree2, tree1), kc_distance_tree(tree2, tree1))
300+
self.assertAlmostEqual(
301+
kc_distance_tree_simple(tree1, tree2),
302+
kc_distance_tree_simple(tree1, tree2))
303+
self.assertAlmostEqual(
304+
kc_distance_tree_simple(tree2, tree1),
305+
kc_distance_tree_simple(tree2, tree1))
236306

237307
def test_non_binary_sample_10(self):
238308
self.validate_nonbinary_trees(10)
@@ -277,6 +347,10 @@ def test_known_kc_sample_3(self):
277347
kc_distance_tree(tree_1, tree_2, 0), 0)
278348
self.assertAlmostEqual(
279349
kc_distance_tree(tree_1, tree_2, 1), 4.243, places=3)
350+
self.assertAlmostEqual(
351+
kc_distance_tree_simple(tree_1, tree_2, 0), 0)
352+
self.assertAlmostEqual(
353+
kc_distance_tree_simple(tree_1, tree_2, 1), 4.243, places=3)
280354

281355
def test_10_samples(self):
282356
nodes_1 = io.StringIO("""\
@@ -377,6 +451,10 @@ def test_10_samples(self):
377451
kc_distance_tree(tree_1, tree_2, 0), 12.85, places=2)
378452
self.assertAlmostEqual(
379453
kc_distance_tree(tree_1, tree_2, 1), 10.64, places=2)
454+
self.assertAlmostEqual(
455+
kc_distance_tree_simple(tree_1, tree_2, 0), 12.85, places=2)
456+
self.assertAlmostEqual(
457+
kc_distance_tree_simple(tree_1, tree_2, 1), 10.64, places=2)
380458

381459
def test_15_samples(self):
382460
nodes_1 = io.StringIO("""\
@@ -518,6 +596,10 @@ def test_15_samples(self):
518596
kc_distance_tree(tree_1, tree_2, 0), 19.95, places=2)
519597
self.assertAlmostEqual(
520598
kc_distance_tree(tree_1, tree_2, 1), 17.74, places=2)
599+
self.assertAlmostEqual(
600+
kc_distance_tree_simple(tree_1, tree_2, 0), 19.95, places=2)
601+
self.assertAlmostEqual(
602+
kc_distance_tree_simple(tree_1, tree_2, 1), 17.74, places=2)
521603

522604
def test_nobinary_trees(self):
523605
nodes_1 = io.StringIO("""\
@@ -627,6 +709,10 @@ def test_nobinary_trees(self):
627709
kc_distance_tree(tree_1, tree_2, 0), 9.434, places=3)
628710
self.assertAlmostEqual(
629711
kc_distance_tree(tree_1, tree_2, 1), 44, places=1)
712+
self.assertAlmostEqual(
713+
kc_distance_tree_simple(tree_1, tree_2, 0), 9.434, places=3)
714+
self.assertAlmostEqual(
715+
kc_distance_tree_simple(tree_1, tree_2, 1), 44, places=1)
630716

631717
def test_multiple_roots(self):
632718
tables = tskit.TableCollection(sequence_length=1.0)
@@ -643,6 +729,8 @@ def test_multiple_roots(self):
643729

644730
with self.assertRaises(ValueError):
645731
kc_distance_tree(ts.first(), ts.first(), 0)
732+
with self.assertRaises(ValueError):
733+
kc_distance_tree_simple(ts.first(), ts.first(), 0)
646734

647735

648736
class TestOverlappingSegments(unittest.TestCase):

0 commit comments

Comments
 (0)