Skip to content

Commit f407c26

Browse files
authored
refactor: Allow accessing the underlying buffer for TensorMap (#940)
* Make `as_tensor_map` optional * Revert debug setting * Add 1D TMA load * Remove leftover print * Add 1D TMA store
1 parent 1dd4861 commit f407c26

File tree

28 files changed

+249
-184
lines changed

28 files changed

+249
-184
lines changed

crates/cubecl-attention/src/components/args.rs

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use cubecl::prelude::*;
22
use cubecl_core::{self as cubecl};
3-
use cubecl_std::tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand};
3+
use cubecl_std::{
4+
CubeOption, CubeOptionExpand,
5+
tensor::r#virtual::{VirtualTensorOperations, VirtualTensorOperationsExpand},
6+
};
47

58
use crate::components::{
69
line_size::AttentionLineSizes, problem::AttentionProblem, selection::AttentionSelection,
@@ -88,17 +91,17 @@ pub trait AttentionArgs: Send + Sync + 'static + Clone {
8891
/// Reinterpret query as tensor map
8992
fn as_tensor_map_query<Q: Float, K: Float, V: Float, O: Float>(
9093
state: &Self::State<Q, K, V, O>,
91-
) -> TensorMap<Q>;
94+
) -> CubeOption<TensorMap<Q>>;
9295

9396
/// Reinterpret key as tensor map
9497
fn as_tensor_map_key<Q: Float, K: Float, V: Float, O: Float>(
9598
state: &Self::State<Q, K, V, O>,
96-
) -> TensorMap<K>;
99+
) -> CubeOption<TensorMap<K>>;
97100

98101
/// Reinterpret value as tensor map
99102
fn as_tensor_map_value<Q: Float, K: Float, V: Float, O: Float>(
100103
state: &Self::State<Q, K, V, O>,
101-
) -> TensorMap<V>;
104+
) -> CubeOption<TensorMap<V>>;
102105

103106
/// Write the line to the output at the given coordinate using the state.
104107
fn write_out<Q: Float, K: Float, V: Float, O: Float>(
@@ -290,11 +293,8 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe
290293
TensorOutputExpand::__expand_buffer_len_method(self.clone(), scope)
291294
}
292295

293-
fn __expand_as_tensor_map_method(
294-
&self,
295-
_scope: &mut Scope,
296-
) -> ExpandElementTyped<TensorMap<O>> {
297-
unimplemented!("TensorOutputExpand can't be turned into a tensor map");
296+
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<O>> {
297+
CubeOption::__expand_new_None(scope)
298298
}
299299
}
300300

@@ -367,7 +367,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe
367367
TensorQueryExpand::__expand_buffer_len_method(self.clone(), scope)
368368
}
369369

370-
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<Q>> {
370+
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<Q>> {
371371
TensorQueryExpand::__expand_as_tensor_map_method(self.clone(), scope)
372372
}
373373
}
@@ -441,7 +441,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe
441441
TensorKeyExpand::__expand_buffer_len_method(self.clone(), scope)
442442
}
443443

444-
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<K>> {
444+
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<K>> {
445445
TensorKeyExpand::__expand_as_tensor_map_method(self.clone(), scope)
446446
}
447447
}
@@ -515,7 +515,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> VirtualTensorOpe
515515
TensorValueExpand::__expand_buffer_len_method(self.clone(), scope)
516516
}
517517

518-
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> ExpandElementTyped<TensorMap<V>> {
518+
fn __expand_as_tensor_map_method(&self, scope: &mut Scope) -> CubeOptionExpand<TensorMap<V>> {
519519
TensorValueExpand::__expand_as_tensor_map_method(self.clone(), scope)
520520
}
521521
}
@@ -606,7 +606,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorQuery<Q, K
606606
}
607607

608608
/// Get the buffer length of the tensor.
609-
pub fn as_tensor_map(&self) -> TensorMap<Q> {
609+
pub fn as_tensor_map(&self) -> CubeOption<TensorMap<Q>> {
610610
unsafe { MA::as_tensor_map_query(&(*self.state)) }
611611
}
612612

@@ -660,7 +660,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorKey<Q, K,
660660
}
661661

