Skip to content

Commit 2e6d385

Browse files
committed
[FIX] intp -> unitp for cupy
This will need to handle -ve fill value for count
1 parent 6a1a4c7 commit 2e6d385

File tree

1 file changed

+21
-19
lines changed

1 file changed

+21
-19
lines changed

flox/aggregations.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,8 @@ def __repr__(self):
228228
combine="sum",
229229
fill_value=0,
230230
final_fill_value=0,
231-
dtypes=np.intp,
232-
final_dtype=np.intp,
231+
dtypes=np.uintp,
232+
final_dtype=np.uintp,
233233
)
234234

235235
# note that the fill values are the result of np.func([np.nan, np.nan])
@@ -250,7 +250,7 @@ def __repr__(self):
250250
combine=("sum", "sum"),
251251
finalize=lambda sum_, count: sum_ / count,
252252
fill_value=(0, 0),
253-
dtypes=(None, np.intp),
253+
dtypes=(None, np.uintp),
254254
final_dtype=np.floating,
255255
)
256256
nanmean = Aggregation(
@@ -259,7 +259,7 @@ def __repr__(self):
259259
combine=("sum", "sum"),
260260
finalize=lambda sum_, count: sum_ / count,
261261
fill_value=(0, 0),
262-
dtypes=(None, np.intp),
262+
dtypes=(None, np.uintp),
263263
final_dtype=np.floating,
264264
)
265265

@@ -283,7 +283,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
283283
finalize=_var_finalize,
284284
fill_value=0,
285285
final_fill_value=np.nan,
286-
dtypes=(None, None, np.intp),
286+
dtypes=(None, None, np.uintp),
287287
final_dtype=np.floating,
288288
)
289289
nanvar = Aggregation(
@@ -293,7 +293,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
293293
finalize=_var_finalize,
294294
fill_value=0,
295295
final_fill_value=np.nan,
296-
dtypes=(None, None, np.intp),
296+
dtypes=(None, None, np.uintp),
297297
final_dtype=np.floating,
298298
)
299299
std = Aggregation(
@@ -303,7 +303,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
303303
finalize=_std_finalize,
304304
fill_value=0,
305305
final_fill_value=np.nan,
306-
dtypes=(None, None, np.intp),
306+
dtypes=(None, None, np.uintp),
307307
final_dtype=np.floating,
308308
)
309309
nanstd = Aggregation(
@@ -313,7 +313,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
313313
finalize=_std_finalize,
314314
fill_value=0,
315315
final_fill_value=np.nan,
316-
dtypes=(None, None, np.intp),
316+
dtypes=(None, None, np.uintp),
317317
final_dtype=np.floating,
318318
)
319319

@@ -336,7 +336,7 @@ def argreduce_preprocess(array, axis):
336336
assert len(axis) == 1
337337
axis = axis[0]
338338

339-
idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.intp)
339+
idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.uintp)
340340
# broadcast (TODO: is this needed?)
341341
idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(array.ndim))]
342342

@@ -362,8 +362,8 @@ def _zip_index(array_, idx_):
362362
fill_value=(dtypes.NINF, 0),
363363
final_fill_value=-1,
364364
finalize=lambda *x: x[1],
365-
dtypes=(None, np.intp),
366-
final_dtype=np.intp,
365+
dtypes=(None, np.uintp),
366+
final_dtype=np.uintp,
367367
)
368368

369369
argmin = Aggregation(
@@ -375,8 +375,8 @@ def _zip_index(array_, idx_):
375375
fill_value=(dtypes.INF, 0),
376376
final_fill_value=-1,
377377
finalize=lambda *x: x[1],
378-
dtypes=(None, np.intp),
379-
final_dtype=np.intp,
378+
dtypes=(None, np.uintp),
379+
final_dtype=np.uintp,
380380
)
381381

382382
nanargmax = Aggregation(
@@ -388,8 +388,8 @@ def _zip_index(array_, idx_):
388388
fill_value=(dtypes.NINF, -1),
389389
final_fill_value=-1,
390390
finalize=lambda *x: x[1],
391-
dtypes=(None, np.intp),
392-
final_dtype=np.intp,
391+
dtypes=(None, np.uintp),
392+
final_dtype=np.uintp,
393393
)
394394

395395
nanargmin = Aggregation(
@@ -401,8 +401,8 @@ def _zip_index(array_, idx_):
401401
fill_value=(dtypes.INF, -1),
402402
final_fill_value=-1,
403403
finalize=lambda *x: x[1],
404-
dtypes=(None, np.intp),
405-
final_dtype=np.intp,
404+
dtypes=(None, np.uintp),
405+
final_dtype=np.uintp,
406406
)
407407

408408
first = Aggregation("first", chunk=None, combine=None, fill_value=0)
@@ -520,7 +520,9 @@ def _initialize_aggregation(
520520
agg.combine += ("sum",)
521521
agg.fill_value["intermediate"] += (0,)
522522
agg.fill_value["numpy"] += (0,)
523-
agg.dtype["intermediate"] += (np.intp,)
524-
agg.dtype["numpy"] += (np.intp,)
523+
# uintp is supported by cupy, intp is not
524+
# Also count is >=0, so uint should be fine.
525+
agg.dtype["intermediate"] += (np.uintp,)
526+
agg.dtype["numpy"] += (np.uintp,)
525527

526528
return agg

0 commit comments

Comments
 (0)