Skip to content

Commit 7cab0b7

Browse files
authored
Expanding on vectors and simplifying/speeding up forward for directsumtps (#34)
* Expanding on vectors and simplifying/speeding up forward for directsumtps * Remove unused fixtures from test_to_from_orthogonal * Increase tolerance for 3D to_from_orthogonal test * Caching inverse of stencil matrix when computed the first time
1 parent 298d11c commit 7cab0b7

File tree

12 files changed

+374
-98
lines changed

12 files changed

+374
-98
lines changed

examples/poisson1D.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from jaxfun.galerkin.Chebyshev import Chebyshev as space
1313
from jaxfun.galerkin.functionspace import FunctionSpace
1414
from jaxfun.galerkin.inner import inner
15-
16-
# from jaxfun.Jacobi import Jacobi as space
1715
from jaxfun.operators import Div, Grad
1816
from jaxfun.utils.common import lambdify, n, ulp
1917

@@ -28,7 +26,8 @@
2826
ue = 1 - x**2 # * sp.exp(sp.cos(2 * sp.pi * x))
2927

3028
# A = inner(v*sp.Derivative(u, x, 2), sparse=True)
31-
# A = inner(-Dot(Grad(v), Grad(u)), sparse=True)
29+
# A = inner(-Dot(sp.sqrt(1-x**2)*Grad(v/sp.sqrt(1-x**2)), Grad(u)), sparse=True) # Cheb
30+
# A = inner(-Dot(Grad(v), Grad(u)), sparse=True) # Legendre
3231
# A = inner(v*Div(Grad(u)), sparse=True)
3332
# b = inner(v*sp.Derivative(ue, x, 2))
3433
A, b = inner(

src/jaxfun/galerkin/Chebyshev.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def derivative_coeffs(self, c: Array, k: int = 0) -> Array:
270270
N: int = c.shape[0] - 1
271271
x0: Array = jnp.array(0.0, dtype=float)
272272
if N == 0:
273-
return x0
273+
return jnp.array([x0])
274274
x1: Array = c[-1] * N * 2
275275
if N == 1:
276276
return jnp.array([x1, x0])
@@ -282,7 +282,7 @@ def inner_loop(
282282
x2 = 2 * (n + 1) * c[n + 1] + x0
283283
return (x1, x2), x2
284284

285-
_, xs = jax.lax.scan(inner_loop, (x0, x1), jnp.arange(N - 2, -1, -1))
285+
xs = jax.lax.scan(inner_loop, (x0, x1), jnp.arange(N - 2, -1, -1))[1]
286286
return jnp.concatenate(
287287
(jnp.array([xs[-1] / 2]), xs[-2::-1], jnp.array([x1, x0]))
288288
)

src/jaxfun/galerkin/Legendre.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def derivative_coeffs(self, c: Array, k: int = 0) -> Array:
168168
N: int = c.shape[0] - 1
169169
x0: Array = jnp.array(0.0)
170170
if N == 0:
171-
return x0
171+
return jnp.array([x0])
172172
x1: Array = c[-1] * (2 * N - 1)
173173
if N == 1:
174174
return jnp.array([x1, x0])

src/jaxfun/galerkin/arguments.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -786,48 +786,11 @@ def evaluate_mesh(self, x: Array | list[Array]) -> Array:
786786
return self.functionspace.evaluate_mesh(x, self.array, True)
787787

788788

789-
def evaluate_jaxfunction_expr(
790-
a: Basic, xj: Array | tuple[Array, ...], jaxf: JAXFunction | None = None
791-
) -> Array:
792-
"""Evaluate a symbolic JAXFunction expression on a mesh.
793-
794-
Input coordinates ``xj`` are always given in the physical (true) domain.
795-
"""
796-
797-
if jaxf is None:
798-
from .forms import get_jaxfunctions
799-
800-
jaxf: set[JAXFunction] = get_jaxfunctions(a)
801-
assert len(jaxf) == 1, "Single JAXFunction not found in expression."
802-
jaxf: JAXFunction = jaxf.pop()
803-
804-
V = jaxf.functionspace
805-
806-
if isinstance(a, sp.Pow):
807-
wa = a.args[0]
808-
variables = getattr(wa, "variables", ())
809-
var = tuple(int(variables.count(s)) for s in V.system.base_scalars())
810-
var = var[0] if V.dims == 1 else var
811-
h = V.evaluate_derivative(xj, jaxf.array, k=var)
812-
return h ** int(a.exp)
813-
814-
if isinstance(a, sp.Derivative):
815-
variables = getattr(a, "variables", ())
816-
var = tuple(int(variables.count(s)) for s in V.system.base_scalars())
817-
var = var[0] if V.dims == 1 else var
818-
return V.evaluate_derivative(xj, jaxf.array, k=var)
819-
820-
if not isinstance(V, OrthogonalSpace | DirectSum):
821-
return V.evaluate_mesh(xj, jaxf.array, True)
822-
823-
assert isinstance(xj, Array)
824-
return V.evaluate(xj, jaxf.array)
825-
826-
827789
def evaluate_jaxfunction_expr_quad(
828790
a: Basic, jaxf: JAXFunction | None = None, N: int | tuple[int, ...] | None = None
829791
) -> Array:
830792
"""Evaluate a symbolic JAXFunction expression on the quadrature mesh."""
793+
from jaxfun.integrators.nonlinear import compile_nonlinear_evaluator
831794

832795
if jaxf is None:
833796
from .forms import get_jaxfunctions
@@ -837,19 +800,5 @@ def evaluate_jaxfunction_expr_quad(
837800
jaxf: JAXFunction = jaxf.pop()
838801

839802
V = jaxf.functionspace
840-
841-
if isinstance(a, sp.Pow):
842-
wa = a.args[0]
843-
variables = getattr(wa, "variables", ())
844-
var = tuple(int(variables.count(s)) for s in V.system.base_scalars())
845-
var = var[0] if V.dims == 1 else var
846-
h = V.backward_primitive(jaxf.array, k=var, N=N)
847-
return h ** int(a.exp)
848-
849-
if isinstance(a, sp.Derivative):
850-
variables = getattr(a, "variables", ())
851-
var = tuple(int(variables.count(s)) for s in V.system.base_scalars())
852-
var = var[0] if V.dims == 1 else var
853-
return V.backward_primitive(jaxf.array, k=var, N=N)
854-
855-
return V.backward(jaxf.array, N=N)
803+
fun = compile_nonlinear_evaluator(cast(sp.Expr, a), V, cast(AppliedUndef, jaxf))
804+
return fun(jaxf.array, N)

src/jaxfun/galerkin/composite.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,11 @@ def to_orthogonal(self, a: Array) -> Array:
252252
"""Map composite coefficients -> underlying orthogonal coefficients."""
253253
return a @ self.S
254254

255+
@jax.jit(static_argnums=0)
256+
def from_orthogonal(self, a: Array) -> Array:
257+
"""Map underlying orthogonal coefficients -> composite coefficients."""
258+
return a @ self.get_inverse_stencil()
259+
255260
@jax.jit(static_argnums=0)
256261
def apply_stencil_galerkin(self, b: Array) -> Array:
257262
"""Apply stencil on both sides (Galerkin mass-like transform)."""
@@ -439,9 +444,11 @@ def __init__(self, a: Composite, b: BCGeneric) -> None:
439444
self.map_reference_domain = a.map_reference_domain
440445
self.map_true_domain = a.map_true_domain
441446
self.quad_points_and_weights = a.quad_points_and_weights
447+
self.num_quad_points = a.num_quad_points
442448
self.get_orthogonal = a.get_orthogonal
443449
self.dims = a.dims
444450
self.rank = a.rank
451+
self.domain = a.domain
445452

446453
@overload
447454
def __getitem__(self, i: Literal[0]) -> Composite: ...
@@ -486,8 +493,16 @@ def to_orthogonal(self, c: Array) -> Array:
486493
"""Map direct-sum coefficients -> underlying orthogonal coefficients."""
487494
c_a = self[0].to_orthogonal(c)
488495
c_b = self[1].to_orthogonal(self[1].bnd_vals())
489-
Nd = c_b.shape[0]
490-
return jnp.concatenate((c_a[:Nd] + c_b, c_a[Nd:]))
496+
c_b = jnp.pad(c_b, (0, c_a.shape[0] - c_b.shape[0]))
497+
return c_a + c_b
498+
499+
@jax.jit(static_argnums=0)
500+
def from_orthogonal(self, a: Array) -> Array:
501+
"""Map underlying orthogonal coefficients -> direct-sum (inhomogeneous) coefficients.""" # noqa: E501
502+
c_b = self[1].to_orthogonal(self[1].bnd_vals())
503+
c_b = jnp.pad(c_b, (0, a.shape[0] - c_b.shape[0]))
504+
c_a = a - c_b
505+
return self[0].from_orthogonal(c_a)
491506

492507
@jax.jit(static_argnums=0)
493508
def evaluate(self, x: Array, c: Array) -> Array:
@@ -512,15 +527,10 @@ def backward_primitive(
512527
"""Return backward transform for k-th derivative."""
513528
return self.orthogonal.backward_primitive(self.to_orthogonal(c), k, kind, N)
514529

530+
@jax.jit(static_argnums=0)
515531
def forward(self, uj: Array) -> Array:
516-
from .arguments import TestFunction, TrialFunction
517-
from .inner import inner
518-
519-
u = TrialFunction(self)
520-
v = TestFunction(self)
521-
M, b = inner(v * u)
522-
b += self[0].scalar_product(uj)
523-
return jnp.linalg.solve(M, b)
532+
a = self[0].orthogonal.forward(uj)
533+
return self.from_orthogonal(a)
524534

525535
@jax.jit(static_argnums=0)
526536
def scalar_product(self, uj: Array) -> Array:

src/jaxfun/galerkin/forms.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,19 @@ def check_if_nonlinear_in_jaxfunction(a: sp.Basic) -> bool:
150150
return False
151151
assert len(jaxfunctions) <= 1, "Multiple JAXFunctions found"
152152
ad = a.doit(linear=True) # assume linear
153-
jf = ad.atoms(Jaxc).pop()
153+
if ad.is_Vector:
154+
for comp in ad.components.values():
155+
jf = comp.atoms(Jaxc)
156+
if len(jf) > 1:
157+
return True
158+
jf = jf.pop()
159+
if sp.diff(comp, jf, 2) != 0:
160+
return True
161+
return False
162+
jf = ad.atoms(Jaxc)
163+
if len(jf) > 1:
164+
return True
165+
jf = jf.pop()
154166
return sp.diff(ad, jf, 2) != 0
155167

156168

@@ -263,7 +275,7 @@ def _split(ms: sp.Expr) -> InnerResultDict:
263275
if arg.atoms(Jaxc):
264276
jaxc.append(arg)
265277
continue
266-
if len(arg.free_symbols) == 1:
278+
if len(arg.free_symbols) <= 1:
267279
rest.append(arg)
268280
else:
269281
multivar.append(arg)
@@ -276,7 +288,7 @@ def _split(ms: sp.Expr) -> InnerResultDict:
276288
if len(multivar) > 0:
277289
d["multivar"] = sp.Mul(*multivar)
278290
if len(jfun) > 0:
279-
d["jaxfunction"] = jfun[0]
291+
d["jaxfunction"] = sp.Mul(*jfun)
280292
if len(jaxc) > 0:
281293
d["coeff"] = jaxc[0]
282294
if d is None:

src/jaxfun/galerkin/inner.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from jax import Array
77
from jax.experimental.sparse import BCOO
88

9+
from jaxfun.galerkin import JAXFunction
910
from jaxfun.typing import TrialSpaceType
1011
from jaxfun.utils.common import lambdify, matmat, tosparse
1112

@@ -28,6 +29,7 @@
2829
)
2930
from .orthogonal import OrthogonalSpace
3031
from .tensorproductspace import (
32+
BlockTPMatrix,
3133
DirectSumTPS,
3234
TensorMatrix,
3335
TensorProductSpace,
@@ -160,6 +162,7 @@ def inner(
160162

161163
if isinstance(z, tuple): # multivar
162164
mats.append((z, global_indices))
165+
sc = 1
163166
continue
164167

165168
if z.size == 0:
@@ -202,7 +205,6 @@ def inner(
202205
scales.append(
203206
evaluate_jaxfunction_expr_quad(a0["jaxfunction"], N=num_quad_points)
204207
)
205-
206208
Am = assemble_multivar(mats_, scales, test_space)
207209
if has_bcs:
208210
sign = 1 if all_linear else -1
@@ -224,7 +226,6 @@ def inner(
224226
bresults.append(vectorize_bresult(res, test_space, gi[0][0]))
225227

226228
if "bilinear" in coeffs:
227-
assert coeffs["bilinear"] == 1
228229
assert isinstance(trial_space, TensorProductSpace)
229230
aresults.append(
230231
TensorMatrix(
@@ -312,6 +313,7 @@ def inner(
312313
sc = sc * (-1)
313314

314315
bs = []
316+
315317
for key, bi in b0.items():
316318
if key in ("coeff", "multivar", "jaxfunction"):
317319
continue
@@ -346,16 +348,17 @@ def inner(
346348
if isinstance(bs[0], tuple):
347349
assert isinstance(num_quad_points, tuple)
348350
# multivar or JAXFunction
351+
uj = jnp.array(1.0)
349352
if "multivar" in b0:
350353
s = test_space.system.base_scalars()
351-
uj = lambdify(s, b0["multivar"], modules="jax")(
354+
uj *= lambdify(s, b0["multivar"], modules="jax")(
352355
*test_space.mesh(N=num_quad_points)
353356
)
354-
elif "jaxfunction" in b0:
355-
uj = evaluate_jaxfunction_expr_quad(
357+
if "jaxfunction" in b0:
358+
uj *= evaluate_jaxfunction_expr_quad(
356359
b0["jaxfunction"], N=num_quad_points
357360
)
358-
else:
361+
if "jaxfunction" not in b0 and "multivar" not in b0:
359362
raise ValueError("Expected multivar or jaxfunction key in b0")
360363
res = bs[0][0].T @ uj @ bs[1][0]
361364
bresults.append(vectorize_bresult(res, test_space, global_index))
@@ -683,7 +686,7 @@ def assemble_multivar(
683686
test_space: Tensor product space (for mesh / variable order).
684687
685688
Returns:
686-
Dense matrix of shape (i*j, k*l) assembled from factors.
689+
Dense matrix of shape (i, k, j, l) assembled from factors.
687690
"""
688691
P0, P1 = mats[0]
689692
P2, P3 = mats[1]
@@ -730,18 +733,44 @@ def project(ue: sp.Expr, V: TrialSpaceType) -> Array:
730733
Returns:
731734
Coefficient array shaped to V.num_dofs.
732735
"""
736+
from scipy import sparse as scipy_sparse
737+
738+
from jaxfun.operators import Dot
739+
733740
if V.dims == 1:
734741
assert isinstance(V, OrthogonalSpace | Composite | DirectSum)
735742
return project1D(ue, V)
736743

737-
if V.is_orthogonal:
744+
if len(get_jaxfunctions(ue)) == 0:
738745
assert not isinstance(V, OrthogonalSpace | Composite | DirectSum)
739-
uj = lambdify(V.system.base_scalars(), ue, modules="jax")(*V.mesh())
740-
uj = jnp.broadcast_to(uj, V.num_dofs)
746+
if V.rank == 0:
747+
uj = lambdify(V.system.base_scalars(), ue, modules="jax")(*V.mesh())
748+
uj = jnp.broadcast_to(uj, V.num_quad_points)
749+
elif V.rank == 1:
750+
assert isinstance(V, VectorTensorProductSpace)
751+
s = V.system.base_scalars()
752+
bv = V.system.base_vectors()
753+
uj = (lambdify(s, Dot(ue, n).doit())(*V.mesh()) for n in bv)
754+
uj = jnp.stack(
755+
[jnp.broadcast_to(ui, V.tensorspaces[0].num_quad_points) for ui in uj],
756+
axis=0,
757+
)
741758
return V.forward(uj)
742759

743760
u = TrialFunction(V)
744761
v = TestFunction(V)
745-
M, b = inner(v * (u - ue))
746-
uh = jnp.linalg.solve(M[0].mat, b.flatten()).reshape(V.num_dofs)
762+
if V.rank == 0:
763+
M, b = inner(v * (u - ue))
764+
uh = jnp.linalg.solve(M[0].mat, b.flatten()).reshape(V.num_dofs)
765+
766+
elif V.rank == 1:
767+
assert isinstance(ue, sp.Mul | sp.Add | JAXFunction), (
768+
"Projection requires unevaluated expressions"
769+
) # noqa: E501
770+
assert isinstance(V, VectorTensorProductSpace)
771+
M, b = inner(Dot(v, (u - ue)))
772+
A = BlockTPMatrix(M, V, V)
773+
C = A.block_array()
774+
uh = jnp.array(scipy_sparse.linalg.spsolve(C, b.ravel()).reshape(b.shape))
775+
747776
return uh

src/jaxfun/galerkin/orthogonal.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
self.S = sparse.BCOO(
7777
(jnp.ones(N), jnp.vstack((jnp.arange(N),) * 2).T), shape=(N, N)
7878
)
79+
self.S_inv: Array | None = None
7980
super().__init__(system, name, fun_str)
8081

8182
@abstractmethod
@@ -424,3 +425,10 @@ def get_padded(self, N: int) -> Self:
424425
def get_orthogonal(self) -> Self:
425426
"""Return self (orthogonal space is self; overridden in Composite)."""
426427
return self
428+
429+
@jax.jit(static_argnums=0)
430+
def get_inverse_stencil(self) -> Array:
431+
"""Return inverse of stencil matrix S."""
432+
if self.S_inv is None:
433+
self.S_inv = jnp.linalg.pinv(self.S.todense())
434+
return self.S_inv

0 commit comments

Comments
 (0)