@@ -51,16 +51,47 @@ def _initialize(self):
5151 return
5252
5353 def add_source (self , model ):
54+ # Get the names of things associated with H1
55+ names = self .space .get_names ("H1" ) # u, v
5456
55- # if "H1" in self.space.get_spaces():
56- # model.add_component(f"src_{self.kind}", ndof, )
57- # if "H1" in self.space.get_spaces():
57+ # Create amigo source component with input names and geo names
58+ input_names = []
59+ geo_names = []
60+ data_names = []
61+ if self .kind == "input" :
62+ input_names = names
63+ elif self .kind == "data" :
64+ data_names = names
65+ elif self .kind == "geo" :
66+ geo_names = names
67+
68+ dof_src = DofSource (
69+ input_names = input_names , geo_names = geo_names , data_names = data_names
70+ )
5871
59- # if "H1" in self.space.get_spaces():
72+ # Add global mesh source component
73+ nnodes = mesh .get_num_nodes ()
74+ model .add_component (f"src_{ self .kind } " , nnodes , dof_src )
75+
76+ def link_dof (self , model , domain , etype , elem_name ):
77+ names = self .space .get_names ("H1" )
78+ conn = self .mesh .get_conn (domain , etype )
79+ for name in names :
80+ model .link (
81+ f"src_{ self .kind } .{ name } " , f"{ elem_name } .{ name } " , src_indices = conn
82+ )
83+ return
6084
61- pass
85+ def get_basis (self , etype ):
86+ basis_list = []
87+
88+ for sp in ["H1" ]:
89+ # self._get_basis(etype, sp, names, kind) # names and kind come from space info
90+ names = self .space .get_names (sp )
91+ basis_list .append (self ._get_basis (etype , sp , names , self .kind ))
92+ return basis .BasisCollection (basis_list )
6293
63- def get_basis (self , etype , space , names = [], kind = "input" ):
94+ def _get_basis (self , etype , space , names = [], kind = "input" ):
6495 if etype == "CPS3" :
6596 if space == "H1" :
6697 return basis .TriangleLagrangeBasis (1 , names , kind = kind )
@@ -90,10 +121,6 @@ def get_quadrature(self, etype):
90121
91122 raise NotImplementedError (f"Quadrature for element { etype } not implemented" )
92123
93- def link_dof (self , model , name , elem_name , conn ):
94- model .link (f"src.{ name } " , f"{ elem_name } .{ name } " , src_indices = conn )
95- return
96-
97124
98125class Mesh :
99126 def __init__ (self , filename ):
@@ -122,6 +149,9 @@ def get_domains(self):
122149 def get_conn (self , name , etype ):
123150 return self .parser .get_conn (name , etype )
124151
152+ def get_num_elements (self , name , etype ):
153+ return self .parser .get_conn (name , etype ).shape [0 ]
154+
125155 def plot (self , u , ax = None , nlevels = 30 , cmap = "coolwarm" , title = None ):
126156 min_level = np .min (u )
127157 max_level = np .max (u )
@@ -215,56 +245,34 @@ def __init__(self, mesh, soln_space, weakform, data_space=[], geo_space=[], ndim
215245 self .weakform = weakform
216246
217247 # Initialize Dof's
218- self .soln_dof = DegreesOfFreedom (mesh , "H1" , "soln" )
219- self .geo_dof = DegreesOfFreedom (mesh , "H1" , "geo" )
220- self .data_dof = DegreesOfFreedom (mesh , "H1" , "data" )
248+ # Take in the soln space -> removes "H1" input
249+ self .soln_dof = DegreesOfFreedom (mesh , self .soln_space , "soln" )
250+ self .geo_dof = DegreesOfFreedom (mesh , self .geo_space , "geo" )
251+ self .data_dof = DegreesOfFreedom (mesh , self .data_space , "data" )
221252
222253 return
223254
224255 def create_model (self , module_name : str ):
225256 """Create and link the Amigo model"""
226257 model = am .Model (module_name )
227258
228- # Get the names of things associated with H1
229- input_names = self .soln_space .get_names ("H1" ) # u, v
230- data_names = self .data_space .get_names ("H1" ) # x, y
231- geo_names = self .geo_space .get_names ("H1" ) # x, y
232-
233- # Create amigo source component with input names and geo names
234- self .dof_src = DofSource (input_names = input_names , geo_names = geo_names )
235-
236- # Add global mesh source component
237- nnodes = mesh .get_num_nodes ()
238- model .add_component ("src" , nnodes , self .dof_src )
259+ self .soln_dof .add_source (model )
260+ self .data_dof .add_source (model )
261+ self .geo_dof .add_source (model )
239262
240263 # Build the elements for all domains
241264 domains = self .mesh .get_domains ()
242265 for domain in domains :
243266 for etype in domains [domain ]:
267+ # Each element type has a dictionary of solution basis's
244268 soln_basis = {}
269+
245270 # Build a finite-element for each weak form
246271 elem_name = f"Element{ etype } _{ domain } "
247272
248- # soln_basis_u = self.soln_dof.get_basis(
249- # etype, "H1", names=input_names, kind="input"
250- # )
251-
252- # soln_basis_v = self.soln_dof.get_basis(
253- # etype, "H1", names=input_names, kind="input"
254- # )
255-
256- # Each input gets a basis function assigned to it
257- for name in input_names :
258- soln_basis [f"{ name } " ] = self .soln_dof .get_basis (
259- etype , "H1" , names = input_names , kind = "input"
260- )
261-
262- data_basis = self .data_dof .get_basis (
263- etype , "H1" , names = data_names , kind = "data"
264- )
265- geo_basis = self .geo_dof .get_basis (
266- etype , "H1" , names = geo_names , kind = "data"
267- )
273+ soln_basis = self .soln_dof .get_basis (etype )
274+ data_basis = self .data_dof .get_basis (etype )
275+ geo_basis = self .geo_dof .get_basis (etype )
268276
269277 # Create the quadrature instance
270278 quadrature = self .soln_dof .get_quadrature (etype )
@@ -277,26 +285,16 @@ def create_model(self, module_name: str):
277285 geo_basis ,
278286 quadrature ,
279287 self .weakform ,
280- etype ,
281- input_names ,
282- data_names ,
283- geo_names ,
284288 )
285289
286- # Get the connectivity
287- # Needs to pull out any type of connectivity for the basis
288- conn = self .mesh .get_conn (domain , etype )
289-
290290 # Add the element/component
291- nelems = conn . shape [ 0 ]
291+ nelems = self . mesh . get_num_elements ( domain , etype )
292292 model .add_component (elem_name , nelems , elem )
293293
294294 # Link all the element dof to the component
295- for name in input_names :
296- self .soln_dof .link_dof (model , name , elem_name , conn )
297-
298- for name in geo_names :
299- self .geo_dof .link_dof (model , name , elem_name , conn )
295+ self .soln_dof .link_dof (model , domain , etype , elem_name )
296+ self .data_dof .link_dof (model , domain , etype , elem_name )
297+ self .geo_dof .link_dof (model , domain , etype , elem_name )
300298
301299 return model
302300
@@ -310,38 +308,26 @@ def __init__(
310308 geo_basis ,
311309 quadrature ,
312310 weakform ,
313- etype ,
314- input_names = [],
315- data_names = [],
316- geo_names = [],
317311 ):
318312 super ().__init__ (name = name )
319313
314+ # NOTE: soln_basis is a dict of objectes for each input
320315 self .soln_basis = soln_basis
316+
321317 self .data_basis = data_basis
322318 self .geo_basis = geo_basis
323319 self .quadrature = quadrature
324320 self .weakform = weakform
325- self .input_names = input_names
326-
327- # The x/y coordinates
328- if etype == "CPS3" :
329- shape = (3 ,)
330-
331- elif etype == "CPS4" :
332- shape = (4 ,)
333321
334- # Data
335- for name in geo_names :
336- self .add_data (name , shape = shape )
337-
338- # Inputs
339- for name in self .input_names :
340- self .add_input (name , shape = shape )
322+ # From BasisCollection
323+ self .soln_basis .add_declarations (self )
324+ self .geo_basis .add_declarations (self )
325+ self .data_basis .add_declarations (self )
341326
342327 # Set the arguments to the compute function for each quadrature point
343328 self .set_args (self .quadrature .get_args ())
344329
330+ # Add the objective to minimize
345331 self .add_objective ("obj" )
346332
347333 return
@@ -350,41 +336,16 @@ def compute(self, **args):
350336
351337 quad_weight , quad_point = self .quadrature .get_point (** args )
352338
353- # # Evaluate the solution fields/data fields (u)
354- # soln_xi_u = self.soln_basis["u"].eval(self, quad_point)
355- # data_xi_u = self.data_basis.eval(self, quad_point)
356- # geo_u = self.geo_basis.eval(self, quad_point)
357-
358- # # Perform the mapping from computational to physical coordinates (u)
359- # detJ_u, Jinv_u = self.geo_basis.compute_transform(geo_u)
360- # soln_phys_u = self.soln_basis["u"].transform(detJ_u, Jinv_u, soln_xi_u)
361- # data_phys_u = self.data_basis.transform(detJ_u, Jinv_u, data_xi_u)
362-
363- # # Evaluate the solution fields/data fields (v)
364- # soln_xi_v = self.soln_basis["v"].eval(self, quad_point)
365- # data_xi_v = self.data_basis.eval(self, quad_point)
366- # geo_v = self.geo_basis.eval(self, quad_point)
367-
368- # # Perform the mapping from computational to physical coordinates (u)
369- # detJ_v, Jinv_v = self.geo_basis.compute_transform(geo_v)
370- # soln_phys_v = self.soln_basis["v"].transform(detJ_v, Jinv_v, soln_xi_v)
371- # data_phys_v = self.data_basis.transform(detJ_v, Jinv_v, data_xi_v)
372-
373- # Evaluate the solution fields/data fields (u, v)
374- soln_phys = {}
375-
339+ # Evaluate the solution fields/data fields (u)
340+ soln_xi = self .soln_basis .eval (self , quad_point )
376341 data_xi = self .data_basis .eval (self , quad_point )
377342 geo = self .geo_basis .eval (self , quad_point )
343+
344+ # Perform the mapping from computational to physical coordinates (u)
378345 detJ , Jinv = self .geo_basis .compute_transform (geo )
346+ soln_phys = self .soln_basis .transform (detJ , Jinv , soln_xi )
379347 data_phys = self .data_basis .transform (detJ , Jinv , data_xi )
380348
381- for name in self .input_names :
382- # Eval basis at quad point
383- soln_xi = self .soln_basis [name ].eval (self , quad_point )
384-
385- # Perform the mapping from computational to physical coordinates (u)
386- soln_phys [name ] = self .soln_basis [name ].transform (detJ , Jinv , soln_xi )
387-
388349 # Add the contributions directly to the Lagrangian
389350 self .objective ["obj" ] = quad_weight * detJ * self .weakform (soln_phys , geo = geo )
390351 return
@@ -411,9 +372,6 @@ def weakform(soln, data=None, geo=None):
411372 alpha = 1.0
412373 pi = 3.14159265358979
413374
414- # Manufactured RHS derived from exact solution
415- # u_exact = sin(πx)sin(πy) → -Δu_exact = 2π²·sin(πx)sin(πy)
416- # v_exact = cos(πx)cos(πy) → -Δv_exact = 2π²·cos(πx)cos(πy)
417375 f1 = 2 * pi ** 2 * am .sin (pi * x ) * am .sin (pi * y ) + alpha * am .cos (
418376 pi * x
419377 ) * am .cos (pi * y )
@@ -432,12 +390,13 @@ def weakform(soln, data=None, geo=None):
432390 return comp1
433391
434392
393+ # Initialize the spaces
435394soln_space = basis .SolutionSpace ({"u" : "H1" , "v" : "H1" })
436395data_space = basis .SolutionSpace ({"x" : "H1" , "y" : "H1" })
437396geo_space = basis .SolutionSpace ({"x" : "H1" , "y" : "H1" })
438397
439- mesh = Mesh ("magnet_order_1.inp" )
440- # mesh = Mesh("plate.inp")
398+ # mesh = Mesh("magnet_order_1.inp")
399+ mesh = Mesh ("plate.inp" )
441400problem = Problem (
442401 mesh ,
443402 soln_space ,
0 commit comments