Skip to content

Commit a15c1d2

Browse files
Feat/runtime error (#1078)
1 parent 3443131 commit a15c1d2

File tree

15 files changed

+175
-76
lines changed

15 files changed

+175
-76
lines changed

crates/cubecl-cpu/src/compute/server.rs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use cubecl_core::{
55
CubeCount, ExecutionMode, MemoryUsage,
66
future::DynFut,
77
server::{
8-
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor, Handle,
9-
IoError, LaunchError, ProfileError, ProfilingToken, ServerCommunication, ServerUtilities,
8+
Allocation, AllocationDescriptor, Binding, Bindings, ComputeServer, CopyDescriptor,
9+
ExecutionError, Handle, IoError, LaunchError, ProfileError, ProfilingToken,
10+
ServerCommunication, ServerUtilities,
1011
},
1112
};
1213
use cubecl_runtime::{
@@ -202,13 +203,15 @@ impl ComputeServer for CpuServer {
202203

203204
fn flush(&mut self, _stream_id: StreamId) {}
204205

205-
fn sync(&mut self, _stream_id: StreamId) -> DynFut<()> {
206+
fn sync(&mut self, _stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
206207
self.utilities.logger.profile_summary();
207-
Box::pin(async move {})
208+
Box::pin(async move { Ok(()) })
208209
}
209210

210211
fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
211-
cubecl_common::future::block_on(self.sync(stream_id));
212+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
213+
self.ctx.timestamps.error(err.into());
214+
};
212215
self.ctx.timestamps.start()
213216
}
214217

@@ -218,7 +221,11 @@ impl ComputeServer for CpuServer {
218221
token: ProfilingToken,
219222
) -> Result<ProfileDuration, ProfileError> {
220223
self.utilities.logger.profile_summary();
221-
cubecl_common::future::block_on(self.sync(stream_id));
224+
225+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
226+
self.ctx.timestamps.error(err.into());
227+
}
228+
222229
self.ctx.timestamps.stop(token)
223230
}
224231

crates/cubecl-cuda/src/compute/command.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use cubecl_common::{
1212
use cubecl_core::{
1313
ExecutionMode, MemoryUsage,
1414
future::DynFut,
15-
server::{Binding, CopyDescriptor, Handle, IoError, ProfileError},
15+
server::{Binding, CopyDescriptor, ExecutionError, Handle, IoError, ProfileError},
1616
};
1717
use cubecl_runtime::{
1818
compiler::{CompilationError, CubeTask},
@@ -170,9 +170,12 @@ impl<'a> Command<'a> {
170170
let fence = Fence::new(self.streams.current().sys);
171171

172172
async move {
173-
fence.wait_sync();
173+
let sync = fence.wait_sync();
174174
// Release memory handle.
175175
core::mem::drop(descriptors_moved);
176+
177+
sync?;
178+
176179
result
177180
}
178181
}
@@ -373,12 +376,10 @@ impl<'a> Command<'a> {
373376
/// # Returns
374377
///
375378
/// * A `DynFut<Result<(), ExecutionError>>` future that resolves when the stream is synchronized, yielding an error if the synchronization fails.
376-
pub fn sync(&mut self) -> DynFut<()> {
379+
pub fn sync(&mut self) -> DynFut<Result<(), ExecutionError>> {
377380
let fence = Fence::new(self.streams.current().sys);
378381

379-
Box::pin(async {
380-
fence.wait_sync();
381-
})
382+
Box::pin(async { fence.wait_sync() })
382383
}
383384

384385
/// Executes a registered CUDA kernel with the specified parameters.

crates/cubecl-cuda/src/compute/server.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::compute::context::CudaContext;
55
use crate::compute::stream::CudaStreamBackend;
66
use crate::compute::sync::Fence;
77
use cubecl_common::{bytes::Bytes, profile::ProfileDuration, stream_id::StreamId};
8-
use cubecl_core::server::{Binding, ServerCommunication, ServerUtilities};
8+
use cubecl_core::server::{Binding, ExecutionError, ServerCommunication, ServerUtilities};
99
use cubecl_core::server::{IoError, LaunchError};
1010
use cubecl_core::{MemoryConfiguration, prelude::*};
1111
use cubecl_core::{
@@ -398,13 +398,16 @@ impl ComputeServer for CudaServer {
398398

399399
fn flush(&mut self, _stream_id: StreamId) {}
400400

401-
fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
401+
fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
402402
let mut command = self.command_no_inputs(stream_id);
403403
command.sync()
404404
}
405405

406406
fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
407-
cubecl_common::future::block_on(self.sync(stream_id));
407+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
408+
self.ctx.timestamps.error(err.into());
409+
}
410+
408411
self.ctx.timestamps.start()
409412
}
410413

@@ -413,7 +416,9 @@ impl ComputeServer for CudaServer {
413416
stream_id: StreamId,
414417
token: ProfilingToken,
415418
) -> Result<ProfileDuration, ProfileError> {
416-
cubecl_common::future::block_on(self.sync(stream_id));
419+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
420+
self.ctx.timestamps.error(err.into());
421+
}
417422
self.ctx.timestamps.stop(token)
418423
}
419424

crates/cubecl-cuda/src/compute/stream.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::compute::{
77
},
88
sync::Fence,
99
};
10-
use cubecl_core::MemoryConfiguration;
10+
use cubecl_core::{MemoryConfiguration, server::ExecutionError};
1111
use cubecl_runtime::{
1212
logging::ServerLogger,
1313
memory_management::{
@@ -77,7 +77,7 @@ impl EventStreamBackend for CudaStreamBackend {
7777
event.wait_async(stream.sys);
7878
}
7979

80-
fn wait_event_sync(event: Self::Event) {
81-
event.wait_sync();
80+
fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError> {
81+
event.wait_sync()
8282
}
8383
}

crates/cubecl-cuda/src/compute/sync/fence.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use cubecl_core::server::ExecutionError;
12
use cudarc::driver::sys::{CUevent_flags, CUevent_st, CUevent_wait_flags, CUstream_st};
23

34
/// A fence is simply an [event](CUevent_st) created on a [stream](CUstream_st) that you can wait
@@ -37,11 +38,21 @@ impl Fence {
3738

3839
/// Wait for the [Fence] to be reached, ensuring that all previous tasks enqueued to the
3940
/// [stream](CUstream_st) are completed.
40-
pub fn wait_sync(self) {
41+
pub fn wait_sync(self) -> Result<(), ExecutionError> {
4142
unsafe {
42-
cudarc::driver::result::event::synchronize(self.event).unwrap();
43-
cudarc::driver::result::event::destroy(self.event).unwrap();
43+
cudarc::driver::result::event::synchronize(self.event).map_err(|err| {
44+
ExecutionError::Generic {
45+
context: format!("{err:?}"),
46+
}
47+
})?;
48+
cudarc::driver::result::event::destroy(self.event).map_err(|err| {
49+
ExecutionError::Generic {
50+
context: format!("{err:?}"),
51+
}
52+
})?;
4453
}
54+
55+
Ok(())
4556
}
4657

4758
/// Wait for the [Fence] to be reached, ensuring that all previous tasks enqueued to the

crates/cubecl-hip/src/compute/command.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use cubecl_common::{bytes::Bytes, stream_id::StreamId};
22
use cubecl_core::{
33
ExecutionMode, MemoryUsage,
44
future::DynFut,
5-
server::{Binding, CopyDescriptor, Handle, IoError, ProfileError},
5+
server::{Binding, CopyDescriptor, ExecutionError, Handle, IoError, ProfileError},
66
};
77
use cubecl_hip_sys::{
88
HIP_SUCCESS, hipMemcpyKind_hipMemcpyDeviceToHost, hipMemcpyKind_hipMemcpyHostToDevice,
@@ -175,9 +175,11 @@ impl<'a> Command<'a> {
175175
let fence = Fence::new(self.streams.current().sys);
176176

177177
async move {
178-
fence.wait_sync();
178+
let sync = fence.wait_sync();
179179
// Release memory handle.
180180
core::mem::drop(descriptors_moved);
181+
182+
sync?;
181183
result
182184
}
183185
}
@@ -367,12 +369,10 @@ impl<'a> Command<'a> {
367369
/// # Returns
368370
///
369371
/// * A `DynFut<Result<(), ExecutionError>>` future that resolves when the stream is synchronized, yielding an error if the synchronization fails.
370-
pub fn sync(&mut self) -> DynFut<()> {
372+
pub fn sync(&mut self) -> DynFut<Result<(), ExecutionError>> {
371373
let fence = Fence::new(self.streams.current().sys);
372374

373-
Box::pin(async {
374-
fence.wait_sync();
375-
})
375+
Box::pin(async { fence.wait_sync() })
376376
}
377377

378378
/// Executes a registered HIP kernel with the specified parameters.

crates/cubecl-hip/src/compute/fence.rs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use cubecl_core::server::ExecutionError;
12
use cubecl_hip_sys::HIP_SUCCESS;
23

34
/// A fence is simply an [event](hipEvent_t) created on a [stream](hipStream_t) that you can wait
@@ -55,21 +56,30 @@ impl Fence {
5556
"Should successfully wait for stream event"
5657
);
5758
let status = cubecl_hip_sys::hipEventDestroy(self.event);
58-
assert_eq!(status, HIP_SUCCESS, "Should destrdestroy the stream eventt");
59+
assert_eq!(status, HIP_SUCCESS, "Should destroy the stream eventt");
5960
}
6061
}
6162

6263
/// Wait for the [Fence] to be reached, ensuring that all previous tasks enqueued to the
6364
/// [stream](hipStream_t) are completed.
64-
pub fn wait_sync(self) {
65+
pub fn wait_sync(self) -> Result<(), ExecutionError> {
6566
unsafe {
6667
let status = cubecl_hip_sys::hipEventSynchronize(self.event);
67-
assert_eq!(
68-
status, HIP_SUCCESS,
69-
"Should successfully wait for stream event"
70-
);
68+
69+
if status != HIP_SUCCESS {
70+
return Err(ExecutionError::Generic {
71+
context: format!("Should successfully wait for stream event: {status}"),
72+
});
73+
}
7174
let status = cubecl_hip_sys::hipEventDestroy(self.event);
72-
assert_eq!(status, HIP_SUCCESS, "Should destrdestroy the stream eventt");
75+
76+
if status != HIP_SUCCESS {
77+
return Err(ExecutionError::Generic {
78+
context: format!("Should destroy the stream event: {status}"),
79+
});
80+
}
7381
}
82+
83+
Ok(())
7484
}
7585
}

crates/cubecl-hip/src/compute/server.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use cubecl_common::bytes::Bytes;
1010
use cubecl_common::future::DynFut;
1111
use cubecl_common::profile::ProfileDuration;
1212
use cubecl_common::stream_id::StreamId;
13+
use cubecl_core::server::ExecutionError;
1314
use cubecl_core::server::LaunchError;
1415
use cubecl_core::server::ServerCommunication;
1516
use cubecl_core::server::ServerUtilities;
@@ -223,13 +224,16 @@ impl ComputeServer for HipServer {
223224

224225
fn flush(&mut self, _stream_id: StreamId) {}
225226

226-
fn sync(&mut self, stream_id: StreamId) -> DynFut<()> {
227+
fn sync(&mut self, stream_id: StreamId) -> DynFut<Result<(), ExecutionError>> {
227228
let mut command = self.command_no_inputs(stream_id);
228229
command.sync()
229230
}
230231

231232
fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken {
232-
cubecl_common::future::block_on(self.sync(stream_id));
233+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
234+
self.ctx.timestamps.error(err.into())
235+
}
236+
233237
self.ctx.timestamps.start()
234238
}
235239

@@ -238,7 +242,9 @@ impl ComputeServer for HipServer {
238242
stream_id: StreamId,
239243
token: ProfilingToken,
240244
) -> Result<ProfileDuration, ProfileError> {
241-
cubecl_common::future::block_on(self.sync(stream_id));
245+
if let Err(err) = cubecl_common::future::block_on(self.sync(stream_id)) {
246+
self.ctx.timestamps.error(err.into())
247+
}
242248
self.ctx.timestamps.stop(token)
243249
}
244250

crates/cubecl-hip/src/compute/stream.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::sync::Arc;
22

3-
use cubecl_core::MemoryConfiguration;
3+
use cubecl_core::{MemoryConfiguration, server::ExecutionError};
44
use cubecl_hip_sys::HIP_SUCCESS;
55
use cubecl_runtime::{
66
logging::ServerLogger,
@@ -78,7 +78,7 @@ impl EventStreamBackend for HipStreamBackend {
7878
event.wait_async(stream.sys);
7979
}
8080

81-
fn wait_event_sync(event: Self::Event) {
82-
event.wait_sync();
81+
fn wait_event_sync(event: Self::Event) -> Result<(), ExecutionError> {
82+
event.wait_sync()
8383
}
8484
}

0 commit comments

Comments
 (0)