@@ -206,6 +206,16 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
 impl<'a, R: Runtime> TensorHandleRef<'a, R> {
     /// Convert the handle into a [tensor argument](TensorArg).
     pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
+        // In debug builds, assert that the requested vectorization is supported
+        // by the runtime. Validation of the chosen factor should normally be
+        // performed upstream (at selection time) to avoid redundant checks in
+        // hot paths.
+        debug_assert!(
+            R::supported_line_sizes().contains(&vectorization),
+            "unsupported vectorization {} (supported: {:?})",
+            vectorization,
+            R::supported_line_sizes()
+        );
         unsafe {
             TensorArg::from_raw_parts_and_size(
                 self.handle,
@@ -218,10 +228,14 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
     }
     /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
     /// for vectorization compatibility.
-    /// Try to convert the handle into a tensor argument, validating that the
-    /// requested vectorization factor is supported by the runtime. This does not
-    /// enforce inner-most contiguity or alignment requirements as kernels may
-    /// legally vectorize along axes other than the innermost.
+    ///
+    /// Note: This convenience is primarily intended for host wrappers / FFI
+    /// ingestion paths. In internal code, prefer validating the chosen
+    /// vectorization factor at selection time and then calling
+    /// [`as_tensor_arg`], to avoid redundant work in hot paths.
+    ///
+    /// This does not enforce inner-most contiguity or alignment requirements as
+    /// kernels may vectorize along axes other than the innermost.
     pub fn try_as_tensor_arg(
         &'a self,
         vectorization: u8,
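A minimal sketch of the selection-time pattern the doc comment above recommends. Only `Runtime::supported_line_sizes` and `as_tensor_arg` come from this diff; the helper name `pick_line_size`, the divisibility heuristic, and the commented call site are illustrative assumptions:

```rust
// Illustrative only: pick the largest supported line size that divides the
// innermost extent, then use the infallible (debug-checked) conversion.
// Assumes the same `Runtime` trait used in the impl above is in scope.
fn pick_line_size<R: Runtime>(inner_extent: usize) -> u8 {
    R::supported_line_sizes()
        .iter()
        .copied()
        .filter(|&ls| ls > 0 && inner_extent % ls as usize == 0)
        .max()
        .unwrap_or(1)
}

// let vectorization = pick_line_size::<R>(*shape.last().unwrap_or(&1));
// let arg = tensor_ref.as_tensor_arg(vectorization);
```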
@@ -244,6 +258,24 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
         shape: &'a [usize],
         elem_size: usize,
     ) -> Self {
+        // Basic invariants for debug builds only; upstream layers are expected
+        // to ensure correctness in release builds.
+        debug_assert_eq!(
+            shape.len(),
+            strides.len(),
+            "rank mismatch (shape={}, strides={})",
+            shape.len(),
+            strides.len()
+        );
+        debug_assert!(elem_size > 0, "element size must be > 0");
+        // Disallow zero strides when the corresponding dimension extent is > 1.
+        for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
+            debug_assert!(
+                !(s == 0 && d > 1),
+                "zero stride on axis {} with extent > 1",
+                i
+            );
+        }
         Self {
             handle,
             strides,
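For reference, row-major (contiguous) strides derived from a shape satisfy the debug-checked invariants above: the rank matches and no stride is zero for an extent greater than 1. The helper below is a sketch, not part of this PR:

```rust
// Hypothetical helper: row-major strides, in elements, for a given shape.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

// contiguous_strides(&[2, 3, 4]) == vec![12, 4, 1]
```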
@@ -254,6 +286,11 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
     }
 
     /// Safely create a tensor handle from raw parts with basic shape/stride validation.
+    ///
+    /// Note: This is mainly useful for host / FFI boundaries to surface clear
+    /// errors early. Internal code should ensure these invariants when
+    /// constructing handles and may use the `unsafe` constructor directly in
+    /// performance-critical paths.
     pub fn try_from_parts(
         handle: &'a cubecl_runtime::server::Handle,
         strides: &'a [usize],
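A hedged sketch of how the fallible constructors might be used at a host / FFI boundary. The full signatures and error types of `try_from_parts` and `try_as_tensor_arg` are not visible in these hunks, so the argument order and `Result`-returning shape assumed below are guesses, not the crate's confirmed API:

```rust
// Assumed usage at an FFI ingestion point: validate externally supplied
// metadata once, then reuse the cheap (debug-checked) infallible paths
// internally. The `Result` return types are inferred from the `try_` names.
let tensor = TensorHandleRef::<R>::try_from_parts(&handle, &strides, &shape, elem_size)
    .expect("caller passed inconsistent shape/strides");
let arg = tensor
    .try_as_tensor_arg(vectorization)
    .expect("vectorization not supported by this runtime");
```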