Skip to content

Commit 432a11d

Browse files
committed
feat: add more kernels with cumulative sum
1 parent 0299bab commit 432a11d

9 files changed

+163
-14
lines changed

dev/generate-kernel-signatures.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@
5050
"awkward_missing_repeat",
5151
"awkward_RegularArray_getitem_jagged_expand",
5252
"awkward_ListArray_getitem_jagged_expand",
53+
"awkward_ListArray_getitem_jagged_carrylen",
5354
"awkward_ListArray_getitem_next_array_advanced",
5455
"awkward_ListArray_getitem_next_array",
5556
"awkward_ListArray_getitem_next_at",
57+
"awkward_ListArray_getitem_next_range_counts",
58+
"awkward_ListArray_rpad_and_clip_length_axis1",
5659
"awkward_NumpyArray_reduce_adjust_starts_64",
5760
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
5861
"awkward_RegularArray_getitem_next_at",

dev/generate-tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,12 @@ def gencpuunittests(specdict):
690690
"awkward_missing_repeat",
691691
"awkward_RegularArray_getitem_jagged_expand",
692692
"awkward_ListArray_getitem_jagged_expand",
693+
"awkward_ListArray_getitem_jagged_carrylen",
693694
"awkward_ListArray_getitem_next_array_advanced",
694695
"awkward_ListArray_getitem_next_array",
695696
"awkward_ListArray_getitem_next_at",
697+
"awkward_ListArray_getitem_next_range_counts",
698+
"awkward_ListArray_rpad_and_clip_length_axis1",
696699
"awkward_NumpyArray_reduce_adjust_starts_64",
697700
"awkward_NumpyArray_reduce_adjust_starts_shifts_64",
698701
"awkward_RegularArray_getitem_next_at",

src/awkward/_connect/cuda/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ def fetch_template_specializations(kernel_dict):
8484
"awkward_IndexedArray_flatten_nextcarry",
8585
"awkward_IndexedArray_getitem_nextcarry",
8686
"awkward_IndexedArray_getitem_nextcarry_outindex",
87+
"awkward_ListArray_getitem_next_range_counts",
8788
"awkward_IndexedArray_index_of_nulls",
8889
"awkward_IndexedArray_reduce_next_64",
8990
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64",
9091
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64",
9192
"awkward_IndexedOptionArray_rpad_and_clip_mask_axis1",
9293
"awkward_ListArray_compact_offsets",
94+
"awkward_ListArray_getitem_jagged_carrylen",
95+
"awkward_ListArray_rpad_and_clip_length_axis1",
9396
"awkward_MaskedArray_getitem_next_jagged_project",
9497
"awkward_UnionArray_project",
9598
"awkward_reduce_count_64",

