Skip to content

Commit 77556c2

Browse files
committed
core: safe TensorHandleRef + try_as_tensor_arg; errors improved; adopt across crates
- TensorHandleRef::{try_from_parts, try_from_typed}
- TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only)
- Errors: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported }
- Adopt try_as_tensor_arg in attention/matmul/convolution/reduce/std
- Runtime tests for handle validation and unsupported vectorization factors

core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs

internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
1 parent cf40b4e commit 77556c2

File tree

6 files changed

+349
-40
lines changed

6 files changed

+349
-40
lines changed

crates/cubecl-convolution/src/tests/convolution_test_launcher.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,12 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
8888
}
8989

9090
let elem_size = size_of::<P::EG>();
91-
let lhs_handle = unsafe {
92-
TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
93-
};
94-
let rhs_handle = unsafe {
95-
TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
96-
};
97-
let out_handle = unsafe {
98-
TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size)
99-
};
91+
let lhs_handle = TensorHandleRef::<R>::try_from_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
92+
.expect("valid lhs handle");
93+
let rhs_handle = TensorHandleRef::<R>::try_from_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
94+
.expect("valid rhs handle");
95+
let out_handle = TensorHandleRef::<R>::try_from_parts(&out.handle, &out.strides, &out.shape, elem_size)
96+
.expect("valid out handle");
10097

10198
let lhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &lhs_handle, MatmulIdent::Lhs);
10299
let rhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &rhs_handle, MatmulIdent::Rhs);

crates/cubecl-core/src/frontend/container/tensor/launch.rs

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,33 @@ use crate::{
1313

1414
use super::Tensor;
1515

16+
/// Errors that can occur when constructing a tensor handle safely.
17+
#[non_exhaustive]
18+
#[derive(Clone, Debug, PartialEq, Eq)]
19+
pub enum TensorHandleError {
20+
/// Rank of shape and strides differ.
21+
RankMismatch {
22+
shape_rank: usize,
23+
stride_rank: usize,
24+
},
25+
/// Element size must be > 0.
26+
ElemSizeZero,
27+
/// A stride is zero for a dimension with extent > 1.
28+
ZeroStride { axis: usize },
29+
}
30+
/// Errors that can occur when converting a handle to a runtime tensor argument.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TensorArgError {
    /// The requested vectorization factor is not among the line sizes the
    /// runtime supports.
    UnsupportedVectorization {
        /// Factor that was requested.
        requested: u8,
        /// Line sizes the runtime accepts.
        supported: &'static [u8],
    },
    /// The inner-most dimension is not contiguous (stride != 1) while a
    /// vectorization factor > 1 was requested.
    NonContiguousInner,
    /// The inner-most dimension extent is not divisible by the vectorization
    /// factor.
    MisalignedVectorization {
        /// Extent of the inner-most dimension.
        last_dim: usize,
        /// Requested vectorization factor.
        factor: u8,
    },
}
1643
/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
1744
#[derive(Debug)]
1845
pub enum TensorArg<'a, R: Runtime> {
@@ -178,17 +205,47 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
178205

