Skip to content

Commit 7dcb116

Browse files
Fix/autotuner (#1062)
* Improve error messages when no valid kernel is found * Fix autotuner * Clippy * Cleanup
1 parent 15e85a9 commit 7dcb116

File tree

6 files changed

+197
-86
lines changed

6 files changed

+197
-86
lines changed

crates/cubecl-common/src/format.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use alloc::format;
2+
use alloc::string::String;
3+
4+
/// Format strings for use in identifiers and types.
5+
pub fn format_str(string: &str, markers: &[(char, char)], include_space: bool) -> String {
6+
let mut result = String::new();
7+
let mut depth = 0;
8+
let indentation = 4;
9+
10+
let mut prev = ' ';
11+
let mut in_string = false;
12+
13+
for c in string.chars() {
14+
if c == ' ' {
15+
if in_string {
16+
result.push(c);
17+
}
18+
19+
continue;
20+
}
21+
if c == '"' {
22+
in_string = !in_string;
23+
}
24+
25+
let mut found_marker = false;
26+
27+
for (start, end) in markers {
28+
let (start, end) = (*start, *end);
29+
30+
if c == start {
31+
depth += 1;
32+
if prev != ' ' && include_space {
33+
result.push(' ');
34+
}
35+
result.push(start);
36+
result.push('\n');
37+
result.push_str(&" ".repeat(indentation * depth));
38+
found_marker = true;
39+
} else if c == end {
40+
depth -= 1;
41+
if prev != start {
42+
if prev == ' ' {
43+
result.pop();
44+
}
45+
result.push_str(",\n");
46+
result.push_str(&" ".repeat(indentation * depth));
47+
result.push(end);
48+
} else {
49+
for _ in 0..(&" ".repeat(indentation * depth).len()) + 1 + indentation {
50+
result.pop();
51+
}
52+
result.push(end);
53+
}
54+
found_marker = true;
55+
}
56+
}
57+
58+
if found_marker {
59+
prev = c;
60+
continue;
61+
}
62+
63+
if c == ',' && depth > 0 {
64+
if prev == ' ' {
65+
result.pop();
66+
}
67+
68+
result.push_str(",\n");
69+
result.push_str(&" ".repeat(indentation * depth));
70+
continue;
71+
}
72+
73+
if c == ':' && include_space {
74+
result.push(c);
75+
result.push(' ');
76+
prev = ' ';
77+
} else {
78+
result.push(c);
79+
prev = c;
80+
}
81+
}
82+
83+
result
84+
}
85+
86+
/// Format a debug type.
87+
pub fn format_debug<F: core::fmt::Debug>(string: &F) -> String {
88+
let string = format!("{string:?}");
89+
format_str(&string, &[('(', ')'), ('[', ']'), ('{', '}')], true)
90+
}
91+
92+
#[cfg(test)]
93+
mod tests {
94+
use hashbrown::HashMap;
95+
96+
use super::*;
97+
98+
#[derive(Debug)]
99+
#[allow(unused)]
100+
struct Test {
101+
map: HashMap<String, u32>,
102+
}
103+
104+
#[test]
105+
fn test_format_debug() {
106+
let test = Test {
107+
map: HashMap::from_iter([("Hey with space".to_string(), 8)].into_iter()),
108+
};
109+
110+
let formatted = format_debug(&test);
111+
let expected = r#"Test {
112+
map: {
113+
"Hey with space": 8,
114+
},
115+
}"#;
116+
117+
assert_eq!(expected, formatted);
118+
}
119+
}

crates/cubecl-common/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ pub mod future;
5151
/// Quantization primitives required outside of `cubecl-quant`
5252
pub mod quant;
5353

54+
/// Format utilities.
55+
pub mod format;
56+
5457
/// Various utilities to create ID's.
5558
extern crate alloc;
5659

crates/cubecl-core/src/compute/kernel.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ use std::{
55
};
66

77
use crate::{Compiler, KernelOptions};
8-
use cubecl_common::{CubeDim, ExecutionMode};
8+
use cubecl_common::{CubeDim, ExecutionMode, format::format_str};
99
use cubecl_ir::{Id, Scope, StorageType, Type};
1010
use cubecl_runtime::{
1111
config::{GlobalConfig, compilation::CompilationLogLevel},
12-
id::{KernelId, format_str},
12+
id::KernelId,
1313
kernel::KernelMetadata,
1414
};
1515
use serde::{Deserialize, Serialize};

crates/cubecl-runtime/src/id.rs

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use alloc::format;
22
use alloc::string::String;
3-
use alloc::string::ToString;
43
use alloc::sync::Arc;
54
use core::{
65
any::{Any, TypeId},
76
fmt::Display,
87
hash::{BuildHasher, Hash, Hasher},
98
};
109
use cubecl_common::ExecutionMode;
10+
use cubecl_common::format::format_str;
1111

