@@ -6,39 +6,60 @@ import numpy as np
6
6
import optype .numpy as onp
7
7
import optype .numpy .compat as npc
8
8
9
- from scipy .linalg import solve , solve_banded , solve_circulant , solve_toeplitz , solve_triangular
9
+ from scipy .linalg import inv , solve , solve_banded , solve_circulant , solve_toeplitz , solve_triangular
10
+
11
+ b1_nd : onp .ArrayND [np .bool_ ]
10
12
11
13
i8_1d : onp .Array1D [np .int8 ]
12
14
i8_2d : onp .Array2D [np .int8 ]
13
15
i8_3d : onp .Array3D [np .int8 ]
16
+ i8_nd : onp .ArrayND [np .int8 ]
17
+
18
+ i32_1d : onp .Array1D [np .int32 ]
19
+ i32_2d : onp .Array2D [np .int32 ]
20
+ i32_3d : onp .Array3D [np .int32 ]
21
+ i32_nd : onp .ArrayND [np .int32 ]
14
22
15
23
f16_1d : onp .Array1D [np .float16 ]
16
24
f16_2d : onp .Array2D [np .float16 ]
17
25
f16_3d : onp .Array3D [np .float16 ]
26
+ f16_nd : onp .ArrayND [np .float16 ]
18
27
19
28
f32_1d : onp .Array1D [np .float32 ]
20
29
f32_2d : onp .Array2D [np .float32 ]
21
30
f32_3d : onp .Array3D [np .float32 ]
31
+ f32_nd : onp .ArrayND [np .float32 ]
22
32
23
33
f64_1d : onp .Array1D [np .float64 ]
24
34
f64_2d : onp .Array2D [np .float64 ]
25
35
f64_3d : onp .Array3D [np .float64 ]
36
+ f64_nd : onp .ArrayND [np .float64 ]
26
37
27
38
f80_1d : onp .Array1D [np .longdouble ]
28
39
f80_2d : onp .Array2D [np .longdouble ]
29
40
f80_3d : onp .Array3D [np .longdouble ]
41
+ f80_nd : onp .ArrayND [np .longdouble ]
30
42
31
43
c64_1d : onp .Array1D [np .complex64 ]
32
44
c64_2d : onp .Array2D [np .complex64 ]
33
45
c64_3d : onp .Array3D [np .complex64 ]
46
+ c64_nd : onp .ArrayND [np .complex64 ]
34
47
35
48
c128_1d : onp .Array1D [np .complex128 ]
36
49
c128_2d : onp .Array2D [np .complex128 ]
37
50
c128_3d : onp .Array3D [np .complex128 ]
51
+ c128_nd : onp .ArrayND [np .complex128 ]
38
52
39
53
c160_1d : onp .Array1D [np .clongdouble ]
40
54
c160_2d : onp .Array2D [np .clongdouble ]
41
55
c160_3d : onp .Array3D [np .clongdouble ]
56
+ c160_nd : onp .ArrayND [np .clongdouble ]
57
+
58
+ py_b_2d : list [list [bool ]]
59
+ py_b_3d : list [list [list [bool ]]]
60
+
61
+ py_i_2d : list [list [int ]]
62
+ py_i_3d : list [list [list [int ]]]
42
63
43
64
py_f_1d : list [float ]
44
65
py_f_2d : list [list [float ]]
@@ -314,4 +335,50 @@ assert_type(solve_circulant(py_c_1d, py_c_3d), onp.ArrayND[np.complex128])
314
335
assert_type (solve_circulant (py_c_2d , py_c_1d ), onp .ArrayND [np .complex128 ])
315
336
316
337
###
317
- # TODO(jorenham): inv, pinv, pinvh, det, lstsq, matrix_balance, matmul_toeplitz
338
+ # inv
339
+
340
+ assert_type (inv (f32_2d ), onp .Array2D [np .float32 ])
341
+ assert_type (inv (f64_2d ), onp .Array2D [np .float64 ])
342
+ assert_type (inv (c64_2d ), onp .Array2D [np .complex64 ])
343
+ assert_type (inv (c128_2d ), onp .Array2D [np .complex128 ])
344
+
345
+ assert_type (inv (py_b_2d ), onp .Array2D [np .float32 ])
346
+ assert_type (inv (py_i_2d ), onp .Array2D [np .float64 ])
347
+ assert_type (inv (py_f_2d ), onp .Array2D [np .float64 ])
348
+ assert_type (inv (py_c_2d ), onp .Array2D [np .complex128 ])
349
+
350
+ assert_type (inv (f32_3d ), onp .Array3D [np .float32 ])
351
+ assert_type (inv (f64_3d ), onp .Array3D [np .float64 ])
352
+ assert_type (inv (c64_3d ), onp .Array3D [np .complex64 ])
353
+ assert_type (inv (c128_3d ), onp .Array3D [np .complex128 ])
354
+
355
+ assert_type (inv (py_b_3d ), onp .ArrayND [np .float32 ])
356
+ assert_type (inv (py_i_3d ), onp .ArrayND [np .float64 ])
357
+ assert_type (inv (py_f_3d ), onp .ArrayND [np .float64 ])
358
+ assert_type (inv (py_c_3d ), onp .ArrayND [np .complex128 ])
359
+
360
+ assert_type (inv (b1_nd ), onp .ArrayND [np .float32 ])
361
+ assert_type (inv (i8_nd ), onp .ArrayND [np .float32 ])
362
+ assert_type (inv (f16_nd ), onp .ArrayND [np .float32 ])
363
+ assert_type (inv (f32_nd ), onp .ArrayND [np .float32 ])
364
+ assert_type (inv (i32_nd ), onp .ArrayND [np .float64 ])
365
+ assert_type (inv (f64_nd ), onp .ArrayND [np .float64 ])
366
+ assert_type (inv (f80_nd ), onp .ArrayND [np .float64 ])
367
+ assert_type (inv (c64_nd ), onp .ArrayND [np .complex64 ])
368
+ assert_type (inv (c128_nd ), onp .ArrayND [np .complex128 ])
369
+ assert_type (inv (c160_nd ), onp .ArrayND [np .complex128 ])
370
+
371
+ ###
372
+ # TODO(jorenham): det
373
+
374
+ ###
375
+ # TODO(jorenham): lstsq
376
+
377
+ ###
378
+ # TODO(jorenham): pinv[h]
379
+
380
+ ###
381
+ # TODO(jorenham): matrix_balance
382
+
383
+ ###
384
+ # TODO(jorenham): matmul_toeplitz
0 commit comments