@@ -55,7 +55,7 @@ def __await__(self) -> Generator[Any, None, None]:
5555 async def _wait (self ) -> None :
5656 """Polls the CUDA event asynchronously with exponential backoff until it completes."""
5757 delay = self .initial_delay
58- while not self .event .query () : # `query() ` returns True if the event is complete
58+ while not self .event .done : # `done ` returns True if the event is complete
5959 await asyncio .sleep (delay ) # Yield control to other async tasks
6060 delay = min (delay * 2 , self .max_delay ) # Exponential backoff
6161
@@ -127,9 +127,13 @@ async def _convert_from_nvcomp_arrays(
127127 self ,
128128 arrays : Iterable [nvcomp .Array ],
129129 chunks_and_specs : Iterable [tuple [Buffer | None , ArraySpec ]],
130+ awaitable : AsyncCUDAEvent ,
130131 ) -> Iterable [Buffer | None ]:
132+ await awaitable # Wait for array computation to complete before accessing
131133 return [
132- spec .prototype .buffer .from_array_like (cp .asarray (a , dtype = np .dtype ("b" ))) if a else None
134+ spec .prototype .buffer .from_array_like (cp .array (a , dtype = np .dtype ("b" ), copy = False ))
135+ if a
136+ else None
133137 for a , (_ , spec ) in zip (arrays , chunks_and_specs , strict = True )
134138 ]
135139
@@ -155,10 +159,15 @@ async def decode(
155159 filtered_inputs , none_indices = await self ._convert_to_nvcomp_arrays (chunks_and_specs )
156160
157161 outputs = self ._zstd_codec .decode (filtered_inputs ) if len (filtered_inputs ) > 0 else []
162+
163+ # Record event for synchronization
164+ event = cp .cuda .Event ()
165+ awaitable = AsyncCUDAEvent (event ) # Convert CUDA event to awaitable object
166+
158167 for index in none_indices :
159168 outputs .insert (index , None )
160169
161- return await self ._convert_from_nvcomp_arrays (outputs , chunks_and_specs )
170+ return await self ._convert_from_nvcomp_arrays (outputs , chunks_and_specs , awaitable )
162171
163172 async def encode (
164173 self ,
@@ -183,10 +192,15 @@ async def encode(
183192 filtered_inputs , none_indices = await self ._convert_to_nvcomp_arrays (chunks_and_specs )
184193
185194 outputs = self ._zstd_codec .encode (filtered_inputs ) if len (filtered_inputs ) > 0 else []
195+
196+ # Record event for synchronization
197+ event = cp .cuda .Event ()
198+ awaitable = AsyncCUDAEvent (event ) # Convert CUDA event to awaitable object
199+
186200 for index in none_indices :
187201 outputs .insert (index , None )
188202
189- return await self ._convert_from_nvcomp_arrays (outputs , chunks_and_specs )
203+ return await self ._convert_from_nvcomp_arrays (outputs , chunks_and_specs , awaitable )
190204
191205 def compute_encoded_size (self , _input_byte_length : int , _chunk_spec : ArraySpec ) -> int :
192206 raise NotImplementedError
0 commit comments