Skip to content

Commit de542ad

Browse files
committed
Merge branch 'release/0.5.6'
Fix: allow non-contiguous arrays again.
2 parents e0bc7bf + 6804668 commit de542ad

File tree

3 files changed

+34
-9
lines changed

3 files changed

+34
-9
lines changed

dcor/VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.5.5
1+
0.5.6

dcor/_fast_dcov_avl.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@
1212
from numba import boolean, float64, int64
1313
from numba.types import Array, Tuple
1414

15-
from dcor._utils import _transform_to_1d
16-
17-
from ._utils import CompileMode, get_namespace
15+
from ._utils import CompileMode, _transform_to_1d, get_namespace
1816

1917
NumbaVector = Array(dtype=float64, ndim=1, layout="C")
2018
NumbaVectorReadOnly = Array(dtype=float64, ndim=1, layout="C", readonly=True)
19+
NumbaVectorReadOnlyNonContiguous = Array(
20+
dtype=float64,
21+
ndim=1,
22+
layout="A",
23+
readonly=True,
24+
)
2125
NumbaIntVector = Array(dtype=int64, ndim=1, layout="C")
2226
NumbaIntVectorReadOnly = Array(dtype=int64, ndim=1, layout="C", readonly=True)
2327
NumbaMatrix = Array(dtype=float64, ndim=2, layout="C")
@@ -333,9 +337,12 @@ def _get_impl_args(
333337
Get the parameters used in the algorithm.
334338
"""
335339

336-
n = x.shape[-1]
340+
x = np.ascontiguousarray(x)
341+
y = np.ascontiguousarray(y)
342+
343+
n = x.shape[0]
337344
assert n > 3
338-
assert n == y.shape[-1]
345+
assert n == y.shape[0]
339346
temp = np.arange(n)
340347

341348
argsort_x = np.argsort(x)
@@ -430,7 +437,7 @@ def _get_impl_args(
430437
NumbaVector,
431438
NumbaIntVectorReadOnly,
432439
NumbaMatrix,
433-
))(NumbaVectorReadOnly, NumbaVectorReadOnly, boolean),
440+
))(NumbaVectorReadOnlyNonContiguous, NumbaVectorReadOnlyNonContiguous, boolean),
434441
cache=True,
435442
)(_get_impl_args)
436443

@@ -507,7 +514,12 @@ def _rowwise_distance_covariance_sqr_avl_generic_internal(
507514
res[0] = _distance_covariance_sqr_avl_impl_compiled(*args)
508515

509516
return numba.guvectorize(
510-
[(NumbaVectorReadOnly, NumbaVectorReadOnly, boolean, float64[:])],
517+
[(
518+
NumbaVectorReadOnlyNonContiguous,
519+
NumbaVectorReadOnlyNonContiguous,
520+
boolean,
521+
float64[:],
522+
)],
511523
'(n),(n),()->()',
512524
nopython=True,
513525
cache=True,

dcor/_fast_dcov_mergesort.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424

2525
NumbaVector = Array(dtype=float64, ndim=1, layout="C")
2626
NumbaVectorReadOnly = Array(dtype=float64, ndim=1, layout="C", readonly=True)
27+
NumbaVectorReadOnlyNonContiguous = Array(
28+
dtype=float64,
29+
ndim=1,
30+
layout="A",
31+
readonly=True,
32+
)
2733
NumbaMatrix = Array(dtype=float64, ndim=2, layout="C")
2834
NumbaMatrixReadOnly = Array(dtype=float64, ndim=2, layout="C", readonly=True)
2935

@@ -209,6 +215,9 @@ def _distance_covariance_sqr_mergesort_generic_impl(
209215
unbiased: bool,
210216
) -> np.typing.NDArray[np.float64]:
211217

218+
x = np.ascontiguousarray(x)
219+
y = np.ascontiguousarray(y)
220+
212221
compute_aijbij_term = (
213222
_compute_aijbij_term_compiled
214223
if compiled
@@ -265,7 +274,11 @@ def _distance_covariance_sqr_mergesort_generic_impl(
265274
)
266275
)
267276
_distance_covariance_sqr_mergesort_generic_impl_compiled = numba.njit(
268-
float64(NumbaVectorReadOnly, NumbaVectorReadOnly, boolean),
277+
float64(
278+
NumbaVectorReadOnlyNonContiguous,
279+
NumbaVectorReadOnlyNonContiguous,
280+
boolean,
281+
),
269282
cache=True,
270283
)(
271284
_generate_distance_covariance_sqr_mergesort_generic_impl(

0 commit comments

Comments
 (0)