Skip to content

Commit 6180b4c

Browse files
Merge pull request #2303 from jeromekelleher/refactor-smck-distributed
Factor out distributed SMC logic
2 parents 8d65577 + ee2b6c0 commit 6180b4c

File tree

4 files changed

+447
-563
lines changed

4 files changed

+447
-563
lines changed

algorithms.py

Lines changed: 22 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,57 +1021,12 @@ def initialise(self, ts):
10211021
lineage = root_lineages[node]
10221022
if lineage is not None and ts.nodes_time[node] == start_time:
10231023
seg = lineage.head
1024-
left_end = seg.left
10251024
while seg is not None:
10261025
self.set_segment_mass(seg)
10271026
lineage.tail = seg
10281027
seg = seg.next
10291028
self.add_lineage(lineage)
10301029

1031-
if self.model == "smc_k":
1032-
for node in range(ts.num_nodes):
1033-
lineage = root_lineages[node]
1034-
if lineage is not None:
1035-
seg = lineage.head
1036-
left_end = seg.left
1037-
pop = lineage.population
1038-
label = lineage.label
1039-
right_end = root_segments_tail[node].right
1040-
new_hull = self.alloc_hull(left_end, right_end, lineage)
1041-
# insert Hull
1042-
floor = self.P[pop].hulls_left[label].floor_key(new_hull)
1043-
insertion_order = 0
1044-
if floor is not None:
1045-
if floor.left == new_hull.left:
1046-
insertion_order = floor.insertion_order + 1
1047-
new_hull.insertion_order = insertion_order
1048-
self.P[pop].hulls_left[label][new_hull] = -1
1049-
1050-
# initialise the correct coalesceable pairs count
1051-
for pop in self.P:
1052-
for label, ost_left in enumerate(pop.hulls_left):
1053-
avl = ost_left.avl
1054-
ost_right = pop.hulls_right[label]
1055-
count = 0
1056-
for hull in avl.keys():
1057-
floor = ost_right.floor_key(HullEnd(hull.left))
1058-
num_ending_before_hull = 0
1059-
if floor is not None:
1060-
num_ending_before_hull = ost_right.rank[floor] + 1
1061-
num_pairs = count - num_ending_before_hull
1062-
avl[hull] = num_pairs
1063-
pop.coal_mass_index[label].set_value(hull.index, num_pairs)
1064-
# insert HullEnd
1065-
hull_end = HullEnd(hull.right)
1066-
floor = ost_right.floor_key(hull_end)
1067-
insertion_order = 0
1068-
if floor is not None:
1069-
if floor.x == hull.right:
1070-
insertion_order = floor.insertion_order + 1
1071-
hull_end.insertion_order = insertion_order
1072-
ost_right[hull_end] = -1
1073-
count += 1
1074-
10751030
def ancestors_remain(self):
10761031
"""
10771032
Returns True if the simulation is not finished, i.e., there is some ancestral
@@ -1203,6 +1158,15 @@ def store_edge(self, left, right, parent, child):
12031158
tskit.Edge(left=left, right=right, parent=parent, child=child)
12041159
)
12051160

1161+
def update_lineage_right(self, lineage):
    """
    Bring the right end of the lineage's hull back in line with its
    current tail segment.

    This does nothing unless the model is "smc_k". In that case the
    hull's right coordinate is recomputed as the tail segment's right
    coordinate plus ``self.hull_offset``, capped at ``self.L``, and the
    lineage's population is told about the change through
    ``reset_hull_right``, which receives both the old and new values.
    """
    if self.model != "smc_k":
        return
    population_id = lineage.population
    hull = lineage.hull
    # Remember the previous right end: reset_hull_right takes both the
    # old and the new value.
    previous_right = hull.right
    hull.right = min(lineage.tail.right + self.hull_offset, self.L)
    population = self.P[population_id]
    population.reset_hull_right(lineage.label, hull, previous_right, hull.right)
1169+
12061170
def add_lineage(self, lineage):
12071171
pop = lineage.population
12081172
self.P[pop].add(lineage, lineage.label)
@@ -1213,6 +1177,15 @@ def add_lineage(self, lineage):
12131177
assert x.lineage == lineage
12141178
x = x.next
12151179

1180+
if self.model == "smc_k":
1181+
head = lineage.head
1182+
assert head.prev is None
1183+
hull = self.alloc_hull(head.left, head.right, lineage)
1184+
right = lineage.tail.right
1185+
hull.right = min(right + self.hull_offset, self.L)
1186+
pop = self.P[lineage.population]
1187+
pop.add_hull(lineage.label, hull)
1188+
12161189
def finalise(self):
12171190
"""
12181191
Finalises the simulation and returns an msprime tree sequence object.
@@ -1838,22 +1811,11 @@ def hudson_recombination_event(self, label):
18381811
left_lineage.tail = x
18391812
lhs_tail = x
18401813

1814+
self.update_lineage_right(left_lineage)
18411815
right_lineage = self.alloc_lineage(alpha, left_lineage.population, label=label)
18421816
self.set_segment_mass(alpha)
18431817
self.add_lineage(right_lineage)
18441818

