Skip to content

Commit 58e6a34

Browse files
committed
core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs
1 parent 8e4b170 commit 58e6a34

File tree

1 file changed

+41
-4
lines changed
  • crates/cubecl-core/src/frontend/container/tensor

1 file changed

+41
-4
lines changed

crates/cubecl-core/src/frontend/container/tensor/launch.rs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@ impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
206206
impl<'a, R: Runtime> TensorHandleRef<'a, R> {
207207
/// Convert the handle into a [tensor argument](TensorArg).
208208
pub fn as_tensor_arg(&'a self, vectorization: u8) -> TensorArg<'a, R> {
209+
// In debug builds, assert that the requested vectorization is supported
210+
// by the runtime. Validation of the chosen factor should normally be
211+
// performed upstream (at selection time) to avoid redundant checks in
212+
// hot paths.
213+
debug_assert!(
214+
R::supported_line_sizes().contains(&vectorization),
215+
"unsupported vectorization {} (supported: {:?})",
216+
vectorization,
217+
R::supported_line_sizes()
218+
);
209219
unsafe {
210220
TensorArg::from_raw_parts_and_size(
211221
self.handle,
@@ -218,10 +228,14 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
218228
}
219229
/// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
220230
/// for vectorization compatibility.
221-
/// Try to convert the handle into a tensor argument, validating that the
222-
/// requested vectorization factor is supported by the runtime. This does not
223-
/// enforce inner-most contiguity or alignment requirements as kernels may
224-
/// legally vectorize along axes other than the innermost.
231+
///
232+
/// Note: This convenience is primarily intended for host wrappers / FFI
233+
/// ingestion paths. In internal code, prefer validating the chosen
234+
/// vectorization factor at selection time and then calling
235+
/// [`Self::as_tensor_arg`], to avoid redundant work in hot paths.
236+
///
237+
/// This does not enforce innermost contiguity or alignment requirements as
238+
/// kernels may vectorize along axes other than the innermost.
225239
pub fn try_as_tensor_arg(
226240
&'a self,
227241
vectorization: u8,
@@ -244,6 +258,24 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
244258
shape: &'a [usize],
245259
elem_size: usize,
246260
) -> Self {
261+
// Basic invariants for debug builds only; upstream layers are expected
262+
// to ensure correctness in release builds.
263+
debug_assert_eq!(
264+
shape.len(),
265+
strides.len(),
266+
"rank mismatch (shape={}, strides={})",
267+
shape.len(),
268+
strides.len()
269+
);
270+
debug_assert!(elem_size > 0, "element size must be > 0");
271+
// Debug-only check: flag a zero stride on any axis whose extent is > 1
272+
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
273+
debug_assert!(
274+
!(s == 0 && d > 1),
275+
"zero stride on axis {} with extent > 1",
276+
i
277+
);
278+
}
247279
Self {
248280
handle,
249281
strides,
@@ -254,6 +286,11 @@ impl<'a, R: Runtime> TensorHandleRef<'a, R> {
254286
}
255287

256288
/// Safely create a tensor handle from raw parts with basic shape/stride validation.
289+
///
290+
/// Note: This is mainly useful for host / FFI boundaries to surface clear
291+
/// errors early. Internal code should ensure these invariants when
292+
/// constructing handles and may use the `unsafe` constructor directly in
293+
/// performance-critical paths.
257294
pub fn try_from_parts(
258295
handle: &'a cubecl_runtime::server::Handle,
259296
strides: &'a [usize],

0 commit comments

Comments
 (0)