Skip to content

Commit bdd15ff

Browse files
Latest changes to acessing basis, links, and components.
1 parent 20379ad commit bdd15ff

File tree

3 files changed

+93
-121
lines changed

3 files changed

+93
-121
lines changed

examples/fem/basis.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,27 @@ def build_2d_lagrange_vandermonde(n, p, pts, exps):
6565
return np.linalg.solve(V, I)
6666

6767

68+
class BasisCollection:
69+
def __init__(self, basis=[]):
70+
self.basis = basis
71+
72+
def add_declarations(self, comp):
73+
for basis in self.basis:
74+
basis.add_declarations(comp)
75+
76+
def eval(self, comp, pt):
77+
soln = {}
78+
for basis in self.basis:
79+
soln.update(basis.eval(comp, pt))
80+
return soln
81+
82+
def transform(self, detJ, Jinv, orig):
83+
soln = {}
84+
for basis in self.basis:
85+
soln.update(basis.transform(detJ, Jinv, orig))
86+
return soln
87+
88+
6889
class Basis:
6990
def __init__(self, names, nnodes=1, kind="input"):
7091
if isinstance(names, (list, tuple)):
@@ -370,11 +391,3 @@ def get_spaces(self):
370391

371392
def get_names(self, space):
372393
return self.names[space]
373-
374-
375-
class BasisCollection:
376-
def __init__(self):
377-
self.basis = []
378-
379-
def add_basis(self, basis):
380-
self.basis.append(basis)

examples/fem/fem.py

Lines changed: 71 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,47 @@ def _initialize(self):
5151
return
5252

5353
def add_source(self, model):
54+
# Get the names of things associated with H1
55+
names = self.space.get_names("H1") # u, v
5456

55-
# if "H1" in self.space.get_spaces():
56-
# model.add_component(f"src_{self.kind}", ndof, )
57-
# if "H1" in self.space.get_spaces():
57+
# Create amigo source component with input names and geo names
58+
input_names = []
59+
geo_names = []
60+
data_names = []
61+
if self.kind == "input":
62+
input_names = names
63+
elif self.kind == "data":
64+
data_names = names
65+
elif self.kind == "geo":
66+
geo_names = names
67+
68+
dof_src = DofSource(
69+
input_names=input_names, geo_names=geo_names, data_names=data_names
70+
)
5871

59-
# if "H1" in self.space.get_spaces():
72+
# Add global mesh source component
73+
nnodes = mesh.get_num_nodes()
74+
model.add_component(f"src_{self.kind}", nnodes, dof_src)
75+
76+
def link_dof(self, model, domain, etype, elem_name):
77+
names = self.space.get_names("H1")
78+
conn = self.mesh.get_conn(domain, etype)
79+
for name in names:
80+
model.link(
81+
f"src_{self.kind}.{name}", f"{elem_name}.{name}", src_indices=conn
82+
)
83+
return
6084

61-
pass
85+
def get_basis(self, etype):
86+
basis_list = []
87+
88+
for sp in ["H1"]:
89+
# self._get_basis(etype, sp, names, kind) # names and kind come from space info
90+
names =self.space.get_names(sp)
91+
basis_list.append(self._get_basis(etype, sp, names, self.kind))
92+
return basis.BasisCollection(basis_list)
6293

63-
def get_basis(self, etype, space, names=[], kind="input"):
94+
def _get_basis(self, etype, space, names=[], kind="input"):
6495
if etype == "CPS3":
6596
if space == "H1":
6697
return basis.TriangleLagrangeBasis(1, names, kind=kind)
@@ -90,10 +121,6 @@ def get_quadrature(self, etype):
90121

91122
raise NotImplementedError(f"Quadrature for element {etype} not implemented")
92123

93-
def link_dof(self, model, name, elem_name, conn):
94-
model.link(f"src.{name}", f"{elem_name}.{name}", src_indices=conn)
95-
return
96-
97124

98125
class Mesh:
99126
def __init__(self, filename):
@@ -122,6 +149,9 @@ def get_domains(self):
122149
def get_conn(self, name, etype):
123150
return self.parser.get_conn(name, etype)
124151