1212
#[macro_export(local_inner_macros)]
1313
/// Create a new storage ID type.
@@ -208,81 +208,6 @@ impl PartialEq for KernelId {
208208

209209
impl Eq for KernelId {}
210210

211-
/// Format strings for use in identifiers and types.
212-
pub fn format_str(kernel_id: &str, markers: &[(char, char)], include_space: bool) -> String {
213-
let kernel_id = kernel_id.to_string();
214-
let mut result = String::new();
215-
let mut depth = 0;
216-
let indentation = 4;
217-
218-
let mut prev = ' ';
219-
220-
for c in kernel_id.chars() {
221-
if c == ' ' {
222-
continue;
223-
}
224-
225-
let mut found_marker = false;
226-
227-
for (start, end) in markers {
228-
let (start, end) = (*start, *end);
229-
230-
if c == start {
231-
depth += 1;
232-
if prev != ' ' && include_space {
233-
result.push(' ');
234-
}
235-
result.push(start);
236-
result.push('\n');
237-
result.push_str(&" ".repeat(indentation * depth));
238-
found_marker = true;
239-
} else if c == end {
240-
depth -= 1;
241-
if prev != start {
242-
if prev == ' ' {
243-
result.pop();
244-
}
245-
result.push_str(",\n");
246-
result.push_str(&" ".repeat(indentation * depth));
247-
result.push(end);
248-
} else {
249-
for _ in 0..(&" ".repeat(indentation * depth).len()) + 1 + indentation {
250-
result.pop();
251-
}
252-
result.push(end);
253-
}
254-
found_marker = true;
255-
}
256-
}
257-
258-
if found_marker {
259-
prev = c;
260-
continue;
261-
}
262-
263-
if c == ',' && depth > 0 {
264-
if prev == ' ' {
265-
result.pop();
266-
}
267-
268-
result.push_str(",\n");
269-
result.push_str(&" ".repeat(indentation * depth));
270-
continue;
271-
}
272-
273-
if c == ':' && include_space {
274-
result.push(c);
275-
result.push(' ');
276-
prev = ' ';
277-
} else {
278-
result.push(c);
279-
prev = c;
280-
}
281-
}
282-
283-
result
284-
}
285-
286211
impl Display for KernelId {
287212
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
288213
match &self.info {

crates/cubecl-runtime/src/tune/base.rs

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ pub struct TuneGroup<K> {
4848
pub(crate) priority: PriorityFunc<K>,
4949
}
5050

51+
impl<K> core::fmt::Debug for TuneGroup<K> {
52+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
53+
f.debug_struct("TuneGroup").field("id", &self.id).finish()
54+
}
55+
}
56+
5157
impl<K> Clone for TuneGroup<K> {
5258
fn clone(&self) -> Self {
5359
Self {
@@ -83,9 +89,12 @@ struct GroupPlan {
8389
indices: HashMap<i8, Vec<usize>>,
8490
}
8591

92+
#[derive(Debug)]
8693
struct Cleanup {
8794
groups: Vec<i8>,
8895
tunables: Vec<(i8, i8)>,
96+
/// Within group priority is too low to even try.
97+
skipped: bool,
8998
}
9099

91100
impl TunePlan {
@@ -154,13 +163,22 @@ impl TunePlan {
154163
};
155164

156165
let (mut group_indices, cleanup) = self.group_plan_next(priority);
166+
// Some entries are skipped for this round of prioritizing.
167+
let skipped = cleanup.skipped || priority < 0;
168+
157169
self.cleanup(cleanup);
158170

159171
if priority >= 0 {
160172
indices.append(&mut group_indices);
161173
}
162174

163-
indices
175+
// The indices list is empty, but it doesn't mean we should stop
176+
// autotuning, since some entries were skipped.
177+
if indices.is_empty() && skipped {
178+
self.next()
179+
} else {
180+
indices
181+
}
164182
}
165183

166184
fn cleanup(&mut self, cleanup: Cleanup) {
@@ -231,6 +249,7 @@ impl TunePlan {
231249
Cleanup {
232250
groups: cleanup_groups,
233251
tunables: cleanup_tunables,
252+
skipped: within_group_prio < 0,
234253
},
235254
)
236255
}
@@ -320,6 +339,24 @@ mod tests {
320339
assert!(plan.next().is_empty());
321340
}
322341

342+
#[test]
343+
fn test_plan_negative_priority() {
344+
let group0 = TuneGroup::<FakeAutotuneKey>::new(|_| 2);
345+
let group1 = TuneGroup::<FakeAutotuneKey>::new(|_| 1);
346+
347+
let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);
348+
let tunable1 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| -1);
349+
let tunable2 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group0, |_| 2);
350+
let tunable3 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel).group(&group1, |_| 2);
351+
352+
let key = FakeAutotuneKey;
353+
let mut plan = TunePlan::new(&key, &[tunable0, tunable1, tunable2, tunable3]);
354+
355+
assert_eq!(plan.next(), vec![0, 2]);
356+
assert_eq!(plan.next(), vec![3]);
357+
assert!(plan.next().is_empty());
358+
}
359+
323360
#[test]
324361
fn test_plan_no_group() {
325362
let tunable0 = Tunable::<FakeAutotuneKey, (), ()>::new(fake_kernel);

0 commit comments

Comments
 (0)