
Commit 018341e

feat: Add broadcasting support to linear layout (#889)

1 parent fb9a730

4 files changed (+103, -15 lines)

crates/cubecl-cpp/src/shared/instruction.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -467,7 +467,7 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{
         } => {
             let out = out.fmt_left();
             match *split_meta {
-                true => writeln!(f, "{out} = static_info.x[{info_offset}];"),
+                true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
                 false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
             }
         }
```
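This is string-based codegen: `writeln!` emits C++ source text, so the buffer name in the read must match whatever name the declaration site emitted, and the fix swaps a hardcoded `static_info` literal for the shared `STATIC_INFO_NAME` constant. A minimal sketch of the pattern, with simplified illustrative types (only the names `STATIC_INFO_NAME` and `INFO_NAME` come from the diff):

```rust
use std::fmt::Write;

// Shared constants keep every emission site in sync; a hardcoded string
// literal at one site can silently diverge if the declaration is renamed.
const STATIC_INFO_NAME: &str = "static_info";
const INFO_NAME: &str = "info";

fn emit_metadata_read(
    f: &mut String,
    out: &str,
    info_offset: u32,
    split_meta: bool,
) -> std::fmt::Result {
    match split_meta {
        true => writeln!(f, "{out} = {STATIC_INFO_NAME}.x[{info_offset}];"),
        false => writeln!(f, "{out} = {INFO_NAME}[{info_offset}];"),
    }
}

fn main() {
    let mut src = String::new();
    emit_metadata_read(&mut src, "const uint shape_0", 3, true).unwrap();
    // Prints: const uint shape_0 = static_info.x[3];
    print!("{src}");
}
```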

crates/cubecl-std/src/fast_math.rs

Lines changed: 0 additions & 14 deletions
```diff
@@ -16,10 +16,6 @@ pub enum FastDivmod {
         multiplier: u32,
         shift_right: u32,
     },
-    PowerOfTwo {
-        shift: u32,
-        mask: u32,
-    },
     Fallback {
         divisor: u32,
     },
@@ -36,13 +32,6 @@ impl<R: Runtime> FastDivmodArgs<'_, R> {
     pub fn new(client: &ComputeClient<R::Server, R::Channel>, divisor: u32) -> Self {
         debug_assert!(divisor != 0);

-        if divisor.is_power_of_two() {
-            return FastDivmodArgs::PowerOfTwo {
-                shift: ScalarArg::new(divisor.trailing_zeros()),
-                mask: ScalarArg::new(divisor - 1),
-            };
-        }
-
         if !u64::supported_uses(client).contains(TypeUsage::Arithmetic) {
             return FastDivmodArgs::Fallback {
                 divisor: ScalarArg::new(divisor),
@@ -73,7 +62,6 @@ impl FastDivmod {
                 let t = u32::mul_hi(dividend, *multiplier);
                 (t + dividend) >> shift_right
             }
-            FastDivmod::PowerOfTwo { shift, .. } => dividend >> *shift,
             FastDivmod::Fallback { divisor } => dividend / divisor,
         }
     }
@@ -82,7 +70,6 @@ impl FastDivmod {
         let q = self.div(dividend);
         match self {
             FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
-            FastDivmod::PowerOfTwo { mask, .. } => dividend & mask,
             FastDivmod::Fallback { divisor } => dividend % divisor,
         }
     }
@@ -92,7 +79,6 @@ impl FastDivmod {
         let r = match self {
             FastDivmod::Fast { divisor, .. } => dividend - q * divisor,
             FastDivmod::Fallback { divisor } => dividend - q * divisor,
-            FastDivmod::PowerOfTwo { mask, .. } => dividend & *mask,
        };

         (q, r)
```
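For context on what remains after this deletion: the `Fast` variant divides by a divisor fixed at launch time using the classic multiply-high scheme, `q = (mul_hi(n, m) + n) >> s`, with `m` and `s` precomputed on the host (Granlund–Montgomery style magic numbers). The intermediate sum needs more than 32 bits, which lines up with the diff's `u64` support check guarding the fallback. A host-side Rust sketch of that scheme; `fast_divmod_params` and its exact derivation are illustrative and may differ from cubecl-std's actual host code:

```rust
/// Derive (multiplier, shift) so that floor(n / d) == (mul_hi(n, m) + n) >> s
/// for every 32-bit n. Illustrative Granlund–Montgomery-style derivation.
fn fast_divmod_params(divisor: u32) -> (u32, u32) {
    assert!(divisor > 0);
    // shift = ceil(log2(divisor))
    let shift = 32 - (divisor - 1).leading_zeros();
    // multiplier = ceil(2^(32+shift) / divisor) - 2^32, computed in u128.
    let m = ((1u128 << (32 + shift)) + divisor as u128 - 1) / divisor as u128;
    ((m - (1u128 << 32)) as u32, shift)
}

fn fast_div(dividend: u32, multiplier: u32, shift_right: u32) -> u32 {
    // mul_hi plus the dividend, widened to u64 so the sum cannot overflow;
    // this mirrors the kernel's `(t + dividend) >> shift_right`.
    let t = (dividend as u64 * multiplier as u64) >> 32;
    ((t + dividend as u64) >> shift_right) as u32
}

fn main() {
    for d in [1u32, 3, 7, 10, 24, 1 << 12, u32::MAX] {
        let (m, s) = fast_divmod_params(d);
        for n in [0u32, 1, d - 1, d, d.saturating_add(1), 12345, u32::MAX] {
            let q = fast_div(n, m, s);
            assert_eq!(q, n / d);
            assert_eq!(n - q * d, n % d); // remainder via q, as in `div_mod`
        }
    }
    // For a power-of-two divisor the derived multiplier is 0, so the fast
    // path degenerates to a plain shift anyway.
    assert_eq!(fast_divmod_params(1 << 12).0, 0);
}
```

Note that the remainder falls out as `dividend - q * divisor`, exactly as the surviving `Fast` arms compute it, and that a power-of-two divisor already reduces the fast path to a shift, which is presumably why the dedicated `PowerOfTwo` variant could be dropped.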

crates/cubecl-std/src/tensor/layout/linear.rs

Lines changed: 55 additions & 0 deletions
```diff
@@ -49,6 +49,7 @@ impl LinearLayoutExpand {
 }

 impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
+    /// Construct a linear layout from shapes, strides and line size of the tensor
     pub fn from_shape_strides(
         client: &ComputeClient<R::Server, R::Channel>,
         shape: &[usize],
@@ -69,13 +70,52 @@ impl<'a, R: Runtime> LinearLayoutArgs<'a, R> {
         }
     }

+    /// Construct a possibly broadcast linear layout from shapes/strides and a reference shape
+    pub fn from_shape_strides_with_reference(
+        client: &ComputeClient<R::Server, R::Channel>,
+        shape: &[usize],
+        reference_shape: &[usize],
+        strides: &[usize],
+        line_size: &'a u8,
+    ) -> Self {
+        if shape != reference_shape {
+            // Broadcast layouts are always treated as permuted
+            Self::Permuted(PermutedLayoutLaunch::from_shapes_strides_ref(
+                client,
+                shape,
+                reference_shape,
+                strides,
+                line_size,
+            ))
+        } else {
+            Self::from_shape_strides(client, shape, strides, line_size)
+        }
+    }
+
+    /// Construct a linear layout from a tensor handle
     pub fn from_handle(
         client: &ComputeClient<R::Server, R::Channel>,
         handle: &TensorHandleRef<'a, R>,
         line_size: &'a u8,
     ) -> Self {
         Self::from_shape_strides(client, handle.shape, handle.strides, line_size)
     }
+
+    /// Construct a possibly broadcast linear layout from a tensor handle and reference handle
+    pub fn from_handle_with_reference(
+        client: &ComputeClient<R::Server, R::Channel>,
+        handle: &TensorHandleRef<'a, R>,
+        reference: &TensorHandleRef<'a, R>,
+        line_size: &'a u8,
+    ) -> Self {
+        Self::from_shape_strides_with_reference(
+            client,
+            handle.shape,
+            reference.shape,
+            handle.strides,
+            line_size,
+        )
+    }
 }

 #[cube]
@@ -120,6 +160,21 @@ pub fn linear_view<'a, R: Runtime>(
     LinearViewLaunch::new(buffer, layout)
 }

+/// Create a possibly broadcast linear tensor view from a handle, reference handle and line size
+pub fn linear_view_with_reference<'a, R: Runtime>(
+    client: &ComputeClient<R::Server, R::Channel>,
+    handle: &'a TensorHandleRef<'a, R>,
+    reference: &'a TensorHandleRef<'a, R>,
+    line_size: &'a u8,
+) -> LinearViewLaunch<'a, R> {
+    let len = handle.shape.iter().product::<usize>();
+    let layout = LinearLayoutArgs::from_handle_with_reference(client, handle, reference, line_size);
+    let buffer = unsafe {
+        ArrayArg::from_raw_parts_and_size(handle.handle, len, *line_size, handle.elem_size)
+    };
+    LinearViewLaunch::new(buffer, layout)
+}
+
 pub fn linear_view_alias<'a, R: Runtime>(
     client: &ComputeClient<R::Server, R::Channel>,
     handle: &'a TensorHandleRef<'a, R>,
```
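The key decision sits in `from_shape_strides_with_reference`: a tensor whose shape already matches the reference keeps the ordinary linear layout, and only a shape mismatch forces the permuted, stride-based path, where broadcasting is realized by zeroing strides (see permuted.rs below). A self-contained sketch of that dispatch, with `LayoutSketch` as a simplified stand-in for the real launch types (the real `from_shape_strides` does more work, e.g. distinguishing contiguous from strided tensors):

```rust
#[derive(Debug, PartialEq)]
enum LayoutSketch {
    Linear { len: usize },
    Permuted { strides: Vec<usize> },
}

fn layout_for(shape: &[usize], reference: &[usize], strides: &[usize]) -> LayoutSketch {
    if shape != reference {
        // Broadcast: zero the stride of every size-1 axis.
        let strides = shape
            .iter()
            .zip(reference)
            .zip(strides)
            .map(|((s, r), st)| if s == r { *st } else { 0 })
            .collect();
        LayoutSketch::Permuted { strides }
    } else {
        LayoutSketch::Linear { len: shape.iter().product() }
    }
}

fn main() {
    // Matching shapes stay on the cheap linear path.
    assert_eq!(
        layout_for(&[2, 3], &[2, 3], &[3, 1]),
        LayoutSketch::Linear { len: 6 }
    );
    // A broadcast row vector falls through to the permuted path.
    assert_eq!(
        layout_for(&[1, 3], &[2, 3], &[3, 1]),
        LayoutSketch::Permuted { strides: vec![0, 1] }
    );
}
```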

crates/cubecl-std/src/tensor/layout/permuted.rs

Lines changed: 47 additions & 0 deletions
```diff
@@ -21,6 +21,8 @@ pub struct PermutedLayout {
 }

 impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
+    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
+    /// broadcast to.
     pub fn from_shape_strides(
         client: &ComputeClient<R::Server, R::Channel>,
         shape: &[usize],
@@ -45,6 +47,51 @@ impl<'a, R: Runtime> PermutedLayoutLaunch<'a, R> {
         Self::new(shape, strides, ScalarArg::new(len as u32), line_size)
     }

+    /// Create a new permuted layout for a possibly broadcast tensor, with a reference shape to be
+    /// broadcast to.
+    pub fn from_shapes_strides_ref(
+        client: &ComputeClient<R::Server, R::Channel>,
+        shape: &[usize],
+        reference_shape: &[usize],
+        strides: &[usize],
+        line_size: &'a u8,
+    ) -> Self {
+        debug_assert!(
+            shape.len() == reference_shape.len(),
+            "Shape and reference should have the same rank"
+        );
+        debug_assert!(
+            shape
+                .iter()
+                .zip(reference_shape)
+                .all(|(s, r)| s == r || *s == 1),
+            "Shape should be equal to reference or 1 on each dimension"
+        );
+
+        let strides: Vec<usize> = strides
+            .iter()
+            .zip(shape.iter().zip(reference_shape))
+            .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
+            .collect();
+
+        Self::from_shape_strides(client, reference_shape, &strides, line_size)
+    }
+
+    pub fn from_handles_ref(
+        client: &ComputeClient<R::Server, R::Channel>,
+        handle: &TensorHandleRef<'_, R>,
+        reference_handle: &TensorHandleRef<'_, R>,
+        line_size: &'a u8,
+    ) -> Self {
+        Self::from_shapes_strides_ref(
+            client,
+            handle.shape,
+            reference_handle.shape,
+            handle.strides,
+            line_size,
+        )
+    }
+
     pub fn from_handle(
         client: &ComputeClient<R::Server, R::Channel>,
         handle: &TensorHandleRef<'_, R>,
```
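The stride-zeroing trick is worth spelling out: a size-1 axis gets stride 0, so when an index over the reference shape is decomposed into coordinates and dotted with the strides, every coordinate along a broadcast axis maps back to element 0. A standalone sketch of the resulting addressing (the coordinate decomposition is assumed row-major, as plain `shape`/`strides` pairs conventionally are):

```rust
/// Zero the stride of every broadcast axis, as in `from_shapes_strides_ref`.
fn broadcast_strides(shape: &[usize], reference: &[usize], strides: &[usize]) -> Vec<usize> {
    assert_eq!(shape.len(), reference.len(), "ranks must match");
    strides
        .iter()
        .zip(shape.iter().zip(reference))
        .map(|(stride, (s, r))| if *s == *r { *stride } else { 0 })
        .collect()
}

/// Map a linear index over the reference shape to an offset in the broadcast
/// tensor: decompose into coordinates, then dot with the (zeroed) strides.
fn offset(linear: usize, reference: &[usize], strides: &[usize]) -> usize {
    let mut rem = linear;
    let mut off = 0;
    for (dim, stride) in reference.iter().zip(strides).rev() {
        off += (rem % dim) * stride;
        rem /= dim;
    }
    off
}

fn main() {
    // A [1, 3] row vector broadcast against a [2, 3] reference shape:
    let strides = broadcast_strides(&[1, 3], &[2, 3], &[3, 1]);
    assert_eq!(strides, vec![0, 1]);
    // Both rows of the reference now read the same three elements.
    let offsets: Vec<usize> = (0..6).map(|i| offset(i, &[2, 3], &strides)).collect();
    assert_eq!(offsets, vec![0, 1, 2, 0, 1, 2]);
}
```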
