Skip to content

Commit 3a98f13

Browse files
committed
core: safe TensorHandleRef + try_as_tensor_arg; errors improved; adopt across crates
- TensorHandleRef::{try_from_parts, try_from_typed} - TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only) - Errors: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported } - Adopt try_as_tensor_arg in attention/matmul/convolution/reduce/std - Runtime tests for handle validation and unsupported vectorization factors core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
1 parent cf40b4e commit 3a98f13

File tree

5 files changed

+377
-37
lines changed

5 files changed

+377
-37
lines changed

crates/cubecl-convolution/src/tests/convolution_test_launcher.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
8888
}
8989

9090
let elem_size = size_of::<P::EG>();
91-
let lhs_handle = unsafe {
92-
TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
93-
};
94-
let rhs_handle = unsafe {
95-
TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
96-
};
97-
let out_handle = unsafe {
98-
TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size)
99-
};
91+
let lhs_handle =
92+
TensorHandleRef::<R>::try_from_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
93+
.expect("valid lhs handle");
94+
let rhs_handle =
95+
TensorHandleRef::<R>::try_from_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
96+
.expect("valid rhs handle");
97+
let out_handle =
98+
TensorHandleRef::<R>::try_from_parts(&out.handle, &out.strides, &out.shape, elem_size)
99+
.expect("valid out handle");
100100

101101
let lhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &lhs_handle, MatmulIdent::Lhs);
102102
let rhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &rhs_handle, MatmulIdent::Rhs);

crates/cubecl-core/src/frontend/container/tensor/launch.rs

