Skip to content

Commit 35cb298

Browse files
authored
Adds zeros_like to the Array API (#28)
* fix asarray * zeros_like * code coverage
1 parent 2fc79f6 commit 35cb298

File tree

9 files changed

+53
-8
lines changed

9 files changed

+53
-8
lines changed

_doc/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ well as to execute it.
3939

4040
Sources available on
4141
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_,
42-
see also `code coverage <cov/index.html>`_.
42+
see also `code coverage <_static/cov_html/index.html>`_.
4343

4444
.. runpython::
4545
:showcode:

_doc/run_coverage.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python3 -m pytest --cov --cov-report html:_doc/_static/cov_html _unittests

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ array_api_tests/test_creation_functions.py::test_empty_like
88
array_api_tests/test_creation_functions.py::test_eye
99
array_api_tests/test_creation_functions.py::test_linspace
1010
array_api_tests/test_creation_functions.py::test_meshgrid
11-
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones_like || exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,25 @@ def test_full_like_mx(self):
127127
matnp = mat.numpy()
128128
self.assertEqualArray(expected, matnp)
129129

130+
def test_ones_like_mx(self):
131+
c = EagerTensor(np.array([], dtype=np.uint8))
132+
expected = np.ones_like(c.numpy())
133+
mat = xp.ones_like(c)
134+
matnp = mat.numpy()
135+
self.assertEqualArray(expected, matnp)
136+
137+
def test_as_array(self):
138+
r = xp.asarray(9223372036854775809)
139+
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
140+
self.assertEqual(r.numpy(), 9223372036854775809)
141+
r = EagerTensor(np.array(9223372036854775809, dtype=np.uint64))
142+
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
143+
self.assertEqual(r.numpy(), 9223372036854775809)
144+
130145

131146
if __name__ == "__main__":
132147
# import logging
133148

134149
# logging.basicConfig(level=logging.DEBUG)
135-
# TestOnnxNumpy().test_full_like_mx()
150+
# TestOnnxNumpy().test_as_array()
136151
unittest.main(verbosity=2)

azure-pipelines.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ jobs:
184184
black --diff .
185185
displayName: 'Black'
186186
- script: |
187-
python -m pytest
187+
python -m pytest --cov
188188
displayName: 'Runs Unit Tests'
189189
- script: |
190190
python -u setup.py bdist_wheel

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"sum",
3030
"take",
3131
"zeros",
32+
"zeros_like",
3233
]
3334

3435

onnx_array_api/array_api/_onnx_common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,17 @@ def asarray(
7171
elif a is True:
7272
v = TEagerTensor(np.array(True, dtype=np.bool_))
7373
else:
74+
va = np.asarray(a)
75+
v = None
7476
try:
75-
v = TEagerTensor(np.asarray(a, dtype=np.int64))
77+
vai = np.asarray(a, dtype=np.int64)
7678
except OverflowError:
77-
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
79+
v = TEagerTensor(va)
80+
if v is None:
81+
if int(va) == int(vai):
82+
v = TEagerTensor(vai)
83+
else:
84+
v = TEagerTensor(va)
7885
elif isinstance(a, float):
7986
v = TEagerTensor(np.array(a, dtype=np.float64))
8087
elif isinstance(a, bool):

onnx_array_api/npx/npx_functions.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def ones_like(
681681
dtype: OptParType[DType] = None,
682682
) -> TensorType[ElemType.numerics, "T"]:
683683
"""
684-
Implements :func:`numpy.zeros`.
684+
Implements :func:`numpy.ones_like`.
685685
"""
686686
o = make_tensor(
687687
name="one",
@@ -955,3 +955,25 @@ def zeros(
955955
value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]),
956956
op="ConstantOfShape",
957957
)
958+
959+
960+
@npxapi_inline
961+
def zeros_like(
962+
x: TensorType[ElemType.allowed, "T"],
963+
/,
964+
*,
965+
dtype: OptParType[DType] = None,
966+
) -> TensorType[ElemType.numerics, "T"]:
967+
"""
968+
Implements :func:`numpy.zeros_like`.
969+
"""
970+
o = make_tensor(
971+
name="zero",
972+
data_type=TensorProto.INT64 if dtype is None else dtype.code,
973+
dims=[1],
974+
vals=[0],
975+
)
976+
v = var(x.shape, value=o, op="ConstantOfShape")
977+
if dtype is None:
978+
return var(v, x, op="CastLike")
979+
return v

0 commit comments

Comments
 (0)