@@ -148,6 +148,34 @@ async def test_codecs_use_of_gpu_prototype() -> None:
148
148
assert cp .array_equal (expect , got )
149
149
150
150
151
+ @gpu_test
152
+ @pytest .mark .asyncio
153
+ async def test_sharding_use_of_gpu_prototype () -> None :
154
+ with zarr .config .enable_gpu ():
155
+ expect = cp .zeros ((10 , 10 ), dtype = "uint16" , order = "F" )
156
+
157
+ a = await zarr .api .asynchronous .create_array (
158
+ StorePath (MemoryStore ()) / "test_codecs_use_of_gpu_prototype" ,
159
+ shape = expect .shape ,
160
+ chunks = (5 , 5 ),
161
+ shards = (10 , 10 ),
162
+ dtype = expect .dtype ,
163
+ fill_value = 0 ,
164
+ )
165
+ expect [:] = cp .arange (100 ).reshape (10 , 10 )
166
+
167
+ await a .setitem (
168
+ selection = (slice (0 , 10 ), slice (0 , 10 )),
169
+ value = expect [:],
170
+ prototype = gpu .buffer_prototype ,
171
+ )
172
+ got = await a .getitem (
173
+ selection = (slice (0 , 10 ), slice (0 , 10 )), prototype = gpu .buffer_prototype
174
+ )
175
+ assert isinstance (got , cp .ndarray )
176
+ assert cp .array_equal (expect , got )
177
+
178
+
151
179
def test_numpy_buffer_prototype () -> None :
152
180
buffer = cpu .buffer_prototype .buffer .create_zero_length ()
153
181
ndbuffer = cpu .buffer_prototype .nd_buffer .create (shape = (1 , 2 ), dtype = np .dtype ("int64" ))
@@ -157,6 +185,16 @@ def test_numpy_buffer_prototype() -> None:
157
185
ndbuffer .as_scalar ()
158
186
159
187
188
+ @gpu_test
189
+ def test_gpu_buffer_prototype () -> None :
190
+ buffer = gpu .buffer_prototype .buffer .create_zero_length ()
191
+ ndbuffer = gpu .buffer_prototype .nd_buffer .create (shape = (1 , 2 ), dtype = cp .dtype ("int64" ))
192
+ assert isinstance (buffer .as_array_like (), cp .ndarray )
193
+ assert isinstance (ndbuffer .as_ndarray_like (), cp .ndarray )
194
+ with pytest .raises (ValueError , match = "Buffer does not contain a single scalar value" ):
195
+ ndbuffer .as_scalar ()
196
+
197
+
160
198
# TODO: the same test for other buffer classes
161
199
def test_cpu_buffer_as_scalar () -> None :
162
200
buf = cpu .buffer_prototype .nd_buffer .create (shape = (), dtype = "int64" )
0 commit comments