2929import random
3030import json
3131import sys
32+ import math
3233
3334import numpy as np
3435import msprime
@@ -120,11 +121,11 @@ def generate_segments(n, sequence_length=100, seed=None):
120121 return segs
121122
122123
123- def kc_distance_tree (tree1 , tree2 , topo_v_age = 0 ):
124+ def kc_distance_tree (tree1 , tree2 , lambda_param = 0 ):
124125 """
125126 Returns the Kendall-Colijn distance between the specified pair of trees.
126- topo_v_age determines weight of topology vs branch lengths in calculating
127- the distance. Set topo_v_age at 0 to only consider topology, set at 1 to
127+ lambda_param determines weight of topology vs branch lengths in calculating
128+ the distance. Set lambda_param at 0 to only consider topology, set at 1 to
128129 only consider branch lengths. See Kendall & Colijn (2016):
129130 https://academic.oup.com/mbe/article/33/10/2735/2925548
130131 """
@@ -155,7 +156,57 @@ def kc_distance_tree(tree1, tree2, topo_v_age=0):
155156 M [tree_index ][pair_index ] = tree .time (tree .root ) - time
156157 if len (tree .children (u )) == 0 :
157158 M [tree_index ][u + n ] = tree .branch_length (u )
158- return np .linalg .norm ((1 - topo_v_age ) * (m [0 ] - m [1 ]) + topo_v_age * (M [0 ] - M [1 ]))
159+ return np .linalg .norm ((1 - lambda_param ) *
160+ (m [0 ] - m [1 ]) + lambda_param * (M [0 ] - M [1 ]))
161+
162+
def kc_distance_tree_simple(tree1, tree2, lambda_param=0):
    """
    Simplified version of the kc_distance_tree() function above.
    Written without Python features to aid writing C implementation.

    Returns the Kendall-Colijn distance between tree1 and tree2.
    lambda_param weights topology against branch lengths: at 0 only the
    depth of each sample pair's MRCA is used, at 1 only MRCA times and
    leaf branch lengths are used.

    Raises ValueError if the two trees do not have exactly the same
    samples, or if either tree does not have exactly one root.
    """
    samples = tree1.tree_sequence.samples()
    other_samples = tree2.tree_sequence.samples()
    # zip() stops at the shorter sequence, so on its own it would accept a
    # tree whose samples are only a prefix of the other's. Compare the
    # lengths explicitly before comparing element by element.
    if len(samples) != len(other_samples):
        raise ValueError("Trees must have the same samples")
    for sample1, sample2 in zip(samples, other_samples):
        if sample1 != sample2:
            raise ValueError("Trees must have the same samples")
    if not len(tree1.roots) == len(tree2.roots) == 1:
        raise ValueError("Trees must have one root")

    k = samples.shape[0]
    # Number of unordered sample pairs. Each vector holds one entry per
    # pair followed by one entry per sample.
    n = (k * (k - 1)) // 2
    m = [np.ones(n + k), np.ones(n + k)]
    M = [np.zeros(n + k), np.zeros(n + k)]
    path_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)]
    time_distance = [np.zeros(tree1.num_nodes), np.zeros(tree2.num_nodes)]

    for tree_index, tree in enumerate([tree1, tree2]):
        # First pass: record, for every node, its depth below the root and
        # how much younger it is than the root.
        root_time = tree.time(tree.root)
        stack = [(tree.root, 0, root_time)]
        while len(stack) > 0:
            u, depth, time = stack.pop()
            children = tree.children(u)
            for v in children:
                stack.append((v, depth + 1, tree.time(v)))
            path_distance[tree_index][u] = depth
            time_distance[tree_index][u] = root_time - time
            if len(children) == 0:
                # NOTE(review): indexing by u + n assumes leaf ids lie in
                # 0..k-1 -- confirm against callers.
                M[tree_index][u + n] = tree.branch_length(u)

        # Second pass: fill the per-pair entries from each pair's MRCA.
        for index, n1 in enumerate(samples):
            for n2 in samples[index + 1:]:
                mrca = tree.mrca(n1, n2)
                # Flattened index of the pair (n1, n2) with n1 < n2.
                pair_index = n1 * (n1 - 2 * k + 1) // -2 + n2 - n1 - 1
                # Each pair slot must still hold its initial value.
                assert m[tree_index][pair_index] == 1
                m[tree_index][pair_index] = path_distance[tree_index][mrca]
                M[tree_index][pair_index] = time_distance[tree_index][mrca]

    # Euclidean norm of the difference between the two weighted vectors.
    distance_sum = 0
    for i in range(n + k):
        vT1 = (m[0][i] * (1 - lambda_param)) + (lambda_param * M[0][i])
        vT2 = (m[1][i] * (1 - lambda_param)) + (lambda_param * M[1][i])
        distance_sum += (vT1 - vT2) ** 2

    return math.sqrt(distance_sum)
