
Commit 27f4c2e

core: safe TensorHandleRef + try_as_tensor_arg; errors improved; adopt across crates

- TensorHandleRef::{try_from_parts, try_from_typed}
- TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only)
- Error enums: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported }
- Update attention/matmul/convolution/reduce/std to use try_as_tensor_arg
- Runtime tests for handle validation and unsupported factors
1 parent ae51a6f commit 27f4c2e
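
The new constructors let call sites drop the unsafe from_raw_parts / as_tensor_arg pair. A minimal usage sketch of the API summarized above; the wrapper function, the f32 element type, and the line size of 4 are illustrative choices, not part of this commit:

// Sketch: assumes Runtime, TensorHandleRef, TensorArg, and the new error types
// are in scope from cubecl-core; the Handle path matches the signatures in this
// commit. prepare_arg, f32, and the line size of 4 are illustrative only.
fn prepare_arg<R: Runtime>(
    handle: &cubecl_runtime::server::Handle,
    strides: &[usize],
    shape: &[usize],
) -> Result<(), String> {
    // Rejects rank mismatches, a zero element size, and zero strides on
    // dimensions with extent > 1, instead of trusting the caller.
    let tensor = TensorHandleRef::<R>::try_from_typed::<f32>(handle, strides, shape)
        .map_err(|e| e.to_string())?;
    // Rejects line sizes the runtime does not report as supported.
    let _arg = tensor.try_as_tensor_arg(4).map_err(|e| e.to_string())?;
    // ...pass `_arg` to a kernel launch here.
    Ok(())
}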

13 files changed (+383, -67 lines)


crates/cubecl-attention/src/base.rs

Lines changed: 10 additions & 4 deletions
@@ -114,11 +114,17 @@ pub fn launch_tmp<R: Runtime, AP: AttentionPrecision>(
         config.cube_dim(),
         cube_count_plan.resolve(),
         TensorInputsLaunch::new(
-            query.as_tensor_arg(line_sizes.query),
-            key.as_tensor_arg(line_sizes.key),
-            value.as_tensor_arg(line_sizes.value),
+            query
+                .try_as_tensor_arg(line_sizes.query)
+                .expect("valid vectorisation for query"),
+            key.try_as_tensor_arg(line_sizes.key)
+                .expect("valid vectorisation for key"),
+            value
+                .try_as_tensor_arg(line_sizes.value)
+                .expect("valid vectorisation for value"),
         ),
-        out.as_tensor_arg(line_sizes.out),
+        out.try_as_tensor_arg(line_sizes.out)
+            .expect("valid vectorisation for out"),
         cube_count_plan.as_args(),
         config,
     );

crates/cubecl-attention/src/components/args.rs

Lines changed: 10 additions & 4 deletions
@@ -521,9 +521,14 @@ impl<EG: Numeric> ConcreteInputsFactory for TensorInputs<EG> {
         line_sizes: &AttentionLineSizes,
     ) -> Self::RuntimeArg<'a, R> {
         TensorInputsLaunch::new(
-            query.as_tensor_arg(line_sizes.query),
-            key.as_tensor_arg(line_sizes.key),
-            value.as_tensor_arg(line_sizes.value),
+            query
+                .try_as_tensor_arg(line_sizes.query)
+                .expect("valid vectorisation for query"),
+            key.try_as_tensor_arg(line_sizes.key)
+                .expect("valid vectorisation for key"),
+            value
+                .try_as_tensor_arg(line_sizes.value)
+                .expect("valid vectorisation for value"),
             // mask.as_tensor_arg(line_sizes.value),
         )
     }
@@ -536,7 +541,8 @@ impl<EG: Numeric> ConcreteOutputFactory for Tensor<Line<EG>> {
         _problem: &AttentionProblem,
         line_sizes: &AttentionLineSizes,
     ) -> Self::RuntimeArg<'a, R> {
-        out.as_tensor_arg(line_sizes.out)
+        out.try_as_tensor_arg(line_sizes.out)
+            .expect("valid vectorisation for out")
     }
 }
crates/cubecl-convolution/src/components/global/args.rs

Lines changed: 25 additions & 8 deletions
@@ -38,11 +38,23 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory for TensorIn
         line_sizes: &MatmulLineSizes,
     ) -> Self::RuntimeArg<'a, R> {
         TensorInputsLaunch::new(
-            lhs.data().as_tensor_arg(line_sizes.lhs),
-            lhs.scale().map(|it| it.as_tensor_arg(1)).into(),
-            rhs.data().as_tensor_arg(line_sizes.rhs),
-            rhs.scale().map(|it| it.as_tensor_arg(1)).into(),
-            bias.map(|it| it.as_tensor_arg(line_sizes.out)).into(),
+            lhs.data()
+                .try_as_tensor_arg(line_sizes.lhs)
+                .expect("valid vec lhs"),
+            lhs.scale()
+                .map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
+                .into(),
+            rhs.data()
+                .try_as_tensor_arg(line_sizes.rhs)
+                .expect("valid vec rhs"),
+            rhs.scale()
+                .map(|it| it.try_as_tensor_arg(1).expect("vec=1"))
+                .into(),
+            bias.map(|it| {
+                it.try_as_tensor_arg(line_sizes.out)
+                    .expect("valid vec out")
+            })
+            .into(),
         )
     }
 }
@@ -104,7 +116,9 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
                 channels_per_pixel: tile_size_k,
                 pixels_per_column: stage_m,
             },
-            lhs.data().as_tensor_arg(line_sizes.lhs),
+            lhs.data()
+                .try_as_tensor_arg(line_sizes.lhs)
+                .expect("valid vec lhs"),
             lhs_elem,
         )
         .with_elem_stride(elem_stride)
@@ -114,12 +128,15 @@ impl<Lhs: Numeric, Rhs: Numeric, EO: Numeric> ConcreteInputsFactory
             TensorMapFormat::Tiled {
                 tile_size: stage_size_rhs,
             },
-            rhs.data().as_tensor_arg(1),
+            rhs.data().try_as_tensor_arg(1).expect("vec=1"),
             Rhs::as_type_native_unchecked(),
         )
         .with_prefetch(prefetch_rhs);

-        let bias = bias.map(|it| it.as_tensor_arg(line_sizes.out));
+        let bias = bias.map(|it| {
+            it.try_as_tensor_arg(line_sizes.out)
+                .expect("valid vec out")
+        });

         // TODO: Think about how to handle scales with TMA
         TensorMapInputsLaunch::new(lhs, rhs, bias.into())

crates/cubecl-convolution/src/tests/convolution_test_launcher.rs

Lines changed: 6 additions & 9 deletions
@@ -88,15 +88,12 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
     }

     let elem_size = size_of::<P::EG>();
-    let lhs_handle = unsafe {
-        TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
-    };
-    let rhs_handle = unsafe {
-        TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
-    };
-    let out_handle = unsafe {
-        TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size)
-    };
+    let lhs_handle = TensorHandleRef::<R>::try_from_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
+        .expect("valid lhs handle");
+    let rhs_handle = TensorHandleRef::<R>::try_from_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
+        .expect("valid rhs handle");
+    let out_handle = TensorHandleRef::<R>::try_from_parts(&out.handle, &out.strides, &out.shape, elem_size)
+        .expect("valid out handle");

     let lhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &lhs_handle, MatmulIdent::Lhs);
     let rhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &rhs_handle, MatmulIdent::Rhs);
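
Since the element type P::EG is statically known here, the same handles could also be built with the new try_from_typed constructor, which derives the element size from the type; a sketch under the assumption that P::EG satisfies the CubePrimitive bound the constructor requires:

// Sketch only: an equivalent construction using try_from_typed, assuming
// P::EG meets the CubePrimitive bound required by the new constructor.
let lhs_handle =
    TensorHandleRef::<R>::try_from_typed::<P::EG>(&lhs.handle, &lhs.strides, &lhs.shape)
        .expect("valid lhs handle");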

crates/cubecl-core/src/frontend/container/tensor/launch.rs

Lines changed: 108 additions & 2 deletions
@@ -13,6 +13,33 @@ use crate::{

 use super::Tensor;

+/// Errors that can occur when constructing a tensor handle safely.
+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum TensorHandleError {
+    /// Rank of shape and strides differ.
+    RankMismatch {
+        shape_rank: usize,
+        stride_rank: usize,
+    },
+    /// Element size must be > 0.
+    ElemSizeZero,
+    /// A stride is zero for a dimension with extent > 1.
+    ZeroStride { axis: usize },
+}
+
+/// Errors that can occur when converting a handle to a runtime tensor argument.
+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum TensorArgError {
+    /// Requested vectorization factor is not supported by the runtime.
+    UnsupportedVectorization { requested: u8, supported: &'static [u8] },
+    /// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
+    NonContiguousInner,
+    /// Inner-most dimension is not divisible by the vectorization factor.
+    MisalignedVectorization { last_dim: usize, factor: u8 },
+}
+
 /// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
 #[derive(Debug)]
 pub enum TensorArg<'a, R: Runtime> {
@@ -178,17 +205,33 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {

 impl<'a, R: Runtime> TensorHandleRef<'a, R> {
     /// Convert the handle into a [tensor argument](TensorArg).
-    pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
+    pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
         unsafe {
             TensorArg::from_raw_parts_and_size(
                 self.handle,
                 self.strides,
                 self.shape,
-                vectorisation,
+                vectorization,
                 self.elem_size,
             )
         }
     }
+
+    /// Try to convert the handle into a [tensor argument](TensorArg), validating that the
+    /// requested vectorization factor is supported by the runtime. This does not enforce
+    /// inner-most contiguity or alignment requirements, as kernels may legally vectorize
+    /// along axes other than the innermost.
+    pub fn try_as_tensor_arg(
+        &'a self,
+        vectorization: u8,
+    ) -> Result<TensorArg<'a, R>, TensorArgError> {
+        if !R::supported_line_sizes().contains(&vectorization) {
+            return Err(TensorArgError::UnsupportedVectorization {
+                requested: vectorization,
+                supported: R::supported_line_sizes(),
+            });
+        }
+        Ok(self.as_tensor_arg(vectorization))
+    }
+
     /// Create a handle from raw parts.
     ///
     /// # Safety
@@ -209,4 +252,67 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
             runtime: PhantomData,
         }
     }
+
+    /// Safely create a tensor handle from raw parts with basic shape/stride validation.
+    pub fn try_from_parts(
+        handle: &'a cubecl_runtime::server::Handle,
+        strides: &'a [usize],
+        shape: &'a [usize],
+        elem_size: usize,
+    ) -> Result<Self, TensorHandleError> {
+        if shape.len() != strides.len() {
+            return Err(TensorHandleError::RankMismatch {
+                shape_rank: shape.len(),
+                stride_rank: strides.len(),
+            });
+        }
+        if elem_size == 0 {
+            return Err(TensorHandleError::ElemSizeZero);
+        }
+        // Disallow zero strides when the corresponding dimension extent is > 1
+        // (broadcast dims with extent 1 are allowed).
+        for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
+            if s == 0 && d > 1 {
+                return Err(TensorHandleError::ZeroStride { axis: i });
+            }
+        }
+        Ok(unsafe { Self::from_raw_parts(handle, strides, shape, elem_size) })
+    }
+
+    /// Safely create a tensor handle from raw parts using the element type for size.
+    pub fn try_from_typed<E: CubePrimitive>(
+        handle: &'a cubecl_runtime::server::Handle,
+        strides: &'a [usize],
+        shape: &'a [usize],
+    ) -> Result<Self, TensorHandleError> {
+        let elem_size = E::size().expect("Element should have a size");
+        Self::try_from_parts(handle, strides, shape, elem_size)
+    }
+}
+
+impl core::fmt::Display for TensorHandleError {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self {
+            TensorHandleError::RankMismatch { shape_rank, stride_rank } => {
+                write!(f, "rank mismatch (shape={}, strides={})", shape_rank, stride_rank)
+            }
+            TensorHandleError::ElemSizeZero => write!(f, "element size must be > 0"),
+            TensorHandleError::ZeroStride { axis } => {
+                write!(f, "zero stride on axis {} with extent > 1", axis)
+            }
+        }
+    }
+}
+
+impl core::fmt::Display for TensorArgError {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        match self {
+            TensorArgError::UnsupportedVectorization { requested, supported } => {
+                write!(f, "unsupported vectorization {}, supported: {:?}", requested, supported)
+            }
+            TensorArgError::NonContiguousInner => {
+                write!(f, "non-contiguous innermost dimension for vectorized access")
+            }
+            TensorArgError::MisalignedVectorization { last_dim, factor } => write!(
+                f,
+                "innermost dimension {} not divisible by vectorization {}",
+                last_dim, factor
+            ),
+        }
+    }
 }
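
Because both error types carry data and implement Display, callers can branch on them instead of panicking. A sketch of falling back to an unvectorized argument when the requested factor is rejected; the helper name and the fallback policy are illustrative, not part of this commit:

// Sketch: fall back to a line size of 1 when the requested factor is rejected;
// pick_arg is an illustrative helper, not an API added by this commit.
fn pick_arg<'a, R: Runtime>(
    tensor: &'a TensorHandleRef<'a, R>,
    preferred: u8,
) -> TensorArg<'a, R> {
    match tensor.try_as_tensor_arg(preferred) {
        Ok(arg) => arg,
        Err(TensorArgError::UnsupportedVectorization { requested, supported }) => {
            // Assumes a line size of 1 is always reported as supported.
            eprintln!("line size {requested} unsupported ({supported:?}); falling back to 1");
            tensor.as_tensor_arg(1)
        }
        // TensorArgError is #[non_exhaustive]; treat anything else as fatal.
        Err(other) => panic!("{other}"),
    }
}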

crates/cubecl-core/src/runtime_tests/mod.rs

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@ pub mod sequence;
 pub mod slice;
 pub mod synchronization;
 pub mod tensor;
+pub mod tensor_handle;
 pub mod tensormap;
 pub mod to_client;
 pub mod topology;
@@ -138,6 +139,7 @@ macro_rules! testgen_untyped {
         cubecl_core::testgen_comparison!();

         cubecl_core::testgen_to_client!();
+        cubecl_core::testgen_tensor_handle!();
     };
 }
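
The tensor_handle test module registered here is not included in this excerpt. A rough sketch of the kind of unsupported-factor check the commit message describes; every name and value below is hypothetical:

// Hypothetical sketch; the actual tensor_handle runtime test added by this
// commit is not shown in the excerpt. Assumes ComputeClient, TensorHandleRef,
// TensorArgError, and the CubeElement as_bytes helper are in scope, and that
// a line size of 3 is not in R::supported_line_sizes() for the runtime under test.
pub fn test_unsupported_line_size<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    let handle = client.create(f32::as_bytes(&[0.0; 8]));
    let (shape, strides) = ([8usize], [1usize]);
    let tensor = TensorHandleRef::<R>::try_from_parts(&handle, &strides, &shape, 4)
        .expect("rank, strides and element size are valid");
    assert!(matches!(
        tensor.try_as_tensor_arg(3),
        Err(TensorArgError::UnsupportedVectorization { requested: 3, .. })
    ));
}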
