Skip to content

Commit b5a02a9

Browse files
Hardened KC metric interface.
1 parent 1809003 commit b5a02a9

File tree

10 files changed

+430
-154
lines changed

10 files changed

+430
-154
lines changed

c/tests/test_core.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ test_strerror(void)
3232
{
3333
int j;
3434
const char *msg;
35-
int max_error_code = 1024; /* totally arbitrary */
35+
int max_error_code = 8192; /* totally arbitrary */
3636

3737
for (j = 0; j < max_error_code; j++) {
3838
msg = tsk_strerror(-j);

c/tests/test_trees.c

Lines changed: 110 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4449,6 +4449,9 @@ test_single_tree_kc(void)
44494449
ret = tsk_tree_kc_distance(&t, &other_t, 0, &result);
44504450
CU_ASSERT_EQUAL_FATAL(ret, 0);
44514451
CU_ASSERT_EQUAL_FATAL(result, 0);
4452+
ret = tsk_tree_kc_distance(&t, &other_t, 1, &result);
4453+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4454+
CU_ASSERT_EQUAL_FATAL(result, 0);
44524455
tsk_treeseq_free(&ts);
44534456
tsk_tree_free(&t);
44544457
tsk_tree_free(&other_t);
@@ -4527,10 +4530,10 @@ test_empty_tree_kc(void)
45274530
CU_ASSERT_EQUAL_FATAL(t.left, 0);
45284531
CU_ASSERT_EQUAL_FATAL(t.right, 1);
45294532
CU_ASSERT_EQUAL_FATAL(tsk_tree_get_parent(&t, 0, &v), TSK_ERR_NODE_OUT_OF_BOUNDS);
4530-
4533+
45314534
ret = tsk_tree_kc_distance(&t, &t, 0, &result);
4532-
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE);
4533-
4535+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MULTIPLE_ROOTS);
4536+
45344537
tsk_tree_free(&t);
45354538
tsk_treeseq_free(&ts);
45364539
tsk_table_collection_free(&tables);
@@ -4564,7 +4567,59 @@ test_nonbinary_tree_kc(void)
45644567
}
45654568

45664569
static void
4567-
test_unequal_samples_kc(void)
4570+
test_nonzero_samples_kc(void)
4571+
{
4572+
const char *nodes =
4573+
"0 0 0\n" /* unused node at the start */
4574+
"1 0 0\n"
4575+
"1 0 0\n"
4576+
"0 1 0";
4577+
const char *edges =
4578+
"0 1 3 1,2\n";
4579+
tsk_treeseq_t ts;
4580+
tsk_tree_t t;
4581+
int ret;
4582+
double result = 0;
4583+
4584+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL);
4585+
ret = tsk_tree_init(&t, &ts, 0);
4586+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4587+
ret = tsk_tree_first(&t);
4588+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4589+
ret = tsk_tree_kc_distance(&t, &t, 0, &result);
4590+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4591+
CU_ASSERT_EQUAL_FATAL(result, 0);
4592+
tsk_treeseq_free(&ts);
4593+
tsk_tree_free(&t);
4594+
}
4595+
4596+
static void
4597+
test_internal_samples_kc(void)
4598+
{
4599+
const char *nodes =
4600+
"1 0 0\n"
4601+
"1 0 0\n"
4602+
"1 1 0";
4603+
const char *edges =
4604+
"0 1 2 0,1\n";
4605+
tsk_treeseq_t ts;
4606+
tsk_tree_t t;
4607+
int ret;
4608+
double result = 0;
4609+
4610+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL);
4611+
ret = tsk_tree_init(&t, &ts, 0);
4612+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4613+
ret = tsk_tree_first(&t);
4614+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4615+
ret = tsk_tree_kc_distance(&t, &t, 0, &result);
4616+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INTERNAL_SAMPLES);
4617+
tsk_treeseq_free(&ts);
4618+
tsk_tree_free(&t);
4619+
}
4620+
4621+
static void
4622+
test_unequal_sample_size_kc(void)
45684623
{
45694624
const char *nodes =
45704625
"1 0 0\n"
@@ -4591,20 +4646,65 @@ test_unequal_samples_kc(void)
45914646
CU_ASSERT_EQUAL_FATAL(ret, 0);
45924647
ret = tsk_tree_first(&t);
45934648
CU_ASSERT_EQUAL_FATAL(ret, 1);
4594-
tsk_treeseq_from_text(&other_ts, 1, nodes_other, edges_other, NULL, NULL, NULL, NULL, NULL);
4649+
tsk_treeseq_from_text(&other_ts, 1, nodes_other, edges_other,
4650+
NULL, NULL, NULL, NULL, NULL);
45954651
ret = tsk_tree_init(&other_t, &other_ts, 0);
45964652
CU_ASSERT_EQUAL_FATAL(ret, 0);
45974653
ret = tsk_tree_first(&other_t);
45984654
CU_ASSERT_EQUAL_FATAL(ret, 1);
45994655
ret = tsk_tree_kc_distance(&t, &other_t, 0, &result);
4600-
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_PARAM_VALUE);
4656+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SAMPLE_SIZE_MISMATCH);
46014657
tsk_treeseq_free(&ts);
46024658
tsk_treeseq_free(&other_ts);
46034659
tsk_tree_free(&t);
46044660
tsk_tree_free(&other_t);
4605-
}
4661+
}
46064662

