11use crate :: {
22 CudaCompiler ,
33 compute:: {
4- DataTransferItem , DataTransferRuntime , MB , context:: CudaContext ,
5- io:: controller:: PinnedMemoryManagedAllocController , storage:: gpu:: GpuResource ,
6- stream:: CudaStreamBackend , sync:: Fence , valid_strides,
4+ MB , context:: CudaContext , io:: controller:: PinnedMemoryManagedAllocController ,
5+ storage:: gpu:: GpuResource , stream:: CudaStreamBackend , sync:: Fence , valid_strides,
76 } ,
87} ;
98use cubecl_common:: { bytes:: Bytes , stream_id:: StreamId } ;
@@ -14,7 +13,6 @@ use cubecl_core::{
1413 server:: { Binding , CopyDescriptor , Handle , IoError , ProfileError } ,
1514} ;
1615use cubecl_runtime:: {
17- data_service:: DataTransferId ,
1816 id:: KernelId ,
1917 logging:: ServerLogger ,
2018 memory_management:: { MemoryAllocationMode , MemoryHandle } ,
@@ -25,21 +23,13 @@ use cudarc::driver::sys::{
2523} ;
2624use std:: { ffi:: c_void, ops:: DerefMut , sync:: Arc } ;
2725
28- const DEVICE_TO_DEVICE : DeviceTransferStrategy = DeviceTransferStrategy :: Serialized ;
29-
30- #[ allow( unused) ]
31- enum DeviceTransferStrategy {
32- Peer ,
33- Serialized ,
34- }
35-
3626#[ derive( new) ]
3727/// The `Command` struct encapsulates a CUDA context and a set of resolved CUDA streams, providing an
3828/// interface for executing GPU-related operations such as memory allocation, data transfers, kernel
3929/// registration, and task execution.
4030pub struct Command < ' a > {
4131 ctx : & ' a mut CudaContext ,
42- streams : ResolvedStreams < ' a , CudaStreamBackend > ,
32+ pub ( crate ) streams : ResolvedStreams < ' a , CudaStreamBackend > ,
4333}
4434
4535impl < ' a > Command < ' a > {
@@ -241,7 +231,7 @@ impl<'a> Command<'a> {
241231 Ok ( ( data, fences) )
242232 }
243233
244- fn copy_to_bytes (
234+ pub fn copy_to_bytes (
245235 & mut self ,
246236 descriptor : CopyDescriptor < ' _ > ,
247237 pinned : bool ,
@@ -344,72 +334,6 @@ impl<'a> Command<'a> {
344334 Ok ( handle)
345335 }
346336
347- /// Registers a source for an asynchronous data transfer operation.
348- ///
349- /// # Parameters
350- ///
351- /// * `id` - The unique identifier for the data transfer.
352- /// * `src` - The descriptor for the source GPU memory.
353- pub fn data_transfer_src ( & mut self , id : DataTransferId , src : CopyDescriptor < ' _ > ) {
354- let src_resource = self . resource ( src. binding . clone ( ) ) . unwrap ( ) ;
355- let client = DataTransferRuntime :: client ( ) ;
356- let current = self . streams . current ( ) ;
357-
358- let handle = DataTransferItem {
359- stream : current. sys ,
360- context : self . ctx . context ,
361- resource : src_resource,
362- } ;
363-
364- match DEVICE_TO_DEVICE {
365- DeviceTransferStrategy :: Peer => {
366- let fence = Fence :: new ( current. sys ) ;
367-
368- client. register_src_peer ( id, handle, fence) ;
369- }
370- DeviceTransferStrategy :: Serialized => {
371- client. register_src_serialized ( id, handle, src. binding ) ;
372- }
373- }
374- }
375-
376- /// Registers a destination for an asynchronous data transfer operation.
377- ///
378- /// # Parameters
379- ///
380- /// * `id` - The unique identifier for the data transfer.
381- /// * `dest` - The descriptor for the destination GPU memory.
382- pub fn data_transfer_dest ( & mut self , id : DataTransferId , dest : CopyDescriptor < ' _ > ) {
383- let dst_resource = self . resource ( dest. binding ) . unwrap ( ) ;
384- let current = self . streams . current ( ) ;
385- let client = DataTransferRuntime :: client ( ) ;
386-
387- let item = DataTransferItem {
388- context : self . ctx . context ,
389- stream : current. sys ,
390- resource : dst_resource,
391- } ;
392-
393- match DEVICE_TO_DEVICE {
394- DeviceTransferStrategy :: Peer => {
395- client. register_dest_peer ( id, item) ;
396- }
397- DeviceTransferStrategy :: Serialized => {
398- let num_bytes = dest. shape . iter ( ) . product :: < usize > ( ) * dest. elem_size ;
399- let bytes = self . reserve_cpu ( num_bytes, true , None ) ;
400-
401- client. register_dest_serialized (
402- id,
403- item,
404- bytes,
405- dest. shape . to_vec ( ) ,
406- dest. strides . to_vec ( ) ,
407- dest. elem_size ,
408- ) ;
409- }
410- }
411- }
412-
413337 /// Synchronizes the current stream, ensuring all pending operations are complete.
414338 ///
415339 /// # Returns
0 commit comments