Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions crates/cubecl-convolution/src/tests/convolution_test_launcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ pub fn test_convolution_algorithm<A, Args, P, R>(
}

let elem_size = size_of::<P::EG>();
let lhs_handle = unsafe {
TensorHandleRef::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
};
let rhs_handle = unsafe {
TensorHandleRef::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
};
let out_handle = unsafe {
TensorHandleRef::from_raw_parts(&out.handle, &out.strides, &out.shape, elem_size)
};
let lhs_handle =
TensorHandleRef::<R>::try_from_parts(&lhs.handle, &lhs.strides, &lhs.shape, elem_size)
.expect("valid lhs handle");
let rhs_handle =
TensorHandleRef::<R>::try_from_parts(&rhs.handle, &rhs.strides, &rhs.shape, elem_size)
.expect("valid rhs handle");
let out_handle =
TensorHandleRef::<R>::try_from_parts(&out.handle, &out.strides, &out.shape, elem_size)
.expect("valid out handle");

let lhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &lhs_handle, MatmulIdent::Lhs);
let rhs_handle = A::into_tensor_handle::<R, P::EG>(&client, &rhs_handle, MatmulIdent::Rhs);
Expand Down
169 changes: 167 additions & 2 deletions crates/cubecl-core/src/frontend/container/tensor/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,36 @@ use crate::{

use super::Tensor;

/// Errors that can occur when constructing a tensor handle safely.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TensorHandleError {
/// Rank of shape and strides differ.
RankMismatch {
shape_rank: usize,
stride_rank: usize,
},
/// Element size must be > 0.
ElemSizeZero,
/// A stride is zero for a dimension with extent > 1.
ZeroStride { axis: usize },
}

/// Errors that can occur when converting a handle to a runtime tensor argument.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TensorArgError {
/// Requested vectorization factor is not supported by the runtime.
UnsupportedVectorization {
requested: u8,
supported: &'static [u8],
},
/// Inner-most dimension is not contiguous (stride != 1) while vectorization > 1.
NonContiguousInner,
/// Inner-most dimension is not divisible by the vectorization factor.
MisalignedVectorization { last_dim: usize, factor: u8 },
}

/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
#[derive(Debug)]
pub enum TensorArg<'a, R: Runtime> {
Expand Down Expand Up @@ -178,17 +208,50 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {

impl<'a, R: Runtime> TensorHandleRef<'a, R> {
/// Convert the handle into a [tensor argument](TensorArg).
pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
// In debug builds, assert that the requested vectorization is supported
// by the runtime. Validation of the chosen factor should normally be
// performed upstream (at selection time) to avoid redundant checks in
// hot paths.
debug_assert!(
R::supported_line_sizes().contains(&vectorization),
"unsupported vectorization {} (supported: {:?})",
vectorization,
R::supported_line_sizes()
);
unsafe {
TensorArg::from_raw_parts_and_size(
self.handle,
self.strides,
self.shape,
vectorisation,
vectorization,
self.elem_size,
)
}
}
/// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
/// for vectorization compatibility.
///
/// Note: This convenience is primarily intended for host wrappers / FFI
/// ingestion paths. In internal code, prefer validating the chosen
/// vectorization factor at selection time and then calling
/// [`as_tensor_arg`], to avoid redundant work in hot paths.
///
/// This does not enforce inner‑most contiguity or alignment requirements as
/// kernels may vectorize along axes other than the innermost.
pub fn try_as_tensor_arg(
&'a self,
vectorization: u8,
) -> Result<TensorArg<'a, R>, TensorArgError> {
if !R::supported_line_sizes().contains(&vectorization) {
return Err(TensorArgError::UnsupportedVectorization {
requested: vectorization,
supported: R::supported_line_sizes(),
});
}
Ok(self.as_tensor_arg(vectorization))
}

