Skip to content

Commit 9348a3e

Browse files
authored
feat: Split traits (#868)
1 parent ae51a6f commit 9348a3e

File tree

51 files changed

+973
-1076
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+973
-1076
lines changed

crates/cubecl-attention/src/components/args.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,10 @@ impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> VirtualTensorOperationsExpand<
243243
) -> ExpandElementTyped<TensorMap<EO>> {
244244
unimplemented!("TensorOutputExpand can't be turned into a tensor map");
245245
}
246+
}
246247

248+
impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> Lined for TensorOutput<EI, EO, MA> {}
249+
impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> LinedExpand for TensorOutputExpand<EI, EO, MA> {
247250
fn line_size(&self) -> u32 {
248251
let mut scope = Scope::root(false);
249252
TensorOutputExpand::__expand_line_size_method(self.clone(), &mut scope)
@@ -312,7 +315,10 @@ impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> VirtualTensorOperationsExpand<
312315
) -> ExpandElementTyped<TensorMap<EI>> {
313316
TensorInputExpand::__expand_as_tensor_map_method(self.clone(), scope)
314317
}
318+
}
315319

320+
impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> Lined for TensorInput<EI, EO, MA> {}
321+
impl<EI: Numeric, EO: Numeric, MA: AttentionArgs> LinedExpand for TensorInputExpand<EI, EO, MA> {
316322
fn line_size(&self) -> u32 {
317323
let mut scope = Scope::root(false);
318324
TensorInputExpand::__expand_line_size_method(self.clone(), &mut scope)

crates/cubecl-attention/src/components/global/dummy/attention.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ impl<
113113
comment!("Global: Init Query Loader");
114114
let layout =
115115
SimpleGlobalLayout::new(&query, config.global_memory_config(FlashIdent::Query));
116-
DummyQueryLoader::<AP, Self::Config>::new(q_offset, query.view(layout.virt()), config)
116+
DummyQueryLoader::<AP, Self::Config>::new(q_offset, query.view(layout), config)
117117
}
118118

119119
fn init_key_loader(
@@ -122,7 +122,7 @@ impl<
122122
) -> Self::KeyLoader {
123123
comment!("Global: Init Key Loader");
124124
let layout = SimpleGlobalLayout::new(&key, config.global_memory_config(FlashIdent::Key));
125-
DummyKeyLoader::new(key.view(layout.virt()), config)
125+
DummyKeyLoader::new(key.view(layout), config)
126126
}
127127

128128
fn init_value_loader(
@@ -132,7 +132,7 @@ impl<
132132
comment!("Global: Init Value Loader");
133133
let layout =
134134
SimpleGlobalLayout::new(&value, config.global_memory_config(FlashIdent::Value));
135-
DummyValueLoader::new(value.view(layout.virt()), config)
135+
DummyValueLoader::new(value.view(layout), config)
136136
}
137137

138138
fn init_writer(
@@ -142,6 +142,6 @@ impl<
142142
) -> Self::Writer {
143143
comment!("Global: Init Writer");
144144
let layout = SimpleGlobalLayout::new(&out, config.global_memory_config(FlashIdent::Out));
145-
SA::init_writer(q_offset, out.view_mut(layout.virt()))
145+
SA::init_writer(q_offset, out.view_mut(layout))
146146
}
147147
}

crates/cubecl-convolution/src/components/global/layout/base.rs

Lines changed: 0 additions & 48 deletions
This file was deleted.

crates/cubecl-convolution/src/components/global/layout/im2col.rs

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ use cubecl_matmul::components::{
66
};
77
use cubecl_std::{
88
FastDivmod,
9-
tensor::layout::{Coords3d, Layout, VirtualLayoutOperations, VirtualLayoutOperationsExpand},
9+
tensor::layout::{Coords3d, Layout, LayoutExpand},
1010
};
1111

