I believe I found a bug in how the Lorenzo predictor stores outliers for 1D data.
There is no out-of-bounds check when outliers are stored, so the outlier list sometimes contains extra, unexplained entries whose indices lie outside the range of the data. This can cause occasional CUDA errors that are difficult to debug. I suspect the same issue exists in the 3d1l predictor, but I have not tested it or the other predictors.
This is the fix I used for the 1D kernel:
$ git diff psz/src/kernel/detail/lrz_c.cuhip.inl
diff --git a/psz/src/kernel/detail/lrz_c.cuhip.inl b/psz/src/kernel/detail/lrz_c.cuhip.inl
index c71ab87..785fd0e 100644
--- a/psz/src/kernel/detail/lrz_c.cuhip.inl
+++ b/psz/src/kernel/detail/lrz_c.cuhip.inl
@@ -115,9 +115,12 @@ __global__ void KERNEL_CUHIP_c_lorenzo_1d1l(
}
if (not quantizable) {
- auto cur_idx = atomicAdd(out_cn, 1);
- out_cidx[cur_idx] = id_base + threadIdx.x * Seq + ix;
- out_cval[cur_idx] = candidate;
+ auto global_idx = id_base + threadIdx.x * Seq + ix;
+ if (global_idx < data_len3.x) {
+ auto cur_idx = atomicAdd(out_cn, 1);
+ out_cidx[cur_idx] = id_base + threadIdx.x * Seq + ix;
+ out_cval[cur_idx] = candidate;
+ }
}
}
__syncthreads();
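To illustrate how the missing check can produce out-of-range indices, here is a minimal host-side C++ sketch that emulates the 1D kernel's tiling sequentially. It assumes, without confirmation from the kernel source, that positions past the end of the data are read as 0, so the first padded position sees a large prediction error against the last real value and is flagged as non-quantizable. The tile shape, `radius`, the prequantization step, and the names `lorenzo_1d_outliers` and `Outliers` are hypothetical; the sequential loop also ignores the nondeterministic ordering that `atomicAdd` gives on the GPU.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical tile shape: the real kernel's block size and Seq are not given here.
constexpr int kBlockDim = 4;             // "threads" per block
constexpr int kSeq = 2;                  // elements handled per "thread"
constexpr int kTile = kBlockDim * kSeq;  // elements per block

struct Outliers {
  std::vector<std::uint32_t> idx;  // plays the role of out_cidx
  std::vector<float> val;          // plays the role of out_cval
};

// Sequential emulation of a 1D first-order Lorenzo pass that records
// non-quantizable prediction errors as outliers. Assumption: positions at or
// past data.size() are read as 0, as if the last tile were zero-padded.
Outliers lorenzo_1d_outliers(const std::vector<float>& data, float eb,
                             int radius, bool bounds_check) {
  assert(eb > 0.0f);
  Outliers out;
  const std::size_t len = data.size();
  const std::size_t n_blocks = (len + kTile - 1) / kTile;
  // Prequantize with step 2*eb; out-of-range reads yield 0 (assumed padding).
  auto quant = [&](std::size_t i) -> float {
    return i < len ? std::round(data[i] / (2.0f * eb)) : 0.0f;
  };

  for (std::size_t b = 0; b < n_blocks; ++b) {
    const std::size_t id_base = b * kTile;
    for (int t = 0; t < kBlockDim; ++t) {  // stands in for threadIdx.x
      for (int ix = 0; ix < kSeq; ++ix) {
        const std::size_t global_idx = id_base + t * kSeq + ix;
        const float prev = global_idx == 0 ? 0.0f : quant(global_idx - 1);
        const float candidate = quant(global_idx) - prev;  // prediction error
        const bool quantizable =
            std::fabs(candidate) < static_cast<float>(radius);
        if (!quantizable) {
          // The fix proposed above: skip indices past the end of the data.
          if (bounds_check && global_idx >= len) continue;
          out.idx.push_back(static_cast<std::uint32_t>(global_idx));
          out.val.push_back(candidate);
        }
      }
    }
  }
  return out;
}
```

With data `{100, 101, 102, 103, 104}`, a tile of 8, `eb = 0.5` and `radius = 32`, the unchecked version reports outliers at indices 0 and 5, where index 5 lies past the end of the 5-element input; with the bounds check only index 0 remains. Note that the patch above recomputes `id_base + threadIdx.x * Seq + ix` when writing `out_cidx`; reusing `global_idx` there, as the sketch does, is equivalent.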