66from jax import Array
77from jax .experimental .sparse import BCOO
88
9+ from jaxfun .galerkin import JAXFunction
910from jaxfun .typing import TrialSpaceType
1011from jaxfun .utils .common import lambdify , matmat , tosparse
1112
2829)
2930from .orthogonal import OrthogonalSpace
3031from .tensorproductspace import (
32+ BlockTPMatrix ,
3133 DirectSumTPS ,
3234 TensorMatrix ,
3335 TensorProductSpace ,
@@ -160,6 +162,7 @@ def inner(
160162
161163 if isinstance (z , tuple ): # multivar
162164 mats .append ((z , global_indices ))
165+ sc = 1
163166 continue
164167
165168 if z .size == 0 :
@@ -202,7 +205,6 @@ def inner(
202205 scales .append (
203206 evaluate_jaxfunction_expr_quad (a0 ["jaxfunction" ], N = num_quad_points )
204207 )
205-
206208 Am = assemble_multivar (mats_ , scales , test_space )
207209 if has_bcs :
208210 sign = 1 if all_linear else - 1
@@ -224,7 +226,6 @@ def inner(
224226 bresults .append (vectorize_bresult (res , test_space , gi [0 ][0 ]))
225227
226228 if "bilinear" in coeffs :
227- assert coeffs ["bilinear" ] == 1
228229 assert isinstance (trial_space , TensorProductSpace )
229230 aresults .append (
230231 TensorMatrix (
@@ -312,6 +313,7 @@ def inner(
312313 sc = sc * (- 1 )
313314
314315 bs = []
316+
315317 for key , bi in b0 .items ():
316318 if key in ("coeff" , "multivar" , "jaxfunction" ):
317319 continue
@@ -346,16 +348,17 @@ def inner(
346348 if isinstance (bs [0 ], tuple ):
347349 assert isinstance (num_quad_points , tuple )
348350 # multivar or JAXFunction
351+ uj = jnp .array (1.0 )
349352 if "multivar" in b0 :
350353 s = test_space .system .base_scalars ()
351- uj = lambdify (s , b0 ["multivar" ], modules = "jax" )(
354+ uj * = lambdify (s , b0 ["multivar" ], modules = "jax" )(
352355 * test_space .mesh (N = num_quad_points )
353356 )
354- elif "jaxfunction" in b0 :
355- uj = evaluate_jaxfunction_expr_quad (
357+ if "jaxfunction" in b0 :
358+ uj * = evaluate_jaxfunction_expr_quad (
356359 b0 ["jaxfunction" ], N = num_quad_points
357360 )
358- else :
361+ if "jaxfunction" not in b0 and "multivar" not in b0 :
359362 raise ValueError ("Expected multivar or jaxfunction key in b0" )
360363 res = bs [0 ][0 ].T @ uj @ bs [1 ][0 ]
361364 bresults .append (vectorize_bresult (res , test_space , global_index ))
@@ -683,7 +686,7 @@ def assemble_multivar(
683686 test_space: Tensor product space (for mesh / variable order).
684687
685688 Returns:
686- Dense matrix of shape (i*j , k* l) assembled from factors.
689+ Dense matrix of shape (i, k, j, l) assembled from factors.
687690 """
688691 P0 , P1 = mats [0 ]
689692 P2 , P3 = mats [1 ]
@@ -730,18 +733,44 @@ def project(ue: sp.Expr, V: TrialSpaceType) -> Array:
730733 Returns:
731734 Coefficient array shaped to V.num_dofs.
732735 """
736+ from scipy import sparse as scipy_sparse
737+
738+ from jaxfun .operators import Dot
739+
733740 if V .dims == 1 :
734741 assert isinstance (V , OrthogonalSpace | Composite | DirectSum )
735742 return project1D (ue , V )
736743
737- if V . is_orthogonal :
744+ if len ( get_jaxfunctions ( ue )) == 0 :
738745 assert not isinstance (V , OrthogonalSpace | Composite | DirectSum )
739- uj = lambdify (V .system .base_scalars (), ue , modules = "jax" )(* V .mesh ())
740- uj = jnp .broadcast_to (uj , V .num_dofs )
746+ if V .rank == 0 :
747+ uj = lambdify (V .system .base_scalars (), ue , modules = "jax" )(* V .mesh ())
748+ uj = jnp .broadcast_to (uj , V .num_quad_points )
749+ elif V .rank == 1 :
750+ assert isinstance (V , VectorTensorProductSpace )
751+ s = V .system .base_scalars ()
752+ bv = V .system .base_vectors ()
753+ uj = (lambdify (s , Dot (ue , n ).doit ())(* V .mesh ()) for n in bv )
754+ uj = jnp .stack (
755+ [jnp .broadcast_to (ui , V .tensorspaces [0 ].num_quad_points ) for ui in uj ],
756+ axis = 0 ,
757+ )
741758 return V .forward (uj )
742759
743760 u = TrialFunction (V )
744761 v = TestFunction (V )
745- M , b = inner (v * (u - ue ))
746- uh = jnp .linalg .solve (M [0 ].mat , b .flatten ()).reshape (V .num_dofs )
762+ if V .rank == 0 :
763+ M , b = inner (v * (u - ue ))
764+ uh = jnp .linalg .solve (M [0 ].mat , b .flatten ()).reshape (V .num_dofs )
765+
766+ elif V .rank == 1 :
767+ assert isinstance (ue , sp .Mul | sp .Add | JAXFunction ), (
768+ "Projection requires unevaluated expressions"
769+ ) # noqa: E501
770+ assert isinstance (V , VectorTensorProductSpace )
771+ M , b = inner (Dot (v , (u - ue )))
772+ A = BlockTPMatrix (M , V , V )
773+ C = A .block_array ()
774+ uh = jnp .array (scipy_sparse .linalg .spsolve (C , b .ravel ()).reshape (b .shape ))
775+
747776 return uh
0 commit comments