179206
impl<'a, R: Runtime> TensorHandleRef<'a, R> {
180207
/// Convert the handle into a [tensor argument](TensorArg).
181-
pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
208+
pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
209+
// In debug builds, assert that the requested vectorization is supported
210+
// by the runtime. Validation of the chosen factor should normally be
211+
// performed upstream (at selection time) to avoid redundant checks in
212+
// hot paths.
213+
debug_assert!(
214+
R::supported_line_sizes().contains(&vectorization),
215+
"unsupported vectorization {} (supported: {:?})",
216+
vectorization,
217+
R::supported_line_sizes()
218+
);
182219
unsafe {
183220
TensorArg::from_raw_parts_and_size(
184221
self.handle,
185222
self.strides,
186223
self.shape,
187-
vectorisation,
224+
vectorization,
188225
self.elem_size,
189226
)
190227
}
191228
}
229+
/// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
230+
/// for vectorization compatibility.
231+
///
232+
/// Note: This convenience is primarily intended for host wrappers / FFI
233+
/// ingestion paths. In internal code, prefer validating the chosen
234+
/// vectorization factor at selection time and then calling
235+
/// [`as_tensor_arg`], to avoid redundant work in hot paths.
236+
///
237+
/// This does not enforce inner‑most contiguity or alignment requirements as
238+
/// kernels may vectorize along axes other than the innermost.
239+
pub fn try_as_tensor_arg(
240+
&'a self,
241+
vectorization: u8,
242+
) -> Result<TensorArg<'a, R>, TensorArgError> {
243+
if !R::supported_line_sizes().contains(&vectorization) {
244+
return Err(TensorArgError::UnsupportedVectorization { requested: vectorization, supported: R::supported_line_sizes() });
245+
}
246+
Ok(self.as_tensor_arg(vectorization))
247+
}
248+
192249
/// Create a handle from raw parts.
193250
///
194251
/// # Safety
@@ -201,6 +258,24 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
201258
shape: &'a [usize],
202259
elem_size: usize,
203260
) -> Self {
261+
// Basic invariants for debug builds only; upstream layers are expected
262+
// to ensure correctness in release builds.
263+
debug_assert_eq!(
264+
shape.len(),
265+
strides.len(),
266+
"rank mismatch (shape={}, strides={})",
267+
shape.len(),
268+
strides.len()
269+
);
270+
debug_assert!(elem_size > 0, "element size must be > 0");
271+
// Disallow zero strides when corresponding dimension extent > 1
272+
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
273+
debug_assert!(
274+
!(s == 0 && d > 1),
275+
"zero stride on axis {} with extent > 1",
276+
i
277+
);
278+
}
204279
Self {
205280
handle,
206281
strides,
@@ -209,4 +284,72 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
209284
runtime: PhantomData,
210285
}
211286
}
287+
288+
/// Safely create a tensor handle from raw parts with basic shape/stride validation.
289+
///
290+
/// Note: This is mainly useful for host / FFI boundaries to surface clear
291+
/// errors early. Internal code should ensure these invariants when
292+
/// constructing handles and may use the `unsafe` constructor directly in
293+
/// performance‑critical paths.
294+
pub fn try_from_parts(
295+
handle: &'a cubecl_runtime::server::Handle,
296+
strides: &'a [usize],
297+
shape: &'a [usize],
298+
elem_size: usize,
299+
) -> Result<Self, TensorHandleError> {
300+
if shape.len() != strides.len() {
301+
return Err(TensorHandleError::RankMismatch {
302+
shape_rank: shape.len(),
303+
stride_rank: strides.len(),
304+
});
305+
}
306+
if elem_size == 0 {
307+
return Err(TensorHandleError::ElemSizeZero);
308+
}
309+
// Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
310+
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
311+
if s == 0 && d > 1 {
312+
return Err(TensorHandleError::ZeroStride { axis: i });
313+
}
314+
}
315+
Ok(unsafe { Self::from_raw_parts(handle, strides, shape, elem_size) })
316+
}
317+
318+
/// Safely create a tensor handle from raw parts using the element type for size.
319+
pub fn try_from_typed<E: CubePrimitive>(
320+
handle: &'a cubecl_runtime::server::Handle,
321+
strides: &'a [usize],
322+
shape: &'a [usize],
323+
) -> Result<Self, TensorHandleError> {
324+
let elem_size = E::size().expect("Element should have a size");
325+
Self::try_from_parts(handle, strides, shape, elem_size)
326+
}
327+
}
328+
329+
impl core::fmt::Display for TensorHandleError {
330+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
331+
match self {
332+
TensorHandleError::RankMismatch { shape_rank, stride_rank } => {
333+
write!(f, "rank mismatch (shape={}, strides={})", shape_rank, stride_rank)
334+
}
335+
TensorHandleError::ElemSizeZero => write!(f, "element size must be > 0"),
336+
TensorHandleError::ZeroStride { axis } => write!(f, "zero stride on axis {} with extent > 1", axis),
337+
}
338+
}
339+
}
340+
341+
impl core::fmt::Display for TensorArgError {
342+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
343+
match self {
344+
TensorArgError::UnsupportedVectorization { requested, supported } => {
345+
write!(f, "unsupported vectorization {}, supported: {:?}", requested, supported)
346+
}
347+
TensorArgError::NonContiguousInner => write!(f, "non-contiguous innermost dimension for vectorized access"),
348+
TensorArgError::MisalignedVectorization { last_dim, factor } => write!(
349+
f,
350+
"innermost dimension {} not divisible by vectorization {}",
351+
last_dim, factor
352+
),
353+
}
354+
}
212355
}