Lines changed: 167 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,36 @@ use crate::{
1313

1414
use super::Tensor;
1515

16+
/// Errors that can occur when constructing a tensor handle safely.
17+
#[non_exhaustive]
18+
#[derive(Clone, Debug, PartialEq, Eq)]
19+
pub enum TensorHandleError {
20+
/// Rank of shape and strides differ.
21+
RankMismatch {
22+
shape_rank: usize,
23+
stride_rank: usize,
24+
},
25+
/// Element size must be > 0.
26+
ElemSizeZero,
27+
/// A stride is zero for a dimension with extent > 1.
28+
ZeroStride { axis: usize },
29+
}
30+
31+
/// Errors that can occur when converting a handle to a runtime tensor argument.
32+
#[non_exhaustive]
33+
#[derive(Clone, Debug, PartialEq, Eq)]
34+
pub enum TensorArgError {
35+
/// Requested vectorization factor is not supported by the runtime.
36+
UnsupportedVectorization {
37+
requested: u8,
38+
supported: &'static [u8],
39+
},
40+
/// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
41+
NonContiguousInner,
42+
/// Inner-most dimension is not divisible by the vectorization factor.
43+
MisalignedVectorization { last_dim: usize, factor: u8 },
44+
}
45+
1646
/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
1747
#[derive(Debug)]
1848
pub enum TensorArg<'a, R: Runtime> {
@@ -178,17 +208,50 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
178208

179209
impl<'a, R: Runtime> TensorHandleRef<'a, R> {
180210
/// Convert the handle into a [tensor argument](TensorArg).
181-
pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
211+
pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
212+
// In debug builds, assert that the requested vectorization is supported
213+
// by the runtime. Validation of the chosen factor should normally be
214+
// performed upstream (at selection time) to avoid redundant checks in
215+
// hot paths.
216+
debug_assert!(
217+
R::supported_line_sizes().contains(&vectorization),
218+
"unsupported vectorization {} (supported: {:?})",
219+
vectorization,
220+
R::supported_line_sizes()
221+
);
182222
unsafe {
183223
TensorArg::from_raw_parts_and_size(
184224
self.handle,
185225
self.strides,
186226
self.shape,
187-
vectorisation,
227+
vectorization,
188228
self.elem_size,
189229
)
190230
}
191231
}
232+
/// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
233+
/// for vectorization compatibility.
234+
///
235+
/// Note: This convenience is primarily intended for host wrappers / FFI
236+
/// ingestion paths. In internal code, prefer validating the chosen
237+
/// vectorization factor at selection time and then calling
238+
/// [`as_tensor_arg`], to avoid redundant work in hot paths.
239+
///
240+
/// This does not enforce inner‑most contiguity or alignment requirements as
241+
/// kernels may vectorize along axes other than the innermost.
242+
pub fn try_as_tensor_arg(
243+
&'a self,
244+
vectorization: u8,
245+
) -> Result<TensorArg<'a, R>, TensorArgError> {
246+
if !R::supported_line_sizes().contains(&vectorization) {
247+
return Err(TensorArgError::UnsupportedVectorization {
248+
requested: vectorization,
249+
supported: R::supported_line_sizes(),
250+
});
251+
}
252+
Ok(self.as_tensor_arg(vectorization))
253+
}
254+
192255
/// Create a handle from raw parts.
193256
///
194257
/// # Safety
@@ -201,6 +264,21 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
201264
shape: &'a [usize],
202265
elem_size: usize,
203266
) -> Self {
267+
// Basic invariants for debug builds only; upstream layers are expected
268+
// to ensure correctness in release builds.
269+
debug_assert_eq!(
270+
shape.len(),
271+
strides.len(),
272+
"rank mismatch (shape={}, strides={})",
273+
shape.len(),
274+
strides.len()
275+
);
276+
debug_assert!(elem_size > 0, "element size must be > 0");
277+
// Note: zero strides are permitted here to support explicit broadcast
278+
// views in advanced/internal paths. The checked constructor
279+
// (`try_from_parts`) rejects them when `d > 1` to provide safety at
280+
// boundaries; callers who intentionally need zero‑stride broadcasting
281+
// can opt into this `unsafe` API.
204282
Self {
205283
handle,
206284
strides,
@@ -209,4 +287,91 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
209287
runtime: PhantomData,
210288
}
211289
}
290+
291+
/// Safely create a tensor handle from raw parts with basic shape/stride validation.
292+
///
293+
/// Note: This is mainly useful for host / FFI boundaries to surface clear
294+
/// errors early. Internal code should ensure these invariants when
295+
/// constructing handles and may use the `unsafe` constructor directly in
296+
/// performance‑critical paths.
297+
pub fn try_from_parts(
298+
handle: &'a cubecl_runtime::server::Handle,
299+
strides: &'a [usize],
300+
shape: &'a [usize],
301+
elem_size: usize,
302+
) -> Result<Self, TensorHandleError> {
303+
if shape.len() != strides.len() {
304+
return Err(TensorHandleError::RankMismatch {
305+
shape_rank: shape.len(),
306+
stride_rank: strides.len(),
307+
});
308+
}
309+
if elem_size == 0 {
310+
return Err(TensorHandleError::ElemSizeZero);
311+
}
312+
// Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
313+
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
314+
if s == 0 && d > 1 {
315+
return Err(TensorHandleError::ZeroStride { axis: i });
316+
}
317+
}
318+
Ok(unsafe { Self::from_raw_parts(handle, strides, shape, elem_size) })
319+
}
320+
321+
/// Safely create a tensor handle from raw parts using the element type for size.
322+
pub fn try_from_typed<E: CubePrimitive>(
323+
handle: &'a cubecl_runtime::server::Handle,
324+
strides: &'a [usize],
325+
shape: &'a [usize],
326+
) -> Result<Self, TensorHandleError> {
327+
let elem_size = E::size().expect("Element should have a size");
328+
Self::try_from_parts(handle, strides, shape, elem_size)
329+
}
330+
}
331+
332+
impl core::fmt::Display for TensorHandleError {
333+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
334+
match self {
335+
TensorHandleError::RankMismatch {
336+
shape_rank,
337+
stride_rank,
338+
} => {
339+
write!(
340+
f,
341+
"rank mismatch (shape={}, strides={})",
342+
shape_rank, stride_rank
343+
)
344+
}
345+
TensorHandleError::ElemSizeZero => write!(f, "element size must be > 0"),
346+
TensorHandleError::ZeroStride { axis } => {
347+
write!(f, "zero stride on axis {} with extent > 1", axis)
348+
}
349+
}
350+
}
351+
}
352+
353+
impl core::fmt::Display for TensorArgError {
354+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
355+
match self {
356+
TensorArgError::UnsupportedVectorization {
357+
requested,
358+
supported,
359+
} => {
360+
write!(
361+
f,
362+
"unsupported vectorization {}, supported: {:?}",
363+
requested, supported
364+
)
365+
}
366+
TensorArgError::NonContiguousInner => write!(
367+
f,
368+
"non-contiguous innermost dimension for vectorized access"
369+
),
370+
TensorArgError::MisalignedVectorization { last_dim, factor } => write!(
371+
f,
372+
"innermost dimension {} not divisible by vectorization {}",
373+
last_dim, factor
374+
),
375+
}
376+
}
212377
}

crates/cubecl-core/src/runtime_tests/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub mod sequence;
2121
pub mod slice;
2222
pub mod synchronization;
2323
pub mod tensor;
24+
pub mod tensor_handle;
2425
pub mod tensormap;
2526
pub mod to_client;
2627
pub mod topology;
@@ -138,6 +139,7 @@ macro_rules! testgen_untyped {
138139
cubecl_core::testgen_comparison!();
139140

140141
cubecl_core::testgen_to_client!();
142+
cubecl_core::testgen_tensor_handle!();
141143
};
142144
}
143145

0 commit comments

Comments
 (0)