@@ -13,6 +13,33 @@ use crate::{
1313
1414use super :: Tensor ;
1515
16+ /// Errors that can occur when constructing a tensor handle safely.
17+ #[ non_exhaustive]
18+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
19+ pub enum TensorHandleError {
20+ /// Rank of shape and strides differ.
21+ RankMismatch {
22+ shape_rank : usize ,
23+ stride_rank : usize ,
24+ } ,
25+ /// Element size must be > 0.
26+ ElemSizeZero ,
27+ /// A stride is zero for a dimension with extent > 1.
28+ ZeroStride { axis : usize } ,
29+ }
30+
31+ /// Errors that can occur when converting a handle to a runtime tensor argument.
32+ #[ non_exhaustive]
33+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
34+ pub enum TensorArgError {
35+ /// Requested vectorization factor is not supported by the runtime.
36+ UnsupportedVectorization { requested : u8 , supported : & ' static [ u8 ] } ,
37+ /// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
38+ NonContiguousInner ,
39+ /// Inner-most dimension is not divisible by the vectorization factor.
40+ MisalignedVectorization { last_dim : usize , factor : u8 } ,
41+ }
42+
1643/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
1744#[ derive( Debug ) ]
1845pub enum TensorArg < ' a , R : Runtime > {
@@ -178,17 +205,33 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
178205
179206impl < ' a , R : Runtime > TensorHandleRef < ' a , R > {
180207 /// Convert the handle into a [tensor argument](TensorArg).
181- pub fn as_tensor_arg ( & ' a self , vectorisation : u8 ) -> TensorArg < ' a , R > {
208+ pub fn as_tensor_arg ( & ' a self , vectorization : u8 ) -> TensorArg < ' a , R > {
182209 unsafe {
183210 TensorArg :: from_raw_parts_and_size (
184211 self . handle ,
185212 self . strides ,
186213 self . shape ,
187- vectorisation ,
214+ vectorization ,
188215 self . elem_size ,
189216 )
190217 }
191218 }
219+ /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
220+ /// for vectorization compatibility.
221+ /// Try to convert the handle into a tensor argument, validating that the
222+ /// requested vectorization factor is supported by the runtime. This does not
223+ /// enforce inner-most contiguity or alignment requirements as kernels may
224+ /// legally vectorize along axes other than the innermost.
225+ pub fn try_as_tensor_arg (
226+ & ' a self ,
227+ vectorization : u8 ,
228+ ) -> Result < TensorArg < ' a , R > , TensorArgError > {
229+ if !R :: supported_line_sizes ( ) . contains ( & vectorization) {
230+ return Err ( TensorArgError :: UnsupportedVectorization { requested : vectorization, supported : R :: supported_line_sizes ( ) } ) ;
231+ }
232+ Ok ( self . as_tensor_arg ( vectorization) )
233+ }
234+
192235 /// Create a handle from raw parts.
193236 ///
194237 /// # Safety
@@ -209,4 +252,67 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
209252 runtime : PhantomData ,
210253 }
211254 }
255+
256+ /// Safely create a tensor handle from raw parts with basic shape/stride validation.
257+ pub fn try_from_parts (
258+ handle : & ' a cubecl_runtime:: server:: Handle ,
259+ strides : & ' a [ usize ] ,
260+ shape : & ' a [ usize ] ,
261+ elem_size : usize ,
262+ ) -> Result < Self , TensorHandleError > {
263+ if shape. len ( ) != strides. len ( ) {
264+ return Err ( TensorHandleError :: RankMismatch {
265+ shape_rank : shape. len ( ) ,
266+ stride_rank : strides. len ( ) ,
267+ } ) ;
268+ }
269+ if elem_size == 0 {
270+ return Err ( TensorHandleError :: ElemSizeZero ) ;
271+ }
272+ // Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
273+ for ( i, ( & s, & d) ) in strides. iter ( ) . zip ( shape. iter ( ) ) . enumerate ( ) {
274+ if s == 0 && d > 1 {
275+ return Err ( TensorHandleError :: ZeroStride { axis : i } ) ;
276+ }
277+ }
278+ Ok ( unsafe { Self :: from_raw_parts ( handle, strides, shape, elem_size) } )
279+ }
280+
281+ /// Safely create a tensor handle from raw parts using the element type for size.
282+ pub fn try_from_typed < E : CubePrimitive > (
283+ handle : & ' a cubecl_runtime:: server:: Handle ,
284+ strides : & ' a [ usize ] ,
285+ shape : & ' a [ usize ] ,
286+ ) -> Result < Self , TensorHandleError > {
287+ let elem_size = E :: size ( ) . expect ( "Element should have a size" ) ;
288+ Self :: try_from_parts ( handle, strides, shape, elem_size)
289+ }
290+ }
291+
292+ impl core:: fmt:: Display for TensorHandleError {
293+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
294+ match self {
295+ TensorHandleError :: RankMismatch { shape_rank, stride_rank } => {
296+ write ! ( f, "rank mismatch (shape={}, strides={})" , shape_rank, stride_rank)
297+ }
298+ TensorHandleError :: ElemSizeZero => write ! ( f, "element size must be > 0" ) ,
299+ TensorHandleError :: ZeroStride { axis } => write ! ( f, "zero stride on axis {} with extent > 1" , axis) ,
300+ }
301+ }
302+ }
303+
304+ impl core:: fmt:: Display for TensorArgError {
305+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
306+ match self {
307+ TensorArgError :: UnsupportedVectorization { requested, supported } => {
308+ write ! ( f, "unsupported vectorization {}, supported: {:?}" , requested, supported)
309+ }
310+ TensorArgError :: NonContiguousInner => write ! ( f, "non-contiguous innermost dimension for vectorized access" ) ,
311+ TensorArgError :: MisalignedVectorization { last_dim, factor } => write ! (
312+ f,
313+ "innermost dimension {} not divisible by vectorization {}" ,
314+ last_dim, factor
315+ ) ,
316+ }
317+ }
212318}
0 commit comments