crates/cubecl-core/src/runtime_tests/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod sequence;
2121
pub mod slice;
2222
pub mod synchronization;
2323
pub mod tensor;
24+
pub mod tensor_handle;
2425
pub mod tensormap;
2526
pub mod to_client;
2627
pub mod topology;
@@ -138,6 +139,7 @@ macro_rules! testgen_untyped {
138139
cubecl_core::testgen_comparison!();
139140

140141
cubecl_core::testgen_to_client!();
142+
cubecl_core::testgen_tensor_handle!();
141143
};
142144
}
143145

crates/cubecl-core/src/runtime_tests/tensor_handle.rs

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use crate::prelude::*; // brings TensorArgError, TensorHandleError, TensorHandleRef, Runtime, ComputeClient
2+
3+
fn make_client<R: Runtime>() -> ComputeClient<R::Server, R::Channel> {
4+
R::client(&R::Device::default())
5+
}
6+
7+
pub fn test_handle_try_from_typed_ok_and_vec_checked_ok<R: Runtime>() {
8+
let client = make_client::<R>();
9+
let shape = vec![2usize, 8usize];
10+
let strides = compact_strides(&shape);
11+
let bytes = bytemuck::cast_slice::<f32, u8>(&vec![0.0f32; shape.iter().product()]).to_vec();
12+
let handle = client.create(&bytes);
13+
14+
let href = TensorHandleRef::<R>::try_from_typed::<f32>(&handle, &strides, &shape).expect("ok");
15+
16+
// Pick a supported factor that divides last dim (if any), else 1
17+
let mut picked = 1u8;
18+
for f in R::supported_line_sizes() {
19+
let f8 = (*f) as u8;
20+
if f8 > 1 && shape[1] % (*f as usize) == 0 {
21+
picked = f8;
22+
break;
23+
}
24+
}
25+
let _arg = href.try_as_tensor_arg(picked).expect("vec ok");
26+
}
27+
28+
pub fn test_handle_try_from_parts_rank_mismatch<R: Runtime>() {
29+
let client = make_client::<R>();
30+
let shape = vec![2usize, 4usize];
31+
let strides_good = compact_strides(&shape);
32+
let bytes = bytemuck::cast_slice::<f32, u8>(&vec![0.0f32; shape.iter().product()]).to_vec();
33+
let handle = client.create(&bytes);
34+
35+
let err = TensorHandleRef::<R>::try_from_parts(
36+
&handle,
37+
&strides_good[..1],
38+
&shape,
39+
core::mem::size_of::<f32>(),
40+
)
41+
.unwrap_err();
42+
match err {
43+
TensorHandleError::RankMismatch { .. } => {}
44+
_ => panic!("wrong error: {err:?}"),
45+
}
46+
}
47+
48+
pub fn test_handle_try_from_parts_zero_stride<R: Runtime>() {
49+
let client = make_client::<R>();
50+
let shape = vec![2usize, 4usize];
51+
let mut strides = compact_strides(&shape);
52+
strides[0] = 0; // invalid when dim > 1
53+
let bytes = bytemuck::cast_slice::<f32, u8>(&vec![0.0f32; shape.iter().product()]).to_vec();
54+
let handle = client.create(&bytes);
55+
56+
let err = TensorHandleRef::<R>::try_from_parts(
57+
&handle,
58+
&strides,
59+
&shape,
60+
core::mem::size_of::<f32>(),
61+
)
62+
.unwrap_err();
63+
match err {
64+
TensorHandleError::ZeroStride { .. } => {}
65+
_ => panic!("wrong error: {err:?}"),
66+
}
67+
}
68+
69+
pub fn test_vec_checked_unsupported_factor<R: Runtime>() {
70+
let client = make_client::<R>();
71+
let shape = vec![1usize, 8usize];
72+
let strides = compact_strides(&shape);
73+
let bytes = bytemuck::cast_slice::<f32, u8>(&vec![0.0f32; shape.iter().product()]).to_vec();
74+
let handle = client.create(&bytes);
75+
let href = TensorHandleRef::<R>::try_from_typed::<f32>(&handle, &strides, &shape).expect("ok");
76+
77+
// pick factor 7 which is typically unsupported
78+
let err = href.try_as_tensor_arg(7).unwrap_err();
79+
match err {
80+
TensorArgError::UnsupportedVectorization { .. } => {}
81+
_ => panic!("wrong error: {err:?}"),
82+
}
83+
}
84+
85+
pub fn test_vec_checked_noncontiguous_inner_allows_vectorized<R: Runtime>() {
86+
let client = make_client::<R>();
87+
let shape = vec![2usize, 8usize];
88+
let mut strides = compact_strides(&shape);
89+
// Make inner stride non-contiguous (allowed by checked API)
90+
strides[1] = 2;
91+
let bytes = bytemuck::cast_slice::<f32, u8>(&vec![0.0f32; shape.iter().product()]).to_vec();
92+
let handle = client.create(&bytes);
93+
94+
let href = TensorHandleRef::<R>::try_from_parts(
95+
&handle,
96+
&strides,
97+
&shape,
98+
core::mem::size_of::<f32>(),
99+
)
100+
.expect("ok");
101+
102+
// Choose a supported factor > 1 if available
103+
let mut picked = None;
104+
for f in R::supported_line_sizes() {
105+
if *f > 1 {
106+
picked = Some(*f as u8);
107+
break;
108+
}
109+
}
110+
if let Some(factor) = picked {
111+
let _ = href
112+
.try_as_tensor_arg(factor)
113+
.expect("non-contiguous inner allowed");
114+
}
115+
}
116+
117+
118+
// Misalignment (last dim not divisible by factor) is permitted; tail handling is kernel-specific.
119+
// We do not error on that case in the checked API.
120+
121+
/// Generates the `#[test]` wrappers for the tensor-handle runtime tests,
/// instantiated against the caller's `TestRuntime`.
///
/// The generated test names and fully-qualified call paths are part of the
/// test-suite contract and must not change.
#[macro_export]
macro_rules! testgen_tensor_handle {
    () => {
        use super::*;

        #[test]
        fn test_tensor_handle_try_from_typed_ok_and_vec_checked_ok() {
            cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_typed_ok_and_vec_checked_ok::<TestRuntime>();
        }

        #[test]
        fn test_tensor_handle_try_from_parts_rank_mismatch() {
            cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_parts_rank_mismatch::<TestRuntime>();
        }

        #[test]
        fn test_tensor_handle_try_from_parts_zero_stride() {
            cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_parts_zero_stride::<TestRuntime>();
        }

        #[test]
        fn test_vec_checked_unsupported_factor() {
            cubecl_core::runtime_tests::tensor_handle::test_vec_checked_unsupported_factor::<TestRuntime>();
        }

        #[test]
        fn test_vec_checked_noncontiguous_inner_allows_vectorized() {
            cubecl_core::runtime_tests::tensor_handle::test_vec_checked_noncontiguous_inner_allows_vectorized::<TestRuntime>();
        }
    };
}
153+
154+
/// Compute row-major (C-contiguous) strides for `shape`: the innermost axis
/// has stride 1 and each outer stride is the product of the inner extents.
fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    // Initializing every slot to 1 covers the innermost axis (and makes the
    // empty-shape case fall out naturally as an empty vector).
    let mut strides = vec![1usize; rank];
    for axis in (0..rank.saturating_sub(1)).rev() {
        strides[axis] = strides[axis + 1] * shape[axis + 1];
    }
    strides
}

0 commit comments

Comments
 (0)