@@ -13,6 +13,33 @@ use crate::{
1313
1414use super :: Tensor ;
1515
16+ /// Errors that can occur when constructing a tensor handle safely.
17+ #[ non_exhaustive]
18+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
19+ pub enum TensorHandleError {
20+ /// Rank of shape and strides differ.
21+ RankMismatch {
22+ shape_rank : usize ,
23+ stride_rank : usize ,
24+ } ,
25+ /// Element size must be > 0.
26+ ElemSizeZero ,
27+ /// A stride is zero for a dimension with extent > 1.
28+ ZeroStride { axis : usize } ,
29+ }
30+
31+ /// Errors that can occur when converting a handle to a runtime tensor argument.
32+ #[ non_exhaustive]
33+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
34+ pub enum TensorArgError {
35+ /// Requested vectorization factor is not supported by the runtime.
36+ UnsupportedVectorization { requested : u8 , supported : & ' static [ u8 ] } ,
37+ /// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
38+ NonContiguousInner ,
39+ /// Inner-most dimension is not divisible by the vectorization factor.
40+ MisalignedVectorization { last_dim : usize , factor : u8 } ,
41+ }
42+
1643/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
1744#[ derive( Debug ) ]
1845pub enum TensorArg < ' a , R : Runtime > {
@@ -178,17 +205,47 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
178205
179206impl < ' a , R : Runtime > TensorHandleRef < ' a , R > {
180207 /// Convert the handle into a [tensor argument](TensorArg).
181- pub fn as_tensor_arg ( & ' a self , vectorisation : u8 ) -> TensorArg < ' a , R > {
208+ pub fn as_tensor_arg ( & ' a self , vectorization : u8 ) -> TensorArg < ' a , R > {
209+ // In debug builds, assert that the requested vectorization is supported
210+ // by the runtime. Validation of the chosen factor should normally be
211+ // performed upstream (at selection time) to avoid redundant checks in
212+ // hot paths.
213+ debug_assert ! (
214+ R :: supported_line_sizes( ) . contains( & vectorization) ,
215+ "unsupported vectorization {} (supported: {:?})" ,
216+ vectorization,
217+ R :: supported_line_sizes( )
218+ ) ;
182219 unsafe {
183220 TensorArg :: from_raw_parts_and_size (
184221 self . handle ,
185222 self . strides ,
186223 self . shape ,
187- vectorisation ,
224+ vectorization ,
188225 self . elem_size ,
189226 )
190227 }
191228 }
229+ /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
230+ /// for vectorization compatibility.
231+ ///
232+ /// Note: This convenience is primarily intended for host wrappers / FFI
233+ /// ingestion paths. In internal code, prefer validating the chosen
234+ /// vectorization factor at selection time and then calling
235+ /// [`as_tensor_arg`], to avoid redundant work in hot paths.
236+ ///
237+ /// This does not enforce inner‑most contiguity or alignment requirements as
238+ /// kernels may vectorize along axes other than the innermost.
239+ pub fn try_as_tensor_arg (
240+ & ' a self ,
241+ vectorization : u8 ,
242+ ) -> Result < TensorArg < ' a , R > , TensorArgError > {
243+ if !R :: supported_line_sizes ( ) . contains ( & vectorization) {
244+ return Err ( TensorArgError :: UnsupportedVectorization { requested : vectorization, supported : R :: supported_line_sizes ( ) } ) ;
245+ }
246+ Ok ( self . as_tensor_arg ( vectorization) )
247+ }
248+
192249 /// Create a handle from raw parts.
193250 ///
194251 /// # Safety
@@ -201,6 +258,24 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
201258 shape : & ' a [ usize ] ,
202259 elem_size : usize ,
203260 ) -> Self {
261+ // Basic invariants for debug builds only; upstream layers are expected
262+ // to ensure correctness in release builds.
263+ debug_assert_eq ! (
264+ shape. len( ) ,
265+ strides. len( ) ,
266+ "rank mismatch (shape={}, strides={})" ,
267+ shape. len( ) ,
268+ strides. len( )
269+ ) ;
270+ debug_assert ! ( elem_size > 0 , "element size must be > 0" ) ;
271+ // Disallow zero strides when corresponding dimension extent > 1
272+ for ( i, ( & s, & d) ) in strides. iter ( ) . zip ( shape. iter ( ) ) . enumerate ( ) {
273+ debug_assert ! (
274+ !( s == 0 && d > 1 ) ,
275+ "zero stride on axis {} with extent > 1" ,
276+ i
277+ ) ;
278+ }
204279 Self {
205280 handle,
206281 strides,
@@ -209,4 +284,72 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
209284 runtime : PhantomData ,
210285 }
211286 }
287+
288+ /// Safely create a tensor handle from raw parts with basic shape/stride validation.
289+ ///
290+ /// Note: This is mainly useful for host / FFI boundaries to surface clear
291+ /// errors early. Internal code should ensure these invariants when
292+ /// constructing handles and may use the `unsafe` constructor directly in
293+ /// performance‑critical paths.
294+ pub fn try_from_parts (
295+ handle : & ' a cubecl_runtime:: server:: Handle ,
296+ strides : & ' a [ usize ] ,
297+ shape : & ' a [ usize ] ,
298+ elem_size : usize ,
299+ ) -> Result < Self , TensorHandleError > {
300+ if shape. len ( ) != strides. len ( ) {
301+ return Err ( TensorHandleError :: RankMismatch {
302+ shape_rank : shape. len ( ) ,
303+ stride_rank : strides. len ( ) ,
304+ } ) ;
305+ }
306+ if elem_size == 0 {
307+ return Err ( TensorHandleError :: ElemSizeZero ) ;
308+ }
309+ // Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
310+ for ( i, ( & s, & d) ) in strides. iter ( ) . zip ( shape. iter ( ) ) . enumerate ( ) {
311+ if s == 0 && d > 1 {
312+ return Err ( TensorHandleError :: ZeroStride { axis : i } ) ;
313+ }
314+ }
315+ Ok ( unsafe { Self :: from_raw_parts ( handle, strides, shape, elem_size) } )
316+ }
317+
318+ /// Safely create a tensor handle from raw parts using the element type for size.
319+ pub fn try_from_typed < E : CubePrimitive > (
320+ handle : & ' a cubecl_runtime:: server:: Handle ,
321+ strides : & ' a [ usize ] ,
322+ shape : & ' a [ usize ] ,
323+ ) -> Result < Self , TensorHandleError > {
324+ let elem_size = E :: size ( ) . expect ( "Element should have a size" ) ;
325+ Self :: try_from_parts ( handle, strides, shape, elem_size)
326+ }
327+ }
328+
329+ impl core:: fmt:: Display for TensorHandleError {
330+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
331+ match self {
332+ TensorHandleError :: RankMismatch { shape_rank, stride_rank } => {
333+ write ! ( f, "rank mismatch (shape={}, strides={})" , shape_rank, stride_rank)
334+ }
335+ TensorHandleError :: ElemSizeZero => write ! ( f, "element size must be > 0" ) ,
336+ TensorHandleError :: ZeroStride { axis } => write ! ( f, "zero stride on axis {} with extent > 1" , axis) ,
337+ }
338+ }
339+ }
340+
341+ impl core:: fmt:: Display for TensorArgError {
342+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
343+ match self {
344+ TensorArgError :: UnsupportedVectorization { requested, supported } => {
345+ write ! ( f, "unsupported vectorization {}, supported: {:?}" , requested, supported)
346+ }
347+ TensorArgError :: NonContiguousInner => write ! ( f, "non-contiguous innermost dimension for vectorized access" ) ,
348+ TensorArgError :: MisalignedVectorization { last_dim, factor } => write ! (
349+ f,
350+ "innermost dimension {} not divisible by vectorization {}" ,
351+ last_dim, factor
352+ ) ,
353+ }
354+ }
212355}
0 commit comments