@@ -34,9 +34,7 @@ def spherical_pw(N, k, r, setup):
3434 n = np .arange (N + 1 )
3535
3636 bn = weights (N , kr , setup )
37- for i , x in enumerate (kr ):
38- bn [i , :] = bn [i , :] * 4 * np .pi * (1j )** n
39- return bn
37+ return 4 * np .pi * (1j )** n * bn
4038
4139
4240def spherical_ps (N , k , r , rs , setup ):
@@ -72,11 +70,14 @@ def spherical_ps(N, k, r, rs, setup):
7270 n = np .arange (N + 1 )
7371
7472 bn = weights (N , k * r , setup )
73+ if len (k ) == 1 :
74+ bn = bn [np .newaxis , :]
75+
7576 for i , x in enumerate (krs ):
7677 hn = special .spherical_jn (n , x ) - 1j * special .spherical_yn (n , x )
7778 bn [i , :] = bn [i , :] * 4 * np .pi * (- 1j ) * hn * k [i ]
7879
79- return bn
80+ return np . squeeze ( bn )
8081
8182
8283def weights (N , kr , setup ):
@@ -106,6 +107,7 @@ def weights(N, kr, setup):
106107 Radial weights for all orders up to N and the given wavenumbers.
107108
108109 """
110+ kr = util .asarray_1d (kr )
109111 n = np .arange (N + 1 )
110112 bns = np .zeros ((len (kr ), N + 1 ), dtype = complex )
111113 for i , x in enumerate (kr ):
@@ -259,9 +261,7 @@ def circular_pw(N, k, r, setup):
259261 n = np .arange (N + 1 )
260262
261263 bn = circ_radial_weights (N , kr , setup )
262- for i , x in enumerate (kr ):
263- bn [i , :] = bn [i , :] * (1j )** (n )
264- return bn
264+ return (1j )** (n ) * bn
265265
266266
267267def circular_ls (N , k , r , rs , setup ):
@@ -297,10 +297,12 @@ def circular_ls(N, k, r, rs, setup):
297297 n = np .arange (N + 1 )
298298
299299 bn = circ_radial_weights (N , k * r , setup )
300+ if len (k ) == 1 :
301+ bn = bn [np .newaxis , :]
300302 for i , x in enumerate (krs ):
301303 Hn = special .hankel2 (n , x )
302304 bn [i , :] = bn [i , :] * - 1j / 4 * Hn
303- return bn
305+ return np . squeeze ( bn )
304306
305307
306308def circ_radial_weights (N , kr , setup ):
@@ -329,6 +331,7 @@ def circ_radial_weights(N, kr, setup):
329331 Radial weights for all orders up to N and the given wavenumbers.
330332
331333 """
334+ kr = util .asarray_1d (kr )
332335 n = np .arange (N + 1 )
333336 Bns = np .zeros ((len (kr ), N + 1 ), dtype = complex )
334337 for i , x in enumerate (kr ):
0 commit comments