1845-
if self.model == "smc_k":
1846-
# modify original hull
1847-
pop = left_lineage.population
1848-
lhs_hull = lhs_tail.get_hull()
1849-
rhs_right = lhs_hull.right
1850-
lhs_hull.right = min(lhs_tail.right + self.hull_offset, self.L)
1851-
self.P[pop].reset_hull_right(label, lhs_hull, rhs_right, lhs_hull.right)
1852-
1853-
# create hull for alpha
1854-
alpha_hull = self.alloc_hull(alpha.left, rhs_right, right_lineage)
1855-
self.P[pop].add_hull(label, alpha_hull)
1856-
18571819
if self.additional_nodes.value & msprime.NODE_IS_RE_EVENT > 0:
18581820
self.store_node(left_lineage.population, flags=msprime.NODE_IS_RE_EVENT)
18591821
self.store_arg_edges(lhs_tail)
@@ -1892,11 +1854,8 @@ def wiuf_gene_conversion_within_event(self, label):
18921854
# lbp rbp
18931855
return None
18941856
self.num_gc_events += 1
1895-
hull = y.get_hull()
1896-
assert (self.model == "smc_k") == (hull is not None)
18971857
lineage = y.lineage
18981858
pop = lineage.population
1899-
reset_right = -1
19001859

19011860
# Process left break
19021861
insert_alpha = True
@@ -1915,7 +1874,6 @@ def wiuf_gene_conversion_within_event(self, label):
19151874
insert_alpha = False
19161875
else:
19171876
x.next = None
1918-
reset_right = x.right
19191877
y.prev = None
19201878
alpha = y
19211879
tail = x
@@ -1937,15 +1895,11 @@ def wiuf_gene_conversion_within_event(self, label):
19371895
y.right = left_breakpoint
19381896
self.set_segment_mass(y)
19391897
tail = y
1940-
reset_right = left_breakpoint
19411898
self.set_segment_mass(alpha)
19421899

19431900
# Find the segment z that the right breakpoint falls in
19441901
z = alpha
1945-
hull_left = z.left
1946-
hull_right = -1
19471902
while z is not None and right_breakpoint >= z.right:
1948-
hull_right = z.right
19491903
z = z.next
19501904

19511905
head = None
@@ -1969,7 +1923,6 @@ def wiuf_gene_conversion_within_event(self, label):
19691923
z.right = right_breakpoint
19701924
z.next = None
19711925
self.set_segment_mass(z)
1972-
hull_right = right_breakpoint
19731926
else:
19741927
# tail z
19751928
# ======
@@ -1987,12 +1940,6 @@ def wiuf_gene_conversion_within_event(self, label):
19871940
tail.next = head
19881941
head.prev = tail
19891942
self.set_segment_mass(head)
1990-
else:
1991-
# rbp lies beyond segment chain, regular recombination logic applies
1992-
if insert_alpha and self.model == "smc_k":
1993-
assert reset_right > 0
1994-
reset_right = min(reset_right + self.hull_offset, self.L)
1995-
self.P[pop].reset_hull_right(label, hull, hull.right, reset_right)
19961943

19971944
# y z
19981945
# | ========== ... ===== |
@@ -2007,12 +1954,8 @@ def wiuf_gene_conversion_within_event(self, label):
20071954
if new_individual_head is not None:
20081955
# FIXME when doing the smc_k update
20091956
lineage.reset_segments()
1957+
self.update_lineage_right(lineage)
20101958
new_lineage = self.alloc_lineage(new_individual_head, pop)
2011-
if self.model == "smc_k":
2012-
assert hull_left < hull_right
2013-
hull_right = min(self.L, hull_right + self.hull_offset)
2014-
hull = self.alloc_hull(hull_left, hull_right, new_lineage)
2015-
self.P[new_lineage.population].add_hull(new_lineage.label, hull)
20161959
self.add_lineage(new_lineage)
20171960

20181961
def wiuf_gene_conversion_left_event(self, label):
@@ -2044,8 +1987,6 @@ def wiuf_gene_conversion_left_event(self, label):
20441987
x = y.prev
20451988
lineage = y.lineage
20461989
pop = lineage.population
2047-
lhs_hull = y.get_hull()
2048-
assert (self.model == "smc_k") == (lhs_hull is not None)
20491990
if y.left < bp:
20501991
# x y
20511992
# ===== =====|====
@@ -2063,7 +2004,6 @@ def wiuf_gene_conversion_left_event(self, label):
20632004
y.next = None
20642005
y.right = bp
20652006
self.set_segment_mass(y)
2066-
right = y.right
20672007
else:
20682008
# x y
20692009
# ===== | =========
@@ -2077,19 +2017,10 @@ def wiuf_gene_conversion_left_event(self, label):
20772017
x.next = None
20782018
y.prev = None
20792019
alpha = y
2080-
right = x.right
2081-
2082-
if self.model == "smc_k":
2083-
# lhs logic is identical to the lhs recombination event
2084-
lhs_old_right = lhs_hull.right
2085-
lhs_new_right = min(self.L, right + self.hull_offset)
2086-
self.P[pop].reset_hull_right(label, lhs_hull, lhs_old_right, lhs_new_right)
2087-
2088-
# rhs
2089-
hull = self.alloc_hull(alpha.left, lhs_old_right, alpha)
2090-
self.P[pop].add_hull(label, hull)
20912020

2021+
# FIXME
20922022
lineage.reset_segments()
2023+
self.update_lineage_right(lineage)
20932024

20942025
self.set_segment_mass(alpha)
20952026
assert alpha.prev is None
@@ -2576,16 +2507,6 @@ def insert_merged_lineage(
25762507
# assert tail == new_lineage.tail
25772508
self.add_lineage(new_lineage)
25782509

2579-
if self.model == "smc_k":
2580-
merged_head = new_lineage.head
2581-
assert merged_head.prev is None
2582-
hull = self.alloc_hull(merged_head.left, merged_head.right, new_lineage)
2583-
while merged_head is not None:
2584-
right = merged_head.right
2585-
merged_head = merged_head.next
2586-
hull.right = min(right + self.hull_offset, self.L)
2587-
pop = self.P[new_lineage.population]
2588-
pop.add_hull(new_lineage.label, hull)
25892510
return new_lineage
25902511

25912512
def print_state(self, verify=False):

0 commit comments

Comments
 (0)