/// Create a handle from raw parts.
///
/// # Safety
Expand All @@ -201,6 +264,21 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
shape: &'a [usize],
elem_size: usize,
) -> Self {
// Basic invariants for debug builds only; upstream layers are expected
// to ensure correctness in release builds.
debug_assert_eq!(
shape.len(),
strides.len(),
"rank mismatch (shape={}, strides={})",
shape.len(),
strides.len()
);
debug_assert!(elem_size > 0, "element size must be > 0");
// Note: zero strides are permitted here to support explicit broadcast
// views in advanced/internal paths. The checked constructor
// (`try_from_parts`) rejects them when `d > 1` to provide safety at
// boundaries; callers who intentionally need zero‑stride broadcasting
// can opt into this `unsafe` API.
Self {
handle,
strides,
Expand All @@ -209,4 +287,91 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
runtime: PhantomData,
}
}

/// Safely create a tensor handle from raw parts with basic shape/stride validation.
///
/// Note: This is mainly useful for host / FFI boundaries to surface clear
/// errors early. Internal code should ensure these invariants when
/// constructing handles and may use the `unsafe` constructor directly in
/// performance‑critical paths.
pub fn try_from_parts(
handle: &'a cubecl_runtime::server::Handle,
strides: &'a [usize],
shape: &'a [usize],
elem_size: usize,
) -> Result<Self, TensorHandleError> {
if shape.len() != strides.len() {
return Err(TensorHandleError::RankMismatch {
shape_rank: shape.len(),
stride_rank: strides.len(),
});
}
if elem_size == 0 {
return Err(TensorHandleError::ElemSizeZero);
}
// Disallow zero strides when corresponding dimension extent > 1 (broadcasted dims with extent 1 are allowed).
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
if s == 0 && d > 1 {
return Err(TensorHandleError::ZeroStride { axis: i });
}
}
Comment on lines +303 to +317
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same thing here: most of those things are validated in other places, and it's kind of wasteful to do those validations multiple times.

Ok(unsafe { Self::from_raw_parts(handle, strides, shape, elem_size) })
}

/// Safely create a tensor handle from raw parts using the element type for size.
pub fn try_from_typed<E: CubePrimitive>(
handle: &'a cubecl_runtime::server::Handle,
strides: &'a [usize],
shape: &'a [usize],
) -> Result<Self, TensorHandleError> {
let elem_size = E::size().expect("Element should have a size");
Self::try_from_parts(handle, strides, shape, elem_size)
}
}

impl core::fmt::Display for TensorHandleError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
TensorHandleError::RankMismatch {
shape_rank,
stride_rank,
} => {
write!(
f,
"rank mismatch (shape={}, strides={})",
shape_rank, stride_rank
)
}
TensorHandleError::ElemSizeZero => write!(f, "element size must be > 0"),
TensorHandleError::ZeroStride { axis } => {
write!(f, "zero stride on axis {} with extent > 1", axis)
}
}
}
}

impl core::fmt::Display for TensorArgError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
TensorArgError::UnsupportedVectorization {
requested,
supported,
} => {
write!(
f,
"unsupported vectorization {}, supported: {:?}",
requested, supported
)
}
TensorArgError::NonContiguousInner => write!(
f,
"non-contiguous innermost dimension for vectorized access"
),
TensorArgError::MisalignedVectorization { last_dim, factor } => write!(
f,
"innermost dimension {} not divisible by vectorization {}",
last_dim, factor
),
}
}
}
2 changes: 2 additions & 0 deletions crates/cubecl-core/src/runtime_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub mod sequence;
pub mod slice;
pub mod synchronization;
pub mod tensor;
pub mod tensor_handle;
pub mod tensormap;
pub mod to_client;
pub mod topology;
Expand Down Expand Up @@ -138,6 +139,7 @@ macro_rules! testgen_untyped {
cubecl_core::testgen_comparison!();

cubecl_core::testgen_to_client!();
cubecl_core::testgen_tensor_handle!();
};
}

Expand Down
Loading
Loading