@@ -34,9 +34,7 @@ def spherical_pw(N, k, r, setup):
3434 n = np .arange (N + 1 )
3535
3636 bn = weights (N , kr , setup )
37- for i , x in enumerate (kr ):
38- bn [i , :] = bn [i , :] * 4 * np .pi * (1j )** n
39- return bn
37+ return 4 * np .pi * (1j )** n * bn
4038
4139
4240def spherical_ps (N , k , r , rs , setup ):
@@ -72,11 +70,14 @@ def spherical_ps(N, k, r, rs, setup):
7270 n = np .arange (N + 1 )
7371
7472 bn = weights (N , k * r , setup )
73+ if len (k ) == 1 :
74+ bn = bn [np .newaxis , :]
75+
7576 for i , x in enumerate (krs ):
7677 hn = special .spherical_jn (n , x ) - 1j * special .spherical_yn (n , x )
7778 bn [i , :] = bn [i , :] * 4 * np .pi * (- 1j ) * hn * k [i ]
7879
79- return bn
80+ return np . squeeze ( bn )
8081
8182
8283def weights (N , kr , setup ):
@@ -106,6 +107,7 @@ def weights(N, kr, setup):
106107 Radial weights for all orders up to N and the given wavenumbers.
107108
108109 """
110+ kr = util .asarray_1d (kr )
109111 n = np .arange (N + 1 )
110112 bns = np .zeros ((len (kr ), N + 1 ), dtype = complex )
111113 for i , x in enumerate (kr ):
0 commit comments