Skip to content

Commit 0b1abc4

Browse files
authored
Unwrap unrolled loop variables (#955)
1 parent 0303f39 commit 0b1abc4

File tree

34 files changed

+212
-216
lines changed

34 files changed

+212
-216
lines changed

crates/cubecl-attention/src/components/global/dummy/read.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,11 @@ impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyKeyReader<AP, G> {
135135
let index_load = row_load_in_tile * tile_cols_load + col_load;
136136
let index_store = col_load * tile_rows_load + row_load_in_tile;
137137

138-
slice[index_store + store_offset] = Line::cast_from(
139-
view.read_checked(((tile_row_load, tile_col_load), index_load)),
140-
);
138+
slice[index_store + store_offset] =
139+
Line::cast_from(view.read_checked((
140+
(tile_row_load, tile_col_load).runtime(),
141+
index_load,
142+
)));
141143
}
142144
}
143145
}
@@ -210,7 +212,7 @@ impl<AP: AttentionPrecision, G: GlobalAttentionConfig> DummyValueReader<AP, G> {
210212
let index = row_in_tile * tile_cols + col;
211213

212214
slice[index + offset] = Line::cast_from(
213-
view.read_checked(((tile_row, tile_col), index)),
215+
view.read_checked(((tile_row, tile_col).runtime(), index)),
214216
);
215217
}
216218
}

crates/cubecl-attention/src/components/stage/dummy/attention.rs

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -56,50 +56,39 @@ impl<
5656

5757
let p = config.tiling_scheme().partition_size;
5858

59-
let mut kv = comptime![0u32];
60-
6159
let mut max_placeholder = TA::init_max_placeholder(config.num_rows_per_unit());
6260
let mut sum_placeholder = TA::init_sum_placeholder(config.num_rows_per_unit());
6361

