|
39 | 39 | from . import util
|
40 | 40 |
|
41 | 41 |
|
42 |
| -PriorParams_base = namedtuple("PriorParams", "alpha, beta, mean, var") |
43 |
| - |
44 |
| - |
45 |
| -class PriorParams(PriorParams_base): |
| 42 | +class PriorParams(namedtuple("PriorParamsBase", "alpha, beta, mean, var")): |
46 | 43 | @classmethod
|
47 | 44 | def field_index(cls, fieldname):
|
48 | 45 | return np.where([f == fieldname for f in cls._fields])[0][0]
|
@@ -652,24 +649,13 @@ def save_to_spans(prev_tree, node, num_fixed_at_0_treenodes):
|
652 | 649 | unary_descendants = set()
|
653 | 650 | for node in changed_nodes:
|
654 | 651 | children = prev_tree.children(node)
|
655 |
| - if children is not None: |
656 |
| - if len(children) == 1: |
657 |
| - # Keep descending |
658 |
| - while True: |
659 |
| - children = prev_tree.children(node) |
660 |
| - if len(children) != 1: |
661 |
| - break |
662 |
| - unary_descendants.add(node) |
663 |
| - node = children[0] |
664 |
| - else: |
665 |
| - # Descend all branches, looking for unary nodes |
666 |
| - for node in prev_tree.children(node): |
667 |
| - while True: |
668 |
| - children = prev_tree.children(node) |
669 |
| - if len(children) != 1: |
670 |
| - break |
671 |
| - unary_descendants.add(node) |
672 |
| - node = children[0] |
| 652 | + for child in children: |
| 653 | + while True: |
| 654 | + children = prev_tree.children(child) |
| 655 | + if len(children) != 1: |
| 656 | + break |
| 657 | + unary_descendants.add(child) |
| 658 | + child = children[0] |
673 | 659 |
|
674 | 660 | # find all the nodes in the tree that might have changed their number
|
675 | 661 | # of descendants, and reset. This might include nodes that are not in
|
|
0 commit comments