@@ -13,6 +13,36 @@ use crate::{
1313
1414use super :: Tensor ;
1515
16+ /// Errors that can occur when constructing a tensor handle safely.
17+ #[ non_exhaustive]
18+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
19+ pub enum TensorHandleError {
20+ /// Rank of shape and strides differ.
21+ RankMismatch {
22+ shape_rank : usize ,
23+ stride_rank : usize ,
24+ } ,
25+ /// Element size must be > 0.
26+ ElemSizeZero ,
27+ /// A stride is zero for a dimension with extent > 1.
28+ ZeroStride { axis : usize } ,
29+ }
30+
31+ /// Errors that can occur when converting a handle to a runtime tensor argument.
32+ #[ non_exhaustive]
33+ #[ derive( Clone , Debug , PartialEq , Eq ) ]
34+ pub enum TensorArgError {
35+ /// Requested vectorization factor is not supported by the runtime.
36+ UnsupportedVectorization {
37+ requested : u8 ,
38+ supported : & ' static [ u8 ] ,
39+ } ,
40+ /// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
41+ NonContiguousInner ,
42+ /// Inner-most dimension is not divisible by the vectorization factor.
43+ MisalignedVectorization { last_dim : usize , factor : u8 } ,
44+ }
45+
1646/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
1747#[ derive( Debug ) ]
1848pub enum TensorArg < ' a , R : Runtime > {
@@ -178,17 +208,50 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
178208
179209impl < ' a , R : Runtime > TensorHandleRef < ' a , R > {
180210 /// Convert the handle into a [tensor argument](TensorArg).
181- pub fn as_tensor_arg ( & ' a self , vectorisation : u8 ) -> TensorArg < ' a , R > {
211+ pub fn as_tensor_arg ( & ' a self , vectorization : u8 ) -> TensorArg < ' a , R > {
212+ // In debug builds, assert that the requested vectorization is supported
213+ // by the runtime. Validation of the chosen factor should normally be
214+ // performed upstream (at selection time) to avoid redundant checks in
215+ // hot paths.
216+ debug_assert ! (
217+ R :: supported_line_sizes( ) . contains( & vectorization) ,
218+ "unsupported vectorization {} (supported: {:?})" ,
219+ vectorization,
220+ R :: supported_line_sizes( )
221+ ) ;
182222 unsafe {
183223 TensorArg :: from_raw_parts_and_size (
184224 self . handle ,
185225 self . strides ,
186226 self . shape ,
187- vectorisation ,
227+ vectorization ,
188228 self . elem_size ,
189229 )
190230 }
191231 }
232+ /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
233+ /// for vectorization compatibility.
234+ ///
235+ /// Note: This convenience is primarily intended for host wrappers / FFI
236+ /// ingestion paths. In internal code, prefer validating the chosen
237+ /// vectorization factor at selection time and then calling
238+ /// [`as_tensor_arg`], to avoid redundant work in hot paths.
239+ ///
240+ /// This does not enforce inner‑most contiguity or alignment requirements as
241+ /// kernels may vectorize along axes other than the innermost.
242+ pub fn try_as_tensor_arg (
243+ & ' a self ,
244+ vectorization : u8 ,
245+ ) -> Result < TensorArg < ' a , R > , TensorArgError > {
246+ if !R :: supported_line_sizes ( ) . contains ( & vectorization) {
247+ return Err ( TensorArgError :: UnsupportedVectorization {
248+ requested : vectorization,
249+ supported : R :: supported_line_sizes ( ) ,
250+ } ) ;
251+ }
252+ Ok ( self . as_tensor_arg ( vectorization) )
253+ }
254+
192255 /// Create a handle from raw parts.
193256 ///
194257 /// # Safety
@@ -201,6 +264,21 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
201264 shape : & ' a [ usize ] ,
202265 elem_size : usize ,
203266 ) -> Self {
267+ // Basic invariants for debug builds only; upstream layers are expected
268+ // to ensure correctness in release builds.
269+ debug_assert_eq ! (
270+ shape. len( ) ,
271+ strides. len( ) ,
272+ "rank mismatch (shape={}, strides={})" ,
273+ shape. len( ) ,
274+ strides. len( )
275+ ) ;
276+ debug_assert ! ( elem_size > 0 , "element size must be > 0" ) ;
277+ // Note: zero strides are permitted here to support explicit broadcast
278+ // views in advanced/internal paths. The checked constructor
279+ // (`try_from_parts`) rejects them when `d > 1` to provide safety at
280+ // boundaries; callers who intentionally need zero‑stride broadcasting
281+ // can opt into this `unsafe` API.
204282 Self {
205283 handle,
206284 strides,
@@ -209,4 +287,91 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
209287 runtime : PhantomData ,
210288 }
211289 }
290+
291+ /// Safely create a tensor handle from raw parts with basic shape/stride validation.
292+ ///
293+ /// Note: This is mainly useful for host / FFI boundaries to surface clear
294+ /// errors early. Internal code should ensure these invariants when
295+ /// constructing handles and may use the `unsafe` constructor directly in
296+ /// performance‑critical paths.
297+ pub fn try_from_parts (
298+ handle : & ' a cubecl_runtime:: server:: Handle ,
299+ strides : & ' a [ usize ] ,
300+ shape : & ' a [ usize ] ,
301+ elem_size : usize ,
302+ ) -> Result < Self , TensorHandleError > {
303+ if shape. len ( ) != strides. len ( ) {
304+ return Err ( TensorHandleError :: RankMismatch {
305+ shape_rank : shape. len ( ) ,
306+ stride_rank : strides. len ( ) ,
307+ } ) ;
308+ }
309+ if elem_size == 0 {
310+ return Err ( TensorHandleError :: ElemSizeZero ) ;
311+ }
312+ // Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
313+ for ( i, ( & s, & d) ) in strides. iter ( ) . zip ( shape. iter ( ) ) . enumerate ( ) {
314+ if s == 0 && d > 1 {
315+ return Err ( TensorHandleError :: ZeroStride { axis : i } ) ;
316+ }
317+ }
318+ Ok ( unsafe { Self :: from_raw_parts ( handle, strides, shape, elem_size) } )
319+ }
320+
321+ /// Safely create a tensor handle from raw parts using the element type for size.
322+ pub fn try_from_typed < E : CubePrimitive > (
323+ handle : & ' a cubecl_runtime:: server:: Handle ,
324+ strides : & ' a [ usize ] ,
325+ shape : & ' a [ usize ] ,
326+ ) -> Result < Self , TensorHandleError > {
327+ let elem_size = E :: size ( ) . expect ( "Element should have a size" ) ;
328+ Self :: try_from_parts ( handle, strides, shape, elem_size)
329+ }
330+ }
331+
332+ impl core:: fmt:: Display for TensorHandleError {
333+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
334+ match self {
335+ TensorHandleError :: RankMismatch {
336+ shape_rank,
337+ stride_rank,
338+ } => {
339+ write ! (
340+ f,
341+ "rank mismatch (shape={}, strides={})" ,
342+ shape_rank, stride_rank
343+ )
344+ }
345+ TensorHandleError :: ElemSizeZero => write ! ( f, "element size must be > 0" ) ,
346+ TensorHandleError :: ZeroStride { axis } => {
347+ write ! ( f, "zero stride on axis {} with extent > 1" , axis)
348+ }
349+ }
350+ }
351+ }
352+
353+ impl core:: fmt:: Display for TensorArgError {
354+ fn fmt ( & self , f : & mut core:: fmt:: Formatter < ' _ > ) -> core:: fmt:: Result {
355+ match self {
356+ TensorArgError :: UnsupportedVectorization {
357+ requested,
358+ supported,
359+ } => {
360+ write ! (
361+ f,
362+ "unsupported vectorization {}, supported: {:?}" ,
363+ requested, supported
364+ )
365+ }
366+ TensorArgError :: NonContiguousInner => write ! (
367+ f,
368+ "non-contiguous innermost dimension for vectorized access"
369+ ) ,
370+ TensorArgError :: MisalignedVectorization { last_dim, factor } => write ! (
371+ f,
372+ "innermost dimension {} not divisible by vectorization {}" ,
373+ last_dim, factor
374+ ) ,
375+ }
376+ }
212377}
0 commit comments