File tree Expand file tree Collapse file tree 6 files changed +71
-0
lines changed
Expand file tree Collapse file tree 6 files changed +71
-0
lines changed Original file line number Diff line number Diff line change 4343 get_logger ,
4444 scan_generated_quantities_csv ,
4545)
46+
4647from .metadata import InferenceMetadata
4748from .runset import RunSet
4849
@@ -122,6 +123,8 @@ def __repr__(self) -> str:
122123
123124 def __getattr__ (self , attr : str ) -> np .ndarray :
124125 """Synonymous with ``fit.stan_variable(attr)"""
126+ if attr .startswith ("_" ):
127+ raise AttributeError (f"Unknown variable name { attr } " )
125128 try :
126129 return self .stan_variable (attr )
127130 except ValueError as e :
@@ -833,6 +836,8 @@ def __repr__(self) -> str:
833836
834837 def __getattr__ (self , attr : str ) -> np .ndarray :
835838 """Synonymous with ``fit.stan_variable(attr)"""
839+ if attr .startswith ("_" ):
840+ raise AttributeError (f"Unknown variable name { attr } " )
836841 try :
837842 return self .stan_variable (attr )
838843 except ValueError as e :
Original file line number Diff line number Diff line change @@ -52,6 +52,8 @@ def __repr__(self) -> str:
5252
5353 def __getattr__ (self , attr : str ) -> Union [np .ndarray , float ]:
5454 """Synonymous with ``fit.stan_variable(attr)"""
55+ if attr .startswith ("_" ):
56+ raise AttributeError (f"Unknown variable name { attr } " )
5557 try :
5658 return self .stan_variable (attr )
5759 except ValueError as e :
Original file line number Diff line number Diff line change @@ -43,6 +43,8 @@ def __repr__(self) -> str:
4343
4444 def __getattr__ (self , attr : str ) -> Union [np .ndarray , float ]:
4545 """Synonymous with ``fit.stan_variable(attr)"""
46+ if attr .startswith ("_" ):
47+ raise AttributeError (f"Unknown variable name { attr } " )
4648 try :
4749 return self .stan_variable (attr )
4850 except ValueError as e :
Original file line number Diff line number Diff line change @@ -627,6 +627,28 @@ def test_attrs(self):
627627 with self .assertRaisesRegex (AttributeError , 'Unknown variable name:' ):
628628 dummy = fit .c
629629
630+ def test_pickle_ability (self ):
631+ """Ensure fit objects are pickle-able and copy-able"""
632+
633+ import pickle
634+
635+ csvfiles_path = os .path .join (
636+ DATAFILES_PATH , 'optimize' , 'rosenbrock_mle.csv'
637+ )
638+ fit = from_csv (path = csvfiles_path )
639+ pickled = pickle .dumps (fit )
640+ unpickled = pickle .loads (pickled )
641+ self .assertSequenceEqual (
642+ fit .stan_variables ().keys (), unpickled .stan_variables ().keys ()
643+ )
644+
645+ import copy
646+
647+ fit2 = copy .deepcopy (fit )
648+ self .assertSequenceEqual (
649+ fit .stan_variables ().keys (), fit2 .stan_variables ().keys ()
650+ )
651+
630652
631653if __name__ == '__main__' :
632654 unittest .main ()
Original file line number Diff line number Diff line change @@ -1881,6 +1881,26 @@ def test_diagnostics(self):
18811881 self .assertEqual (fit .max_treedepths , None )
18821882 self .assertEqual (fit .divergences , None )
18831883
1884+ def test_pickle_ability (self ):
1885+ """Ensure fit objects are pickle-able and copy-able"""
1886+
1887+ import pickle
1888+
1889+ csvfiles_path = os .path .join (DATAFILES_PATH , 'lotka-volterra.csv' )
1890+ fit = from_csv (path = csvfiles_path )
1891+ pickled = pickle .dumps (fit )
1892+ unpickled = pickle .loads (pickled )
1893+ self .assertSequenceEqual (
1894+ fit .stan_variables ().keys (), unpickled .stan_variables ().keys ()
1895+ )
1896+
1897+ import copy
1898+
1899+ fit2 = copy .deepcopy (fit )
1900+ self .assertSequenceEqual (
1901+ fit .stan_variables ().keys (), fit2 .stan_variables ().keys ()
1902+ )
1903+
18841904
18851905if __name__ == '__main__' :
18861906 unittest .main ()
Original file line number Diff line number Diff line change @@ -284,6 +284,26 @@ def test_attrs(self):
284284 with self .assertRaisesRegex (AttributeError , 'Unknown variable name:' ):
285285 dummy = fit .c
286286
287+ def test_pickle_ability (self ):
288+ """Ensure fit objects are pickle-able and copy-able"""
289+
290+ import pickle
291+
292+ csvfiles_path = os .path .join (DATAFILES_PATH , 'variational' )
293+ fit = from_csv (path = csvfiles_path )
294+ pickled = pickle .dumps (fit )
295+ unpickled = pickle .loads (pickled )
296+ self .assertSequenceEqual (
297+ fit .stan_variables ().keys (), unpickled .stan_variables ().keys ()
298+ )
299+
300+ import copy
301+
302+ fit2 = copy .deepcopy (fit )
303+ self .assertSequenceEqual (
304+ fit .stan_variables ().keys (), fit2 .stan_variables ().keys ()
305+ )
306+
287307
288308if __name__ == '__main__' :
289309 unittest .main ()
You can’t perform that action at this time.
0 commit comments