Skip to content

Commit 231b508

Browse files
authored
perf: better integral (#59)
1 parent 68de305 commit 231b508

File tree

2 files changed

+34
-14
lines changed

2 files changed

+34
-14
lines changed

src/ultrasphere/_integral.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,20 @@ def integrate(
150150
151151
Parameters
152152
----------
153-
f : Callable[ [Mapping[TSpherical, Array]],
154-
Mapping[TSpherical, Array] | Array, ]
155-
| Mapping[TSpherical, Array] | Array
156-
The function to integrate or the values of the function.
157-
In case of vectorized function, the function should add extra
158-
axis to the last dimension, not the first dimension.
153+
f : Callable[ [Mapping[TSpherical, Array]], Mapping[TSpherical, Array] | Array, ] | Mapping[TSpherical, Array] | Array # noqa: E501
154+
The function to integrate or the values of the function.
155+
156+
If mapping, the separated parts of the function for each spherical coordinate.
157+
158+
If mapping, the shapes do not need to be broadcastable.
159+
160+
If function, if does_f_support_separation_of_variables is True,
161+
1D array of integration points are passed,
162+
and extra axis should be added to the last dimension.
163+
164+
If function, if does_f_support_separation_of_variables is False,
165+
``c.s_ndim``-D array of integration points are passed,
166+
and extra axis should be added to the last dimension.
159167
does_f_support_separation_of_variables : bool
160168
Whether the function supports separation of variables.
161169
This could significantly reduce the computational cost.
@@ -170,6 +178,7 @@ def integrate(
170178
-------
171179
Array | Mapping[TSpherical, Array]
172180
The integrated value.
181+
Has the same shape as the return values of f or the values of f.
173182
174183
"""
175184
xs, ws = roots(
@@ -199,13 +208,25 @@ def integrate(
199208
# theta(node),u1,...,uM
200209
xpx.broadcast_shapes(value.shape[:1], ws[node].shape)
201210
w = xp.reshape(ws[node], (-1,) + (1,) * (value.ndim - 1))
202-
result[node] = xp.sum(value * w, axis=0)
211+
if value.shape[0] == 1:
212+
result[node] = value[0, ...] * xp.sum(w)
213+
else:
214+
result[node] = xp.vecdot(value, w, axis=0)
203215
# we don't know how to einsum the result
204216
return result
217+
if val.ndim < c.s_ndim:
218+
raise ValueError(
219+
f"The dimension of the return value of f should be at least {c.s_ndim}, got {val.ndim}."
220+
)
221+
xpx.broadcast_shapes(
222+
val.shape[: c.s_ndim],
223+
xpx.broadcast_shapes(*(xs[node].shape for node in c.s_nodes)),
224+
)
205225
# theta1,...,thetaN,u1,...,uM\
206226
for node in c.s_nodes:
207227
w = ws[node]
208-
xpx.broadcast_shapes(val.shape[:1], w.shape)
209-
# val = xp.einsum("i...,i->...", val, w.astype(val.dtype))
210-
val = xp.sum(val * w[(slice(None),) + (None,) * (val.ndim - 1)], axis=0)
228+
if val.shape[0] == 1:
229+
val = val[0, ...] * xp.sum(w)
230+
else:
231+
val = xp.vecdot(val, w[(slice(None),) + (None,) * (val.ndim - 1)], axis=0)
211232
return val

tests/test_integral.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import math
2-
from collections import defaultdict
32
from collections.abc import Callable, Mapping
43
from typing import Any, Literal
54

@@ -21,7 +20,7 @@ def test_sphere_surface_integrate(
2120
xp: ArrayNamespaceFull,
2221
) -> None:
2322
def f2(s):
24-
return xp.asarray(f(s))
23+
return xp.asarray(f(s)) * xp.ones_like(s["theta"])
2524

2625
c = c_spherical()
2726
assert integrate(
@@ -50,9 +49,9 @@ def test_integrate(
5049
# surface integral (area) of the sphere
5150
def f(s: Mapping[TSpherical, Array]) -> Array:
5251
if concat:
53-
return xp.asarray(r**c.s_ndim)
52+
return xp.asarray(r**c.s_ndim) * xp.ones_like(next(iter(s.values())))
5453
else:
55-
return defaultdict(lambda: xp.asarray(r))
54+
return {k: xp.asarray(r) * xp.ones_like(s[k]) for k in c.s_nodes}
5655

5756
actual = integrate(
5857
c,

0 commit comments

Comments
 (0)