src/awkward/_connect/cuda/cuda_kernels/awkward_ByteMaskedArray_numnull.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ awkward_ByteMaskedArray_numnull_a(T* numnull,
2424
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
2525

2626
if (thread_id < length) {
27-
*numnull = 0;
2827
if ((mask[thread_id] != 0) != validwhen) {
2928
scan_in_array[thread_id] = 1;
3029
}
@@ -45,6 +44,6 @@ awkward_ByteMaskedArray_numnull_b(T* numnull,
4544
uint64_t invocation_index,
4645
uint64_t* err_code) {
4746
if (err_code[0] == NO_ERROR) {
48-
*numnull = scan_in_array[length - 1];
47+
*numnull = (T)scan_in_array[length - 1];
4948
}
5049
}

src/awkward/_connect/cuda/cuda_kernels/awkward_IndexedArray_numnull.cu

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
template <typename T, typename C>
1515
__global__ void
1616
awkward_IndexedArray_numnull_a(T* numnull,
17-
const C* fromindex,
18-
int64_t lenindex,
19-
int64_t* scan_in_array,
20-
uint64_t invocation_index,
21-
uint64_t* err_code) {
17+
const C* fromindex,
18+
int64_t lenindex,
19+
int64_t* scan_in_array,
20+
uint64_t invocation_index,
21+
uint64_t* err_code) {
2222
if (err_code[0] == NO_ERROR) {
2323
int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
2424

@@ -36,12 +36,12 @@ awkward_IndexedArray_numnull_a(T* numnull,
3636
template <typename T, typename C>
3737
__global__ void
3838
awkward_IndexedArray_numnull_b(T* numnull,
39-
const C* fromindex,
40-
int64_t lenindex,
41-
int64_t* scan_in_array,
42-
uint64_t invocation_index,
43-
uint64_t* err_code) {
39+
const C* fromindex,
40+
int64_t lenindex,
41+
int64_t* scan_in_array,
42+
uint64_t invocation_index,
43+
uint64_t* err_code) {
4444
if (err_code[0] == NO_ERROR) {
45-
*numnull = scan_in_array[lenindex - 1];
45+
*numnull = (T)scan_in_array[lenindex - 1];
4646
}
4747
}

src/awkward/_connect/cuda/cuda_kernels/awkward_IndexedArray_numnull_parents.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ awkward_IndexedArray_numnull_parents_b(T* numnull,
5454
numnull[thread_id] = 0;
5555
}
5656
}
57-
*tolength = scan_in_array[lenindex - 1];
57+
*tolength = (T)scan_in_array[lenindex - 1];
5858
}
5959
}
6060

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (carrylen, slicestarts, slicestops, sliceouterlen, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(sliceouterlen, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_carrylen_a", carrylen.dtype, slicestarts.dtype, slicestops.dtype]))(grid, block, (carrylen, slicestarts, slicestops, sliceouterlen, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_jagged_carrylen_b", carrylen.dtype, slicestarts.dtype, slicestops.dtype]))(grid, block, (carrylen, slicestarts, slicestops, sliceouterlen, scan_in_array, invocation_index, err_code))
10+
// out["awkward_ListArray_getitem_jagged_carrylen_a", {dtype_specializations}] = None
11+
// out["awkward_ListArray_getitem_jagged_carrylen_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// First stage of awkward_ListArray_getitem_jagged_carrylen.
//
// Every thread whose flat index is below sliceouterlen stores the difference
// slicestops[i] - slicestarts[i] (cast through T) into scan_in_array[i].
// carrylen and invocation_index are not used by this stage. The host driver
// in the BEGIN PYTHON header of this file scans scan_in_array afterwards.
template <typename T, typename C, typename U>
__global__ void
awkward_ListArray_getitem_jagged_carrylen_a(T* carrylen,
                                            const C* slicestarts,
                                            const U* slicestops,
                                            int64_t sliceouterlen,
                                            int64_t* scan_in_array,
                                            uint64_t invocation_index,
                                            uint64_t* err_code) {
  // Skip all work once an error has been recorded.
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the end of the input have nothing to write.
  if (idx >= sliceouterlen) {
    return;
  }

  scan_in_array[idx] = (T)(slicestops[idx] - slicestarts[idx]);
}
31+
32+
// Second stage of awkward_ListArray_getitem_jagged_carrylen: publishes the
// last element of scan_in_array to *carrylen.
//
// The host driver in the BEGIN PYTHON header of this file runs inclusive_scan
// over scan_in_array between the _a and _b kernels, so the last element is
// expected to hold the sum of all slice lengths.
//
// Parameters:
//   carrylen         - output scalar; receives the total (0 for empty input).
//   slicestarts      - unused in this stage (kept for a uniform signature).
//   slicestops       - unused in this stage.
//   sliceouterlen    - number of elements in scan_in_array.
//   scan_in_array    - scanned per-slice lengths.
//   invocation_index - unused in this stage.
//   err_code         - err_code[0] must equal NO_ERROR for the kernel to act.
template <typename T, typename C, typename U>
__global__ void
awkward_ListArray_getitem_jagged_carrylen_b(T* carrylen,
                                            const C* slicestarts,
                                            const U* slicestops,
                                            int64_t sliceouterlen,
                                            int64_t* scan_in_array,
                                            uint64_t invocation_index,
                                            uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

    // One writer is enough for a scalar result; thread 0 exists in every
    // launch, so the output is always produced.
    if (thread_id == 0) {
      // With no slices there is no last element: report zero instead of
      // reading scan_in_array[-1].
      *carrylen = (sliceouterlen > 0) ? (T)scan_in_array[sliceouterlen - 1]
                                      : (T)0;
    }
  }
}
45+
46+
// FIXME: this kernel is known to fail when sliceouterlen = 1.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (total, fromoffsets, lenstarts, invocation_total, err_code) = args
6+
// scan_in_array = cupy.empty(lenstarts, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_next_range_counts_a", total.dtype, fromoffsets.dtype]))(grid, block, (total, fromoffsets, lenstarts, scan_in_array, invocation_total, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_total, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_getitem_next_range_counts_b", total.dtype, fromoffsets.dtype]))(grid, block, (total, fromoffsets, lenstarts, scan_in_array, invocation_total, err_code))
10+
// out["awkward_ListArray_getitem_next_range_counts_a", {dtype_specializations}] = None
11+
// out["awkward_ListArray_getitem_next_range_counts_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// First stage of awkward_ListArray_getitem_next_range_counts.
//
// Every thread whose flat index i is below lenstarts stores
// fromoffsets[i + 1] - fromoffsets[i] (cast through T) into scan_in_array[i],
// so fromoffsets is read up to index lenstarts. total and invocation_total
// are not used by this stage. The host driver in the BEGIN PYTHON header of
// this file scans scan_in_array afterwards.
template <typename T, typename C>
__global__ void
awkward_ListArray_getitem_next_range_counts_a(T* total,
                                              const C* fromoffsets,
                                              int64_t lenstarts,
                                              int64_t* scan_in_array,
                                              uint64_t invocation_total,
                                              uint64_t* err_code) {
  // Skip all work once an error has been recorded.
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the end of the input have nothing to write.
  if (idx >= lenstarts) {
    return;
  }

  scan_in_array[idx] = (T)(fromoffsets[idx + 1] - fromoffsets[idx]);
}
30+
31+
// Second stage of awkward_ListArray_getitem_next_range_counts: publishes the
// last element of scan_in_array to *total.
//
// The host driver in the BEGIN PYTHON header of this file runs inclusive_scan
// over scan_in_array between the _a and _b kernels, so the last element is
// expected to hold the sum of all per-list counts.
//
// Parameters:
//   total            - output scalar; receives the total (0 for empty input).
//   fromoffsets      - unused in this stage (kept for a uniform signature).
//   lenstarts        - number of elements in scan_in_array.
//   scan_in_array    - scanned per-list counts.
//   invocation_total - unused in this stage.
//   err_code         - err_code[0] must equal NO_ERROR for the kernel to act.
template <typename T, typename C>
__global__ void
awkward_ListArray_getitem_next_range_counts_b(T* total,
                                              const C* fromoffsets,
                                              int64_t lenstarts,
                                              int64_t* scan_in_array,
                                              uint64_t invocation_total,
                                              uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

    // One writer is enough for a scalar result. Guarding on
    // thread_id < lenstarts would leave *total unwritten when lenstarts == 0,
    // so thread 0 always writes.
    if (thread_id == 0) {
      // With no lists there is no last element: report zero.
      *total = (lenstarts > 0) ? (T)scan_in_array[lenstarts - 1] : (T)0;
    }
  }
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
2+
3+
// BEGIN PYTHON
4+
// def f(grid, block, args):
5+
// (tomin, fromstarts, fromstops, target, lenstarts, invocation_index, err_code) = args
6+
// scan_in_array = cupy.empty(lenstarts, dtype=cupy.int64)
7+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_rpad_and_clip_length_axis1_a", tomin.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tomin, fromstarts, fromstops, target, lenstarts, scan_in_array, invocation_index, err_code))
8+
// scan_in_array = inclusive_scan(grid, block, (scan_in_array, invocation_index, err_code))
9+
// cuda_kernel_templates.get_function(fetch_specialization(["awkward_ListArray_rpad_and_clip_length_axis1_b", tomin.dtype, fromstarts.dtype, fromstops.dtype]))(grid, block, (tomin, fromstarts, fromstops, target, lenstarts, scan_in_array, invocation_index, err_code))
10+
// out["awkward_ListArray_rpad_and_clip_length_axis1_a", {dtype_specializations}] = None
11+
// out["awkward_ListArray_rpad_and_clip_length_axis1_b", {dtype_specializations}] = None
12+
// END PYTHON
13+
14+
// First stage of awkward_ListArray_rpad_and_clip_length_axis1.
//
// Every thread whose flat index i is below lenstarts stores the larger of
// target and fromstops[i] - fromstarts[i] into scan_in_array[i]. tomin and
// invocation_index are not used by this stage. The host driver in the
// BEGIN PYTHON header of this file scans scan_in_array afterwards.
template <typename T, typename C, typename U>
__global__ void
awkward_ListArray_rpad_and_clip_length_axis1_a(T* tomin,
                                               const C* fromstarts,
                                               const U* fromstops,
                                               int64_t target,
                                               int64_t lenstarts,
                                               int64_t* scan_in_array,
                                               uint64_t invocation_index,
                                               uint64_t* err_code) {
  // Skip all work once an error has been recorded.
  if (err_code[0] != NO_ERROR) {
    return;
  }

  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads past the end of the input have nothing to write.
  if (idx >= lenstarts) {
    return;
  }

  // Start from the list's own length and raise it to target if it is shorter.
  int64_t padded = fromstops[idx] - fromstarts[idx];
  if (target > padded) {
    padded = target;
  }
  scan_in_array[idx] = padded;
}
33+
34+
// Second stage of awkward_ListArray_rpad_and_clip_length_axis1: publishes the
// last element of scan_in_array to *tomin.
//
// The host driver in the BEGIN PYTHON header of this file runs inclusive_scan
// over scan_in_array between the _a and _b kernels, so the last element is
// expected to hold the sum of max(target, stop - start) over all lists.
//
// Parameters:
//   tomin            - output scalar; receives the total (0 for empty input).
//   fromstarts       - unused in this stage (kept for a uniform signature).
//   fromstops        - unused in this stage.
//   target           - unused in this stage.
//   lenstarts        - number of elements in scan_in_array.
//   scan_in_array    - scanned per-list padded lengths.
//   invocation_index - unused in this stage.
//   err_code         - err_code[0] must equal NO_ERROR for the kernel to act.
template <typename T, typename C, typename U>
__global__ void
awkward_ListArray_rpad_and_clip_length_axis1_b(T* tomin,
                                               const C* fromstarts,
                                               const U* fromstops,
                                               int64_t target,
                                               int64_t lenstarts,
                                               int64_t* scan_in_array,
                                               uint64_t invocation_index,
                                               uint64_t* err_code) {
  if (err_code[0] == NO_ERROR) {
    int64_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

    // One writer is enough for a scalar result; thread 0 exists in every
    // launch, so the output is always produced.
    if (thread_id == 0) {
      // With no lists there is no last element: report zero instead of
      // reading scan_in_array[-1].
      *tomin = (lenstarts > 0) ? (T)scan_in_array[lenstarts - 1] : (T)0;
    }
  }
}
48+
49+
// FIXME: this kernel is known to fail when lenstarts = 1.

0 commit comments

Comments
 (0)