Skip to content

Commit f97cd21

Browse files
committed
Fixing backward Euler - using scalar_product for rhs
1 parent 29f7231 commit f97cd21

File tree

4 files changed

+19
-12
lines changed

4 files changed

+19
-12
lines changed

src/jaxfun/galerkin/ChebyshevU.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ChebyshevU(Jacobi):
1717
1818
Implements a Chebyshev basis via the Jacobi formulation with
1919
alpha = beta = 1/2. Provides several evaluation kernels:
20-
* eval_basis_function: Single T_i evaluation (iterative).
20+
* eval_basis_function: Single U_i(x) evaluation (iterative).
2121
* eval_basis_functions: Vectorized generation of all modes < N.
2222
2323
The series expansion (degree N-1):
@@ -226,21 +226,17 @@ def gn(self, n: Symbol | int) -> Expr:
226226
def matrices(
227227
test: tuple[ChebyshevU, int], trial: tuple[ChebyshevU, int]
228228
) -> sparse.BCOO | None:
229-
"""Sparse operator matrices between test/trial Chebyshev modes.
229+
"""Sparse operator matrices between test/trial ChebyshevU modes.
230230
231231
Constructs (possibly rectangular) sparse differentiation / mass-like
232232
matrices for combinations of test index i and trial index j flags:
233233
234234
(i, j):
235235
(0,0): Diagonal mass-matrix.
236-
(0,1): First derivative.
237-
(1,0): Transpose of (0,1).
238-
(0,2): Second derivative.
239-
(2,0): Transpose of (0,2).
240236
241237
Args:
242-
test: Tuple (v, i) with Chebyshev space v and number of derivatives i.
243-
trial: Tuple (u, j) with Chebyshev space u and number of derivatives j.
238+
test: Tuple (v, i) with ChebyshevU space v and number of derivatives i.
239+
trial: Tuple (u, j) with ChebyshevU space u and number of derivatives j.
244240
245241
Returns:
246242
jax.experimental.sparse.BCOO or None if combination unsupported.

src/jaxfun/galerkin/Fourier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def scalar_product(self, c: Array) -> Array:
149149
Returns:
150150
Coefficients scaled by 2π / domain_factor.
151151
"""
152-
out = jnp.fft.fft(c, norm="forward") * 2 * jnp.pi / self.domain_factor
152+
out = jnp.fft.fft(c, norm="forward") * 2 * jnp.pi / float(self.domain_factor)
153153
if len(c) > self.N:
154154
return out[self.wavenumbers()]
155155
return out

src/jaxfun/integrators/backward_euler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
class BackwardEuler(BaseIntegrator):
1212
"""First-order implicit Euler for linear terms (IMEX for nonlinear terms)."""
1313

14-
_system_diag: Array | None = nnx.data(None)
15-
_system_matrix: Array | None = nnx.data(None)
14+
_system_diag: Array | None = None
15+
_system_matrix: Array | None = None
1616

1717
def setup(self, dt: float) -> None:
1818
"""Precompute the implicit system matrix for the given step size."""
@@ -33,7 +33,7 @@ def step(self, u_hat: Array, dt: float, N: Padding = None) -> Array:
3333
if self.linear_forcing is not None:
3434
rhs = rhs + dt * jnp.asarray(self.linear_forcing)
3535
if self.has_nonlinear:
36-
rhs = rhs + dt * self.apply_mass(self.nonlinear_rhs(u_hat, N))
36+
rhs = rhs + dt * self.nonlinear_rhs_scalar_product(u_hat, N)
3737

3838
if self._system_diag is not None:
3939
return rhs / self._system_diag

src/jaxfun/integrators/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,17 @@ def nonlinear_rhs(self, uh: Array, N: Padding = None) -> Array:
230230
assert self._nonlinear_evaluator is not None
231231
return self.functionspace.forward(self._nonlinear_evaluator(uh, N))
232232

233+
def nonlinear_rhs_scalar_product(self, uh: Array, N: Padding = None) -> Array:
234+
"""Return the nonlinear contribution in coefficient space.
235+
236+
Do *not* apply the mass inverse to complete the forward transformation,
237+
because the mass inverse may be required elsewhere.
238+
"""
239+
if not self.has_nonlinear:
240+
return jnp.zeros_like(uh)
241+
assert self._nonlinear_evaluator is not None
242+
return self.functionspace.scalar_product(self._nonlinear_evaluator(uh, N))
243+
233244
def linear_rhs(self, uh: Array) -> Array:
234245
"""Return the linear contribution after applying the inverse mass matrix."""
235246
rhs = jnp.zeros_like(uh)

0 commit comments

Comments
 (0)