Skip to content

Commit 8c89fef

Browse files
committed
ENH: Add special-casing for complexfloating so that it can take 2 parameters
1 parent aecdb9f commit 8c89fef

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,12 +1812,22 @@ numbertype_class_getitem_abc(PyObject *cls, PyObject *args)
18121812

18131813
#ifdef Py_GENERICALIASOBJECT_H
18141814
Py_ssize_t args_len;
1815+
int args_len_expected;
1816+
1817+
/* complexfloating should take 2 parameters, all others take 1 */
1818+
if (PyType_IsSubtype((PyTypeObject *)cls,
1819+
&PyComplexFloatingArrType_Type)) {
1820+
args_len_expected = 2;
1821+
}
1822+
else {
1823+
args_len_expected = 1;
1824+
}
18151825

18161826
args_len = PyTuple_Check(args) ? PyTuple_Size(args) : 1;
1817-
if (args_len != 1) {
1827+
if (args_len != args_len_expected) {
18181828
return PyErr_Format(PyExc_TypeError,
18191829
"Too %s arguments for %s",
1820-
args_len > 1 ? "many" : "few",
1830+
args_len > args_len_expected ? "many" : "few",
18211831
((PyTypeObject *)cls)->tp_name);
18221832
}
18231833
generic_alias = Py_GenericAlias(cls, args);

numpy/core/tests/test_scalar_methods.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,17 @@ class TestClassGetItem:
142142
np.unsignedinteger,
143143
np.signedinteger,
144144
np.floating,
145-
np.complexfloating,
146145
])
147146
def test_abc(self, cls: Type[np.number]) -> None:
148147
alias = cls[Any]
149148
assert isinstance(alias, types.GenericAlias)
150149
assert alias.__origin__ is cls
151150

151+
def test_abc_complexfloating(self) -> None:
152+
alias = np.complexfloating[Any, Any]
153+
assert isinstance(alias, types.GenericAlias)
154+
assert alias.__origin__ is np.complexfloating
155+
152156
@pytest.mark.parametrize("cls", [np.generic, np.flexible, np.character])
153157
def test_abc_non_numeric(self, cls: Type[np.generic]) -> None:
154158
with pytest.raises(TypeError):
@@ -174,7 +178,7 @@ def test_subscript_scalar(self) -> None:
174178

175179

176180
@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Requires python 3.8")
177-
@pytest.mark.parametrize("cls", [np.number, np.int64])
181+
@pytest.mark.parametrize("cls", [np.number, np.complexfloating, np.int64])
178182
def test_class_getitem_38(cls: Type[np.number]) -> None:
179183
match = "Type subscription requires python >= 3.9"
180184
with pytest.raises(TypeError, match=match):

0 commit comments

Comments
 (0)