Skip to content

Commit 9b08383 — "enable unit attention" (#1079)

Browse files

Authored; 1 parent: a15c1d2

File tree

4 files changed: +108 −35 lines

crates/cubecl-attention/src/base.rs

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@ use cubecl_std::tensor::TensorHandle;
44

55
use crate::{
66
components::{
7-
AttentionElems, AttentionIdent, AttentionPartitionSize, AttentionProblem,
8-
AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
9-
AttentionTilingScheme, AvailableLineSizes,
7+
AttentionElems, AttentionIdent, AttentionProblem, AttentionSetupError, AvailableLineSizes,
108
args::{TensorArgs, TensorInputsLaunch},
11-
batch::HypercubeSelection,
129
},
1310
kernels::{Algorithm, blackbox_accelerated::BlackboxAcceleratedAlgorithm, unit::UnitAlgorithm},
1411
};
@@ -120,48 +117,22 @@ pub fn launch_attention<R: Runtime, A: Algorithm>(
120117
causal: false,
121118
};
122119

123-
let tile_size = AttentionTileSize {
124-
seq_q: 8,
125-
head_dim: 8,
126-
seq_kv: 8,
127-
val_dim: 8,
128-
};
129-
130-
assert!(problem.head_dim as u32 % tile_size.head_dim == 0);
131-
let partition_head_dim = problem.head_dim as u32 / tile_size.head_dim;
132-
let partition_val_dim = partition_head_dim;
133-
134-
let selection = AttentionSelection {
135-
hypercube_selection: HypercubeSelection {},
136-
tiling_scheme: AttentionTilingScheme {
137-
tile_size,
138-
partition_size: AttentionPartitionSize {
139-
seq_q: 1,
140-
head_dim: partition_head_dim,
141-
seq_kv: 1,
142-
val_dim: partition_val_dim,
143-
},
144-
stage_size: AttentionStageSize { seq_q: 1 },
145-
},
146-
plane_dim: 32,
147-
reuse_key_value: false,
148-
two_rows_in_array_tile: false,
149-
};
150-
151-
let config = BlackboxAcceleratedAlgorithm::setup(
120+
let selection = A::selection(
152121
client,
153122
&problem,
154-
&selection,
123+
client.properties().hardware.plane_size_max,
155124
&line_sizes,
156125
attention_elems,
157126
)?;
158127

128+
let config = A::setup(client, &problem, &selection, &line_sizes, attention_elems)?;
129+
159130
let cube_count_plan = config
160131
.hypercube_config()
161132
.cube_count_plan(&problem, &selection);
162133

163134
let result = unsafe {
164-
<BlackboxAcceleratedAlgorithm as Algorithm>::BatchAttention::launch_unchecked::<TensorArgs, R>(
135+
<A as Algorithm>::BatchAttention::launch_unchecked::<TensorArgs, R>(
165136
client,
166137
config.cube_dim(),
167138
cube_count_plan.resolve(),

crates/cubecl-attention/src/kernels/algorithm.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,12 @@ pub trait Algorithm {
2727
) -> Result<<Self::BatchAttention as BatchAttentionFamily>::Config, AttentionSetupError> {
2828
Self::BatchAttention::setup(client, problem, selection, line_sizes, dtypes)
2929
}
30+
31+
fn selection<R: Runtime>(
32+
client: &ComputeClient<R>,
33+
problem: &AttentionProblem,
34+
plane_dim: u32,
35+
line_sizes: &AttentionLineSizes,
36+
dtypes: &AttentionElems,
37+
) -> Result<AttentionSelection, AttentionSetupError>;
3038
}

crates/cubecl-attention/src/kernels/blackbox_accelerated.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1+
use cubecl_core::client::ComputeClient;
12
use cubecl_matmul::components::{global::PartitionedStageFamily, stage::StridedStageFamily};
23

4+
use crate::components::batch::HypercubeSelection;
35
use crate::components::stage::plane::PlanePartitionStageAttentionFamily;
46
use crate::components::tile::TileAttentionFamily;
57
use crate::components::tile::accelerated::BlackboxAcceleratedTileAttention;
8+
use crate::components::{
9+
AttentionElems, AttentionLineSizes, AttentionPartitionSize, AttentionProblem,
10+
AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
11+
AttentionTilingScheme,
12+
};
613
use crate::{
714
components::{
815
AvailableLineSizes, batch::simple::SimpleBatchAttentionFamily,
@@ -27,4 +34,48 @@ impl Algorithm for BlackboxAcceleratedAlgorithm {
2734
fn filter_line_sizes(available_line_sizes: AvailableLineSizes) -> AvailableLineSizes {
2835
Self::TileAttention::filter_line_sizes(available_line_sizes)
2936
}
37+
38+
fn selection<R: cubecl_core::Runtime>(
39+
_client: &ComputeClient<R>,
40+
problem: &AttentionProblem,
41+
plane_dim: u32,
42+
_line_sizes: &AttentionLineSizes,
43+
_dtypes: &AttentionElems,
44+
) -> Result<AttentionSelection, AttentionSetupError> {
45+
#[cfg(target_os = "macos")]
46+
let tile_size = AttentionTileSize {
47+
seq_q: 8,
48+
head_dim: 8,
49+
seq_kv: 8,
50+
val_dim: 8,
51+
};
52+
#[cfg(not(target_os = "macos"))]
53+
let tile_size = AttentionTileSize {
54+
seq_q: 16,
55+
head_dim: 16,
56+
seq_kv: 16,
57+
val_dim: 16,
58+
};
59+
60+
assert!(problem.head_dim as u32 % tile_size.head_dim == 0);
61+
let partition_head_dim = problem.head_dim as u32 / tile_size.head_dim;
62+
let partition_val_dim = partition_head_dim;
63+
64+
Ok(AttentionSelection {
65+
hypercube_selection: HypercubeSelection {},
66+
tiling_scheme: AttentionTilingScheme {
67+
tile_size,
68+
partition_size: AttentionPartitionSize {
69+
seq_q: 1,
70+
head_dim: partition_head_dim,
71+
seq_kv: 1,
72+
val_dim: partition_val_dim,
73+
},
74+
stage_size: AttentionStageSize { seq_q: 1 },
75+
},
76+
plane_dim,
77+
reuse_key_value: false,
78+
two_rows_in_array_tile: false,
79+
})
80+
}
3081
}

crates/cubecl-attention/src/kernels/unit.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
use cubecl_core::client::ComputeClient;
12
use cubecl_matmul::components::{global::PartitionedStageFamily, stage::StridedStageFamily};
23

4+
use crate::components::batch::HypercubeSelection;
35
use crate::components::stage::unit::UnitPartitionStageAttentionFamily;
46
use crate::components::tile::unit_register::UnitRegisterTileAttention;
7+
use crate::components::{
8+
AttentionElems, AttentionLineSizes, AttentionPartitionSize, AttentionProblem,
9+
AttentionSelection, AttentionSetupError, AttentionStageSize, AttentionTileSize,
10+
AttentionTilingScheme,
11+
};
512
use crate::{
613
components::{
714
batch::simple::SimpleBatchAttentionFamily, global::simple::SimpleGlobalAttentionFamily,
@@ -21,4 +28,40 @@ impl Algorithm for UnitAlgorithm {
2128
>;
2229
type GlobalAttention = SimpleGlobalAttentionFamily<Self::StageAttention>;
2330
type BatchAttention = SimpleBatchAttentionFamily<Self::GlobalAttention>;
31+
32+
fn selection<R: cubecl_core::Runtime>(
33+
_client: &ComputeClient<R>,
34+
problem: &AttentionProblem,
35+
plane_dim: u32,
36+
_line_sizes: &AttentionLineSizes,
37+
_dtypes: &AttentionElems,
38+
) -> Result<AttentionSelection, AttentionSetupError> {
39+
let tile_size = AttentionTileSize {
40+
seq_q: 4,
41+
head_dim: 4,
42+
seq_kv: 4,
43+
val_dim: 4,
44+
};
45+
46+
assert!(problem.head_dim as u32 % tile_size.head_dim == 0);
47+
let partition_head_dim = problem.head_dim as u32 / tile_size.head_dim;
48+
let partition_val_dim = partition_head_dim;
49+
50+
Ok(AttentionSelection {
51+
hypercube_selection: HypercubeSelection {},
52+
tiling_scheme: AttentionTilingScheme {
53+
tile_size,
54+
partition_size: AttentionPartitionSize {
55+
seq_q: 1,
56+
head_dim: partition_head_dim,
57+
seq_kv: 1,
58+
val_dim: partition_val_dim,
59+
},
60+
stage_size: AttentionStageSize { seq_q: plane_dim },
61+
},
62+
plane_dim,
63+
reuse_key_value: false,
64+
two_rows_in_array_tile: false,
65+
})
66+
}
2467
}

0 commit comments

Comments (0)