Skip to content

Commit 7f3e1e4

Browse files
committed
Compute stationary distribution using power iteration to avoid numerical issue on Windows
1 parent f7cd407 commit 7f3e1e4

File tree

3 files changed

+19
-8
lines changed

3 files changed

+19
-8
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ jobs:
9595
run: |
9696
source ~/.profile
9797
conda activate anaconda-client-env
98-
pytest -xvs -n0 tests/test_mutations.py::TestMatrixMutationModel::test_TPM[0.5-0.25-50-500]
98+
pytest -xvs -n0

msprime/mutations.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ def __init__(
362362
hi=None,
363363
root_distribution=None,
364364
):
365+
print(np.show_config())
365366
s = 0.0 if s is None else s
366367
u = 0.5 if u is None else u
367368
v = 0.0 if v is None else v
@@ -388,13 +389,23 @@ def __init__(
388389
raise ValueError("u must be between 0 and 1")
389390
alleles = [str(int(x)) for x in range(lo, hi + 1)]
390391
transition_matrix = general_microsat_rate_matrix(s, u, v, p, m, lo, hi)
392+
391393
if root_distribution is None:
392-
# solve for stationary distribution
393-
S, U = np.linalg.eig(transition_matrix.T)
394-
U = np.real_if_close(U, tol=1)
395-
stationary = np.array(U[:, np.where(np.abs(S - 1.0) < 1e-8)[0][0]])
396-
stationary = stationary / np.sum(stationary)
397-
root_distribution = stationary
394+
# Compute the stationary distribution using power iteration
395+
n = transition_matrix.shape[0]
396+
pi = np.ones(n) / n
397+
max_iter = 1000
398+
tol = 1e-10
399+
for _ in range(max_iter):
400+
pi_next = pi @ transition_matrix
401+
if np.max(np.abs(pi_next - pi)) < tol:
402+
root_distribution = pi_next
403+
break
404+
pi = pi_next
405+
else:
406+
# If we haven't converged use the last approximation
407+
root_distribution = pi
408+
398409
super().__init__(alleles, root_distribution, transition_matrix)
399410

400411

requirements/CI-complete/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pytest==8.3.5
77
pytest-cov==6.0.0
88
pytest-xdist==3.6.1
99
python_jsonschema_objects==0.5.7
10-
scipy==1.13.1
10+
scipy==1.14.0
1111
stdpopsim==0.1.2 #Pinned for OOA model
1212
tskit==0.6.0
1313
kastore==0.3.3

0 commit comments

Comments
 (0)