Skip to content

Commit 23bcaff

Browse files
committed
Test skeletons
1 parent b6376c6 commit 23bcaff

File tree

4 files changed

+64
-17
lines changed

4 files changed

+64
-17
lines changed

tests/test_inference.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,41 @@ def test_simple_sim_multi_tree(self):
372372
assert all(
373373
[a == b for a, b in zip(ts.haplotypes(), io_dated_ts.haplotypes())]
374374
)
375+
376+
377+
class TestVariational:
378+
"""
379+
Tests for tsdate with variational algorithm
380+
"""
381+
382+
def test_simple_sim_1_tree(self):
383+
ts = msprime.simulate(8, mutation_rate=5, random_seed=2)
384+
tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma")
385+
386+
def test_simple_sim_multi_tree(self):
387+
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
388+
tsdate.date(ts, mutation_rate=5, population_size=1, method="variational_gamma")
389+
390+
def test_nonglobal_priors(self):
391+
ts = msprime.simulate(8, mutation_rate=5, recombination_rate=5, random_seed=2)
392+
priors = tsdate.prior.MixturePrior(ts, prior_distribution="gamma")
393+
grid = priors.make_parameter_grid(population_size=1)
394+
grid.grid_data[:] = [1.0, 0.0] # noninformative prior
395+
tsdate.date(
396+
ts,
397+
mutation_rate=5,
398+
method="variational_gamma",
399+
priors=grid,
400+
global_prior=False,
401+
)
402+
403+
def test_bad_arguments(self):
404+
ts = utility_functions.two_tree_mutation_ts()
405+
with pytest.raises(ValueError, match="Maximum number of iterations"):
406+
tsdate.date(
407+
ts,
408+
mutation_rate=5,
409+
population_size=1,
410+
method="variational_gamma",
411+
max_iterations=-1,
412+
)

tsdate/approx.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def sufficient_statistics(a_i, b_i, a_j, b_j, y_ij, mu_ij):
111111
112112
:return: normalizing constant, E[t_i], E[log t_i], E[t_j], E[log t_j]
113113
"""
114-
assert a_i > 0 and b_i > 0, "Invalid parent parameters"
115-
assert a_j > 0 and b_j > 0, "Invalid child parameters"
114+
assert a_i > 0 and b_i >= 0, "Invalid parent parameters"
115+
assert a_j > 0 and b_j >= 0, "Invalid child parameters"
116116
assert y_ij >= 0 and mu_ij > 0, "Invalid edge parameters"
117117

118118
a = a_i + a_j + y_ij

tsdate/core.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -630,17 +630,20 @@ def __init__(self, priors, lik, *, progress=False):
630630

631631
# === Grouped edge iterators ===
632632

633-
def edges_by_parent_asc(self):
633+
def edges_by_parent_asc(self, grouped=True):
634634
"""
635635
Return an itertools.groupby object of edges grouped by parent in ascending order
636636
of the time of the parent. Since tree sequence properties guarantee that edges
637637
are listed in nondecreasing order of parent time
638638
(https://tskit.readthedocs.io/en/latest/data-model.html#edge-requirements)
639639
we can simply use the standard edge order
640640
"""
641-
return itertools.groupby(self.ts.edges(), operator.attrgetter("parent"))
641+
if grouped:
642+
return itertools.groupby(self.ts.edges(), operator.attrgetter("parent"))
643+
else:
644+
return self.ts.edges()
642645

643-
def edges_by_child_desc(self):
646+
def edges_by_child_desc(self, grouped=True):
644647
"""
645648
Return an itertools.groupby object of edges grouped by child in descending order
646649
of the time of the child.
@@ -651,9 +654,12 @@ def edges_by_child_desc(self):
651654
(self.ts.edges_child, -self.ts.nodes_time[self.ts.edges_child])
652655
)
653656
)
654-
return itertools.groupby(it, operator.attrgetter("child"))
657+
if grouped:
658+
return itertools.groupby(it, operator.attrgetter("child"))
659+
else:
660+
return it
655661

656-
def edges_by_child_then_parent_desc(self):
662+
def edges_by_child_then_parent_desc(self, grouped=True):
657663
"""
658664
Return an itertools.groupby object of edges grouped by child in descending order
659665
of the time of the child, then by descending order of age of child
@@ -675,7 +681,10 @@ def edges_by_child_then_parent_desc(self):
675681
np.argsort(w, order=("child_age", "child_node", "parent_age"))
676682
)
677683
)
678-
return itertools.groupby(sorted_child_parent, operator.attrgetter("child"))
684+
if grouped:
685+
return itertools.groupby(sorted_child_parent, operator.attrgetter("child"))
686+
else:
687+
return sorted_child_parent
679688

680689
# === MAIN ALGORITHMS ===
681690

@@ -1038,8 +1047,8 @@ def iterate(self, *, progress=None, **kwargs):
10381047
Update edge factors from leaves to root then from root to leaves,
10391048
and return approximate log marginal likelihood
10401049
"""
1041-
self.propagate(edges=self.edges_by_parent_asc(), progress=progress)
1042-
self.propagate(edges=self.edges_by_child_desc(), progress=progress)
1050+
self.propagate(edges=self.edges_by_parent_asc(grouped=False), progress=progress)
1051+
self.propagate(edges=self.edges_by_child_desc(grouped=False), progress=progress)
10431052
# TODO
10441053
# marginal_lik = np.sum(self.factor_norm)
10451054
# return marginal_lik

tsdate/hypergeo.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,8 @@ def _hyp2f1_dlmf1583(a_i, b_i, a_j, b_j, y, mu):
341341
"""
342342
DLMF 15.8.3, sum of recurrence and series expansion
343343
"""
344-
assert b_i > 0
345-
assert 0 < mu <= b_j
344+
assert b_i >= 0
345+
assert 0 <= mu <= b_j
346346
assert y >= 0 and y % 1.0 == 0.0
347347

348348
f_1, s_1, da_i_1, db_i_1, da_j_1, db_j_1 = _hyp2f1_dlmf1583_first(
@@ -380,8 +380,8 @@ def _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu):
380380
"""
381381
DLMF 15.2.1, series expansion without transformation
382382
"""
383-
assert b_i > 0
384-
assert mu >= b_j > 0
383+
assert b_i >= 0
384+
assert mu >= b_j >= 0
385385
assert y >= 0 and y % 1 == 0.0
386386

387387
y = int(y)
@@ -409,8 +409,8 @@ def _hyp2f1_dlmf1581(a_i, b_i, a_j, b_j, y, mu):
409409
"""
410410
DLMF 15.8.1, series expansion with Pfaff transformation
411411
"""
412-
assert b_i > 0
413-
assert 0 < mu <= b_j
412+
assert b_i >= 0
413+
assert 0 <= mu <= b_j
414414
assert y >= 0 and y % 1 == 0.0
415415

416416
y = int(y)
@@ -456,7 +456,7 @@ def _hyp2f1(a_i, b_i, a_j, b_j, y, mu):
456456
and dividing the gradient by the function value.
457457
"""
458458
z = (mu - b_j) / (mu + b_i)
459-
assert z < 1.0 # TODO: allow z == 1.0 for improper prior
459+
assert z < 1.0, "Invalid hypergeometric function argument"
460460
if 0.0 <= z < 1.0:
461461
return _hyp2f1_dlmf1521(a_i, b_i, a_j, b_j, y, mu)
462462
elif -1.0 < z < 0.0:

0 commit comments

Comments
 (0)