@@ -388,13 +388,29 @@ def __init__(
388388 raise ValueError ("u must be between 0 and 1" )
389389 alleles = [str (int (x )) for x in range (lo , hi + 1 )]
390390 transition_matrix = general_microsat_rate_matrix (s , u , v , p , m , lo , hi )
391+
392+ # Print for debugging (keep existing print statements)
393+ print ("Transition matrix:" )
394+ print (transition_matrix , np .sum (transition_matrix ))
395+ print ("Transition matrix T:" )
396+ print (transition_matrix .T , np .sum (transition_matrix .T ))
397+
391398 if root_distribution is None :
392- # solve for stationary distribution
393- S , U = np .linalg .eig (transition_matrix .T )
394- U = np .real_if_close (U , tol = 1 )
395- stationary = np .array (U [:, np .where (np .abs (S - 1.0 ) < 1e-8 )[0 ][0 ]])
396- stationary = stationary / np .sum (stationary )
397- root_distribution = stationary
399+ # Compute the stationary distribution using power iteration
400+ n = transition_matrix .shape [0 ]
401+ pi = np .ones (n ) / n
402+ max_iter = 1000
403+ tol = 1e-10
404+ for _ in range (max_iter ):
405+ pi_next = pi @ transition_matrix
406+ if np .max (np .abs (pi_next - pi )) < tol :
407+ root_distribution = pi_next
408+ break
409+ pi = pi_next
410+ else :
411+ # If the loop completes without breaking, use the last approximation
412+ root_distribution = pi
413+
398414 super ().__init__ (alleles , root_distribution , transition_matrix )
399415
400416
0 commit comments