@@ -6,19 +6,16 @@ use std::time::Duration;
66
77use cudarc:: driver:: sys:: CUevent_flags :: CU_EVENT_DEFAULT ;
88use cudarc:: driver:: {
9- CudaContext , CudaFunction , CudaStream , CudaViewMut , DeviceRepr , LaunchConfig , PushKernelArg ,
9+ CudaContext , CudaFunction , CudaStream , DeviceRepr , LaunchConfig , PushKernelArg ,
1010} ;
1111use cudarc:: nvrtc:: Ptx ;
12- use vortex_array:: Canonical ;
1312use vortex_array:: arrays:: PrimitiveArray ;
14- use vortex_array:: validity:: Validity ;
15- use vortex_buffer:: BufferMut ;
1613use vortex_dtype:: { NativePType , PType , match_each_native_ptype} ;
1714use vortex_error:: { VortexExpect , VortexResult , vortex_err} ;
1815use vortex_fastlanes:: { BitPackedVTable , FoRArray } ;
1916
20- use crate :: bit_unpack;
2117use crate :: task:: GPUTask ;
18+ use crate :: { ErasedCudaSlice , GpuArray , GpuPrimitiveArray , bit_unpack} ;
2219
2320struct ForTask < P > {
2421 stream : Arc < CudaStream > ,
@@ -71,43 +68,23 @@ fn cuda_for_kernel(ptype: PType, ctx: &Arc<CudaContext>) -> VortexResult<CudaFun
7168
7269impl < P : NativePType + DeviceRepr > GPUTask for ForTask < P > {
7370 fn launch_task ( & mut self ) -> VortexResult < ( ) > {
74- let len = self . len ( ) ;
7571 self . bp_task . launch_task ( ) ?;
7672 let mut launch = self . stream . launch_builder ( & self . func ) ;
77- let mut view = unsafe {
78- self . bp_task
79- . output ( )
80- . transmute_mut :: < P > ( len)
81- . vortex_expect ( "" )
82- } ;
73+ let mut view = self . bp_task . output ( ) . as_slice :: < P > ( ) ;
8374 launch. arg ( & mut view) ;
8475 launch. arg ( & self . reference ) ;
8576 unsafe { launch. launch ( self . launch_config ) }
8677 . map_err ( |e| vortex_err ! ( "Failed to launch: {e}" ) )
8778 . map ( |_| ( ) )
8879 }
8980
90- fn export_result ( & mut self ) -> VortexResult < Canonical > {
91- let len = self . len ( ) ;
92- let mut buffer = BufferMut :: < P > :: with_capacity ( len) ;
93-
94- unsafe { buffer. set_len ( len) }
95- self . stream
96- . memcpy_dtoh (
97- & unsafe { self . bp_task . output ( ) . transmute :: < P > ( len) . vortex_expect ( "" ) } ,
98- & mut buffer,
99- )
100- . map_err ( |e| vortex_err ! ( "Failed to copy to device: {e}" ) ) ?;
101- self . stream
102- . synchronize ( )
103- . map_err ( |e| vortex_err ! ( "Failed to synchronize: {e}" ) ) ?;
104- Ok ( Canonical :: Primitive ( PrimitiveArray :: new (
105- buffer,
106- Validity :: NonNullable ,
107- ) ) )
81+ fn export_result ( & mut self ) -> VortexResult < GpuArray > {
82+ Ok ( GpuArray :: Primitive ( GpuPrimitiveArray {
83+ values : self . bp_task . output ( ) ,
84+ } ) )
10885 }
10986
110- fn output ( & mut self ) -> CudaViewMut < ' _ , u8 > {
87+ fn output ( & mut self ) -> ErasedCudaSlice {
11188 self . bp_task . output ( )
11289 }
11390
0 commit comments