|
12 | 12 | from numba import boolean, float64, int64 |
13 | 13 | from numba.types import Array, Tuple |
14 | 14 |
|
15 | | -from dcor._utils import _transform_to_1d |
16 | | - |
17 | | -from ._utils import CompileMode, get_namespace |
| 15 | +from ._utils import CompileMode, _transform_to_1d, get_namespace |
18 | 16 |
|
19 | 17 | NumbaVector = Array(dtype=float64, ndim=1, layout="C") |
20 | 18 | NumbaVectorReadOnly = Array(dtype=float64, ndim=1, layout="C", readonly=True) |
| 19 | +NumbaVectorReadOnlyNonContiguous = Array( |
| 20 | + dtype=float64, |
| 21 | + ndim=1, |
| 22 | + layout="A", |
| 23 | + readonly=True, |
| 24 | +) |
21 | 25 | NumbaIntVector = Array(dtype=int64, ndim=1, layout="C") |
22 | 26 | NumbaIntVectorReadOnly = Array(dtype=int64, ndim=1, layout="C", readonly=True) |
23 | 27 | NumbaMatrix = Array(dtype=float64, ndim=2, layout="C") |
@@ -333,9 +337,12 @@ def _get_impl_args( |
333 | 337 | Get the parameters used in the algorithm. |
334 | 338 | """ |
335 | 339 |
|
336 | | - n = x.shape[-1] |
| 340 | + x = np.ascontiguousarray(x) |
| 341 | + y = np.ascontiguousarray(y) |
| 342 | + |
| 343 | + n = x.shape[0] |
337 | 344 | assert n > 3 |
338 | | - assert n == y.shape[-1] |
| 345 | + assert n == y.shape[0] |
339 | 346 | temp = np.arange(n) |
340 | 347 |
|
341 | 348 | argsort_x = np.argsort(x) |
@@ -430,7 +437,7 @@ def _get_impl_args( |
430 | 437 | NumbaVector, |
431 | 438 | NumbaIntVectorReadOnly, |
432 | 439 | NumbaMatrix, |
433 | | - ))(NumbaVectorReadOnly, NumbaVectorReadOnly, boolean), |
| 440 | + ))(NumbaVectorReadOnlyNonContiguous, NumbaVectorReadOnlyNonContiguous, boolean), |
434 | 441 | cache=True, |
435 | 442 | )(_get_impl_args) |
436 | 443 |
|
@@ -507,7 +514,12 @@ def _rowwise_distance_covariance_sqr_avl_generic_internal( |
507 | 514 | res[0] = _distance_covariance_sqr_avl_impl_compiled(*args) |
508 | 515 |
|
509 | 516 | return numba.guvectorize( |
510 | | - [(NumbaVectorReadOnly, NumbaVectorReadOnly, boolean, float64[:])], |
| 517 | + [( |
| 518 | + NumbaVectorReadOnlyNonContiguous, |
| 519 | + NumbaVectorReadOnlyNonContiguous, |
| 520 | + boolean, |
| 521 | + float64[:], |
| 522 | + )], |
511 | 523 | '(n),(n),()->()', |
512 | 524 | nopython=True, |
513 | 525 | cache=True, |
|
0 commit comments