@@ -402,9 +402,11 @@ def __init__(
402
402
mutations_parent ,
403
403
mutations_time ,
404
404
mutations_derived_state ,
405
+ mutations_inherited_state ,
405
406
breakpoints ,
406
407
max_ancestral_length ,
407
408
max_derived_length ,
409
+ max_inherited_length ,
408
410
):
409
411
self .num_trees = num_trees
410
412
self .num_nodes = num_nodes
@@ -431,9 +433,11 @@ def __init__(
431
433
self .mutations_parent = mutations_parent
432
434
self .mutations_time = mutations_time
433
435
self .mutations_derived_state = mutations_derived_state
436
+ self .mutations_inherited_state = mutations_inherited_state
434
437
self .breakpoints = breakpoints
435
438
self .max_ancestral_length = max_ancestral_length
436
439
self .max_derived_length = max_derived_length
440
+ self .max_inherited_length = max_inherited_length
437
441
438
442
def tree_index (self ):
439
443
"""
@@ -526,7 +530,7 @@ def parent_index(self):
526
530
527
531
# We cache these classes to avoid repeated JIT compilation
528
532
@functools .lru_cache (None )
529
- def _jitwrap (max_ancestral_length , max_derived_length ):
533
+ def _jitwrap (max_ancestral_length , max_derived_length , max_inherited_length ):
530
534
# We have a circular dependency in JIT compilation between NumbaTreeSequence
531
535
# and NumbaTreeIndex so we used a deferred type to break it
532
536
tree_sequence_type = numba .deferred_type ()
@@ -576,9 +580,14 @@ def _jitwrap(max_ancestral_length, max_derived_length):
576
580
("mutations_parent" , numba .int32 [:]),
577
581
("mutations_time" , numba .float64 [:]),
578
582
("mutations_derived_state" , numba .types .UnicodeCharSeq (max_derived_length )[:]),
583
+ (
584
+ "mutations_inherited_state" ,
585
+ numba .types .UnicodeCharSeq (max_inherited_length )[:],
586
+ ),
579
587
("breakpoints" , numba .float64 [:]),
580
588
("max_ancestral_length" , numba .int32 ),
581
589
("max_derived_length" , numba .int32 ),
590
+ ("max_inherited_length" , numba .int32 ),
582
591
]
583
592
584
593
# The `tree_index` method on NumbaTreeSequence uses NumbaTreeIndex
@@ -614,8 +623,13 @@ def jitwrap(ts):
614
623
"""
615
624
max_ancestral_length = max (1 , max (map (len , ts .sites_ancestral_state ), default = 1 ))
616
625
max_derived_length = max (1 , max (map (len , ts .mutations_derived_state ), default = 1 ))
626
+ max_inherited_length = max (
627
+ 1 , max (map (len , ts .mutations_inherited_state ), default = 1 )
628
+ )
617
629
618
- JittedTreeSequence = _jitwrap (max_ancestral_length , max_derived_length )
630
+ JittedTreeSequence = _jitwrap (
631
+ max_ancestral_length , max_derived_length , max_inherited_length
632
+ )
619
633
620
634
# Create the tree sequence instance
621
635
numba_ts = JittedTreeSequence (
@@ -648,9 +662,13 @@ def jitwrap(ts):
648
662
mutations_derived_state = ts .mutations_derived_state .astype (
649
663
f"U{ max_derived_length } "
650
664
),
665
+ mutations_inherited_state = ts .mutations_inherited_state .astype (
666
+ f"U{ max_inherited_length } "
667
+ ),
651
668
breakpoints = ts .breakpoints (as_array = True ),
652
669
max_ancestral_length = max_ancestral_length ,
653
670
max_derived_length = max_derived_length ,
671
+ max_inherited_length = max_inherited_length ,
654
672
)
655
673
656
674
return numba_ts
0 commit comments