Skip to content

Commit 1fadc92

Browse files
Merge pull request #332 from jeromekelleher/prevent-state-updates
general_stat funcs get copies of underlying data.
2 parents e36db7a + 3e76bc4 commit 1fadc92

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

python/_tskitmodule.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6503,10 +6503,11 @@ general_stat_func(size_t K, double *X, size_t M, double *Y, void *params)
65036503
npy_intp X_dims = (npy_intp) K;
65046504
npy_intp *Y_dims;
65056505

6506-
X_array = (PyArrayObject *) PyArray_SimpleNewFromData(1, &X_dims, NPY_FLOAT64, X);
6506+
X_array = (PyArrayObject *) PyArray_SimpleNew(1, &X_dims, NPY_FLOAT64);
65076507
if (X_array == NULL) {
65086508
goto out;
65096509
}
6510+
memcpy(PyArray_DATA(X_array), X, K * sizeof(*X));
65106511
arglist = Py_BuildValue("(O)", X_array);
65116512
if (arglist == NULL) {
65126513
goto out;

python/tests/test_tree_stats.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3315,6 +3315,25 @@ def get_tree_sequence(self):
33153315
mutation_rate=2, random_seed=1)
33163316
return ts
33173317

3318+
def test_function_cannot_update_state(self):
3319+
ts = self.get_tree_sequence()
3320+
3321+
def f(x):
3322+
out = x.copy()
3323+
x[:] = 0.0
3324+
return out
3325+
3326+
def g(x):
3327+
return x
3328+
3329+
x = ts.sample_count_stat(
3330+
[ts.samples()], f, output_dim=1, strict=False, mode="node",
3331+
span_normalise=False)
3332+
y = ts.sample_count_stat(
3333+
[ts.samples()], g, output_dim=1, strict=False, mode="node",
3334+
span_normalise=False)
3335+
self.assertArrayEqual(x, y)
3336+
33183337
def test_default_mode(self):
33193338
ts = msprime.simulate(10, recombination_rate=1, random_seed=2)
33203339
W = np.ones((ts.num_samples, 2))

0 commit comments

Comments
 (0)