Skip to content

Commit 813e361

Browse files
lkirk authored and mergify[bot] committed
Dividing by zero should be NaN
Allow the C and Python portions of the code to divide by zero natively, relying on their respective runtime implementations. Remove the code that forced the two-locus statistics to return zero in 0/0 situations. The tests now ensure that some cases return NaN values, which are compared using a NaN-aware check. The mingw toolchain seems to require the NAN values to be cast to float, so the tests preemptively cast them to float.
1 parent 7a0b863 commit 813e361

File tree

4 files changed

+40
-29
lines changed

4 files changed

+40
-29
lines changed

c/tests/test_stats.c

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2045,10 +2045,10 @@ test_paper_ex_two_site(void)
20452045
double truth_two_sets[18] = { 1, 1, 0.1111111111111111, 0.1111111111111111,
20462046
0.1111111111111111, 0.1111111111111111, 0.1111111111111111, 0.1111111111111111,
20472047
1, 1, 1, 1, 0.1111111111111111, 0.1111111111111111, 1, 1, 1, 1 };
2048-
double truth_three_sets[27]
2049-
= { 1, 1, 0, 0.1111111111111111, 0.1111111111111111, 0, 0.1111111111111111,
2050-
0.1111111111111111, 0, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1,
2051-
1, 1, 1, 0.1111111111111111, 0.1111111111111111, 0, 1, 1, 1, 1, 1, 1 };
2048+
double truth_three_sets[27] = { 1, 1, NAN, 0.1111111111111111, 0.1111111111111111,
2049+
NAN, 0.1111111111111111, 0.1111111111111111, NAN, 0.1111111111111111,
2050+
0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1, 0.1111111111111111,
2051+
0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1 };
20522052

20532053
tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites,
20542054
paper_ex_mutations, paper_ex_individuals, NULL, 0);
@@ -2104,7 +2104,8 @@ test_paper_ex_two_site(void)
21042104
row_sites, num_sites, col_sites, 0, result);
21052105

21062106
CU_ASSERT_EQUAL_FATAL(ret, 0);
2107-
assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_three_sets);
2107+
assert_arrays_almost_equal_nan(
2108+
result_size * num_sample_sets, result, truth_three_sets);
21082109

21092110
tsk_treeseq_free(&ts);
21102111
tsk_safe_free(row_sites);

c/tests/testlib.h

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,22 @@ void unsort_edges(tsk_edge_table_t *edges, size_t start);
6565
} while (0); \
6666
}
6767

68+
/* Array equality if the arrays contain NaN values
69+
NB: the float cast for NaNs is for mingw, which complains without */
70+
#define assert_arrays_almost_equal_nan(len, a, b) \
71+
{ \
72+
do { \
73+
tsk_size_t _j; \
74+
for (_j = 0; _j < len; _j++) { \
75+
if (isnan((float) a[_j]) || isnan((float) b[_j])) { \
76+
CU_ASSERT_EQUAL_FATAL(isnan((float) a[_j]), isnan((float) b[_j])); \
77+
} else { \
78+
CU_ASSERT_DOUBLE_EQUAL(a[_j], b[_j], 1e-9); \
79+
} \
80+
} \
81+
} while (0); \
82+
}
83+
6884
extern const char *single_tree_ex_nodes;
6985
extern const char *single_tree_ex_edges;
7086
extern const char *single_tree_ex_sites;

c/tskit/trees.c

Lines changed: 4 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3597,11 +3597,7 @@ r2_summary_func(tsk_size_t state_dim, const double *state,
35973597
double D = p_AB - (p_A * p_B);
35983598
double denom = p_A * p_B * (1 - p_A) * (1 - p_B);
35993599

3600-
if (denom == 0 && D == 0) {
3601-
result[j] = 0;
3602-
} else {
3603-
result[j] = (D * D) / denom;
3604-
}
3600+
result[j] = (D * D) / denom;
36053601
}
36063602
return 0;
36073603
}
@@ -3637,8 +3633,8 @@ D_prime_summary_func(tsk_size_t state_dim, const double *state,
36373633
double p_B = p_AB + p_aB;
36383634

36393635
double D = p_AB - (p_A * p_B);
3640-
result[j] = 0;
3641-
if (D > 0) {
3636+
3637+
if (D >= 0) {
36423638
result[j] = D / TSK_MIN(p_A * (1 - p_B), (1 - p_A) * p_B);
36433639
} else if (D < 0) {
36443640
result[j] = D / TSK_MIN(p_A * p_B, (1 - p_A) * (1 - p_B));
@@ -3681,11 +3677,7 @@ r_summary_func(tsk_size_t state_dim, const double *state,
36813677
double D = p_AB - (p_A * p_B);
36823678
double denom = p_A * p_B * (1 - p_A) * (1 - p_B);
36833679

3684-
if (denom == 0 && D == 0) {
3685-
result[j] = 0;
3686-
} else {
3687-
result[j] = D / sqrt(denom);
3688-
}
3680+
result[j] = D / sqrt(denom);
36893681
}
36903682
return 0;
36913683
}

python/tests/test_ld_matrix.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
"""
2323
Test cases for two-locus statistics
2424
"""
25+
import contextlib
2526
import io
2627
from itertools import combinations_with_replacement
2728
from itertools import permutations
@@ -40,6 +41,12 @@
4041
from tests.test_highlevel import get_example_tree_sequences
4142

4243

44+
@contextlib.contextmanager
45+
def suppress_division_by_zero_warning():
46+
with np.errstate(invalid="ignore", divide="ignore"):
47+
yield
48+
49+
4350
class BitSet:
4451
"""BitSet object, which stores values in arrays of unsigned integers.
4552
The rows represent all possible values a bit can take, and the rows
@@ -729,9 +736,7 @@ def r2_summary_func(
729736
D = p_AB - (p_A * p_B)
730737
denom = p_A * p_B * (1 - p_A) * (1 - p_B)
731738

732-
if denom == 0 and D == 0:
733-
result[k] = 0
734-
else:
739+
with suppress_division_by_zero_warning():
735740
result[k] = (D * D) / denom
736741

737742

@@ -782,12 +787,11 @@ def D_prime_summary_func(
782787
p_B = p_AB + p_aB
783788

784789
D = p_AB - (p_A * p_B)
785-
if D == 0:
786-
result[k] = 0
787-
elif D > 0:
788-
result[k] = D / min(p_A * (1 - p_B), (1 - p_A) * p_B)
789-
else:
790-
result[k] = D / min(p_A * p_B, (1 - p_A) * (1 - p_B))
790+
with suppress_division_by_zero_warning():
791+
if D >= 0:
792+
result[k] = D / min(p_A * (1 - p_B), (1 - p_A) * p_B)
793+
else:
794+
result[k] = D / min(p_A * p_B, (1 - p_A) * (1 - p_B))
791795

792796

793797
def r_summary_func(
@@ -806,9 +810,7 @@ def r_summary_func(
806810
D = p_AB - (p_A * p_B)
807811
denom = p_A * p_B * (1 - p_A) * (1 - p_B)
808812

809-
if denom == 0 and D == 0:
810-
result[k] = 0
811-
else:
813+
with suppress_division_by_zero_warning():
812814
result[k] = D / np.sqrt(denom)
813815

814816

0 commit comments

Comments
 (0)