99# This software is distributed under the 3-clause BSD License.
1010# ___________________________________________________________________________
1111
12- import copy
1312import enum
1413from io import StringIO
1514from math import inf
1615
17- from pyomo .common .collections import Bunch
16+ from pyomo .common .collections import Bunch , Sequence , Mapping
1817
1918
2019class ScalarType (str , enum .Enum ):
@@ -35,17 +34,31 @@ def __str__(self):
3534
3635
3736default_print_options = Bunch (schema = False , ignore_time = False )
38-
3937strict = False
4038
4139
4240class UndefinedData (object ):
41+ singleton = {}
42+
43+ def __new__ (cls , name = 'undefined' ):
44+ if name not in UndefinedData .singleton :
45+ UndefinedData .singleton [name ] = super ().__new__ (cls )
46+ UndefinedData .singleton [name ].name = name
47+ return UndefinedData .singleton [name ]
48+
49+ def __deepcopy__ (self , memo ):
50+ # Prevent deepcopy from duplicating this object
51+ return self
52+
53+ def __reduce__ (self ):
54+ return self .__class__ , (self .name ,)
55+
4356 def __str__ (self ):
44- return "<undefined >"
57+ return f"< { self . name } >"
4558
4659
47- undefined = UndefinedData ()
48- ignore = UndefinedData ()
60+ undefined = UndefinedData ('undefined' )
61+ ignore = UndefinedData ('ignore' )
4962
5063
5164class ScalarData (object ):
@@ -64,6 +77,10 @@ def __init__(
6477 self .scalar_description = scalar_description
6578 self .scalar_type = type
6679 self ._required = required
80+ self ._active = False
81+
82+ def __eq__ (self , other ):
83+ return self .__dict__ == other .__dict__
6784
6885 def get_value (self ):
6986 if isinstance (self .value , enum .Enum ):
@@ -109,9 +126,9 @@ def pprint(self, ostream, option, prefix="", repn=None):
109126
110127 value = self .yaml_fix (self .get_value ())
111128
112- if value is inf :
129+ if value == inf :
113130 value = '.inf'
114- elif value is - inf :
131+ elif value == - inf :
115132 value = '-.inf'
116133
117134 if not option .schema and self .description is None and self .units is None :
@@ -149,8 +166,8 @@ def yaml_fix(self, val):
149166
150167 def load (self , repn ):
151168 if type (repn ) is dict :
152- for key in repn :
153- setattr (self , key , repn [ key ] )
169+ for key , val in repn . items () :
170+ setattr (self , key , val )
154171 else :
155172 self .value = repn
156173
@@ -167,12 +184,15 @@ def __init__(self, cls):
167184
168185 def __len__ (self ):
169186 if '_list' in self .__dict__ :
170- return len (self .__dict__ [ ' _list' ] )
187+ return len (self ._list )
171188 return 0
172189
173190 def __getitem__ (self , i ):
174191 return self ._list [i ]
175192
193+ def __eq__ (self , other ):
194+ return self .__dict__ == other .__dict__
195+
176196 def clear (self ):
177197 self ._list = []
178198
@@ -183,21 +203,15 @@ def __call__(self, i=0):
183203 return self ._list [i ]
184204
185205 def __getattr__ (self , name ):
186- try :
187- return self .__dict__ [name ]
188- except :
189- pass
206+ if name [0 ] == "_" :
207+ super ().__getattr__ (name )
190208 if len (self ) == 0 :
191209 self .add ()
192210 return getattr (self ._list [0 ], name )
193211
194212 def __setattr__ (self , name , val ):
195- if name == "__class__" :
196- self .__class__ = val
197- return
198213 if name [0 ] == "_" :
199- self .__dict__ [name ] = val
200- return
214+ return super ().__setattr__ (name , val )
201215 if len (self ) == 0 :
202216 self .add ()
203217 setattr (self ._list [0 ], name , val )
@@ -239,16 +253,10 @@ def load(self, repn):
239253 item = self .add ()
240254 item .load (data )
241255
242- def __getstate__ (self ):
243- return copy .copy (self .__dict__ )
244-
245- def __setstate__ (self , state ):
246- self .__dict__ .update (state )
247-
248256 def __str__ (self ):
249257 ostream = StringIO ()
250258 option = default_print_options
251- self .pprint (ostream , self . _option , repn = self ._repn_ (self . _option ))
259+ self .pprint (ostream , option , repn = self ._repn_ (option ))
252260 return ostream .getvalue ()
253261
254262
@@ -259,41 +267,21 @@ def __str__(self):
259267# first letter is capitalized.
260268#
261269class MapContainer (dict ):
262- def __getnewargs_ex__ (self ):
263- # Pass arguments to __new__ when unpickling
264- return ((0 , 0 ), {})
265-
266- def __getnewargs__ (self ):
267- # Pass arguments to __new__ when unpickling
268- return (0 , 0 )
269-
270- def __new__ (cls , * args , ** kwargs ):
271- #
272- # If the user provides "too many" arguments, then
273- # pre-initialize the '_order' attribute. This pre-initializes
274- # the class during unpickling.
275- #
276- _instance = super (MapContainer , cls ).__new__ (cls , * args , ** kwargs )
277- if len (args ) > 1 :
278- super (MapContainer , _instance ).__setattr__ ('_order' , [])
279- return _instance
280270
281271 def __init__ (self , ordered = False ):
282- dict .__init__ (self )
272+ super () .__init__ ()
283273 self ._active = True
284274 self ._required = False
285- self ._ordered = ordered
286- self ._order = []
287275 self ._option = default_print_options
288276
289- def keys (self ):
290- return self ._order
277+ def __eq__ (self , other ):
278+ # We need to check both our __dict__ (local attributes) and the
279+ # underlying dict data (which doesn't show up in the __dict__).
280+ # So we will use the base __eq__ in addition to checking
281+ # __dict__.
282+ return super ().__eq__ (other ) and self .__dict__ == other .__dict__
291283
292284 def __getattr__ (self , name ):
293- try :
294- return self .__dict__ [name ]
295- except :
296- pass
297285 try :
298286 self ._active = True
299287 return self [self ._convert (name )]
@@ -307,12 +295,8 @@ def __getattr__(self, name):
307295 )
308296
309297 def __setattr__ (self , name , val ):
310- if name == "__class__" :
311- self .__class__ = val
312- return
313298 if name [0 ] == "_" :
314- self .__dict__ [name ] = val
315- return
299+ return super ().__setattr__ (name , val )
316300 self ._active = True
317301 tmp = self ._convert (name )
318302 if tmp not in self :
@@ -341,12 +325,18 @@ def __setitem__(self, name, val):
341325 self ._set_value (tmp , val )
342326
343327 def _set_value (self , name , val ):
344- if isinstance (val , ListContainer ) or isinstance ( val , MapContainer ):
345- dict .__setitem__ (self , name , val )
328+ if isinstance (val , ( ListContainer , MapContainer ) ):
329+ super () .__setitem__ (name , val )
346330 elif isinstance (val , ScalarData ):
347- dict .__getitem__ (self , name ).value = val .value
331+ data = super ().__getitem__ (name )
332+ data .value = val .value
333+ data ._active = val ._active
334+ data ._required = val ._required
335+ data .scalar_type = val .scalar_type
348336 else :
349- dict .__getitem__ (self , name ).value = val
337+ data = super ().__getitem__ (name )
338+ data .value = val
339+ data ._active = True
350340
351341 def __getitem__ (self , name ):
352342 tmp = self ._convert (name )
@@ -357,25 +347,21 @@ def __getitem__(self, name):
357347 + "' for object with type "
358348 + str (type (self ))
359349 )
360- item = dict .__getitem__ (self , tmp )
361- if isinstance (item , ListContainer ) or isinstance ( item , MapContainer ):
350+ item = super () .__getitem__ (tmp )
351+ if isinstance (item , ( ListContainer , MapContainer ) ):
362352 return item
363353 return item .value
364354
365355 def declare (self , name , ** kwds ):
366356 if name in self or type (name ) is int :
367357 return
368- tmp = self ._convert (name )
369- self ._order .append (tmp )
370- if 'value' in kwds and (
371- isinstance (kwds ['value' ], MapContainer )
372- or isinstance (kwds ['value' ], ListContainer )
373- ):
358+ data = kwds .get ('value' , None )
359+ if isinstance (data , (MapContainer , ListContainer )):
374360 if 'active' in kwds :
375- kwds [ 'value' ] ._active = kwds ['active' ]
361+ data ._active = kwds ['active' ]
376362 if 'required' in kwds and kwds ['required' ] is True :
377- kwds [ 'value' ] ._required = True
378- dict .__setitem__ (self , tmp , kwds [ 'value' ] )
363+ data ._required = True
364+ super () .__setitem__ (self . _convert ( name ), data )
379365 else :
380366 data = ScalarData (** kwds )
381367 if 'required' in kwds and kwds ['required' ] is True :
@@ -387,23 +373,16 @@ def declare(self, name, **kwds):
387373 #
388374 # if 'value' in kwds:
389375 # data._default = kwds['value']
390- dict .__setitem__ (self , tmp , data )
376+ super () .__setitem__ (self . _convert ( name ) , data )
391377
392378 def _repn_ (self , option ):
393379 if not option .schema and not self ._active and not self ._required :
394380 return ignore
395- if self ._ordered :
396- tmp = []
397- for key in self ._order :
398- rep = dict .__getitem__ (self , key )._repn_ (option )
399- if not rep == ignore :
400- tmp .append ({key : rep })
401- else :
402- tmp = {}
403- for key in self .keys ():
404- rep = dict .__getitem__ (self , key )._repn_ (option )
405- if not rep == ignore :
406- tmp [key ] = rep
381+ tmp = {}
382+ for key , val in self .items ():
383+ rep = val ._repn_ (option )
384+ if not rep == ignore :
385+ tmp [key ] = rep
407386 return tmp
408387
409388 def _convert (self , name ):
@@ -417,7 +396,6 @@ def __repr__(self):
417396
418397 def __str__ (self ):
419398 ostream = StringIO ()
420- option = default_print_options
421399 self .pprint (ostream , self ._option , repn = self ._repn_ (self ._option ))
422400 return ostream .getvalue ()
423401
@@ -427,10 +405,9 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
427405 else :
428406 _prefix = prefix
429407 ostream .write ('\n ' )
430- for key in self ._order :
408+ for key , item in self .items () :
431409 if not key in repn :
432410 continue
433- item = dict .__getitem__ (self , key )
434411 ostream .write (_prefix + key + ": " )
435412 _prefix = prefix
436413 if isinstance (item , ListContainer ):
@@ -439,46 +416,16 @@ def pprint(self, ostream, option, from_list=False, prefix="", repn=None):
439416 item .pprint (ostream , option , prefix = _prefix + " " , repn = repn [key ])
440417
441418 def load (self , repn ):
442- for key in repn :
419+ for key , val in repn . items () :
443420 tmp = self ._convert (key )
444421 if tmp not in self :
445422 self .declare (tmp )
446- item = dict .__getitem__ (self , tmp )
423+ item = super () .__getitem__ (tmp )
447424 item ._active = True
448- item .load (repn [key ])
449-
450- def __getnewargs__ (self ):
451- return (False , False )
452-
453- def __getstate__ (self ):
454- return copy .copy (self .__dict__ )
455-
456- def __setstate__ (self , state ):
457- self .__dict__ .update (state )
458-
459-
460- if __name__ == '__main__' :
461- d = MapContainer ()
462- d .declare ('f' )
463- d .declare ('g' )
464- d .declare ('h' )
465- d .declare ('i' , value = ListContainer (UndefinedData ))
466- d .declare ('j' , value = ListContainer (UndefinedData ), active = False )
467- print ("X" )
468- d .f = 1
469- print ("Y" )
470- print (d .f )
471- print (d .keys ())
472- d .g = None
473- print (d .keys ())
474- try :
475- print (d .f , d .g , d .h )
476- except :
477- pass
478- d ['h' ] = None
479- print ("" )
480- print ("FINAL" )
481- print (d .f , d .g , d .h , d .i , d .j )
482- print (d .i ._active , d .j ._active )
483- d .j .add ()
484- print (d .i ._active , d .j ._active )
425+ item .load (val )
426+
427+
428+ # Register these as sequence / mapping types (so things like
429+ # assertStructuredAlmostEqual will process them correctly)
430+ Sequence .register (ListContainer )
431+ Mapping .register (MapContainer )
0 commit comments