Skip to content

Commit c47fc84

Browse files
Add shared state map (#934)
1 parent c85f511 commit c47fc84

File tree

5 files changed

+144
-29
lines changed

5 files changed

+144
-29
lines changed

crates/cubecl-common/src/bytes/default_controller.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use crate::bytes::{AllocationController, AllocationError};
44
use alloc::alloc::Layout;
5+
use alloc::vec::Vec;
56
use bytemuck::Contiguous;
67
use core::{alloc::LayoutError, marker::PhantomData, mem::MaybeUninit, ptr::NonNull};
78

crates/cubecl-common/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ pub mod rand;
1515
/// Device module.
1616
pub mod device;
1717

18+
/// Map utilities and implementations.
19+
pub mod map;
20+
1821
/// Utilities module to manipulate bytes.
1922
#[cfg(feature = "serde")]
2023
pub mod bytes;

crates/cubecl-common/src/map.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
use crate::stub::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
2+
use alloc::sync::Arc;
3+
use hashbrown::HashMap;
4+
5+
/// A thread-safe map that allows concurrent access to values using read-write locks.
6+
pub struct SharedStateMap<K, V> {
7+
state: Mutex<Option<HashMap<K, Arc<RwLock<V>>>>>,
8+
}
9+
10+
/// A value in the [SharedStateMap] that provides read and write access.
11+
pub struct SharedState<V> {
12+
val: Arc<RwLock<V>>,
13+
}
14+
15+
impl<V> SharedState<V> {
16+
/// Acquires a read lock on the value, returning a read guard.
17+
pub fn read(&self) -> RwLockReadGuard<'_, V> {
18+
self.val.read().unwrap()
19+
}
20+
21+
/// Acquires a write lock on the value, returning a write guard.
22+
pub fn write(&self) -> RwLockWriteGuard<'_, V> {
23+
self.val.write().unwrap()
24+
}
25+
}
26+
27+
impl<K, V> Default for SharedStateMap<K, V>
28+
where
29+
K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
30+
{
31+
fn default() -> Self {
32+
Self::new()
33+
}
34+
}
35+
36+
impl<K, V> SharedStateMap<K, V>
37+
where
38+
K: core::hash::Hash + core::cmp::PartialEq + core::cmp::Eq,
39+
{
40+
/// Creates a new, empty `SharedStateMap`.
41+
pub const fn new() -> Self {
42+
Self {
43+
state: Mutex::new(None),
44+
}
45+
}
46+
47+
/// Retrieves a value associated with the given key, if it exists.
48+
pub fn get(&self, k: &K) -> Option<SharedState<V>> {
49+
let mut state = self.state.lock().unwrap();
50+
let map = get_or_init(&mut state);
51+
52+
match map.get(k) {
53+
Some(val) => Some(SharedState { val: val.clone() }),
54+
None => None,
55+
}
56+
}
57+
58+
/// Retrieves a value associated with the given key, or inserts a new value using the provided
59+
/// initializer function if the key does not exist.
60+
pub fn get_or_init<Fn: FnMut(&K) -> V>(&self, k: &K, mut init: Fn) -> SharedState<V>
61+
where
62+
K: Clone,
63+
{
64+
let mut state = self.state.lock().unwrap();
65+
let map = get_or_init(&mut state);
66+
67+
match map.get(k) {
68+
Some(val) => SharedState { val: val.clone() },
69+
None => {
70+
let val = init(k);
71+
let val = Arc::new(RwLock::new(val));
72+
map.insert(k.clone(), val.clone());
73+
SharedState { val: val.clone() }
74+
}
75+
}
76+
}
77+
78+
/// Inserts a key-value pair into the map.
79+
pub fn insert(&self, k: K, v: V) {
80+
let mut state = self.state.lock().unwrap();
81+
let map = get_or_init(&mut state);
82+
83+
map.insert(k, Arc::new(RwLock::new(v)));
84+
}
85+
86+
/// Clears the map, removing all key-value pairs.
87+
pub fn clear(&self) {
88+
let mut state = self.state.lock().unwrap();
89+
let map = get_or_init(&mut state);
90+
map.clear();
91+
}
92+
}
93+
94+
fn get_or_init<T: Default>(state: &mut Option<T>) -> &mut T {
95+
match state {
96+
Some(state) => state,
97+
None => {
98+
*state = Some(T::default());
99+
state.as_mut().unwrap()
100+
}
101+
}
102+
}

crates/cubecl-common/src/stub.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#[cfg(not(feature = "std"))]
2-
use spin::{
3-
Mutex as MutexImported, MutexGuard, Once as OnceImported, RwLock as RwLockImported,
4-
RwLockReadGuard, RwLockWriteGuard,
5-
};
2+
use spin::{Mutex as MutexImported, MutexGuard, Once as OnceImported, RwLock as RwLockImported};
63
#[cfg(feature = "std")]
74
use std::sync::{
85
Mutex as MutexImported, MutexGuard, OnceLock as OnceImported, RwLock as RwLockImported,
9-
RwLockReadGuard, RwLockWriteGuard,
106
};
117

