
Commit a26fce4

Finish some todos (#951)
1 parent f451b1c commit a26fce4

File tree: 4 files changed (+172, -53 lines)

crates/cubecl-hip/src/compute/command.rs

Lines changed: 62 additions & 48 deletions
@@ -7,14 +7,15 @@ use cubecl_core::{
 };
 use cubecl_hip_sys::{
     HIP_SUCCESS, hipMemcpyKind_hipMemcpyDeviceToHost, hipMemcpyKind_hipMemcpyHostToDevice,
+    ihipStream_t,
 };
 use cubecl_runtime::{
     id::KernelId,
     logging::ServerLogger,
     memory_management::{MemoryAllocationMode, MemoryHandle},
     stream::ResolvedStreams,
 };
-use std::sync::Arc;
+use std::{ffi::c_void, sync::Arc};
 
 use crate::{
     compute::{
@@ -31,7 +32,7 @@ use crate::{
 /// registration, and task execution.
 pub struct Command<'a> {
     ctx: &'a mut HipContext,
-    streams: ResolvedStreams<'a, HipStreamBackend>,
+    pub(crate) streams: ResolvedStreams<'a, HipStreamBackend>,
 }
 
 impl<'a> Command<'a> {
@@ -274,58 +275,13 @@ impl<'a> Command<'a> {
             return Err(IoError::UnsupportedStrides);
         }
 
-        let rank = shape.len();
         let resource = self.resource(binding)?;
         let stream = match stream_id {
             Some(id) => self.streams.get(&id),
             None => self.streams.current(),
         };
 
-        if rank <= 1 {
-            unsafe {
-                let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
-                    bytes.as_mut_ptr() as *mut _,
-                    resource.ptr,
-                    bytes.len(),
-                    stream.sys,
-                );
-
-                if status != HIP_SUCCESS {
-                    return Err(IoError::Unknown(format!("HIP memcpy failed: {}", status)));
-                }
-            }
-            return Ok(());
-        }
-
-        let dim_x = shape[rank - 1];
-        let width_bytes = dim_x * elem_size;
-        let dim_y: usize = shape.iter().rev().skip(1).product();
-        let pitch = strides[rank - 2] * elem_size;
-
-        unsafe {
-            let status = cubecl_hip_sys::hipMemcpy2DAsync(
-                bytes.as_mut_ptr() as *mut _,
-                width_bytes,
-                resource.ptr,
-                pitch,
-                width_bytes,
-                dim_y,
-                hipMemcpyKind_hipMemcpyDeviceToHost,
-                stream.sys,
-            );
-
-            // Fallback, sometimes the copy doesn't work.
-            if status != HIP_SUCCESS {
-                let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
-                    bytes.as_mut_ptr() as *mut _,
-                    resource.ptr,
-                    bytes.len(),
-                    stream.sys,
-                );
-                assert_eq!(status, HIP_SUCCESS, "Should send data to device");
-            }
-        }
-        Ok(())
+        unsafe { write_to_cpu(shape, strides, elem_size, bytes, resource.ptr, stream.sys) }
     }
 
     /// Writes data from the host to GPU memory as specified by the copy descriptor.
@@ -469,3 +425,61 @@ impl<'a> Command<'a> {
         };
     }
 }
+
+pub(crate) unsafe fn write_to_cpu(
+    shape: &[usize],
+    strides: &[usize],
+    elem_size: usize,
+    bytes: &mut Bytes,
+    resource_ptr: *mut c_void,
+    stream: *mut ihipStream_t,
+) -> Result<(), IoError> {
+    let rank = shape.len();
+
+    if rank <= 1 {
+        let status = unsafe {
+            cubecl_hip_sys::hipMemcpyDtoHAsync(
+                bytes.as_mut_ptr() as *mut _,
+                resource_ptr,
+                bytes.len(),
+                stream,
+            )
+        };
+
+        if status != HIP_SUCCESS {
+            return Err(IoError::Unknown(format!("HIP memcpy failed: {}", status)));
+        }
+        return Ok(());
+    }
+
+    let dim_x = shape[rank - 1];
+    let width_bytes = dim_x * elem_size;
+    let dim_y: usize = shape.iter().rev().skip(1).product();
+    let pitch = strides[rank - 2] * elem_size;
+
+    unsafe {
+        let status = cubecl_hip_sys::hipMemcpy2DAsync(
+            bytes.as_mut_ptr() as *mut _,
+            width_bytes,
+            resource_ptr,
+            pitch,
+            width_bytes,
+            dim_y,
+            hipMemcpyKind_hipMemcpyDeviceToHost,
+            stream,
+        );
+
+        // Fallback: sometimes the 2D copy doesn't work.
+        if status != HIP_SUCCESS {
+            let status = cubecl_hip_sys::hipMemcpyDtoHAsync(
+                bytes.as_mut_ptr() as *mut _,
+                resource_ptr,
+                bytes.len(),
+                stream,
+            );
+            assert_eq!(status, HIP_SUCCESS, "Should copy data to the host");
+        }
+    }
+
+    Ok(())
+}
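Note: the standalone sketch below reproduces, outside the diff, the geometry that write_to_cpu hands to hipMemcpy2DAsync for a row-major tensor. The helper name copy_2d_geometry is illustrative and not part of the commit; only the arithmetic visible above is shown.

/// Standalone sketch (not part of the commit): derive the 2D-copy geometry used above.
/// Each row holds `width_bytes` contiguous bytes, rows start `pitch` bytes apart on the
/// device, and `dim_y` rows are copied in total.
fn copy_2d_geometry(shape: &[usize], strides: &[usize], elem_size: usize) -> (usize, usize, usize) {
    let rank = shape.len();
    let width_bytes = shape[rank - 1] * elem_size; // bytes per contiguous row
    let dim_y: usize = shape.iter().rev().skip(1).product(); // all leading dims flattened
    let pitch = strides[rank - 2] * elem_size; // distance between row starts on the device
    (width_bytes, dim_y, pitch)
}

fn main() {
    // A 4x8 f32 tensor whose device rows are padded to 10 elements.
    let (width_bytes, dim_y, pitch) = copy_2d_geometry(&[4, 8], &[10, 1], 4);
    assert_eq!((width_bytes, dim_y, pitch), (32, 4, 40));
}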

crates/cubecl-hip/src/compute/server.rs

Lines changed: 84 additions & 1 deletion
@@ -1,7 +1,9 @@
 use super::storage::gpu::GpuResource;
 use super::storage::gpu::GpuStorage;
 use crate::compute::command::Command;
+use crate::compute::command::write_to_cpu;
 use crate::compute::context::HipContext;
+use crate::compute::fence::Fence;
 use crate::compute::stream::HipStreamBackend;
 use crate::runtime::HipCompiler;
 use cubecl_common::bytes::Bytes;
@@ -238,7 +240,17 @@ impl ComputeServer for HipServer {
 }
 
 impl ServerCommunication for HipServer {
-    const SERVER_COMM_ENABLED: bool = false;
+    const SERVER_COMM_ENABLED: bool = true;
+
+    fn copy(
+        server_src: &mut Self,
+        server_dst: &mut Self,
+        src: CopyDescriptor<'_>,
+        stream_id_src: StreamId,
+        stream_id_dst: StreamId,
+    ) -> Result<Allocation, IoError> {
+        Self::change_server_serialized(server_src, server_dst, src, stream_id_src, stream_id_dst)
+    }
 }
 
 impl HipServer {
@@ -277,6 +289,77 @@ impl HipServer {
 
         Command::new(&mut self.ctx, streams)
     }
+
+    fn change_server_serialized(
+        server_src: &mut Self,
+        server_dst: &mut Self,
+        src: CopyDescriptor<'_>,
+        stream_id_src: StreamId,
+        stream_id_dst: StreamId,
+    ) -> Result<Allocation, IoError> {
+        let shape = src.shape.to_vec();
+        let strides = src.strides.to_vec();
+        let elem_size = src.elem_size;
+        let binding = src.binding.clone();
+        let num_bytes = shape.iter().product::<usize>() * elem_size;
+
+        // We start by creating a command on the destination server.
+        //
+        // Here we allocate the necessary bytes using pinned memory managed by the destination
+        // server, along with a new GPU handle. This way, the bytes can be reused later by that
+        // server, and the lifetime of the handle is aligned with the execution order of the
+        // destination server, removing the need to keep the bytes handle alive using
+        // synchronization, which would be the case if we allocated the bytes on the source
+        // server.
+        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
+        let handle = command_dst.reserve(binding.size())?;
+        let mut bytes = command_dst.reserve_cpu(num_bytes, true, None);
+        let copy_desc = handle.copy_descriptor(&shape, &strides, elem_size);
+
+        // We need to free the command before creating another one.
+        core::mem::drop(command_dst);
+
+        // We create a command on the source server to retrieve the correct resource from the
+        // source memory pools. We also make sure the current stream is aligned with the stream
+        // of the binding, where the data was first allocated.
+        //
+        // We use the source stream to copy the data from the source server into the allocated
+        // bytes. This ensures that the source binding follows the correct execution order,
+        // meaning that we don't have to keep the source handle alive using synchronization,
+        // which would be the case if we performed the copy on the destination server.
+        let mut command_src = server_src.command(stream_id_src, [&src.binding].into_iter());
+        let resource_src = command_src.resource(binding.clone())?;
+        let stream_src = command_src.streams.current().sys;
+
+        unsafe {
+            write_to_cpu(
+                &shape,
+                &strides,
+                elem_size,
+                &mut bytes,
+                resource_src.ptr,
+                stream_src,
+            )?;
+        }
+        let fence_src = Fence::new(stream_src);
+
+        // We need to free the command before creating another one.
+        core::mem::drop(command_src);
+
+        // Finally, we create another command on the destination server to write the data stored
+        // in pinned memory into the destination GPU memory. Here we need to wait for the initial
+        // copy made by the source server using an event. The synchronization is done lazily on
+        // the destination stream, which is very efficient.
+        let mut command_dst = server_dst.command_no_inputs(stream_id_dst);
+        let stream_dst = command_dst.streams.current().sys;
+
+        fence_src.wait_async(stream_dst);
+        command_dst.write_to_gpu(copy_desc, &bytes)?;
+
+        // We drop the last command.
+        core::mem::drop(command_dst);
+
+        Ok(Allocation { handle, strides })
+    }
 }
 
 pub(crate) fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
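Note: the copy above is a three-step staging pipeline: reserve pinned memory and a GPU handle on the destination, copy device-to-host on the source stream and record a fence, then make the destination stream wait on that fence before the host-to-device write. The plain-Rust analogy below only illustrates that ordering; it uses no HIP, an mpsc channel stands in for the fence, and a Vec<u8> stands in for the pinned staging buffer. It is not the crate's API.

use std::sync::mpsc;
use std::thread;

fn main() {
    let src_device = vec![1u8, 2, 3, 4];
    let (fence_tx, fence_rx) = mpsc::channel();

    // "Source stream": device-to-host copy into the staging buffer, then signal the fence.
    let producer = thread::spawn(move || {
        let staging = src_device.clone();
        fence_tx.send(staging).unwrap();
    });

    // "Destination stream": wait on the fence, then consume the staging buffer.
    let dst_device: Vec<u8> = fence_rx.recv().unwrap();
    producer.join().unwrap();
    assert_eq!(dst_device, vec![1, 2, 3, 4]);
}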

crates/cubecl-runtime/src/memory_management/memory_pool/exclusive_pool.rs

Lines changed: 18 additions & 2 deletions
@@ -1,5 +1,5 @@
 use crate::{
-    memory_management::MemoryUsage,
+    memory_management::{BytesFormat, MemoryUsage},
     server::IoError,
     storage::{ComputeStorage, StorageHandle, StorageUtilization},
 };
@@ -24,7 +24,23 @@ pub struct ExclusiveMemoryPool {
 
 impl core::fmt::Display for ExclusiveMemoryPool {
     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str("ExclusiveMemoryPool")
+        f.write_fmt(format_args!(
+            " - Exclusive Pool max_alloc_size={}\n",
+            BytesFormat::new(self.max_alloc_size)
+        ))?;
+
+        for page in self.pages.iter() {
+            let is_free = page.slice.is_free();
+            let size = BytesFormat::new(page.slice.effective_size());
+
+            f.write_fmt(format_args!(" - Page {size} is_free={is_free}\n"))?;
+        }
+
+        if !self.pages.is_empty() {
+            f.write_fmt(format_args!("\n{}\n", self.get_memory_usage()))?;
+        }
+
+        Ok(())
     }
 }
 
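Note: BytesFormat is the crate's own human-readable size formatter used in the Display output above. As a rough, generic stand-in (not the crate's implementation), such a formatter can look like this:

// Generic stand-in for a human-readable byte formatter (illustrative only;
// the crate's `BytesFormat` may behave differently).
fn format_bytes(bytes: usize) -> String {
    const UNITS: [&str; 5] = ["B", "KiB", "MiB", "GiB", "TiB"];
    let mut value = bytes as f64;
    let mut unit = 0;
    while value >= 1024.0 && unit < UNITS.len() - 1 {
        value /= 1024.0;
        unit += 1;
    }
    format!("{:.2} {}", value, UNITS[unit])
}

fn main() {
    assert_eq!(format_bytes(1536), "1.50 KiB");
}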

crates/cubecl-runtime/src/memory_management/memory_pool/persistent_pool.rs

Lines changed: 8 additions & 2 deletions
@@ -136,16 +136,22 @@ impl MemoryPool for PersistentPool {
         explicit: bool,
     ) {
         if explicit {
-            self.slices.retain(|_, slice| {
+            let mut removed = Vec::new();
+            self.slices.retain(|id, slice| {
                 if slice.is_free() {
                     storage.dealloc(slice.storage.id);
+                    removed.push((*id, slice.effective_size()));
                     false
                 } else {
                     true
                 }
-                // TODO: Remove the slice id from the sizes map.
             });
 
+            for (id, size) in removed {
+                let ids = self.sizes.get_mut(&size).expect("The size should match");
+                ids.retain(|id_| *id_ != id);
+            }
+
             storage.flush();
         }
     }
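Note: this change resolves the removed TODO by keeping the sizes index consistent with the slices map: ids freed during retain are collected and then stripped from the per-size lists. The self-contained sketch below shows the same pattern on plain HashMaps; the names are illustrative, not the pool's types.

use std::collections::HashMap;

fn main() {
    // slices: id -> (size, is_free); sizes: size -> ids, a secondary index that must stay in sync.
    let mut slices: HashMap<u32, (usize, bool)> = HashMap::from([(1, (64, true)), (2, (64, false))]);
    let mut sizes: HashMap<usize, Vec<u32>> = HashMap::from([(64, vec![1, 2])]);

    // Collect the ids removed during `retain`, as the diff does.
    let mut removed = Vec::new();
    slices.retain(|id, (size, is_free)| {
        if *is_free {
            removed.push((*id, *size));
            false
        } else {
            true
        }
    });

    // Then drop those ids from the secondary index so it no longer references freed slices.
    for (id, size) in removed {
        let ids = sizes.get_mut(&size).expect("the size should match");
        ids.retain(|id_| *id_ != id);
    }

    assert_eq!(sizes[&64], vec![2]);
}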