6462
#[unroll]
6563
#[allow(clippy::explicit_counter_loop)]
66-
for _ in 0..p.seq_kv {
67-
let mut hd = comptime![0u32];
68-
64+
for kv in 0..p.seq_kv {
6965
#[unroll]
7066
#[allow(clippy::explicit_counter_loop)]
71-
for _ in 0..p.head_dim {
67+
for hd in 0..p.head_dim {
7268
let key_tile = SK::tile(key_reader, (hd, kv).runtime());
7369

7470
TA::fill_key(
7571
&key_tile,
7672
key_value_partition.get_key_at_mut(hd, kv, config),
7773
config.tile_config(),
7874
);
79-
80-
comptime![hd += 1];
8175
}
8276

83-
let mut q = comptime![0u32];
8477
let mut scales = Sequence::<RowWise<SM<AP>>>::new();
8578

8679
#[unroll]
8780
#[allow(clippy::explicit_counter_loop)]
88-
for _ in 0..p.seq_q {
81+
for q in 0..p.seq_q {
8982
let softmax_tile = softmax_partition.get_at_mut(q, kv, config);
9083
TA::zero_softmax(softmax_tile, config.tile_config());
9184

92-
let mut hd = comptime![0u32];
93-
9485
#[unroll]
9586
#[allow(clippy::explicit_counter_loop)]
96-
for _ in 0..p.head_dim {
87+
for hd in 0..p.head_dim {
9788
let query_tile = query_partition.get_at(q, hd, config);
9889
let key_tile = key_value_partition.get_key_at(hd, kv, config);
9990

10091
TA::accumulate_score(query_tile, key_tile, softmax_tile, config.tile_config());
101-
102-
comptime![hd += 1];
10392
}
10493

10594
let state_q = state.index_mut(q);
@@ -113,52 +102,37 @@ impl<
113102
config.tiling_scheme().elements_in_partition_head_dim(),
114103
config.tile_config(),
115104
));
116-
117-
comptime![q += 1];
118105
}
119106

120-
let mut vd = comptime![0u32];
121-
122107
#[unroll]
123108
#[allow(clippy::explicit_counter_loop)]
124-
for _ in 0..p.val_dim {
109+
for vd in 0..p.val_dim {
125110
let value_tile = SV::tile(value_reader, (kv, vd).runtime());
126111

127112
TA::fill_value(
128113
&value_tile,
129114
key_value_partition.get_value_at_mut(kv, vd, config),
130115
config.tile_config(),
131116
);
132-
133-
comptime![vd += 1];
134117
}
135118

136-
let mut q = comptime![0u32];
137-
138119
#[unroll]
139120
#[allow(clippy::explicit_counter_loop)]
140-
for _ in 0..p.seq_q {
141-
let mut vd = comptime![0u32];
121+
for q in 0..p.seq_q {
142122
let softmax_tile = softmax_partition.get_at(q, kv, config);
143123

144124
#[unroll]
145125
#[allow(clippy::explicit_counter_loop)]
146-
for _ in 0..p.val_dim {
126+
for vd in 0..p.val_dim {
147127
TA::accumulate_value(
148128
softmax_tile,
149129
key_value_partition.get_value_at(kv, vd, config),
150130
accumulator_partition.get_at_mut(q, vd, config),
151131
scales.index(q),
152132
config.tile_config(),
153133
);
154-
155-
comptime![vd += 1];
156134
}
157-
158-
comptime![q += 1];
159135
}
160-
161-
comptime![kv += 1];
162136
}
163137
}
164138

crates/cubecl-attention/src/components/tile/dummy/attention_matmul/dummy_register/matmul.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ impl<E: Float> ArrayTile<E> {
106106
self.unit_size.1 * (UNIT_POS_X % self.num_units_per_row) + c
107107
}
108108

109-
fn abs_pos(&self, local_pos: Coords2d) -> Coords2d {
109+
fn abs_pos(&self, #[comptime] local_pos: Coords2d) -> Coords2d {
110110
(
111111
self.abs_row_index(local_pos.0),
112112
self.abs_col_index(local_pos.1),

crates/cubecl-attention/src/components/tile/row/reduce/dummy_reducer.rs

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,18 @@ impl Reducer for DummyReducer {
2525
let plane_offset = UNIT_POS_Y * num_vals_in_plane;
2626
let unit_offset = UNIT_POS_X;
2727

28-
let mut r = comptime![0u32];
29-
3028
#[unroll]
31-
for _ in 0..config.num_rows_per_unit() {
29+
for r in 0..config.num_rows_per_unit() {
3230
let row_offset = r * config.plane_dim();
3331
let offset = plane_offset + row_offset + unit_offset;
3432

3533
smem[offset] = local_vals.index(r);
36-
37-
comptime![r += 1];
3834
}
3935

4036
sync_cube();
4137

42-
let mut r = comptime![0u32];
43-
4438
#[unroll]
45-
for _ in 0..config.num_rows_per_unit() {
39+
for r in 0..config.num_rows_per_unit() {
4640
let mut val = vals.index(r);
4741

4842
let row_offset = r * config.plane_dim();
@@ -56,8 +50,6 @@ impl Reducer for DummyReducer {
5650
}
5751

5852
vals.replace_at(r, val);
59-
60-
comptime![r += 1];
6153
}
6254

6355
sync_cube();

crates/cubecl-attention/src/components/tile/row/rowwise.rs

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,10 @@ impl<E: Float> RowWise<E> {
3333
}
3434

3535
pub fn copy_from(&mut self, other: &RowWise<E>) {
36-
let mut i = comptime![0u32];
3736
#[unroll]
38-
for _ in 0..self.num_rows {
37+
for i in 0..self.num_rows {
3938
let row_val = self.vals.index_mut(i);
4039
row_val.val = other.index(i);
41-
42-
comptime![i += 1];
4340
}
4441
}
4542

@@ -48,57 +45,42 @@ impl<E: Float> RowWise<E> {
4845
}
4946

5047
pub fn fill(&mut self, val: E) {
51-
let mut i = comptime![0u32];
5248
#[unroll]
53-
for _ in 0..self.num_rows {
49+
for i in 0..self.num_rows {
5450
let row_val = self.vals.index_mut(i);
5551
row_val.val = val;
56-
57-
comptime![i += 1];
5852
}
5953
}
6054

6155
pub fn add_inplace(&mut self, other: &RowWise<E>) {
62-
let mut i = comptime![0u32];
6356
#[unroll]
64-
for _ in 0..self.num_rows {
57+
for i in 0..self.num_rows {
6558
let row_val = self.vals.index_mut(i);
6659
row_val.val += other.index(i);
67-
68-
comptime![i += 1];
6960
}
7061
}
7162

7263
pub fn mul_inplace(&mut self, other: &RowWise<E>) {
73-
let mut i = comptime![0u32];
7464
#[unroll]
75-
for _ in 0..self.num_rows {
65+
for i in 0..self.num_rows {
7666
let row_val = self.vals.index_mut(i);
7767
row_val.val *= other.index(i);
78-
79-
comptime![i += 1];
8068
}
8169
}
8270

8371
pub fn recip_inplace(&mut self) {
84-
let mut i = comptime![0u32];
8572
#[unroll]
86-
for _ in 0..self.num_rows {
73+
for i in 0..self.num_rows {
8774
let row_val = self.vals.index_mut(i);
8875
row_val.val = Recip::recip(row_val.val);
89-
90-
comptime![i += 1];
9176
}
9277
}
9378

9479
pub fn max_inplace(&mut self, other: &RowWise<E>) {
95-
let mut i = comptime![0u32];
9680
#[unroll]
97-
for _ in 0..self.num_rows {
81+
for i in 0..self.num_rows {
9882
let row_val = self.vals.index_mut(i);
9983
row_val.val = Max::max(row_val.val, other.index(i));
100-
101-
comptime![i += 1];
10284
}
10385
}
10486

@@ -109,14 +91,11 @@ impl<E: Float> RowWise<E> {
10991

11092
pub fn cast_from<E2: Float>(&self) -> RowWise<E2> {
11193
let mut vals = Sequence::new();
112-
let mut i = comptime![0u32];
11394

11495
#[unroll]
115-
for _ in 0..self.num_rows {
96+
for i in 0..self.num_rows {
11697
let val = E2::cast_from(self.index(i));
11798
vals.push(RowVal::<E2> { val });
118-
119-
comptime![i += 1];
12099
}
121100

122101
RowWise::<E2> {
@@ -127,14 +106,11 @@ impl<E: Float> RowWise<E> {
127106

128107
pub fn exp_m_diff(&self, other: &RowWise<E>) -> RowWise<E> {
129108
let mut vals = Sequence::new();
130-
let mut i = comptime![0u32];
131109

132110
#[unroll]
133-
for _ in 0..self.num_rows {
111+
for i in 0..self.num_rows {
134112
let val = Exp::exp(self.index(i) - other.index(i));
135113
vals.push(RowVal::<E> { val });
136-
137-
comptime![i += 1];
138114
}
139115

140116
RowWise::<E> {
@@ -145,14 +121,11 @@ impl<E: Float> RowWise<E> {
145121

146122
pub fn mul(&self, other: &RowWise<E>) -> RowWise<E> {
147123
let mut vals = Sequence::new();
148-
let mut i = comptime![0u32];
149124

150125
#[unroll]
151-
for _ in 0..self.num_rows {
126+
for i in 0..self.num_rows {
152127
let val = self.index(i) * other.index(i);
153128
vals.push(RowVal::<E> { val });
154-
155-
comptime![i += 1];
156129
}
157130

158131
RowWise::<E> {
@@ -163,14 +136,11 @@ impl<E: Float> RowWise<E> {
163136

164137
pub fn add(&self, other: &RowWise<E>) -> RowWise<E> {
165138
let mut vals = Sequence::new();
166-
let mut i = comptime![0u32];
167139

168140
#[unroll]
169-
for _ in 0..self.num_rows {
141+
for i in 0..self.num_rows {
170142
let val = self.index(i) + other.index(i);
171143
vals.push(RowVal::<E> { val });
172-
173-
comptime![i += 1];
174144
}
175145

176146
RowWise::<E> {

crates/cubecl-convolution/src/components/global/layout/im2col.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ use cubecl_std::{
1212
use crate::{
1313
components::{
1414
ConvGemmConfig, ConvolutionConfig, ConvolutionParams, ConvolutionProblem,
15-
global::{
16-
layout::{NhwcCoords, unwrap},
17-
read::im2col_tma::div_mod_seq,
18-
},
15+
global::{layout::NhwcCoords, read::im2col_tma::div_mod_seq},
1916
},
2017
kernels::layered::selector::RuntimeArgs,
2118
};
@@ -80,7 +77,6 @@ impl Layout for Im2colLayout {
8077

8178
#[unroll]
8279
for i in 0..spatial_dims {
83-
let i = unwrap(i);
8480
let dim = comptime![spatial_dims - i - 1];
8581
let ksize = comptime![params.kernel_size[dim as usize]];
8682
let k_pos = rem % ksize;

crates/cubecl-convolution/src/components/global/layout/spatial.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use cubecl::prelude::*;
2-
use cubecl_core::{self as cubecl, intrinsic};
2+
use cubecl_core::{self as cubecl};
33
use cubecl_std::tensor::{
44
layout::{Coordinates, Coords1d, Layout, LayoutExpand},
55
r#virtual::VirtualTensor,
@@ -154,7 +154,6 @@ impl Layout for NhwcLayout {
154154

155155
#[unroll]
156156
for i in 0..spatial_dims {
157-
let i = unwrap(i);
158157
read_pos += *spatial.index(i) as u32 * *self.strides_spatial.index(i);
159158
}
160159

@@ -172,7 +171,6 @@ impl Layout for NhwcLayout {
172171

173172
#[unroll]
174173
for i in 0..spatial_dims {
175-
let i = unwrap(i);
176174
let pos = *pos.spatial.index(i);
177175
spatial_in_bounds &= pos >= 0 && (pos as u32) < *self.shapes_spatial.index(i);
178176
}
@@ -192,12 +190,6 @@ impl Layout for NhwcLayout {
192190
}
193191
}
194192

195-
#[allow(unused_variables)]
196-
#[cube]
197-
pub(crate) fn unwrap(v: u32) -> comptime_type!(u32) {
198-
intrinsic!(|_| v.constant().expect("Must be constant").as_u32())
199-
}
200-
201193
#[cube]
202194
pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
203195
seq: Sequence<From>,
@@ -206,7 +198,6 @@ pub(crate) fn cast_seq<From: CubePrimitive, To: CubePrimitive>(
206198
let mut out_seq = Sequence::new();
207199
#[unroll]
208200
for i in 0..num_elems {
209-
let i = unwrap(i);
210201
let elem = To::cast_from(*seq.index(i));
211202
out_seq.push(elem);
212203
}

0 commit comments

Comments
 (0)