4663+
static void
4664+
test_unequal_samples_kc(void)
4665+
{
4666+
const char *nodes =
4667+
"1 0 0\n"
4668+
"1 0 0\n"
4669+
"1 0 0\n"
4670+
"0 2 0\n"
4671+
"0 3 0\n";
4672+
const char *nodes_other =
4673+
"0 0 0\n" /* Unused node at the start */
4674+
"1 0 0\n"
4675+
"1 0 0\n"
4676+
"1 0 0\n"
4677+
"0 2 0\n"
4678+
"0 3 0\n";
4679+
const char *edges =
4680+
"0 1 3 0,1\n"
4681+
"0 1 4 2,3\n";
4682+
const char *edges_other =
4683+
"0 1 4 1,2\n"
4684+
"0 1 5 3,4\n";
4685+
int ret;
4686+
tsk_treeseq_t ts, other_ts;
4687+
tsk_tree_t t, other_t;
4688+
double result = 0;
46074689

4690+
tsk_treeseq_from_text(&ts, 1, nodes, edges, NULL, NULL, NULL, NULL, NULL);
4691+
ret = tsk_tree_init(&t, &ts, 0);
4692+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4693+
ret = tsk_tree_first(&t);
4694+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4695+
tsk_treeseq_from_text(&other_ts, 1, nodes_other, edges_other,
4696+
NULL, NULL, NULL, NULL, NULL);
4697+
ret = tsk_tree_init(&other_t, &other_ts, 0);
4698+
CU_ASSERT_EQUAL_FATAL(ret, 0);
4699+
ret = tsk_tree_first(&other_t);
4700+
CU_ASSERT_EQUAL_FATAL(ret, 1);
4701+
ret = tsk_tree_kc_distance(&t, &other_t, 0, &result);
4702+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SAMPLES_NOT_EQUAL);
4703+
tsk_treeseq_free(&ts);
4704+
tsk_treeseq_free(&other_ts);
4705+
tsk_tree_free(&t);
4706+
tsk_tree_free(&other_t);
4707+
}
46084708

46094709
/*=======================================================
46104710
* Miscellaneous tests.
@@ -5299,6 +5399,9 @@ main(int argc, char **argv)
52995399
{"test_two_trees_kc", test_two_trees_kc},
53005400
{"test_empty_tree_kc", test_empty_tree_kc},
53015401
{"test_nonbinary_tree_kc", test_nonbinary_tree_kc},
5402+
{"test_nonzero_samples_kc", test_nonzero_samples_kc},
5403+
{"test_internal_samples_kc", test_internal_samples_kc},
5404+
{"test_unequal_sample_size_kc", test_unequal_sample_size_kc},
53025405
{"test_unequal_samples_kc", test_unequal_samples_kc},
53035406

53045407
/* Misc */

c/tskit/core.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,20 @@ tsk_strerror_internal(int err)
373373
"represented. Either generate variants (which support missing "
374374
"data) or use the impute missing data option.";
375375
break;
376+
377+
/* Distance metric errors */
378+
case TSK_ERR_SAMPLE_SIZE_MISMATCH:
379+
ret = "Cannot compare trees with different numbers of samples.";
380+
break;
381+
case TSK_ERR_SAMPLES_NOT_EQUAL:
382+
ret = "Samples must be identical in trees to compare.";
383+
break;
384+
case TSK_ERR_INTERNAL_SAMPLES:
385+
ret = "Internal samples are not supported.";
386+
break;
387+
case TSK_ERR_MULTIPLE_ROOTS:
388+
ret = "Trees with multiple roots not supported.";
389+
break;
376390
}
377391
return ret;
378392
}

c/tskit/core.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ of tskit.
223223
#define TSK_ERR_MUST_IMPUTE_NON_SAMPLES -1100
224224
#define TSK_ERR_MUST_IMPUTE_HAPLOTYPES -1101
225225

226+
/* Distance metric errors */
227+
#define TSK_ERR_SAMPLE_SIZE_MISMATCH -1200
228+
#define TSK_ERR_SAMPLES_NOT_EQUAL -1201
229+
#define TSK_ERR_INTERNAL_SAMPLES -1202
230+
#define TSK_ERR_MULTIPLE_ROOTS -1203
231+
226232

