@@ -151,25 +151,29 @@ async def test_codecs_use_of_gpu_prototype() -> None:
151151@gpu_test  
152152@pytest .mark .asyncio  
153153async  def  test_sharding_use_of_gpu_prototype () ->  None :
154-     expect  =  cp .zeros ((10 , 10 ), dtype = "uint16" , order = "F" )
155-     a  =  await  zarr .api .asynchronous .create_array (
156-         StorePath (MemoryStore ()) /  "test_codecs_use_of_gpu_prototype" ,
157-         shape = expect .shape ,
158-         chunks = (5 , 5 ),
159-         shards = (10 , 10 ),
160-         dtype = expect .dtype ,
161-         fill_value = 0 ,
162-     )
163-     expect [:] =  cp .arange (100 ).reshape (10 , 10 )
164- 
165-     await  a .setitem (
166-         selection = (slice (0 , 10 ), slice (0 , 10 )),
167-         value = expect [:],
168-         prototype = gpu .buffer_prototype ,
169-     )
170-     got  =  await  a .getitem (selection = (slice (0 , 10 ), slice (0 , 10 )), prototype = gpu .buffer_prototype )
171-     assert  isinstance (got , cp .ndarray )
172-     assert  cp .array_equal (expect , got )
154+     with  zarr .config .enable_gpu ():
155+         expect  =  cp .zeros ((10 , 10 ), dtype = "uint16" , order = "F" )
156+ 
157+         a  =  await  zarr .api .asynchronous .create_array (
158+             StorePath (MemoryStore ()) /  "test_codecs_use_of_gpu_prototype" ,
159+             shape = expect .shape ,
160+             chunks = (5 , 5 ),
161+             shards = (10 , 10 ),
162+             dtype = expect .dtype ,
163+             fill_value = 0 ,
164+         )
165+         expect [:] =  cp .arange (100 ).reshape (10 , 10 )
166+ 
167+         await  a .setitem (
168+             selection = (slice (0 , 10 ), slice (0 , 10 )),
169+             value = expect [:],
170+             prototype = gpu .buffer_prototype ,
171+         )
172+         got  =  await  a .getitem (
173+             selection = (slice (0 , 10 ), slice (0 , 10 )), prototype = gpu .buffer_prototype 
174+         )
175+         assert  isinstance (got , cp .ndarray )
176+         assert  cp .array_equal (expect , got )
173177
174178
175179def  test_numpy_buffer_prototype () ->  None :
0 commit comments