@@ -105,7 +105,9 @@ def validate(self, chains: Optional[int]) -> None:
105105 * length of per-chain lists equals specified # of chains
106106 """
107107 if not isinstance (chains , (int , np .integer )) or chains < 1 :
108- raise ValueError ('Sampler expects number of chains to be greater than 0.' )
108+ raise ValueError (
109+ 'Sampler expects number of chains to be greater than 0.'
110+ )
109111 if not (
110112 self .adapt_delta is None
111113 and self .adapt_init_phase is None
@@ -117,13 +119,17 @@ def validate(self, chains: Optional[int]) -> None:
117119 if self .adapt_delta is not None :
118120 msg = '{}, adapt_delta: {}' .format (msg , self .adapt_delta )
119121 if self .adapt_init_phase is not None :
120- msg = '{}, adapt_init_phase: {}' .format (msg , self .adapt_init_phase )
122+ msg = '{}, adapt_init_phase: {}' .format (
123+ msg , self .adapt_init_phase
124+ )
121125 if self .adapt_metric_window is not None :
122126 msg = '{}, adapt_metric_window: {}' .format (
123127 msg , self .adapt_metric_window
124128 )
125129 if self .adapt_step_size is not None :
126- msg = '{}, adapt_step_size: {}' .format (msg , self .adapt_step_size )
130+ msg = '{}, adapt_step_size: {}' .format (
131+ msg , self .adapt_step_size
132+ )
127133 raise ValueError (msg )
128134
129135 if self .iter_warmup is not None :
@@ -151,7 +157,9 @@ def validate(self, chains: Optional[int]) -> None:
151157 positive_int (self .max_treedepth , 'max_treedepth' )
152158
153159 if self .step_size is not None :
154- if isinstance (self .step_size , (float , int , np .integer , np .floating )):
160+ if isinstance (
161+ self .step_size , (float , int , np .integer , np .floating )
162+ ):
155163 if self .step_size <= 0 :
156164 raise ValueError (
157165 'Argument "step_size" must be > 0, found {}.' .format (
@@ -189,7 +197,9 @@ def validate(self, chains: Optional[int]) -> None:
189197 self .metric_file = self .metric
190198 elif isinstance (self .metric , dict ):
191199 if 'inv_metric' not in self .metric :
192- raise ValueError ('Entry "inv_metric" not found in metric dict.' )
200+ raise ValueError (
201+ 'Entry "inv_metric" not found in metric dict.'
202+ )
193203 dims = list (np .asarray (self .metric ['inv_metric' ]).shape )
194204 if len (dims ) == 1 :
195205 self .metric_type = 'diag_e'
@@ -218,14 +228,20 @@ def validate(self, chains: Optional[int]) -> None:
218228 'for chain {}.' .format (i + 1 )
219229 )
220230 if i == 0 :
221- dims = list (np .asarray (metric_dict ['inv_metric' ]).shape )
231+ dims = list (
232+ np .asarray (metric_dict ['inv_metric' ]).shape
233+ )
222234 else :
223- dims2 = list (np .asarray (metric_dict ['inv_metric' ]).shape )
235+ dims2 = list (
236+ np .asarray (metric_dict ['inv_metric' ]).shape
237+ )
224238 if dims != dims2 :
225239 raise ValueError (
226240 'Found inconsistent "inv_metric" entry '
227241 'for chain {}: entry has dims '
228- '{}, expected {}.' .format (i + 1 , dims , dims2 )
242+ '{}, expected {}.' .format (
243+ i + 1 , dims , dims2
244+ )
229245 )
230246 dict_file = create_named_text_file (
231247 dir = _TMPDIR , prefix = "metric" , suffix = ".json"
@@ -249,13 +265,15 @@ def validate(self, chains: Optional[int]) -> None:
249265 dims2 = read_metric (metric )
250266 if len (dims ) != len (dims2 ):
251267 raise ValueError (
252- 'Metrics files {}, {}, inconsistent metrics' .format (
268+ 'Metrics files {}, {},'
269+ ' inconsistent metrics' .format (
253270 self .metric [0 ], metric
254271 )
255272 )
256273 if dims != dims2 :
257274 raise ValueError (
258- 'Metrics files {}, {}, inconsistent metrics' .format (
275+ 'Metrics files {}, {},'
276+ ' inconsistent metrics' .format (
259277 self .metric [0 ], metric
260278 )
261279 )
@@ -268,7 +286,9 @@ def validate(self, chains: Optional[int]) -> None:
268286 else :
269287 raise ValueError (
270288 'Argument "metric" must be a list of pathnames or '
271- 'Python dicts, found list of {}.' .format (type (self .metric [0 ]))
289+ 'Python dicts, found list of {}.' .format (
290+ type (self .metric [0 ])
291+ )
272292 )
273293 else :
274294 raise ValueError (
@@ -281,9 +301,8 @@ def validate(self, chains: Optional[int]) -> None:
281301 if self .adapt_delta is not None :
282302 if not 0 < self .adapt_delta < 1 :
283303 raise ValueError (
284- 'Argument "adapt_delta" must be between 0 and 1, found {}' .format (
285- self .adapt_delta
286- )
304+ 'Argument "adapt_delta" must be between 0 and 1,'
305+ ' found {}' .format (self .adapt_delta )
287306 )
288307 if self .adapt_init_phase is not None :
289308 if self .adapt_init_phase < 0 or not isinstance (
@@ -437,7 +456,9 @@ def validate(self, _chains: Optional[int] = None) -> None:
437456 )
438457 if self .algorithm .lower () != 'lbfgs' :
439458 if self .history_size is not None :
440- raise ValueError ('history_size requires that algorithm be set to lbfgs' )
459+ raise ValueError (
460+ 'history_size requires that algorithm be set to lbfgs'
461+ )
441462
442463 positive_float (self .init_alpha , 'init_alpha' )
443464 positive_int (self .iter , 'iter' )
@@ -620,7 +641,9 @@ def validate(
620641 """
621642 for csv in self .sample_csv_files :
622643 if not os .path .exists (csv ):
623- raise ValueError ('Invalid path for sample csv file: {}' .format (csv ))
644+ raise ValueError (
645+ 'Invalid path for sample csv file: {}' .format (csv )
646+ )
624647
625648 def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
626649 """
@@ -667,7 +690,10 @@ def validate(
667690 """
668691 Check arguments correctness and consistency.
669692 """
670- if self .algorithm is not None and self .algorithm not in self .VARIATIONAL_ALGOS :
693+ if (
694+ self .algorithm is not None
695+ and self .algorithm not in self .VARIATIONAL_ALGOS
696+ ):
671697 raise ValueError (
672698 'Please specify variational algorithms as one of [{}]' .format (
673699 ', ' .join (self .VARIATIONAL_ALGOS )
@@ -794,16 +820,19 @@ def validate(self) -> None:
794820 if chain_id < 1 :
795821 raise ValueError ('invalid chain_id {}' .format (chain_id ))
796822 if self .output_dir is not None :
797- self .output_dir = os .path .realpath (os .path .expanduser (self .output_dir ))
823+ self .output_dir = os .path .realpath (
824+ os .path .expanduser (self .output_dir )
825+ )
798826 if not os .path .exists (self .output_dir ):
799827 try :
800828 os .makedirs (self .output_dir )
801- get_logger ().info ('created output directory: %s' , self .output_dir )
829+ get_logger ().info (
830+ 'created output directory: %s' , self .output_dir
831+ )
802832 except (RuntimeError , PermissionError ) as exc :
803833 raise ValueError (
804- 'Invalid path for output files, no such dir: {}.' .format (
805- self .output_dir
806- )
834+ 'Invalid path for output files, '
835+ 'no such dir: {}.' .format (self .output_dir )
807836 ) from exc
808837 if not os .path .isdir (self .output_dir ):
809838 raise ValueError (
@@ -818,12 +847,14 @@ def validate(self) -> None:
818847 os .remove (testpath ) # cleanup
819848 except Exception as exc :
820849 raise ValueError (
821- 'Invalid path for output files, cannot write to dir: {}.' .format (
822- self .output_dir
823- )
850+ 'Invalid path for output files,'
851+ ' cannot write to dir: {}.' .format (self .output_dir )
824852 ) from exc
825853 if self .refresh is not None :
826- if not isinstance (self .refresh , (int , np .integer )) or self .refresh < 1 :
854+ if (
855+ not isinstance (self .refresh , (int , np .integer ))
856+ or self .refresh < 1
857+ ):
827858 raise ValueError (
828859 'Argument "refresh" must be a positive integer value, '
829860 'found {}.' .format (self .refresh )
@@ -895,7 +926,9 @@ def validate(self) -> None:
895926 if isinstance (self .inits , (float , int , np .floating , np .integer )):
896927 if self .inits < 0 :
897928 raise ValueError (
898- 'Argument "inits" must be > 0, found {}' .format (self .inits )
929+ 'Argument "inits" must be > 0, found {}' .format (
930+ self .inits
931+ )
899932 )
900933 elif isinstance (self .inits , str ):
901934 if not (
0 commit comments