1212
use crate::{
1313
components::{
1414
ConvolutionConfig,
1515
global::{
16-
layout::{NhwcCoords, unwrap, virtual_layout},
16+
layout::{NhwcCoords, unwrap},
1717
load::im2col_tma::div_mod_seq,
1818
},
1919
},
@@ -24,7 +24,7 @@ use crate::{
2424
/// It first decomposes the `(m, k)` matrix into `((n, out_h, out_w), (k_h, k_w, c))`, then applies
2525
/// the convolution parameters to calculate the position in the input tensor for that kernel element.
2626
#[derive(CubeType, Clone)]
27-
pub struct Im2colGlobalLayout {
27+
pub struct Im2colLayout {
2828
/// Shape of output DHW
2929
pub shape_out: Sequence<FastDivmod>,
3030
/// Shape of channel, for decomposing k
@@ -54,14 +54,14 @@ pub struct Im2colGlobalLayout {
5454
}
5555

5656
#[cube]
57-
impl Im2colGlobalLayout {
57+
impl Im2colLayout {
5858
pub fn new<G: GlobalConfig>(
5959
args: &RuntimeArgs,
6060
#[comptime] config: ConvolutionConfig<G>,
61-
) -> Im2colGlobalLayout {
61+
) -> Im2colLayout {
6262
let shape_out = args.shape_out.clone();
6363

64-
Im2colGlobalLayout {
64+
Im2colLayout {
6565
shape_out,
6666
shape_channel: args.shape_channel,
6767
shape_m: args.shape_m,
@@ -76,32 +76,32 @@ impl Im2colGlobalLayout {
7676
}
7777

7878
#[cube]
79-
impl Layout for Im2colGlobalLayout {
79+
impl Layout for Im2colLayout {
8080
type Coordinates = Coords3d;
8181
type SourceCoordinates = NhwcCoords;
8282

83-
fn to_source_pos(this: &Self, pos: Self::Coordinates) -> NhwcCoords {
83+
fn to_source_pos(&self, pos: Self::Coordinates) -> NhwcCoords {
8484
let (_, view_m, view_k) = pos;
8585

86-
let (batch, out_offs) = div_mod_seq(view_m, &this.shape_out);
86+
let (batch, out_offs) = div_mod_seq(view_m, &self.shape_out);
8787

88-
let (mut rem, channel) = this.shape_channel.div_mod(view_k);
88+
let (mut rem, channel) = self.shape_channel.div_mod(view_k);
8989

90-
let spatial_dims = comptime![this.shape_out.len()];
90+
let spatial_dims = comptime![self.shape_out.len()];
9191
let mut in_pos = Sequence::<i32>::new();
9292

9393
#[unroll]
9494
for i in 0..spatial_dims {
9595
let i = unwrap(i);
9696
let dim = comptime![spatial_dims - i - 1];
97-
let ksize = comptime![this.kernel_size[dim as usize]];
97+
let ksize = comptime![self.kernel_size[dim as usize]];
9898
let k_pos = rem % ksize;
9999
rem /= ksize;
100100

101101
let out_pos = *out_offs.index(dim);
102-
let stride = comptime![this.stride[dim as usize]];
103-
let dilate = comptime![this.dilation[dim as usize]];
104-
let pad = comptime![this.padding[dim as usize]];
102+
let stride = comptime![self.stride[dim as usize]];
103+
let dilate = comptime![self.dilation[dim as usize]];
104+
let pad = comptime![self.padding[dim as usize]];
105105

106106
let pos = (out_pos * stride + k_pos * dilate) as i32 - pad;
107107
in_pos.push(pos);
@@ -116,21 +116,19 @@ impl Layout for Im2colGlobalLayout {
116116
}
117117
}
118118

119-
fn shape(this: &Self) -> Self::Coordinates {
120-
(1, this.shape_m, this.shape_k)
119+
fn shape(&self) -> Self::Coordinates {
120+
(1, self.shape_m, self.shape_k)
121121
}
122122

123-
fn to_source_pos_checked(this: &Self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
124-
(this.to_source_pos(pos), this.is_in_bounds(pos))
123+
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (NhwcCoords, bool) {
124+
(self.to_source_pos(pos), self.is_in_bounds(pos))
125125
}
126126

127-
fn is_in_bounds(this: &Self, pos: Self::Coordinates) -> bool {
127+
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
128128
let (_, view_m, view_k) = pos;
129129
// Shouldn't be relied on because it doesn't check spatial
130-
let m_in_bounds = comptime!(!this.config.check_row_bounds) || view_m < this.shape_m;
131-
let k_in_bounds = comptime!(!this.config.check_col_bounds) || view_k < this.shape_k;
130+
let m_in_bounds = comptime!(!self.config.check_row_bounds) || view_m < self.shape_m;
131+
let k_in_bounds = comptime!(!self.config.check_col_bounds) || view_k < self.shape_k;
132132
m_in_bounds && k_in_bounds
133133
}
134134
}
135-
136-
virtual_layout!(Im2colGlobalLayout, Im2colGlobalLayoutExpand);

crates/cubecl-convolution/src/components/global/layout/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
mod base;
21
mod im2col;
32
mod spatial;
43
mod weight;
54
mod write;
65

7-
pub(crate) use base::*;
86
pub use im2col::*;
97
pub use spatial::*;
108
pub use weight::*;

crates/cubecl-convolution/src/components/global/layout/spatial.rs

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
use cubecl::prelude::*;
22
use cubecl_core::{self as cubecl, intrinsic};
33
use cubecl_std::tensor::{
4-
layout::{
5-
Coordinates, Coords1d, Layout, VirtualLayoutOperations, VirtualLayoutOperationsExpand,
6-
},
4+
layout::{Coordinates, Coords1d, Layout, LayoutExpand},
75
r#virtual::VirtualTensor,
86
};
97

10-
use crate::components::{Dimensionality, global::layout::virtual_layout};
8+
use crate::components::Dimensionality;
119

1210
#[derive(CubeType, Clone)]
1311
pub struct NhwcCoords {
@@ -93,42 +91,39 @@ impl Layout for NhwcLayout {
9391
type Coordinates = NhwcCoords;
9492
type SourceCoordinates = Coords1d;
9593

96-
fn to_source_pos(this: &Self, pos: Self::Coordinates) -> Self::SourceCoordinates {
94+
fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
9795
let NhwcCoords {
9896
batch,
9997
spatial,
10098
channel,
10199
} = pos;
102100

103-
let spatial_dims = this.shapes_spatial.len();
104-
let mut read_pos = batch * this.stride_batch + channel * this.stride_channel;
101+
let spatial_dims = self.shapes_spatial.len();
102+
let mut read_pos = batch * self.stride_batch + channel * self.stride_channel;
105103

106104
#[unroll]
107105
for i in 0..spatial_dims {
108106
let i = unwrap(i);
109-
read_pos += *spatial.index(i) as u32 * *this.strides_spatial.index(i);
107+
read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
110108
}
111109

112-
read_pos / this.line_size
110+
read_pos / self.line_size
113111
}
114112

115-
fn to_source_pos_checked(
116-
this: &Self,
117-
pos: Self::Coordinates,
118-
) -> (Self::SourceCoordinates, bool) {
119-
(this.to_source_pos(pos.clone()), this.is_in_bounds(pos))
113+
fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
114+
(self.to_source_pos(pos.clone()), self.is_in_bounds(pos))
120115
}
121116

122-
fn is_in_bounds(this: &Self, pos: Self::Coordinates) -> bool {
123-
if comptime![this.check_spatial] {
124-
let spatial_dims = this.shapes_spatial.len();
117+
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
118+
if comptime![self.check_spatial] {
119+
let spatial_dims = self.shapes_spatial.len();
125120
let mut spatial_in_bounds = true;
126121

127122
#[unroll]
128123
for i in 0..spatial_dims {
129124
let i = unwrap(i);
130125
let pos = *pos.spatial.index(i);
131-
spatial_in_bounds &= pos >= 0 && (pos as u32) < *this.shapes_spatial.index(i);
126+
spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
132127
}
133128

134129
spatial_in_bounds
@@ -137,17 +132,15 @@ impl Layout for NhwcLayout {
137132
}
138133
}
139134

140-
fn shape(this: &Self) -> Self::Coordinates {
135+
fn shape(&self) -> Self::Coordinates {
141136
NhwcCoords {
142-
batch: this.shape_batch,
143-
spatial: cast_seq(this.shapes_spatial.clone()),
144-
channel: this.shape_channel,
137+
batch: self.shape_batch,
138+
spatial: cast_seq(self.shapes_spatial.clone()),
139+
channel: self.shape_channel,
145140
}
146141
}
147142
}
148143

149-
virtual_layout!(NhwcLayout, NhwcLayoutExpand);
150-
151144
#[allow(unused_variables)]
152145
#[cube]
153146
pub(crate) fn unwrap(v: u32) -> comptime_type!(u32) {

crates/cubecl-convolution/src/components/global/layout/weight.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ use cubecl_matmul::components::{
77
use cubecl_std::{
88
FastDivmod,
99
tensor::{
10-
layout::{Coords3d, Layout, VirtualLayoutOperations, VirtualLayoutOperationsExpand},
10+
layout::{Coords3d, Layout, LayoutExpand},
1111
r#virtual::VirtualTensor,
1212
},
1313
};
1414

1515
use crate::{
1616
components::{
1717
ConvGemmConfig, ConvolutionConfig,
18-
global::layout::{NhwcCoords, unwrap, virtual_layout},
18+
global::layout::{NhwcCoords, unwrap},
1919
},
2020
kernels::layered::selector::RuntimeArgs,
2121
};
@@ -83,19 +83,19 @@ impl Layout for WeightLayout {
8383
type Coordinates = Coords3d;
8484
type SourceCoordinates = NhwcCoords;
8585

86-
fn to_source_pos(this: &Self, coords: Self::Coordinates) -> NhwcCoords {
86+
fn to_source_pos(&self, coords: Self::Coordinates) -> NhwcCoords {
8787
let (_, k, n) = coords;
8888

89-
let (mut rem, in_c) = this.channels.div_mod(k);
89+
let (mut rem, in_c) = self.channels.div_mod(k);
9090

91-
let spatial_dims = comptime![this.strides_spatial.len()];
91+
let spatial_dims = comptime![self.strides_spatial.len()];
9292
let mut kernel_pos = Sequence::<i32>::new();
9393

9494
#[unroll]
9595
for i in 0..spatial_dims {
9696
let i = unwrap(i);
9797
let dim = comptime![spatial_dims - i - 1];
98-
let ksize = comptime![this.kernel_size[dim as usize]];
98+
let ksize = comptime![self.kernel_size[dim as usize]];
9999
let k_pos = rem % ksize;
100100
rem /= ksize;
101101

@@ -111,20 +111,18 @@ impl Layout for WeightLayout {
111111
}
112112
}
113113

114-
fn to_source_pos_checked(this: &Self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
115-
(this.to_source_pos(coords), this.is_in_bounds(coords))
114+
fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (NhwcCoords, bool) {
115+
(self.to_source_pos(coords), self.is_in_bounds(coords))
116116
}
117117

118-
fn shape(this: &Self) -> Self::Coordinates {
119-
(1, this.shape_k, this.shape_n)
118+
fn shape(&self) -> Self::Coordinates {
119+
(1, self.shape_k, self.shape_n)
120120
}
121121

122-
fn is_in_bounds(this: &Self, pos: Self::Coordinates) -> bool {
122+
fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
123123
let (_, k, n) = pos;
124-
let check_k = comptime![this.config.check_row_bounds];
125-
let check_n = comptime![this.config.check_col_bounds];
126-
(!check_k || k < this.shape_k) && (!check_n || n < this.shape_n)
124+
let check_k = comptime![self.config.check_row_bounds];
125+
let check_n = comptime![self.config.check_col_bounds];
126+
(!check_k || k < self.shape_k) && (!check_n || n < self.shape_n)
127127
}
128128
}
129-
130-
virtual_layout!(WeightLayout, WeightLayoutExpand);

0 commit comments

Comments
 (0)