Skip to content

Commit f935594

Browse files
BUG: lib: Allow type uint64 for eye() arguments.
Closes numpygh-9982. (Plus a few small PEP 8 fixes.)
1 parent e49478c commit f935594

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

numpy/lib/tests/test_twodim_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def test_basic(self):
4444
assert_equal(eye(3) == 1,
4545
eye(3, dtype=bool))
4646

47+
def test_uint64(self):
48+
# Regression test for gh-9982
49+
assert_equal(eye(np.uint64(2), dtype=int), array([[1, 0], [0, 1]]))
50+
assert_equal(eye(np.uint64(2), M=np.uint64(4), k=np.uint64(1)),
51+
array([[0, 1, 0, 0], [0, 0, 1, 0]]))
52+
4753
def test_diag(self):
4854
assert_equal(eye(4, k=1),
4955
array([[0, 1, 0, 0],
@@ -382,7 +388,7 @@ def test_tril_triu_dtype():
382388
assert_equal(np.triu(arr).dtype, arr.dtype)
383389
assert_equal(np.tril(arr).dtype, arr.dtype)
384390

385-
arr = np.zeros((3,3), dtype='f4,f4')
391+
arr = np.zeros((3, 3), dtype='f4,f4')
386392
assert_equal(np.triu(arr).dtype, arr.dtype)
387393
assert_equal(np.tril(arr).dtype, arr.dtype)
388394

numpy/lib/twodim_base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
"""
44
import functools
5+
import operator
56

67
from numpy.core.numeric import (
78
asanyarray, arange, zeros, greater_equal, multiply, ones,
@@ -214,6 +215,11 @@ def eye(N, M=None, k=0, dtype=float, order='C', *, like=None):
214215
m = zeros((N, M), dtype=dtype, order=order)
215216
if k >= M:
216217
return m
218+
# Ensure M and k are integers, so we don't get any surprise casting
219+
# results in the expressions `M-k` and `M+1` used below. This avoids
220+
# a problem with inputs with type (for example) np.uint64.
221+
M = operator.index(M)
222+
k = operator.index(k)
217223
if k >= 0:
218224
i = k
219225
else:
@@ -494,8 +500,8 @@ def triu(m, k=0):
494500
Upper triangle of an array.
495501
496502
Return a copy of an array with the elements below the `k`-th diagonal
497-
zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the final
498-
two axes.
503+
zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the
504+
final two axes.
499505
500506
Please refer to the documentation for `tril` for further details.
501507
@@ -804,7 +810,7 @@ def histogram2d(x, y, bins=10, range=None, normed=None, weights=None,
804810
>>> plt.show()
805811
"""
806812
from numpy import histogramdd
807-
813+
808814
if len(x) != len(y):
809815
raise ValueError('x and y must have the same length.')
810816

0 commit comments

Comments
 (0)