@@ -150,12 +150,20 @@ def integrate(
150150
151151 Parameters
152152 ----------
153- f : Callable[ [Mapping[TSpherical, Array]],
154- Mapping[TSpherical, Array] | Array, ]
155- | Mapping[TSpherical, Array] | Array
156- The function to integrate or the values of the function.
157- In case of vectorized function, the function should add extra
158- axis to the last dimension, not the first dimension.
153+ f : Callable[ [Mapping[TSpherical, Array]], Mapping[TSpherical, Array] | Array, ] | Mapping[TSpherical, Array] | Array # noqa: E501
154+ The function to integrate or the values of the function.
155+
156+ If mapping, the separated parts of the function for each spherical coordinate.
157+
158+ If mapping, the shapes do not need to be broadcastable.
159+
160+ If function, if does_f_support_separation_of_variables is True,
161+ 1D array of integration points are passed,
162+ and extra axis should be added to the last dimension.
163+
164+ If function, if does_f_support_separation_of_variables is False,
165+ ``c.s_ndim``-D array of integration points are passed,
166+ and extra axis should be added to the last dimension.
159167 does_f_support_separation_of_variables : bool
160168 Whether the function supports separation of variables.
161169 This could significantly reduce the computational cost.
@@ -170,6 +178,7 @@ def integrate(
170178 -------
171179 Array | Mapping[TSpherical, Array]
172180 The integrated value.
181+ Has the same shape as the return values of f or the values of f.
173182
174183 """
175184 xs , ws = roots (
@@ -199,13 +208,25 @@ def integrate(
199208 # theta(node),u1,...,uM
200209 xpx .broadcast_shapes (value .shape [:1 ], ws [node ].shape )
201210 w = xp .reshape (ws [node ], (- 1 ,) + (1 ,) * (value .ndim - 1 ))
202- result [node ] = xp .sum (value * w , axis = 0 )
211+ if value .shape [0 ] == 1 :
212+ result [node ] = value [0 , ...] * xp .sum (w )
213+ else :
214+ result [node ] = xp .vecdot (value , w , axis = 0 )
203215 # we don't know how to einsum the result
204216 return result
217+ if val .ndim < c .s_ndim :
218+ raise ValueError (
219+ f"The dimension of the return value of f should be at least { c .s_ndim } , got { val .ndim } ."
220+ )
221+ xpx .broadcast_shapes (
222+ val .shape [: c .s_ndim ],
223+ xpx .broadcast_shapes (* (xs [node ].shape for node in c .s_nodes )),
224+ )
205225 # theta1,...,thetaN,u1,...,uM\
206226 for node in c .s_nodes :
207227 w = ws [node ]
208- xpx .broadcast_shapes (val .shape [:1 ], w .shape )
209- # val = xp.einsum("i...,i->...", val, w.astype(val.dtype))
210- val = xp .sum (val * w [(slice (None ),) + (None ,) * (val .ndim - 1 )], axis = 0 )
228+ if val .shape [0 ] == 1 :
229+ val = val [0 , ...] * xp .sum (w )
230+ else :
231+ val = xp .vecdot (val , w [(slice (None ),) + (None ,) * (val .ndim - 1 )], axis = 0 )
211232 return val
0 commit comments