152+
def get_num_elements(self, name, etype):
153+
return self.parser.get_conn(name, etype).shape[0]
154+
125155
def plot(self, u, ax=None, nlevels=30, cmap="coolwarm", title=None):
126156
min_level = np.min(u)
127157
max_level = np.max(u)
@@ -215,56 +245,34 @@ def __init__(self, mesh, soln_space, weakform, data_space=[], geo_space=[], ndim
215245
self.weakform = weakform
216246

217247
# Initialize Dof's
218-
self.soln_dof = DegreesOfFreedom(mesh, "H1", "soln")
219-
self.geo_dof = DegreesOfFreedom(mesh, "H1", "geo")
220-
self.data_dof = DegreesOfFreedom(mesh, "H1", "data")
248+
# Take in the soln space -> removes "H1" input
249+
self.soln_dof = DegreesOfFreedom(mesh, self.soln_space, "soln")
250+
self.geo_dof = DegreesOfFreedom(mesh, self.geo_space, "geo")
251+
self.data_dof = DegreesOfFreedom(mesh, self.data_space, "data")
221252

222253
return
223254

224255
def create_model(self, module_name: str):
225256
"""Create and link the Amigo model"""
226257
model = am.Model(module_name)
227258

228-
# Get the names of things associated with H1
229-
input_names = self.soln_space.get_names("H1") # u, v
230-
data_names = self.data_space.get_names("H1") # x, y
231-
geo_names = self.geo_space.get_names("H1") # x, y
232-
233-
# Create amigo source component with input names and geo names
234-
self.dof_src = DofSource(input_names=input_names, geo_names=geo_names)
235-
236-
# Add global mesh source component
237-
nnodes = mesh.get_num_nodes()
238-
model.add_component("src", nnodes, self.dof_src)
259+
self.soln_dof.add_source(model)
260+
self.data_dof.add_source(model)
261+
self.geo_dof.add_source(model)
239262

240263
# Build the elements for all domains
241264
domains = self.mesh.get_domains()
242265
for domain in domains:
243266
for etype in domains[domain]:
267+
# Each element type has a dictionary of solution basis's
244268
soln_basis = {}
269+
245270
# Build a finite-element for each weak form
246271
elem_name = f"Element{etype}_{domain}"
247272

248-
# soln_basis_u = self.soln_dof.get_basis(
249-
# etype, "H1", names=input_names, kind="input"
250-
# )
251-
252-
# soln_basis_v = self.soln_dof.get_basis(
253-
# etype, "H1", names=input_names, kind="input"
254-
# )
255-
256-
# Each input gets a basis function assigned to it
257-
for name in input_names:
258-
soln_basis[f"{name}"] = self.soln_dof.get_basis(
259-
etype, "H1", names=input_names, kind="input"
260-
)
261-
262-
data_basis = self.data_dof.get_basis(
263-
etype, "H1", names=data_names, kind="data"
264-
)
265-
geo_basis = self.geo_dof.get_basis(
266-
etype, "H1", names=geo_names, kind="data"
267-
)
273+
soln_basis = self.soln_dof.get_basis(etype)
274+
data_basis = self.data_dof.get_basis(etype)
275+
geo_basis = self.geo_dof.get_basis(etype)
268276

269277
# Create the quadrature instance
270278
quadrature = self.soln_dof.get_quadrature(etype)
@@ -277,26 +285,16 @@ def create_model(self, module_name: str):
277285
geo_basis,
278286
quadrature,
279287
self.weakform,
280-
etype,
281-
input_names,
282-
data_names,
283-
geo_names,
284288
)
285289

286-
# Get the connectivity
287-
# Needs to pull out any type of connectivity for the basis
288-
conn = self.mesh.get_conn(domain, etype)
289-
290290
# Add the element/component
291-
nelems = conn.shape[0]
291+
nelems = self.mesh.get_num_elements(domain, etype)
292292
model.add_component(elem_name, nelems, elem)
293293

294294
# Link all the element dof to the component
295-
for name in input_names:
296-
self.soln_dof.link_dof(model, name, elem_name, conn)
297-
298-
for name in geo_names:
299-
self.geo_dof.link_dof(model, name, elem_name, conn)
295+
self.soln_dof.link_dof(model, domain, etype, elem_name)
296+
self.data_dof.link_dof(model, domain, etype, elem_name)
297+
self.geo_dof.link_dof(model, domain, etype, elem_name)
300298

301299
return model
302300

@@ -310,38 +308,26 @@ def __init__(
310308
geo_basis,
311309
quadrature,
312310
weakform,
313-
etype,
314-
input_names=[],
315-
data_names=[],
316-
geo_names=[],
317311
):
318312
super().__init__(name=name)
319313

314+
# NOTE: soln_basis is a dict of objectes for each input
320315
self.soln_basis = soln_basis
316+
321317
self.data_basis = data_basis
322318
self.geo_basis = geo_basis
323319
self.quadrature = quadrature
324320
self.weakform = weakform
325-
self.input_names = input_names
326-
327-
# The x/y coordinates
328-
if etype == "CPS3":
329-
shape = (3,)
330-
331-
elif etype == "CPS4":
332-
shape = (4,)
333321

334-
# Data
335-
for name in geo_names:
336-
self.add_data(name, shape=shape)
337-
338-
# Inputs
339-
for name in self.input_names:
340-
self.add_input(name, shape=shape)
322+
# From BasisCollection
323+
self.soln_basis.add_declarations(self)
324+
self.geo_basis.add_declarations(self)
325+
self.data_basis.add_declarations(self)
341326

