diff --git a/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs b/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs index 1aa568835..f5845dadb 100644 --- a/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs +++ b/crates/cubecl-convolution/src/tests/convolution_test_launcher.rs @@ -88,15 +88,15 @@ pub fn test_convolution_algorithm( } let elem_size = size_of::(); - let lhs_handle = unsafe { - TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size) - }; - let rhs_handle = unsafe { - TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size) - }; - let out_handle = unsafe { - TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size) - }; + let lhs_handle = + TensorHandleRef::::try_from_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size) + .expect("valid lhs handle"); + let rhs_handle = + TensorHandleRef::::try_from_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size) + .expect("valid rhs handle"); + let out_handle = + TensorHandleRef::::try_from_parts(&out.handle, &out.strides, &out.shape, elem_size) + .expect("valid out handle"); let lhs_handle = A::into_tensor_handle::(&client, &lhs_handle, MatmulIdent::Lhs); let rhs_handle = A::into_tensor_handle::(&client, &rhs_handle, MatmulIdent::Rhs); diff --git a/crates/cubecl-core/src/frontend/container/tensor/launch.rs b/crates/cubecl-core/src/frontend/container/tensor/launch.rs index b442e5e44..34f8f1ec5 100644 --- a/crates/cubecl-core/src/frontend/container/tensor/launch.rs +++ b/crates/cubecl-core/src/frontend/container/tensor/launch.rs @@ -13,6 +13,36 @@ use crate::{ use super::Tensor; +/// Errors that can occur when constructing a tensor handle safely. +#[non_exhaustive] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TensorHandleError { + /// Rank of shape and strides differ. + RankMismatch { + shape_rank: usize, + stride_rank: usize, + }, + /// Element size must be > 0. 
+ ElemSizeZero, + /// A stride is zero for a dimension with extent > 1. + ZeroStride { axis: usize }, +} + +/// Errors that can occur when converting a handle to a runtime tensor argument. +#[non_exhaustive] +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TensorArgError { + /// Requested vectorization factor is not supported by the runtime. + UnsupportedVectorization { + requested: u8, + supported: &'static [u8], + }, + /// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1. + NonContiguousInner, + /// Inner-most dimension is not divisible by the vectorization factor. + MisalignedVectorization { last_dim: usize, factor: u8 }, +} + /// Argument to be used for [tensors](Tensor) passed as arguments to kernels. #[derive(Debug)] pub enum TensorArg<'a, R: Runtime> { @@ -178,17 +208,50 @@ impl ArgSettings for TensorArg<'_, R> { impl<'a, R: Runtime> TensorHandleRef<'a, R> { /// Convert the handle into a [tensor argument](TensorArg). - pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> { + pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> { + // In debug builds, assert that the requested vectorization is supported + // by the runtime. Validation of the chosen factor should normally be + // performed upstream (at selection time) to avoid redundant checks in + // hot paths. + debug_assert!( + R::supported_line_sizes().contains(&vectorization), + "unsupported vectorization {} (supported: {:?})", + vectorization, + R::supported_line_sizes() + ); unsafe { TensorArg::from_raw_parts_and_size( self.handle, self.strides, self.shape, - vectorisation, + vectorization, self.elem_size, ) } } + /// Convert the handle into a [tensor argument](TensorArg) with basic safety checks + /// for vectorization compatibility. + /// + /// Note: This convenience is primarily intended for host wrappers / FFI + /// ingestion paths. 
In internal code, prefer validating the chosen + /// vectorization factor at selection time and then calling + /// [`as_tensor_arg`], to avoid redundant work in hot paths. + /// + /// This does not enforce inner‑most contiguity or alignment requirements as + /// kernels may vectorize along axes other than the innermost. + pub fn try_as_tensor_arg( + &'a self, + vectorization: u8, + ) -> Result, TensorArgError> { + if !R::supported_line_sizes().contains(&vectorization) { + return Err(TensorArgError::UnsupportedVectorization { + requested: vectorization, + supported: R::supported_line_sizes(), + }); + } + Ok(self.as_tensor_arg(vectorization)) + } + /// Create a handle from raw parts. /// /// # Safety @@ -201,6 +264,21 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> { shape: &'a [usize], elem_size: usize, ) -> Self { + // Basic invariants for debug builds only; upstream layers are expected + // to ensure correctness in release builds. + debug_assert_eq!( + shape.len(), + strides.len(), + "rank mismatch (shape={}, strides={})", + shape.len(), + strides.len() + ); + debug_assert!(elem_size > 0, "element size must be > 0"); + // Note: zero strides are permitted here to support explicit broadcast + // views in advanced/internal paths. The checked constructor + // (`try_from_parts`) rejects them when `d > 1` to provide safety at + // boundaries; callers who intentionally need zero‑stride broadcasting + // can opt into this `unsafe` API. Self { handle, strides, @@ -209,4 +287,91 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> { runtime: PhantomData, } } + + /// Safely create a tensor handle from raw parts with basic shape/stride validation. + /// + /// Note: This is mainly useful for host / FFI boundaries to surface clear + /// errors early. Internal code should ensure these invariants when + /// constructing handles and may use the `unsafe` constructor directly in + /// performance‑critical paths. 
+ pub fn try_from_parts( + handle: &'a cubecl_runtime::server::Handle, + strides: &'a [usize], + shape: &'a [usize], + elem_size: usize, + ) -> Result { + if shape.len() != strides.len() { + return Err(TensorHandleError::RankMismatch { + shape_rank: shape.len(), + stride_rank: strides.len(), + }); + } + if elem_size == 0 { + return Err(TensorHandleError::ElemSizeZero); + } + // Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed). + for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() { + if s == 0 && d > 1 { + return Err(TensorHandleError::ZeroStride { axis: i }); + } + } + Ok(unsafe { Self::from_raw_parts(handle, strides, shape, elem_size) }) + } + + /// Safely create a tensor handle from raw parts using the element type for size. + pub fn try_from_typed( + handle: &'a cubecl_runtime::server::Handle, + strides: &'a [usize], + shape: &'a [usize], + ) -> Result { + let elem_size = E::size().expect("Element should have a size"); + Self::try_from_parts(handle, strides, shape, elem_size) + } +} + +impl core::fmt::Display for TensorHandleError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TensorHandleError::RankMismatch { + shape_rank, + stride_rank, + } => { + write!( + f, + "rank mismatch (shape={}, strides={})", + shape_rank, stride_rank + ) + } + TensorHandleError::ElemSizeZero => write!(f, "element size must be > 0"), + TensorHandleError::ZeroStride { axis } => { + write!(f, "zero stride on axis {} with extent > 1", axis) + } + } + } +} + +impl core::fmt::Display for TensorArgError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TensorArgError::UnsupportedVectorization { + requested, + supported, + } => { + write!( + f, + "unsupported vectorization {}, supported: {:?}", + requested, supported + ) + } + TensorArgError::NonContiguousInner => write!( + f, + "non-contiguous innermost dimension for vectorized 
access" + ), + TensorArgError::MisalignedVectorization { last_dim, factor } => write!( + f, + "innermost dimension {} not divisible by vectorization {}", + last_dim, factor + ), + } + } } diff --git a/crates/cubecl-core/src/runtime_tests/mod.rs b/crates/cubecl-core/src/runtime_tests/mod.rs index 16487b54c..e0a53a2aa 100644 --- a/crates/cubecl-core/src/runtime_tests/mod.rs +++ b/crates/cubecl-core/src/runtime_tests/mod.rs @@ -21,6 +21,7 @@ pub mod sequence; pub mod slice; pub mod synchronization; pub mod tensor; +pub mod tensor_handle; pub mod tensormap; pub mod to_client; pub mod topology; @@ -138,6 +139,7 @@ macro_rules! testgen_untyped { cubecl_core::testgen_comparison!(); cubecl_core::testgen_to_client!(); + cubecl_core::testgen_tensor_handle!(); }; } diff --git a/crates/cubecl-core/src/runtime_tests/tensor_handle.rs b/crates/cubecl-core/src/runtime_tests/tensor_handle.rs new file mode 100644 index 000000000..77a62204d --- /dev/null +++ b/crates/cubecl-core/src/runtime_tests/tensor_handle.rs @@ -0,0 +1,164 @@ +use crate::prelude::*; // brings TensorArgError, TensorHandleError, TensorHandleRef, Runtime, ComputeClient + +fn make_client() -> ComputeClient { + R::client(&R::Device::default()) +} + +pub fn test_handle_try_from_typed_ok_and_vec_checked_ok() { + let client = make_client::(); + let shape = vec![2usize, 8usize]; + let strides = compact_strides(&shape); + let bytes = bytemuck::cast_slice::(&vec![0.0f32; shape.iter().product()]).to_vec(); + let handle = client.create(&bytes); + + let href = TensorHandleRef::::try_from_typed::(&handle, &strides, &shape).expect("ok"); + + // Pick a supported factor that divides last dim (if any), else 1 + let mut picked = 1u8; + for f in R::supported_line_sizes() { + let f8 = (*f) as u8; + if f8 > 1 && shape[1] % (*f as usize) == 0 { + picked = f8; + break; + } + } + let _arg = href.try_as_tensor_arg(picked).expect("vec ok"); +} + +pub fn test_handle_try_from_parts_rank_mismatch() { + let client = make_client::(); + let 
shape = vec![2usize, 4usize]; + let strides_good = compact_strides(&shape); + let bytes = bytemuck::cast_slice::(&vec![0.0f32; shape.iter().product()]).to_vec(); + let handle = client.create(&bytes); + + let err = TensorHandleRef::::try_from_parts( + &handle, + &strides_good[..1], + &shape, + core::mem::size_of::(), + ) + .unwrap_err(); + match err { + TensorHandleError::RankMismatch { .. } => {} + _ => panic!("wrong error: {err:?}"), + } +} + +pub fn test_handle_try_from_parts_zero_stride() { + let client = make_client::(); + let shape = vec![2usize, 4usize]; + let mut strides = compact_strides(&shape); + strides[0] = 0; // invalid when dim > 1 + let bytes = bytemuck::cast_slice::(&vec![0.0f32; shape.iter().product()]).to_vec(); + let handle = client.create(&bytes); + + let err = TensorHandleRef::::try_from_parts( + &handle, + &strides, + &shape, + core::mem::size_of::(), + ) + .unwrap_err(); + match err { + TensorHandleError::ZeroStride { .. } => {} + _ => panic!("wrong error: {err:?}"), + } +} + +pub fn test_vec_checked_unsupported_factor() { + let client = make_client::(); + let shape = vec![1usize, 8usize]; + let strides = compact_strides(&shape); + let bytes = bytemuck::cast_slice::(&vec![0.0f32; shape.iter().product()]).to_vec(); + let handle = client.create(&bytes); + let href = TensorHandleRef::::try_from_typed::(&handle, &strides, &shape).expect("ok"); + + // pick factor 7 which is typically unsupported + let err = href.try_as_tensor_arg(7).unwrap_err(); + match err { + TensorArgError::UnsupportedVectorization { .. 
} => {} + _ => panic!("wrong error: {err:?}"), + } +} + +pub fn test_vec_checked_noncontiguous_inner_allows_vectorized() { + let client = make_client::(); + let shape = vec![2usize, 8usize]; + let mut strides = compact_strides(&shape); + // Make inner stride non-contiguous (allowed by checked API) + strides[1] = 2; + let bytes = bytemuck::cast_slice::(&vec![0.0f32; shape.iter().product()]).to_vec(); + let handle = client.create(&bytes); + + let href = TensorHandleRef::::try_from_parts( + &handle, + &strides, + &shape, + core::mem::size_of::(), + ) + .expect("ok"); + + // Choose a supported factor > 1 if available + let mut picked = None; + for f in R::supported_line_sizes() { + if *f > 1 { + picked = Some(*f as u8); + break; + } + } + if let Some(factor) = picked { + let _ = href + .try_as_tensor_arg(factor) + .expect("non-contiguous inner allowed"); + } +} + +// Misalignment (last dim not divisible by factor) is permitted; tail handling is kernel-specific. +// We do not error on that case in the checked API. + +#[macro_export] +macro_rules! 
testgen_tensor_handle { + () => { + use super::*; + + #[test] + fn test_tensor_handle_try_from_typed_ok_and_vec_checked_ok() { + cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_typed_ok_and_vec_checked_ok::(); + } + + #[test] + fn test_tensor_handle_try_from_parts_rank_mismatch() { + cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_parts_rank_mismatch::(); + } + + #[test] + fn test_tensor_handle_try_from_parts_zero_stride() { + cubecl_core::runtime_tests::tensor_handle::test_handle_try_from_parts_zero_stride::(); + } + + #[test] + fn test_vec_checked_unsupported_factor() { + cubecl_core::runtime_tests::tensor_handle::test_vec_checked_unsupported_factor::(); + } + + #[test] + fn test_vec_checked_noncontiguous_inner_allows_vectorized() { + cubecl_core::runtime_tests::tensor_handle::test_vec_checked_noncontiguous_inner_allows_vectorized::(); + } + + }; +} + +fn compact_strides(shape: &[usize]) -> Vec { + let rank = shape.len(); + if rank == 0 { + return vec![]; + } + let mut strides = vec![0; rank]; + strides[rank - 1] = 1; + for i in (0..rank - 1).rev() { + strides[i] = strides[i + 1] * shape[i + 1]; + } + strides +} diff --git a/crates/cubecl-reduce/src/test.rs b/crates/cubecl-reduce/src/test.rs index 3149b025a..1411572cb 100644 --- a/crates/cubecl-reduce/src/test.rs +++ b/crates/cubecl-reduce/src/test.rs @@ -507,22 +507,32 @@ impl TestCase { output_shape[self.axis.unwrap()] = 1; let output_stride = self.output_stride(); - let input = unsafe { - TensorHandleRef::::from_raw_parts( - &input_handle, - &self.stride, - &self.shape, - size_of::
<F>
(), - ) - }; - let output = unsafe { - TensorHandleRef::::from_raw_parts( - &output_handle, - &output_stride, - &output_shape, - size_of::(), - ) + let input = match TensorHandleRef::::try_from_parts( + &input_handle, + &self.stride, + &self.shape, + size_of::
<F>
(), + ) { + Ok(h) => h, + // Broadcasted inputs may legally use zero strides; fall back to the + // unsafe constructor for those specific cases to exercise kernels. + Err(TensorHandleError::ZeroStride { .. }) => unsafe { + TensorHandleRef::::from_raw_parts( + &input_handle, + &self.stride, + &self.shape, + size_of::
<F>
(), + ) + }, + Err(e) => panic!("valid input handle: {e}"), }; + let output = TensorHandleRef::::try_from_parts( + &output_handle, + &output_stride, + &output_shape, + size_of::(), + ) + .expect("valid output handle"); let result = reduce::( &client, @@ -553,17 +563,16 @@ impl TestCase { let input_handle = client.create(F::as_bytes(&input_values)); let output_handle = client.create(F::as_bytes(&[F::from_int(0)])); - let input = unsafe { - TensorHandleRef::::from_raw_parts( - &input_handle, - &self.stride, - &self.shape, - size_of::(), - ) - }; - let output = unsafe { - TensorHandleRef::::from_raw_parts(&output_handle, &[1], &[1], size_of::()) - }; + let input = TensorHandleRef::::try_from_parts( + &input_handle, + &self.stride, + &self.shape, + size_of::(), + ) + .expect("valid input handle"); + let output = + TensorHandleRef::::try_from_parts(&output_handle, &[1], &[1], size_of::()) + .expect("valid output handle"); let cube_count = 3; let result = shared_sum::(&client, input, output, cube_count);