Skip to content

Commit feab17f

Browse files
Switched to thread local runtime
1 parent 388a9a3 commit feab17f

File tree

12 files changed

+281
-753
lines changed

12 files changed

+281
-753
lines changed

crates/cubecl-runtime/src/channel/base.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use alloc::vec::Vec;
1010

1111
/// The ComputeChannel trait links the ComputeClient to the ComputeServer
1212
/// while ensuring thread-safety
13-
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
13+
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug {
1414
/// Given a binding, returns owned resource as bytes
15-
fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send;
15+
fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>>;
1616

1717
/// Given a resource handle, return the storage resource.
1818
fn get_resource(&self, binding: Binding) -> BindingResource<Server>;
@@ -40,12 +40,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
4040
fn flush(&self);
4141

4242
/// Wait for the completion of every task in the server.
43-
fn sync(&self) -> impl Future<Output = ()> + Send;
43+
fn sync(&self) -> impl Future<Output = ()>;
4444

4545
/// Wait for the completion of every task in the server.
4646
///
4747
/// Returns the (approximate) total amount of GPU work done since the last sync.
48-
fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult> + Send;
48+
fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult>;
4949

5050
/// Get the current memory usage of the server.
5151
fn memory_usage(&self) -> crate::memory_management::MemoryUsage;

crates/cubecl-runtime/src/channel/cell.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ where
4242

4343
impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
4444
where
45-
Server: ComputeServer + Send,
45+
Server: ComputeServer,
4646
{
4747
async fn read(&self, binding: Binding) -> Vec<u8> {
4848
let future = {
@@ -108,8 +108,3 @@ where
108108
self.server.borrow_mut().disable_timestamps();
109109
}
110110
}
111-
112-
/// This is unsafe, since no concurrency is supported by the `RefCell` channel.
113-
/// However using this channel should only be done in single threaded environments such as `no-std`.
114-
unsafe impl<Server: ComputeServer> Send for RefCellComputeChannel<Server> {}
115-
unsafe impl<Server: ComputeServer> Sync for RefCellComputeChannel<Server> {}

crates/cubecl-runtime/src/channel/mpsc.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use super::ComputeChannel;
66
use crate::{
77
memory_management::MemoryUsage,
88
server::{Binding, ComputeServer, CubeCount, Handle},
9-
storage::BindingResource,
9+
storage::{BindingResource, ComputeStorage},
1010
ExecutionMode,
1111
};
1212

@@ -50,7 +50,8 @@ where
5050

5151
impl<Server> MpscComputeChannel<Server>
5252
where
53-
Server: ComputeServer + 'static,
53+
Server: ComputeServer + Send + 'static,
54+
<Server::Storage as ComputeStorage>::Resource: Send,
5455
{
5556
/// Create a new mpsc compute channel.
5657
pub fn new(mut server: Server) -> Self {
@@ -123,6 +124,7 @@ impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
123124
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
124125
where
125126
Server: ComputeServer + 'static,
127+
<Server::Storage as ComputeStorage>::Resource: Send,
126128
{
127129
async fn read(&self, binding: Binding) -> Vec<u8> {
128130
let sender = self.state.sender.clone();

crates/cubecl-runtime/src/channel/mutex.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ where
3535

3636
impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
3737
where
38-
Server: ComputeServer,
38+
Server: ComputeServer + Send,
3939
{
4040
async fn read(&self, handle: Binding) -> Vec<u8> {
4141
// Nb: The order here is really important - the mutex guard has to be dropped before

crates/cubecl-runtime/src/server.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use cubecl_common::benchmark::TimestampsResult;
1414
///
1515
/// Everything in the server is mutable, therefore it should be solely accessed through the
1616
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
17-
pub trait ComputeServer: Send + core::fmt::Debug
17+
pub trait ComputeServer: core::fmt::Debug
1818
where
1919
Self: Sized,
2020
{
@@ -26,7 +26,7 @@ where
2626
type Feature: Ord + Copy + Debug + Send + Sync;
2727

2828
/// Given a handle, returns the owned resource as bytes.
29-
fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send + 'static;
29+
fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + 'static;
3030

3131
/// Given a resource handle, returns the storage resource.
3232
fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;
@@ -57,12 +57,12 @@ where
5757
fn flush(&mut self);
5858

5959
/// Wait for the completion of every task in the server.
60-
fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
60+
fn sync(&mut self) -> impl Future<Output = ()> + 'static;
6161

6262
/// Wait for the completion of every task in the server.
6363
///
6464
/// Returns the (approximate) total amount of GPU work done since the last sync.
65-
fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;
65+
fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static;
6666

6767
/// The current memory usage of the server.
6868
fn memory_usage(&self) -> MemoryUsage;

crates/cubecl-runtime/src/storage/base.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ impl StorageHandle {
6363
}
6464

6565
/// Storage types are responsible for allocating and deallocating memory.
66-
pub trait ComputeStorage: Send {
66+
pub trait ComputeStorage {
6767
/// The resource associated type determines the way data is implemented and how
6868
/// it can be accessed by kernels.
69-
type Resource: Send;
69+
type Resource;
7070

7171
/// The alignment memory is allocated with in this storage.
7272
const ALIGNMENT: u64;

crates/cubecl-runtime/src/tune/tuner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ impl<K: AutotuneKey> Tuner<K> {
224224
}
225225
}
226226

227-
fn spawn_benchmark_task(future: impl Future<Output = ()> + Send + 'static) {
227+
fn spawn_benchmark_task(future: impl Future<Output = ()> + 'static) {
228228
// On wasm, spawn the tuning as a detached task.
229229
#[cfg(target_family = "wasm")]
230230
wasm_bindgen_futures::spawn_local(future);

crates/cubecl-wgpu/src/compiler/base.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@ use cubecl_core::{
44
use cubecl_runtime::DeviceProperties;
55
use wgpu::{Adapter, ComputePipeline, Device, Queue};
66

7-
use crate::{Pdrc, WgpuServer, WgpuServerInner};
7+
use crate::{Pdrc, WgpuServer};
88

99
pub trait WgpuCompiler: Compiler {
1010
fn compile(
11-
server: &mut WgpuServerInner<Self>,
11+
server: &mut WgpuServer<Self>,
1212
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
1313
mode: ExecutionMode,
1414
) -> CompiledKernel<Self>;
1515

1616
fn create_pipeline(
17-
server: &mut WgpuServerInner<Self>,
17+
server: &mut WgpuServer<Self>,
1818
kernel: CompiledKernel<Self>,
1919
mode: ExecutionMode,
2020
) -> Pdrc<ComputePipeline>;

crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@ use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory};
44
use super::{LocalArray, Subgroup};
55
use crate::{
66
compiler::{base::WgpuCompiler, wgsl},
7-
WgpuServer,
7+
Pdrc, WgpuServer,
88
};
9-
use crate::{Pdrc, WgpuServerInner};
109
use cubecl_core::{
1110
ir::{self as cube, HybridAllocator, UIntKind},
1211
prelude::CompiledKernel,
@@ -71,7 +70,7 @@ impl cubecl_core::Compiler for WgslCompiler {
7170

7271
impl WgpuCompiler for WgslCompiler {
7372
fn create_pipeline(
74-
server: &mut WgpuServerInner<Self>,
73+
server: &mut WgpuServer<Self>,
7574
kernel: CompiledKernel<Self>,
7675
mode: ExecutionMode,
7776
) -> Pdrc<ComputePipeline> {
@@ -137,7 +136,7 @@ impl WgpuCompiler for WgslCompiler {
137136
}
138137

139138
fn compile(
140-
_server: &mut WgpuServerInner<Self>,
139+
_server: &mut WgpuServer<Self>,
141140
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
142141
mode: ExecutionMode,
143142
) -> CompiledKernel<Self> {

0 commit comments

Comments
 (0)