1
1
from typing import Any , assert_type
2
2
3
3
import numpy as np
4
- import numpy . typing as npt
4
+ import optype . numpy as onp
5
5
import optype .numpy .compat as npc
6
6
7
7
from scipy .special import logsumexp , softmax
@@ -28,23 +28,23 @@ assert_type(logsumexp(f16_1d), np.float16)
28
28
assert_type (logsumexp (c64_0d ), np .complex64 )
29
29
assert_type (logsumexp (c64_1d ), np .complex64 )
30
30
31
- assert_type (logsumexp (py_f_0d , keepdims = True ), npt . NDArray [np .float64 ])
32
- assert_type (logsumexp (py_f_1d , keepdims = True ), npt . NDArray [np .float64 ])
33
- assert_type (logsumexp (py_c_0d , keepdims = True ), npt . NDArray [np .complex128 ])
34
- assert_type (logsumexp (py_c_1d , keepdims = True ), npt . NDArray [np .complex128 ])
35
- assert_type (logsumexp (f16_0d , keepdims = True ), npt . NDArray [np .float16 ])
36
- assert_type (logsumexp (f16_1d , keepdims = True ), npt . NDArray [np .float16 ])
37
- assert_type (logsumexp (c64_0d , keepdims = True ), npt . NDArray [np .complex64 ])
38
- assert_type (logsumexp (c64_1d , keepdims = True ), npt . NDArray [np .complex64 ])
31
+ assert_type (logsumexp (py_f_0d , keepdims = True ), onp . ArrayND [np .float64 ])
32
+ assert_type (logsumexp (py_f_1d , keepdims = True ), onp . ArrayND [np .float64 ])
33
+ assert_type (logsumexp (py_c_0d , keepdims = True ), onp . ArrayND [np .complex128 ])
34
+ assert_type (logsumexp (py_c_1d , keepdims = True ), onp . ArrayND [np .complex128 ])
35
+ assert_type (logsumexp (f16_0d , keepdims = True ), onp . ArrayND [np .float16 ])
36
+ assert_type (logsumexp (f16_1d , keepdims = True ), onp . ArrayND [np .float16 ])
37
+ assert_type (logsumexp (c64_0d , keepdims = True ), onp . ArrayND [np .complex64 ])
38
+ assert_type (logsumexp (c64_1d , keepdims = True ), onp . ArrayND [np .complex64 ])
39
39
40
- assert_type (logsumexp (py_f_0d , axis = 0 ), npt . NDArray [np .float64 ] | Any )
41
- assert_type (logsumexp (py_f_1d , axis = 0 ), npt . NDArray [np .float64 ] | Any )
42
- assert_type (logsumexp (py_c_0d , axis = 0 ), npt . NDArray [np .complex128 ] | Any )
43
- assert_type (logsumexp (py_c_1d , axis = 0 ), npt . NDArray [np .complex128 ] | Any )
44
- assert_type (logsumexp (f16_0d , axis = 0 ), npt . NDArray [np .float16 ] | Any )
45
- assert_type (logsumexp (f16_1d , axis = 0 ), npt . NDArray [np .float16 ] | Any )
46
- assert_type (logsumexp (c64_0d , axis = 0 ), npt . NDArray [np .complex64 ] | Any )
47
- assert_type (logsumexp (c64_1d , axis = 0 ), npt . NDArray [np .complex64 ] | Any )
40
+ assert_type (logsumexp (py_f_0d , axis = 0 ), onp . ArrayND [np .float64 ] | Any )
41
+ assert_type (logsumexp (py_f_1d , axis = 0 ), onp . ArrayND [np .float64 ] | Any )
42
+ assert_type (logsumexp (py_c_0d , axis = 0 ), onp . ArrayND [np .complex128 ] | Any )
43
+ assert_type (logsumexp (py_c_1d , axis = 0 ), onp . ArrayND [np .complex128 ] | Any )
44
+ assert_type (logsumexp (f16_0d , axis = 0 ), onp . ArrayND [np .float16 ] | Any )
45
+ assert_type (logsumexp (f16_1d , axis = 0 ), onp . ArrayND [np .float16 ] | Any )
46
+ assert_type (logsumexp (c64_0d , axis = 0 ), onp . ArrayND [np .complex64 ] | Any )
47
+ assert_type (logsumexp (c64_1d , axis = 0 ), onp . ArrayND [np .complex64 ] | Any )
48
48
49
49
assert_type (logsumexp (py_f_0d , return_sign = True ), tuple [np .float64 , np .float64 ])
50
50
assert_type (logsumexp (py_f_1d , return_sign = True ), tuple [np .float64 , np .float64 ])
@@ -55,23 +55,23 @@ assert_type(logsumexp(f16_1d, return_sign=True), tuple[np.float16, np.float16])
55
55
assert_type (logsumexp (c64_0d , return_sign = True ), tuple [npc .floating , np .complex64 ])
56
56
assert_type (logsumexp (c64_1d , return_sign = True ), tuple [npc .floating , np .complex64 ])
57
57
58
- assert_type (logsumexp (py_f_0d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float64 ], npt . NDArray [np .float64 ]])
59
- assert_type (logsumexp (py_f_1d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float64 ], npt . NDArray [np .float64 ]])
60
- assert_type (logsumexp (py_c_0d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float64 ], npt . NDArray [np .complex128 ]])
61
- assert_type (logsumexp (py_c_1d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float64 ], npt . NDArray [np .complex128 ]])
62
- assert_type (logsumexp (f16_0d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float16 ], npt . NDArray [np .float16 ]])
63
- assert_type (logsumexp (f16_1d , keepdims = True , return_sign = True ), tuple [npt . NDArray [np .float16 ], npt . NDArray [np .float16 ]])
64
- assert_type (logsumexp (c64_0d , keepdims = True , return_sign = True ), tuple [npt . NDArray [npc .floating ], npt . NDArray [np .complex64 ]])
65
- assert_type (logsumexp (c64_1d , keepdims = True , return_sign = True ), tuple [npt . NDArray [npc .floating ], npt . NDArray [np .complex64 ]])
58
+ assert_type (logsumexp (py_f_0d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float64 ], onp . ArrayND [np .float64 ]])
59
+ assert_type (logsumexp (py_f_1d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float64 ], onp . ArrayND [np .float64 ]])
60
+ assert_type (logsumexp (py_c_0d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float64 ], onp . ArrayND [np .complex128 ]])
61
+ assert_type (logsumexp (py_c_1d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float64 ], onp . ArrayND [np .complex128 ]])
62
+ assert_type (logsumexp (f16_0d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float16 ], onp . ArrayND [np .float16 ]])
63
+ assert_type (logsumexp (f16_1d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [np .float16 ], onp . ArrayND [np .float16 ]])
64
+ assert_type (logsumexp (c64_0d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [npc .floating ], onp . ArrayND [np .complex64 ]])
65
+ assert_type (logsumexp (c64_1d , keepdims = True , return_sign = True ), tuple [onp . ArrayND [npc .floating ], onp . ArrayND [np .complex64 ]])
66
66
67
- assert_type (logsumexp (py_f_0d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float64 ] | Any , npt . NDArray [np .float64 ] | Any ])
68
- assert_type (logsumexp (py_f_1d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float64 ] | Any , npt . NDArray [np .float64 ] | Any ])
69
- assert_type (logsumexp (py_c_0d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float64 ] | Any , npt . NDArray [np .complex128 ] | Any ])
70
- assert_type (logsumexp (py_c_1d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float64 ] | Any , npt . NDArray [np .complex128 ] | Any ])
71
- assert_type (logsumexp (f16_0d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float16 ] | Any , npt . NDArray [np .float16 ] | Any ])
72
- assert_type (logsumexp (f16_1d , axis = 0 , return_sign = True ), tuple [npt . NDArray [np .float16 ] | Any , npt . NDArray [np .float16 ] | Any ])
73
- assert_type (logsumexp (c64_0d , axis = 0 , return_sign = True ), tuple [npt . NDArray [npc .floating ] | Any , npt . NDArray [np .complex64 ] | Any ])
74
- assert_type (logsumexp (c64_1d , axis = 0 , return_sign = True ), tuple [npt . NDArray [npc .floating ] | Any , npt . NDArray [np .complex64 ] | Any ])
67
+ assert_type (logsumexp (py_f_0d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float64 ] | Any , onp . ArrayND [np .float64 ] | Any ])
68
+ assert_type (logsumexp (py_f_1d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float64 ] | Any , onp . ArrayND [np .float64 ] | Any ])
69
+ assert_type (logsumexp (py_c_0d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float64 ] | Any , onp . ArrayND [np .complex128 ] | Any ])
70
+ assert_type (logsumexp (py_c_1d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float64 ] | Any , onp . ArrayND [np .complex128 ] | Any ])
71
+ assert_type (logsumexp (f16_0d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float16 ] | Any , onp . ArrayND [np .float16 ] | Any ])
72
+ assert_type (logsumexp (f16_1d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [np .float16 ] | Any , onp . ArrayND [np .float16 ] | Any ])
73
+ assert_type (logsumexp (c64_0d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [npc .floating ] | Any , onp . ArrayND [np .complex64 ] | Any ])
74
+ assert_type (logsumexp (c64_1d , axis = 0 , return_sign = True ), tuple [onp . ArrayND [npc .floating ] | Any , onp . ArrayND [np .complex64 ] | Any ])
75
75
76
76
###
77
77
# softmax (equiv log_softmax)
@@ -81,7 +81,7 @@ assert_type(softmax(py_c_0d), np.complex128)
81
81
assert_type (softmax (f16_0d ), np .float16 )
82
82
assert_type (softmax (c64_0d ), np .complex64 )
83
83
84
- assert_type (softmax (py_f_1d ), npt . NDArray [np .float64 ])
85
- assert_type (softmax (py_c_1d ), npt . NDArray [np .complex128 ])
84
+ assert_type (softmax (py_f_1d ), onp . ArrayND [np .float64 ])
85
+ assert_type (softmax (py_c_1d ), onp . ArrayND [np .complex128 ])
86
86
assert_type (softmax (f16_1d ), np .ndarray [tuple [int ], np .dtype [np .float16 ]])
87
87
assert_type (softmax (c64_1d ), np .ndarray [tuple [int ], np .dtype [np .complex64 ]])
0 commit comments