@@ -630,17 +630,20 @@ def __init__(self, priors, lik, *, progress=False):
630
630
631
631
# === Grouped edge iterators ===
632
632
633
- def edges_by_parent_asc (self ):
633
+ def edges_by_parent_asc (self , grouped = True ):
634
634
"""
635
635
Return an itertools.groupby object of edges grouped by parent in ascending order
636
636
of the time of the parent. Since tree sequence properties guarantee that edges
637
637
are listed in nondecreasing order of parent time
638
638
(https://tskit.readthedocs.io/en/latest/data-model.html#edge-requirements)
639
639
we can simply use the standard edge order
640
640
"""
641
- return itertools .groupby (self .ts .edges (), operator .attrgetter ("parent" ))
641
+ if grouped :
642
+ return itertools .groupby (self .ts .edges (), operator .attrgetter ("parent" ))
643
+ else :
644
+ return self .ts .edges ()
642
645
643
- def edges_by_child_desc (self ):
646
+ def edges_by_child_desc (self , grouped = True ):
644
647
"""
645
648
Return an itertools.groupby object of edges grouped by child in descending order
646
649
of the time of the child.
@@ -651,9 +654,12 @@ def edges_by_child_desc(self):
651
654
(self .ts .edges_child , - self .ts .nodes_time [self .ts .edges_child ])
652
655
)
653
656
)
654
- return itertools .groupby (it , operator .attrgetter ("child" ))
657
+ if grouped :
658
+ return itertools .groupby (it , operator .attrgetter ("child" ))
659
+ else :
660
+ return it
655
661
656
- def edges_by_child_then_parent_desc (self ):
662
+ def edges_by_child_then_parent_desc (self , grouped = True ):
657
663
"""
658
664
Return an itertools.groupby object of edges grouped by child in descending order
659
665
of the time of the child, then by descending order of age of child
@@ -675,7 +681,10 @@ def edges_by_child_then_parent_desc(self):
675
681
np .argsort (w , order = ("child_age" , "child_node" , "parent_age" ))
676
682
)
677
683
)
678
- return itertools .groupby (sorted_child_parent , operator .attrgetter ("child" ))
684
+ if grouped :
685
+ return itertools .groupby (sorted_child_parent , operator .attrgetter ("child" ))
686
+ else :
687
+ return sorted_child_parent
679
688
680
689
# === MAIN ALGORITHMS ===
681
690
@@ -1038,8 +1047,8 @@ def iterate(self, *, progress=None, **kwargs):
1038
1047
Update edge factors from leaves to root then from root to leaves,
1039
1048
and return approximate log marginal likelihood
1040
1049
"""
1041
- self .propagate (edges = self .edges_by_parent_asc (), progress = progress )
1042
- self .propagate (edges = self .edges_by_child_desc (), progress = progress )
1050
+ self .propagate (edges = self .edges_by_parent_asc (grouped = False ), progress = progress )
1051
+ self .propagate (edges = self .edges_by_child_desc (grouped = False ), progress = progress )
1043
1052
# TODO
1044
1053
# marginal_lik = np.sum(self.factor_norm)
1045
1054
# return marginal_lik
0 commit comments