@@ -402,9 +402,11 @@ def __init__(
402402 mutations_parent ,
403403 mutations_time ,
404404 mutations_derived_state ,
405+ mutations_inherited_state ,
405406 breakpoints ,
406407 max_ancestral_length ,
407408 max_derived_length ,
409+ max_inherited_length ,
408410 ):
409411 self .num_trees = num_trees
410412 self .num_nodes = num_nodes
@@ -431,9 +433,11 @@ def __init__(
431433 self .mutations_parent = mutations_parent
432434 self .mutations_time = mutations_time
433435 self .mutations_derived_state = mutations_derived_state
436+ self .mutations_inherited_state = mutations_inherited_state
434437 self .breakpoints = breakpoints
435438 self .max_ancestral_length = max_ancestral_length
436439 self .max_derived_length = max_derived_length
440+ self .max_inherited_length = max_inherited_length
437441
438442 def tree_index (self ):
439443 """
@@ -526,7 +530,7 @@ def parent_index(self):
526530
527531# We cache these classes to avoid repeated JIT compilation
528532@functools .lru_cache (None )
529- def _jitwrap (max_ancestral_length , max_derived_length ):
533+ def _jitwrap (max_ancestral_length , max_derived_length , max_inherited_length ):
530534 # We have a circular dependency in JIT compilation between NumbaTreeSequence
531535 # and NumbaTreeIndex so we used a deferred type to break it
532536 tree_sequence_type = numba .deferred_type ()
@@ -576,9 +580,14 @@ def _jitwrap(max_ancestral_length, max_derived_length):
576580 ("mutations_parent" , numba .int32 [:]),
577581 ("mutations_time" , numba .float64 [:]),
578582 ("mutations_derived_state" , numba .types .UnicodeCharSeq (max_derived_length )[:]),
583+ (
584+ "mutations_inherited_state" ,
585+ numba .types .UnicodeCharSeq (max_inherited_length )[:],
586+ ),
579587 ("breakpoints" , numba .float64 [:]),
580588 ("max_ancestral_length" , numba .int32 ),
581589 ("max_derived_length" , numba .int32 ),
590+ ("max_inherited_length" , numba .int32 ),
582591 ]
583592
584593 # The `tree_index` method on NumbaTreeSequence uses NumbaTreeIndex
@@ -614,8 +623,13 @@ def jitwrap(ts):
614623 """
615624 max_ancestral_length = max (1 , max (map (len , ts .sites_ancestral_state ), default = 1 ))
616625 max_derived_length = max (1 , max (map (len , ts .mutations_derived_state ), default = 1 ))
626+ max_inherited_length = max (
627+ 1 , max (map (len , ts .mutations_inherited_state ), default = 1 )
628+ )
617629
618- JittedTreeSequence = _jitwrap (max_ancestral_length , max_derived_length )
630+ JittedTreeSequence = _jitwrap (
631+ max_ancestral_length , max_derived_length , max_inherited_length
632+ )
619633
620634 # Create the tree sequence instance
621635 numba_ts = JittedTreeSequence (
@@ -648,9 +662,13 @@ def jitwrap(ts):
648662 mutations_derived_state = ts .mutations_derived_state .astype (
649663 f"U{ max_derived_length } "
650664 ),
665+ mutations_inherited_state = ts .mutations_inherited_state .astype (
666+ f"U{ max_inherited_length } "
667+ ),
651668 breakpoints = ts .breakpoints (as_array = True ),
652669 max_ancestral_length = max_ancestral_length ,
653670 max_derived_length = max_derived_length ,
671+ max_inherited_length = max_inherited_length ,
654672 )
655673
656674 return numba_ts
0 commit comments