Skip to content

Commit 80f0cf7

Browse files
committed
.
1 parent 8719775 commit 80f0cf7

File tree

3 files changed

+69
-92
lines changed

3 files changed

+69
-92
lines changed

c/tskit/trees.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3551,6 +3551,7 @@ tsk_treeseq_update_branch_afs(const tsk_treeseq_t *self, tsk_id_t u, double righ
35513551
tsk_size_t k;
35523552
tsk_size_t time_window_index;
35533553
double *afs;
3554+
// note: moving this malloc outside this function doesn't speed things up
35543555
tsk_size_t *coordinate = tsk_malloc(num_sample_sets * sizeof(*coordinate));
35553556
bool polarised = !!(options & TSK_STAT_POLARISED);
35563557
const double *count_row = GET_2D_ROW(counts, num_sample_sets + 1, u);

python/tests/test_tree_stats.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@
4545

4646
np.random.seed(5)
4747

48+
# Notes for refactoring:
49+
#
50+
# Things we need to test here are:
51+
# 1. general_stat: correctly uses summary functions
52+
# 2. general_stat: branch mode, correctness
53+
# 3. general_stat: site mode, correctness
54+
# 4. general_stat: node mode, correctness
55+
# 5. sample sets: correctness
56+
# 6. indexes: correctness
57+
# 7. genome windowing: correctness
58+
# 8. time windowing: correctness
59+
# 9. dropping dimensions, output
60+
# 10. span normalise
61+
# 11. sample_count_stat: correctly uses summary functions
62+
# 12. each statistic: a single tree suffices, with edge cases
63+
# a. agrees with naive version, polarised and not;
64+
# b. agrees with python version, polarised and not;
65+
# c. stat-specific options (eg centre)
4866

4967
def cached_np(func):
5068
"""
@@ -724,7 +742,9 @@ def ts_10_recomb_fixture():
724742
@pytest.fixture(scope="session")
725743
def ts_10_mut_fixture():
726744
"""10-sample tree sequence with mutations (used 10 times)."""
727-
return msprime.simulate(10, mutation_rate=1, random_seed=1)
745+
ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
746+
assert ts.num_mutations > 0
747+
return ts
728748

729749

730750
@pytest.fixture(scope="session")
@@ -6810,36 +6830,31 @@ class TestOutputDimensions(StatsTestCase):
68106830
Tests for the dimension stripping behaviour of the stats functions.
68116831
"""
68126832

6813-
def get_example_ts(self, ts_10_mut_fixture):
6814-
ts = ts_10_mut_fixture
6815-
assert ts.num_sites > 1
6816-
return ts
6817-
68186833
def test_one_way_no_window_scalar_stat(self, ts_10_mut_fixture):
6819-
ts = self.get_example_ts(ts_10_mut_fixture)
6834+
ts = ts_10_mut_fixture
68206835
x = ts.diversity()
68216836
assert isinstance(x, np.floating)
68226837

68236838
def test_one_way_one_list_scalar_stat(self, ts_10_mut_fixture):
6824-
ts = self.get_example_ts(ts_10_mut_fixture)
6839+
ts = ts_10_mut_fixture
68256840
x = ts.diversity(sample_sets=list(ts.samples()))
68266841
assert isinstance(x, np.floating)
68276842

68286843
def test_one_way_nested_list_not_scalar_stat(self, ts_10_mut_fixture):
6829-
ts = self.get_example_ts(ts_10_mut_fixture)
6844+
ts = ts_10_mut_fixture
68306845
x = ts.diversity(sample_sets=[list(ts.samples())])
68316846
assert x.shape == (1,)
68326847

68336848
def test_one_way_one_window_scalar_stat(self, ts_10_mut_fixture):
6834-
ts = self.get_example_ts(ts_10_mut_fixture)
6849+
ts = ts_10_mut_fixture
68356850
x = ts.diversity(windows=[0, ts.sequence_length])
68366851
assert x.shape == (1,)
68376852
for samples in (None, list(ts.samples())):
68386853
x = ts.diversity(sample_sets=samples, windows=[0, ts.sequence_length])
68396854
assert x.shape == (1,)
68406855

68416856
def test_multi_way_no_window_scalar_stat(self, ts_10_mut_fixture):
6842-
ts = self.get_example_ts(ts_10_mut_fixture)
6857+
ts = ts_10_mut_fixture
68436858
n = ts.num_samples
68446859
x = ts.f2(
68456860
sample_sets=[
@@ -6850,7 +6865,7 @@ def test_multi_way_no_window_scalar_stat(self, ts_10_mut_fixture):
68506865
assert isinstance(x, np.floating)
68516866

68526867
def test_multi_way_one_window_not_scalar_stat(self, ts_10_mut_fixture):
6853-
ts = self.get_example_ts(ts_10_mut_fixture)
6868+
ts = ts_10_mut_fixture
68546869
n = ts.num_samples
68556870
x = ts.f2(
68566871
sample_sets=[
@@ -6862,7 +6877,7 @@ def test_multi_way_one_window_not_scalar_stat(self, ts_10_mut_fixture):
68626877
assert x.shape == (1,)
68636878

68646879
def test_multi_way_no_indexes_scalar_stat(self, ts_10_mut_fixture):
6865-
ts = self.get_example_ts(ts_10_mut_fixture)
6880+
ts = ts_10_mut_fixture
68666881
n = ts.num_samples
68676882
x = ts.f2(
68686883
sample_sets=[
@@ -6873,7 +6888,7 @@ def test_multi_way_no_indexes_scalar_stat(self, ts_10_mut_fixture):
68736888
assert isinstance(x, np.floating)
68746889

68756890
def test_multi_way_indexes_not_scalar_stat(self, ts_10_mut_fixture):
6876-
ts = self.get_example_ts(ts_10_mut_fixture)
6891+
ts = ts_10_mut_fixture
68776892
n = ts.num_samples
68786893
x = ts.f2(
68796894
sample_sets=[
@@ -6885,7 +6900,7 @@ def test_multi_way_indexes_not_scalar_stat(self, ts_10_mut_fixture):
68856900
assert x.shape == (1,)
68866901

68876902
def test_afs_default_windows(self, ts_10_mut_fixture):
6888-
ts = self.get_example_ts(ts_10_mut_fixture)
6903+
ts = ts_10_mut_fixture
68896904
n = ts.num_samples
68906905
A = ts.samples()[:4]
68916906
B = ts.samples()[6:]
@@ -6900,7 +6915,7 @@ def test_afs_default_windows(self, ts_10_mut_fixture):
69006915
assert x.shape == (len(A) + 1, len(B) + 1)
69016916

69026917
def test_afs_windows(self, ts_10_mut_fixture):
6903-
ts = self.get_example_ts(ts_10_mut_fixture)
6918+
ts = ts_10_mut_fixture
69046919
L = ts.sequence_length
69056920

69066921
windows = [0, L / 4, L / 2, L]
@@ -6920,7 +6935,7 @@ def test_afs_windows(self, ts_10_mut_fixture):
69206935
self.assertArrayEqual(x, y)
69216936

69226937
def test_one_way_stat_default_windows(self, ts_10_mut_fixture):
6923-
ts = self.get_example_ts(ts_10_mut_fixture)
6938+
ts = ts_10_mut_fixture
69246939
# Use diversity as the example one-way stat.
69256940
for mode in ["site", "branch"]:
69266941
x = ts.diversity(mode=mode)
@@ -6989,19 +7004,19 @@ def verify_one_way_stat_windows(self, ts, method):
69897004
self.assertArrayEqual(x[0], x[2])
69907005

69917006
def test_diversity_windows(self, ts_10_mut_fixture):
6992-
ts = self.get_example_ts(ts_10_mut_fixture)
7007+
ts = ts_10_mut_fixture
69937008
self.verify_one_way_stat_windows(ts, ts.diversity)
69947009

69957010
def test_Tajimas_D_windows(self, ts_10_mut_fixture):
6996-
ts = self.get_example_ts(ts_10_mut_fixture)
7011+
ts = ts_10_mut_fixture
69977012
self.verify_one_way_stat_windows(ts, ts.Tajimas_D)
69987013

69997014
def test_segregating_sites_windows(self, ts_10_mut_fixture):
7000-
ts = self.get_example_ts(ts_10_mut_fixture)
7015+
ts = ts_10_mut_fixture
70017016
self.verify_one_way_stat_windows(ts, ts.segregating_sites)
70027017

70037018
def test_two_way_stat_default_windows(self, ts_10_mut_fixture):
7004-
ts = self.get_example_ts(ts_10_mut_fixture)
7019+
ts = ts_10_mut_fixture
70057020
# Use divergence as the example two-way stat.
70067021
A = ts.samples()[:6]
70077022
B = ts.samples()[6:]
@@ -7072,15 +7087,15 @@ def verify_two_way_stat_windows(self, ts, method):
70727087
self.assertArrayEqual(x[0], x[2])
70737088

70747089
def test_divergence_windows(self, ts_10_mut_fixture):
7075-
ts = self.get_example_ts(ts_10_mut_fixture)
7090+
ts = ts_10_mut_fixture
70767091
self.verify_two_way_stat_windows(ts, ts.divergence)
70777092

70787093
def test_Fst_windows(self, ts_10_mut_fixture):
7079-
ts = self.get_example_ts(ts_10_mut_fixture)
7094+
ts = ts_10_mut_fixture
70807095
self.verify_two_way_stat_windows(ts, ts.Fst)
70817096

70827097
def test_f2_windows(self, ts_10_mut_fixture):
7083-
ts = self.get_example_ts(ts_10_mut_fixture)
7098+
ts = ts_10_mut_fixture
70847099
self.verify_two_way_stat_windows(ts, ts.f2)
70857100

70867101
def verify_three_way_stat_windows(self, ts, method):
@@ -7136,11 +7151,11 @@ def verify_three_way_stat_windows(self, ts, method):
71367151
self.assertArrayEqual(x[0], x[2])
71377152

71387153
def test_Y3_windows(self, ts_10_mut_fixture):
7139-
ts = self.get_example_ts(ts_10_mut_fixture)
7154+
ts = ts_10_mut_fixture
71407155
self.verify_three_way_stat_windows(ts, ts.Y3)
71417156

71427157
def test_f3_windows(self, ts_10_mut_fixture):
7143-
ts = self.get_example_ts(ts_10_mut_fixture)
7158+
ts = ts_10_mut_fixture
71447159
self.verify_three_way_stat_windows(ts, ts.f3)
71457160

71467161

python/tskit/trees.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -7994,13 +7994,13 @@ def __one_way_sample_set_stat(
79947994
ll_method,
79957995
sample_sets,
79967996
windows=None,
7997+
time_windows=None,
79977998
mode=None,
79987999
span_normalise=True,
79998000
polarised=False,
80008001
):
80018002
if sample_sets is None:
80028003
sample_sets = self.samples()
8003-
80048004
# First try to convert to a 1D numpy array. If it is, then we strip off
80058005
# the corresponding dimension from the output.
80068006
drop_dimension = False
@@ -8013,82 +8013,43 @@ def __one_way_sample_set_stat(
80138013
# of integers then drop the dimension
80148014
if len(sample_sets.shape) == 1:
80158015
sample_sets = [sample_sets]
8016-
drop_dimension = True
8017-
8016+
if ll_method.__name__ != "allele_frequency_spectrum":
8017+
drop_dimension = True
80188018
sample_set_sizes = np.array(
80198019
[len(sample_set) for sample_set in sample_sets], dtype=np.uint32
80208020
)
80218021
if np.any(sample_set_sizes == 0):
80228022
raise ValueError("Sample sets must contain at least one element")
80238023

80248024
flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32)
8025-
stat = self.__run_windowed_stat(
8026-
windows,
8027-
ll_method,
8028-
sample_set_sizes,
8029-
flattened,
8030-
mode=mode,
8031-
span_normalise=span_normalise,
8032-
polarised=polarised,
8033-
)
8025+
use_tw = (ll_method.__name__ == "allele_frequency_spectrum")
8026+
if use_tw:
8027+
stat = self.__run_windowed_stat_tw(
8028+
windows,
8029+
time_windows,
8030+
ll_method,
8031+
sample_set_sizes,
8032+
flattened,
8033+
mode=mode,
8034+
span_normalise=span_normalise,
8035+
polarised=polarised,
8036+
)
8037+
else:
8038+
stat = self.__run_windowed_stat(
8039+
windows,
8040+
ll_method,
8041+
sample_set_sizes,
8042+
flattened,
8043+
mode=mode,
8044+
span_normalise=span_normalise,
8045+
polarised=polarised,
8046+
)
80348047
if drop_dimension:
80358048
stat = stat.reshape(stat.shape[:-1])
8036-
if stat.shape == () and windows is None:
8049+
if stat.shape == () and windows is None and time_windows is None:
80378050
stat = stat[()]
80388051
return stat
80398052

8040-
# only for temporary tw version
8041-
def __one_way_sample_set_stat_tw(
8042-
self,
8043-
ll_method,
8044-
sample_sets,
8045-
windows=None,
8046-
time_windows=None,
8047-
mode=None,
8048-
span_normalise=True,
8049-
polarised=False,
8050-
):
8051-
if sample_sets is None:
8052-
sample_sets = self.samples()
8053-
# First try to convert to a 1D numpy array. If it is, then we strip off
8054-
# the corresponding dimension from the output.
8055-
drop_dimension = False
8056-
try:
8057-
sample_sets = np.array(sample_sets, dtype=np.uint64)
8058-
except ValueError:
8059-
pass
8060-
else:
8061-
# If we've successfully converted sample_sets to a 1D numpy array
8062-
# of integers then drop the dimension
8063-
if len(sample_sets.shape) == 1:
8064-
sample_sets = [sample_sets]
8065-
drop_dimension = True
8066-
sample_set_sizes = np.array(
8067-
[len(sample_set) for sample_set in sample_sets], dtype=np.uint32
8068-
)
8069-
if np.any(sample_set_sizes == 0):
8070-
raise ValueError("Sample sets must contain at least one element")
8071-
8072-
flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32)
8073-
stat = self.__run_windowed_stat_tw(
8074-
windows,
8075-
time_windows,
8076-
ll_method,
8077-
sample_set_sizes,
8078-
flattened,
8079-
mode=mode,
8080-
span_normalise=span_normalise,
8081-
polarised=polarised,
8082-
)
8083-
if drop_dimension:
8084-
# not applicable for AFS
8085-
if not ll_method.__name__ == "allele_frequency_spectrum":
8086-
stat = stat.reshape(stat.shape[:-1])
8087-
# We'll need this for non-AFS functions; but can't test it with AFS:
8088-
# if stat.shape == () and windows is None and time_windows is None:
8089-
# stat = stat[()]
8090-
return stat
8091-
80928053
def parse_sites(self, sites):
80938054
row_sites, col_sites = None, None
80948055
if sites is not None:
@@ -9781,7 +9742,7 @@ def allele_frequency_spectrum(
97819742
"""
97829743
if sample_sets is None:
97839744
sample_sets = [self.samples()]
9784-
return self.__one_way_sample_set_stat_tw(
9745+
return self.__one_way_sample_set_stat(
97859746
self._ll_tree_sequence.allele_frequency_spectrum,
97869747
sample_sets,
97879748
windows=windows,

0 commit comments

Comments
 (0)