159210
160211
161212class TestKCMetric (unittest .TestCase ):
@@ -168,28 +219,35 @@ def test_same_tree_zero_distance(self):
168219 ts = msprime .simulate (n , random_seed = seed )
169220 tree = ts .first ()
170221 self .assertEqual (kc_distance_tree (tree , tree ), 0 )
222+ self .assertEqual (kc_distance_tree_simple (tree , tree ), 0 )
171223 ts = msprime .simulate (n , random_seed = seed )
172224 tree2 = ts .first ()
173225 self .assertEqual (kc_distance_tree (tree , tree2 ), 0 )
226+ self .assertEqual (kc_distance_tree_simple (tree , tree2 ), 0 )
174227
def test_sample_2_zero_distance(self):
    """
    Any two trees with 2 leaves are expected to be at distance zero
    under the default lambda_param, for both implementations.
    """
    for seed in range(1, 10):
        first, second = (
            msprime.simulate(2, random_seed=s).first()
            for s in (seed, seed + 1)
        )
        for distance_func in (kc_distance_tree, kc_distance_tree_simple):
            self.assertEqual(distance_func(first, second), 0)
181235
def test_different_samples_error(self):
    """
    Both implementations must raise ValueError when given a 10-sample
    tree and a 2-sample tree.
    """
    ten_sample_tree = msprime.simulate(10, random_seed=1).first()
    two_sample_tree = msprime.simulate(2, random_seed=1).first()
    for distance_func in (kc_distance_tree, kc_distance_tree_simple):
        with self.assertRaises(ValueError):
            distance_func(ten_sample_tree, two_sample_tree)
186241
def validate_trees(self, n):
    """
    Simulate pairs of n-sample trees and check the KC distance
    implementations against each other.

    For each pair, the two implementations must agree and each must be
    symmetric in its arguments. Comparing a call with an identical call
    would always pass and so would verify nothing.
    """
    for seed in range(1, 10):
        tree1 = msprime.simulate(n, random_seed=seed).first()
        tree2 = msprime.simulate(n, random_seed=seed + 1).first()
        kc = kc_distance_tree(tree1, tree2)
        kc_simple = kc_distance_tree_simple(tree1, tree2)
        # The reference and simplified implementations must agree.
        self.assertAlmostEqual(kc, kc_simple)
        # The distance is a norm of a difference, so it must be symmetric.
        self.assertAlmostEqual(kc, kc_distance_tree(tree2, tree1))
        self.assertAlmostEqual(
            kc_simple, kc_distance_tree_simple(tree2, tree1))
193251
def test_sample_3(self):
    """Run the simulated-tree validation with 3 samples."""
    sample_count = 3
    self.validate_trees(n=sample_count)