8+
#[cfg(not(feature = "std"))]
9+
pub use spin::{RwLockReadGuard, RwLockWriteGuard};
10+
#[cfg(feature = "std")]
11+
pub use std::sync::{RwLockReadGuard, RwLockWriteGuard};
12+
1213
/// A mutual exclusion primitive useful for protecting shared data
1314
///
1415
/// This mutex will block threads waiting for the lock to become available. The

crates/cubecl-runtime/src/tune/local.rs

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use core::{
99
fmt::Display,
1010
hash::Hash,
1111
};
12+
use cubecl_common::map::SharedStateMap;
1213
use hashbrown::HashMap;
1314

1415
#[cfg(not(feature = "std"))]
@@ -17,7 +18,7 @@ use alloc::string::ToString;
1718
/// A local tuner allows to create a tuner for a specific key that can be different from the server
1819
/// key.
1920
pub struct LocalTuner<AK: AutotuneKey, ID> {
20-
state: spin::Mutex<Option<HashMap<ID, Tuner<AK>>>>,
21+
state: SharedStateMap<ID, Tuner<AK>>,
2122
name: &'static str,
2223
sets: spin::RwLock<Option<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
2324
}
@@ -45,7 +46,7 @@ where
4546
/// Create a new local tuner.
4647
pub const fn new(name: &'static str) -> Self {
4748
Self {
48-
state: spin::Mutex::new(None),
49+
state: SharedStateMap::new(),
4950
name,
5051
sets: spin::RwLock::new(None),
5152
}
@@ -94,8 +95,7 @@ where
9495

9596
/// Clear the autotune state.
9697
pub fn clear(&self) {
97-
let mut state = self.state.lock();
98-
*state = None;
98+
self.state.clear()
9999
}
100100

101101
#[cfg(feature = "autotune-checks")]
@@ -131,20 +131,16 @@ where
131131

132132
// If this is cached and ready, use the operation.
133133
let autotune_job = {
134-
let mut state = self.state.lock();
135-
let map = state.get_or_insert_with(Default::default);
136-
137-
let tuner = match map.get_mut(id) {
138-
Some(val) => val,
139-
None => map.entry(id.clone()).or_insert_with(move || {
140-
let name = self.name.replace("::", "-");
141-
Tuner::new(&name, &id.to_string())
142-
}),
143-
};
134+
let tuner_state = self.state.get_or_init(id, move |id| {
135+
let name = self.name.replace("::", "-");
136+
Tuner::new(&name, &id.to_string())
137+
});
138+
let tuner = tuner_state.read();
144139

145-
match tuner.fastest(&key) {
140+
let mut tuner = match tuner.fastest(&key) {
146141
TuneCacheResult::Hit { fastest_index } => {
147-
core::mem::drop(state);
142+
core::mem::drop(tuner);
143+
core::mem::drop(tuner_state);
148144

149145
#[cfg(feature = "autotune-checks")]
150146
self.checks(&operations, &inputs);
@@ -156,7 +152,8 @@ where
156152
return result;
157153
}
158154
TuneCacheResult::Pending => {
159-
core::mem::drop(state);
155+
core::mem::drop(tuner);
156+
core::mem::drop(tuner_state);
160157

161158
let op = operations.fastest(0);
162159
let result = op
@@ -166,26 +163,38 @@ where
166163
}
167164
#[cfg(std_io)]
168165
TuneCacheResult::Unchecked => {
166+
core::mem::drop(tuner);
167+
let mut tuner = tuner_state.write();
168+
169169
// If the cache checksum hasn't been checked, do so now, and retry.
170170
let checksum = operations.compute_checksum();
171171
tuner.validate_checksum(&key, &checksum);
172172

173173
// Check if with validation we can use its result
174174
if let TuneCacheResult::Hit { fastest_index } = tuner.fastest(&key) {
175-
core::mem::drop(state);
175+
core::mem::drop(tuner);
176+
core::mem::drop(tuner_state);
176177

177178
let op = operations.fastest(fastest_index);
178179
let result = op
179180
.execute(inputs)
180181
.expect("Should run when selected by autotune.");
181182
return result;
182183
}
184+
185+
tuner
183186
}
184187

185188
#[cfg(not(std_io))]
186-
TuneCacheResult::Unchecked => (),
187-
TuneCacheResult::Miss => (),
188-
}
189+
TuneCacheResult::Unchecked => {
190+
core::mem::drop(tuner);
191+
tuner_state.write()
192+
}
193+
TuneCacheResult::Miss => {
194+
core::mem::drop(tuner);
195+
tuner_state.write()
196+
}
197+
};
189198

190199
if tuner.autotuning.contains(&key) {
191200
Box::new(move || {})
@@ -198,9 +207,8 @@ where
198207
autotune_job();
199208

200209
let index_to_run = {
201-
let mut state = self.state.lock();
202-
let map = state.as_mut().unwrap();
203-
let tuner = map.get_mut(id).unwrap();
210+
let tuner_state = self.state.get(id).unwrap();
211+
let mut tuner = tuner_state.write();
204212

205213
tuner.handle_results();
206214

0 commit comments

Comments
 (0)