227233
/* This bit is 0 for any errors originating from kastore */
228234
#define TSK_KAS_ERR_BIT 14

c/tskit/stats.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,6 @@ int tsk_ld_calc_get_r2_array(tsk_ld_calc_t *self, tsk_id_t a, int direction,
4848
tsk_size_t max_sites, double max_distance,
4949
double *r2, tsk_size_t *num_r2_values);
5050

51-
typedef struct stack_elmt {
52-
tsk_id_t node;
53-
int path_depth;
54-
double time_depth;
55-
} stack_elmt;
56-
57-
58-
5951
#ifdef __cplusplus
6052
}
6153
#endif

c/tskit/trees.c

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4431,42 +4431,44 @@ tsk_diff_iter_next(tsk_diff_iter_t *self, double *ret_left, double *ret_right,
44314431
int
44324432
tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double *result)
44334433
{
4434-
tsk_size_t num_nodes_self, num_nodes_other;
4435-
int ret = 0, stack_top = 0;
4436-
int path_depth, tree_index, pair_index, i;
4437-
double vT1, vT2, distance_sum, time_depth, root_time;
4438-
int *m[2], *path_distance[2];
4439-
double *M[2], *times, *time_distance[2];
4440-
tsk_id_t N, u, v, mrca, n1, n2, num_samples;
4441-
tsk_tree_t *trees[2], *tree;
4442-
44434434
struct stack_elmt {
44444435
tsk_id_t node;
44454436
int path_depth;
44464437
double time_depth;
44474438
};
44484439

4449-
trees[0] = self;
4450-
trees[1] = other;
4451-
4440+
tsk_size_t num_nodes_self, num_nodes_other;
4441+
int ret = 0;
4442+
int stack_top = 0;
4443+
int path_depth, tree_index, pair_index, i;
4444+
double vT1, vT2, distance_sum, time_depth, root_time;
4445+
int *m[2], *path_distance[2];
4446+
double *M[2], *time_distance[2];
4447+
tsk_id_t N, u, v, mrca, n1, n2, num_samples, u_index;
4448+
tsk_tree_t *trees[2] = {self, other};
4449+
tsk_tree_t *tree;
4450+
const tsk_id_t *samples = self->tree_sequence->samples;
4451+
const tsk_id_t *other_samples = other->tree_sequence->samples;
4452+
const double *times;
4453+
const tsk_id_t *sample_index_map;
44524454
struct stack_elmt* stack = NULL;
4455+
44534456
memset(path_distance, 0, sizeof(path_distance));
44544457
memset(time_distance, 0, sizeof(time_distance));
44554458
memset(m, 0, sizeof(m));
44564459
memset(M, 0, sizeof(M));
44574460

44584461
if (tsk_tree_get_num_roots(self) != 1 || tsk_tree_get_num_roots(other) != 1) {
4459-
ret = TSK_ERR_BAD_PARAM_VALUE;
4460-
goto out;
4462+
ret = TSK_ERR_MULTIPLE_ROOTS;
4463+
goto out;
44614464
}
4462-
44634465
if (self->tree_sequence->num_samples != other->tree_sequence->num_samples) {
4464-
ret = TSK_ERR_BAD_PARAM_VALUE;
4466+
ret = TSK_ERR_SAMPLE_SIZE_MISMATCH;
44654467
goto out;
44664468
}
44674469

44684470
num_samples = (tsk_id_t) self->tree_sequence->num_samples;
4469-
N = (num_samples * (num_samples - 1)) / 2;
4471+
N = (num_samples * (num_samples - 1)) / 2;
44704472
num_nodes_self = self->num_nodes;
44714473
num_nodes_other = other->num_nodes;
44724474
stack = malloc(TSK_MAX(num_nodes_self, num_nodes_other) * sizeof(*stack));
@@ -4478,8 +4480,7 @@ tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double
44784480
path_distance[1] = malloc(num_nodes_other * sizeof(path_distance[1]));
44794481
time_distance[0] = malloc(num_nodes_self * sizeof(time_distance[0]));
44804482
time_distance[1] = malloc(num_nodes_other * sizeof(time_distance[1]));
4481-
4482-
if (stack == NULL
4483+
if (stack == NULL
44834484
|| m[0] == NULL
44844485
|| m[1] == NULL
44854486
|| M[0] == NULL
@@ -4489,15 +4490,30 @@ tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double
44894490
ret = TSK_ERR_NO_MEMORY;
44904491
goto out;
44914492
}
4492-
4493-
for (i=0; i <= N + num_samples; i++) {
4493+
4494+
for (i = 0; i < num_samples; i++) {
4495+
if (samples[i] != other_samples[i]) {
4496+
ret = TSK_ERR_SAMPLES_NOT_EQUAL;
4497+
goto out;
4498+
}
4499+
u = samples[i];
4500+
if (self->left_child[u] != TSK_NULL || other->left_child[u] != TSK_NULL) {
4501+
/* It's probably possible to support this, but it's too awkward
4502+
* to deal with and seems like a fairly niche requirement. */
4503+
ret = TSK_ERR_INTERNAL_SAMPLES;
4504+
goto out;
4505+
}
4506+
}
4507+
4508+
for (i = 0; i <= N + num_samples; i++) {
44944509
m[0][i] = 1;
44954510
m[1][i] = 1;
44964511
}
44974512

4498-
for (tree_index=0; tree_index != 2; tree_index++) {
4513+
for (tree_index = 0; tree_index < 2; tree_index++) {
44994514
tree = trees[tree_index];
45004515
times = tree->tree_sequence->tables->nodes.time;
4516+
sample_index_map = tree->tree_sequence->sample_index_map;
45014517
stack_top = 0;
45024518
u = tree->left_root;
45034519
root_time = times[u];
@@ -4509,26 +4525,26 @@ tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double
45094525
path_depth = stack[stack_top].path_depth;
45104526
time_depth = stack[stack_top].time_depth;
45114527
stack_top--;
4512-
for (v = tree->left_child[u]; v != TSK_NULL;
4513-
v = tree->right_sib[v]) {
4528+
for (v = tree->left_child[u]; v != TSK_NULL; v = tree->right_sib[v]) {
45144529
stack_top++;
45154530
stack[stack_top].node = v;
45164531
stack[stack_top].path_depth = path_depth + 1;
45174532
stack[stack_top].time_depth = times[v];
45184533
}
45194534
path_distance[tree_index][u] = path_depth;
45204535
time_distance[tree_index][u] = root_time - time_depth;
4521-
if (tree->left_child[u] == TSK_NULL) {
4522-
M[tree_index][u + N] = times[tree->parent[u]] - times[u];
4536+
u_index = sample_index_map[u];
4537+
if (u_index != TSK_NULL) {
4538+
M[tree_index][u_index + N] = times[tree->parent[u]] - times[u];
45234539
}
45244540
}
4525-
for (n1 = 0; n1 != num_samples; n1++) {
4526-
for (n2 = n1 + 1; n2 != num_samples; n2++){
4527-
ret = tsk_tree_get_mrca(tree, n1, n2, &mrca);
4541+
for (n1 = 0; n1 < num_samples; n1++) {
4542+
for (n2 = n1 + 1; n2 < num_samples; n2++){
4543+
ret = tsk_tree_get_mrca(tree, samples[n1], samples[n2], &mrca);
45284544
if (ret != 0) {
45294545
goto out;
45304546
}
4531-
pair_index = (n1 * (n1 - 2 * num_samples + 1)) / (-2) + n2 - n1 -1;
4547+
pair_index = n2 - n1 - 1 + (-1 * n1 * (n1 - 2 * num_samples + 1)) / 2;
45324548
assert (m[tree_index][pair_index] == 1);
45334549
m[tree_index][pair_index] = path_distance[tree_index][mrca];
45344550
M[tree_index][pair_index] = time_distance[tree_index][mrca];
@@ -4539,25 +4555,20 @@ tsk_tree_kc_distance(tsk_tree_t *self, tsk_tree_t *other, double lambda, double
45394555
vT1 = 0;
45404556
vT2 = 0;
45414557
distance_sum=0;
4542-
4543-
for (i = 0; i != N + num_samples; i++) {
4558+
for (i = 0; i < N + num_samples; i++) {
45444559
vT1 = (m[0][i] * (1 - lambda)) + (lambda * M[0][i]);
45454560
vT2 = (m[1][i] * (1 - lambda)) + (lambda * M[1][i]);
45464561
distance_sum += (vT1 - vT2) * (vT1 - vT2);
45474562
}
4548-
4563+
45494564
*result = sqrt(distance_sum);
45504565
out:
45514566
tsk_safe_free(stack);
4552-
tsk_safe_free(m[0]);
4553-
tsk_safe_free(m[1]);
4554-
tsk_safe_free(M[0]);
4555-
tsk_safe_free(M[1]);
4556-
tsk_safe_free(path_distance[0]);
4557-
tsk_safe_free(path_distance[1]);
4558-
tsk_safe_free(time_distance[0]);
4559-
tsk_safe_free(time_distance[1]);
4567+
for (i = 0; i < 2; i++) {
4568+
tsk_safe_free(m[i]);
4569+
tsk_safe_free(M[i]);
4570+
tsk_safe_free(path_distance[i]);
4571+
tsk_safe_free(time_distance[i]);
4572+
}
45604573
return ret;
45614574
}
4562-
4563-

0 commit comments

Comments
 (0)