342327
# Set the arguments to the compute function for each quadrature point
343328
self.set_args(self.quadrature.get_args())
344329

330+
# Add the objective to minimize
345331
self.add_objective("obj")
346332

347333
return
@@ -350,41 +336,16 @@ def compute(self, **args):
350336

351337
quad_weight, quad_point = self.quadrature.get_point(**args)
352338

353-
# # Evaluate the solution fields/data fields (u)
354-
# soln_xi_u = self.soln_basis["u"].eval(self, quad_point)
355-
# data_xi_u = self.data_basis.eval(self, quad_point)
356-
# geo_u = self.geo_basis.eval(self, quad_point)
357-
358-
# # Perform the mapping from computational to physical coordinates (u)
359-
# detJ_u, Jinv_u = self.geo_basis.compute_transform(geo_u)
360-
# soln_phys_u = self.soln_basis["u"].transform(detJ_u, Jinv_u, soln_xi_u)
361-
# data_phys_u = self.data_basis.transform(detJ_u, Jinv_u, data_xi_u)
362-
363-
# # Evaluate the solution fields/data fields (v)
364-
# soln_xi_v = self.soln_basis["v"].eval(self, quad_point)
365-
# data_xi_v = self.data_basis.eval(self, quad_point)
366-
# geo_v = self.geo_basis.eval(self, quad_point)
367-
368-
# # Perform the mapping from computational to physical coordinates (u)
369-
# detJ_v, Jinv_v = self.geo_basis.compute_transform(geo_v)
370-
# soln_phys_v = self.soln_basis["v"].transform(detJ_v, Jinv_v, soln_xi_v)
371-
# data_phys_v = self.data_basis.transform(detJ_v, Jinv_v, data_xi_v)
372-
373-
# Evaluate the solution fields/data fields (u, v)
374-
soln_phys = {}
375-
339+
# Evaluate the solution fields/data fields (u)
340+
soln_xi = self.soln_basis.eval(self, quad_point)
376341
data_xi = self.data_basis.eval(self, quad_point)
377342
geo = self.geo_basis.eval(self, quad_point)
343+
344+
# Perform the mapping from computational to physical coordinates (u)
378345
detJ, Jinv = self.geo_basis.compute_transform(geo)
346+
soln_phys = self.soln_basis.transform(detJ, Jinv, soln_xi)
379347
data_phys = self.data_basis.transform(detJ, Jinv, data_xi)
380348

381-
for name in self.input_names:
382-
# Eval basis at quad point
383-
soln_xi = self.soln_basis[name].eval(self, quad_point)
384-
385-
# Perform the mapping from computational to physical coordinates (u)
386-
soln_phys[name] = self.soln_basis[name].transform(detJ, Jinv, soln_xi)
387-
388349
# Add the contributions directly to the Lagrangian
389350
self.objective["obj"] = quad_weight * detJ * self.weakform(soln_phys, geo=geo)
390351
return
@@ -411,9 +372,6 @@ def weakform(soln, data=None, geo=None):
411372
alpha = 1.0
412373
pi = 3.14159265358979
413374

414-
# Manufactured RHS derived from exact solution
415-
# u_exact = sin(πx)sin(πy) → -Δu_exact = 2π²·sin(πx)sin(πy)
416-
# v_exact = cos(πx)cos(πy) → -Δv_exact = 2π²·cos(πx)cos(πy)
417375
f1 = 2 * pi**2 * am.sin(pi * x) * am.sin(pi * y) + alpha * am.cos(
418376
pi * x
419377
) * am.cos(pi * y)
@@ -432,12 +390,13 @@ def weakform(soln, data=None, geo=None):
432390
return comp1
433391

434392

393+
# Initialize the spaces
435394
soln_space = basis.SolutionSpace({"u": "H1", "v": "H1"})
436395
data_space = basis.SolutionSpace({"x": "H1", "y": "H1"})
437396
geo_space = basis.SolutionSpace({"x": "H1", "y": "H1"})
438397

439-
mesh = Mesh("magnet_order_1.inp")
440-
# mesh = Mesh("plate.inp")
398+
# mesh = Mesh("magnet_order_1.inp")
399+
mesh = Mesh("plate.inp")
441400
problem = Problem(
442401
mesh,
443402
soln_space,

examples/fem/mixed_mesh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def check_areas(X, conn):
3434
gmsh.model.add("magnet")
3535

3636
# Mesh refinement at nodes
37-
lc = 3e-1
37+
lc = 9e-2
3838
lc1 = 9e-2
3939

4040
# Geometry dimentions

0 commit comments

Comments
 (0)