662662
/// Get the buffer length of the tensor.
663-
pub fn as_tensor_map(&self) -> TensorMap<K> {
663+
pub fn as_tensor_map(&self) -> CubeOption<TensorMap<K>> {
664664
unsafe { MA::as_tensor_map_key(&(*self.state)) }
665665
}
666666

@@ -714,7 +714,7 @@ impl<Q: Float, K: Float, V: Float, O: Float, MA: AttentionArgs> TensorValue<Q, K
714714
}
715715

716716
/// Get the buffer length of the tensor.
717-
pub fn as_tensor_map(&self) -> TensorMap<V> {
717+
pub fn as_tensor_map(&self) -> CubeOption<TensorMap<V>> {
718718
unsafe { MA::as_tensor_map_value(&(*self.state)) }
719719
}
720720

@@ -885,26 +885,20 @@ impl AttentionArgs for TensorArgs {
885885

886886
fn as_tensor_map_query<Q: Float, K: Float, V: Float, O: Float>(
887887
_state: &Self::State<Q, K, V, O>,
888-
) -> TensorMap<Q> {
889-
comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
890-
#[allow(unreachable_code)]
891-
TensorMap::dummy()
888+
) -> CubeOption<TensorMap<Q>> {
889+
CubeOption::new_None()
892890
}
893891

894892
fn as_tensor_map_key<Q: Float, K: Float, V: Float, O: Float>(
895893
_state: &Self::State<Q, K, V, O>,
896-
) -> TensorMap<K> {
897-
comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
898-
#[allow(unreachable_code)]
899-
TensorMap::dummy()
894+
) -> CubeOption<TensorMap<K>> {
895+
CubeOption::new_None()
900896
}
901897

902898
fn as_tensor_map_value<Q: Float, K: Float, V: Float, O: Float>(
903899
_state: &Self::State<Q, K, V, O>,
904-
) -> TensorMap<V> {
905-
comptime!(unimplemented!("Can't use `TensorArgs` as `TensorMap`"));
906-
#[allow(unreachable_code)]
907-
TensorMap::dummy()
900+
) -> CubeOption<TensorMap<V>> {
901+
CubeOption::new_None()
908902
}
909903

910904
fn shape_query<Q: Float, K: Float, V: Float, O: Float>(

crates/cubecl-convolution/src/components/global/memory/tma.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ impl<E: Numeric> Im2colTmaReader<E> {
2222
spatial_offsets: Sequence<u32>,
2323
k_offset: u32,
2424
) -> Im2colTmaReader<E> {
25-
let map = tensor.as_tensor_map();
25+
let map = tensor.as_tensor_map().unwrap();
2626

2727
Im2colTmaReader::<E> {
2828
tensor: map,

crates/cubecl-convolution/src/components/global/multi_stage/tma/convolution.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ where
222222
) -> Self::RhsGlobalReader {
223223
let (x_offset, y_offset) = offset;
224224
Self::RhsGlobalReader::new(
225-
rhs.as_tensor_map(),
225+
rhs.as_tensor_map().unwrap(),
226226
x_offset,
227227
y_offset,
228228
runtime_args,

crates/cubecl-convolution/src/components/global/single_stage/tma/convolution.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ where
151151
) -> Self::RhsGlobalReader {
152152
let (x_offset, y_offset) = offset;
153153
Self::RhsGlobalReader::new(
154-
rhs.as_tensor_map(),
154+
rhs.as_tensor_map().unwrap(),
155155
x_offset,
156156
y_offset,
157157
runtime_args,

crates/cubecl-core/src/codegen/integrator.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ pub struct KernelIntegrator {
1414
expansion: KernelExpansion,
1515
buffer_bindings: Vec<Binding>,
1616
scalar_bindings: Vec<ScalarBinding>,
17-
tensor_maps: Vec<Id>,
17+
tensor_maps: Vec<Binding>,
1818
}
1919

2020
/// The information necessary to compile a [kernel definition](KernelDefinition).
2121
#[derive(Clone)]
2222
pub struct KernelExpansion {
2323
pub buffers: Vec<BufferInfo>,
2424
pub scalars: Vec<ScalarInfo>,
25-
pub tensor_maps: Vec<Id>,
25+
pub tensor_maps: Vec<BufferInfo>,
2626
pub scope: Scope,
2727
}
2828

@@ -143,8 +143,15 @@ impl KernelIntegrator {
143143
}
144144

145145
fn register_tensor_maps(&mut self) {
146-
for id in self.expansion.tensor_maps.drain(..) {
147-
self.tensor_maps.push(id);
146+
for buffer in self.expansion.tensor_maps.drain(..) {
147+
self.tensor_maps.push(Binding {
148+
id: buffer.id,
149+
ty: buffer.item,
150+
visibility: buffer.visibility,
151+
location: Location::Storage,
152+
has_extended_meta: buffer.has_extended_meta,
153+
size: None,
154+
});
148155
}
149156
}
150157
}

crates/cubecl-core/src/compute/builder.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ use std::{
55

66
use alloc::collections::BTreeMap;
77

8-
use cubecl_ir::{
9-
ExpandElement, Scope, SemanticType, StorageType, TargetProperties, Variable, VariableKind,
10-
};
8+
use cubecl_ir::{ExpandElement, Scope, StorageType, TargetProperties, Variable, VariableKind};
119
use cubecl_runtime::config::{GlobalConfig, compilation::CompilationLogLevel};
1210

1311
use crate::ir::{Id, Type};
@@ -23,7 +21,7 @@ pub struct KernelBuilder {
2321
pub scope: Scope,
2422
buffers: Vec<BufferInfo>,
2523
scalars: BTreeMap<StorageType, usize>,
26-
tensor_maps: Vec<Id>,
24+
tensor_maps: Vec<BufferInfo>,
2725
}
2826

2927
static DEBUG: AtomicI8 = AtomicI8::new(-1);
@@ -54,13 +52,27 @@ impl KernelBuilder {
5452
}
5553

5654
/// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
57-
pub fn tensor_map(&mut self) -> ExpandElement {
55+
pub fn input_tensor_map(&mut self, item: Type) -> ExpandElement {
5856
let id = self.buffer_id();
59-
self.tensor_maps.push(id);
60-
ExpandElement::Plain(Variable::new(
61-
VariableKind::TensorMap(id),
62-
Type::semantic(SemanticType::TensorMap),
63-
))
57+
self.tensor_maps.push(BufferInfo {
58+
id,
59+
item,
60+
visibility: Visibility::ReadWrite,
61+
has_extended_meta: true,
62+
});
63+
ExpandElement::Plain(Variable::new(VariableKind::TensorMapInput(id), item))
64+
}
65+
66+
/// Register a tensor map and return the [element](ExpandElement) to be used for kernel expansion.
67+
pub fn output_tensor_map(&mut self, item: Type) -> ExpandElement {
68+
let id = self.buffer_id();
69+
self.tensor_maps.push(BufferInfo {
70+
id,
71+
item,
72+
visibility: Visibility::Read,
73+
has_extended_meta: true,
74+
});
75+
ExpandElement::Plain(Variable::new(VariableKind::TensorMapOutput(id), item))
6476
}
6577

6678
/// Register an input array and return the [element](ExpandElement) to be used for kernel expansion.

crates/cubecl-core/src/compute/kernel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ source:
156156
#[allow(missing_docs)]
157157
pub struct KernelDefinition {
158158
pub buffers: Vec<Binding>,
159-
pub tensor_maps: Vec<Id>,
159+
pub tensor_maps: Vec<Binding>,
160160
pub scalars: Vec<ScalarBinding>,
161161
pub cube_dim: CubeDim,
162162
pub body: Scope,

crates/cubecl-core/src/frontend/barrier.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ macro_rules! tensor_map_load_im2col {
240240
};
241241
}
242242

243+
tensor_map_load!(1, x);
243244
tensor_map_load!(2, y, x);
244245
tensor_map_load!(3, z, y, x);
245246
tensor_map_load!(4, w, z, y, x);

0 commit comments

Comments
 (0)