@@ -227,12 +285,24 @@ def validate_nonbinary_trees(self, n):
227285 kc_distance_tree (tree1 , tree2 ), kc_distance_tree (tree1 , tree2 ))
228286 self .assertAlmostEqual (
229287 kc_distance_tree (tree2 , tree1 ), kc_distance_tree (tree2 , tree1 ))
288+ self .assertAlmostEqual (
289+ kc_distance_tree_simple (tree1 , tree2 ),
290+ kc_distance_tree_simple (tree1 , tree2 ))
291+ self .assertAlmostEqual (
292+ kc_distance_tree_simple (tree2 , tree1 ),
293+ kc_distance_tree_simple (tree2 , tree1 ))
230294 # compare to a binary tree also
231295 tree2 = msprime .simulate (n , random_seed = seed + 1 ).first ()
232296 self .assertAlmostEqual (
233297 kc_distance_tree (tree1 , tree2 ), kc_distance_tree (tree1 , tree2 ))
234298 self .assertAlmostEqual (
235299 kc_distance_tree (tree2 , tree1 ), kc_distance_tree (tree2 , tree1 ))
300+ self .assertAlmostEqual (
301+ kc_distance_tree_simple (tree1 , tree2 ),
302+ kc_distance_tree_simple (tree1 , tree2 ))
303+ self .assertAlmostEqual (
304+ kc_distance_tree_simple (tree2 , tree1 ),
305+ kc_distance_tree_simple (tree2 , tree1 ))
236306
def test_non_binary_sample_10(self):
    """Run the non-binary tree validation with 10 samples."""
    sample_count = 10
    self.validate_nonbinary_trees(n=sample_count)
@@ -277,6 +347,10 @@ def test_known_kc_sample_3(self):
277347 kc_distance_tree (tree_1 , tree_2 , 0 ), 0 )
278348 self .assertAlmostEqual (
279349 kc_distance_tree (tree_1 , tree_2 , 1 ), 4.243 , places = 3 )
350+ self .assertAlmostEqual (
351+ kc_distance_tree_simple (tree_1 , tree_2 , 0 ), 0 )
352+ self .assertAlmostEqual (
353+ kc_distance_tree_simple (tree_1 , tree_2 , 1 ), 4.243 , places = 3 )
280354
281355 def test_10_samples (self ):
282356 nodes_1 = io .StringIO ("""\
@@ -377,6 +451,10 @@ def test_10_samples(self):
377451 kc_distance_tree (tree_1 , tree_2 , 0 ), 12.85 , places = 2 )
378452 self .assertAlmostEqual (
379453 kc_distance_tree (tree_1 , tree_2 , 1 ), 10.64 , places = 2 )
454+ self .assertAlmostEqual (
455+ kc_distance_tree_simple (tree_1 , tree_2 , 0 ), 12.85 , places = 2 )
456+ self .assertAlmostEqual (
457+ kc_distance_tree_simple (tree_1 , tree_2 , 1 ), 10.64 , places = 2 )
380458
381459 def test_15_samples (self ):
382460 nodes_1 = io .StringIO ("""\
@@ -518,6 +596,10 @@ def test_15_samples(self):
518596 kc_distance_tree (tree_1 , tree_2 , 0 ), 19.95 , places = 2 )
519597 self .assertAlmostEqual (
520598 kc_distance_tree (tree_1 , tree_2 , 1 ), 17.74 , places = 2 )
599+ self .assertAlmostEqual (
600+ kc_distance_tree_simple (tree_1 , tree_2 , 0 ), 19.95 , places = 2 )
601+ self .assertAlmostEqual (
602+ kc_distance_tree_simple (tree_1 , tree_2 , 1 ), 17.74 , places = 2 )
521603
522604 def test_nobinary_trees (self ):
523605 nodes_1 = io .StringIO ("""\
@@ -627,6 +709,10 @@ def test_nobinary_trees(self):
627709 kc_distance_tree (tree_1 , tree_2 , 0 ), 9.434 , places = 3 )
628710 self .assertAlmostEqual (
629711 kc_distance_tree (tree_1 , tree_2 , 1 ), 44 , places = 1 )
712+ self .assertAlmostEqual (
713+ kc_distance_tree_simple (tree_1 , tree_2 , 0 ), 9.434 , places = 3 )
714+ self .assertAlmostEqual (
715+ kc_distance_tree_simple (tree_1 , tree_2 , 1 ), 44 , places = 1 )
630716
631717 def test_multiple_roots (self ):
632718 tables = tskit .TableCollection (sequence_length = 1.0 )
@@ -643,6 +729,8 @@ def test_multiple_roots(self):
643729
644730 with self .assertRaises (ValueError ):
645731 kc_distance_tree (ts .first (), ts .first (), 0 )
732+ with self .assertRaises (ValueError ):
733+ kc_distance_tree_simple (ts .first (), ts .first (), 0 )
646734
647735
648736class TestOverlappingSegments (unittest .TestCase ):
0 commit comments