11"""
22CmdStan arguments
33"""
4+
45import os
56from enum import Enum , auto
67from time import time
7- from typing import Any , Dict , List , Mapping , Optional , Union
8+ from typing import Any , Mapping , Optional , Union
89
910import numpy as np
1011from numpy .random import default_rng
@@ -65,9 +66,9 @@ def __init__(
6566 thin : Optional [int ] = None ,
6667 max_treedepth : Optional [int ] = None ,
6768 metric : Union [
68- str , Dict [str , Any ], List [str ], List [ Dict [str , Any ]], None
69+ str , dict [str , Any ], list [str ], list [ dict [str , Any ]], None
6970 ] = None ,
70- step_size : Union [float , List [float ], None ] = None ,
71+ step_size : Union [float , list [float ], None ] = None ,
7172 adapt_engaged : bool = True ,
7273 adapt_delta : Optional [float ] = None ,
7374 adapt_init_phase : Optional [int ] = None ,
@@ -84,7 +85,7 @@ def __init__(
8485 self .max_treedepth = max_treedepth
8586 self .metric = metric
8687 self .metric_type : Optional [str ] = None
87- self .metric_file : Union [str , List [str ], None ] = None
88+ self .metric_file : Union [str , list [str ], None ] = None
8889 self .step_size = step_size
8990 self .adapt_engaged = adapt_engaged
9091 self .adapt_delta = adapt_delta
@@ -161,8 +162,9 @@ def validate(self, chains: Optional[int]) -> None:
161162 ):
162163 if self .step_size <= 0 :
163164 raise ValueError (
164- 'Argument "step_size" must be > 0, '
165- 'found {}.' .format (self .step_size )
165+ 'Argument "step_size" must be > 0, found {}.' .format (
166+ self .step_size
167+ )
166168 )
167169 else :
168170 if len (self .step_size ) != chains :
@@ -217,9 +219,9 @@ def validate(self, chains: Optional[int]) -> None:
217219 )
218220 )
219221 if all (isinstance (elem , dict ) for elem in self .metric ):
220- metric_files : List [str ] = []
222+ metric_files : list [str ] = []
221223 for i , metric in enumerate (self .metric ):
222- metric_dict : Dict [str , Any ] = metric # type: ignore
224+ metric_dict : dict [str , Any ] = metric # type: ignore
223225 if 'inv_metric' not in metric_dict :
224226 raise ValueError (
225227 'Entry "inv_metric" not found in metric dict '
@@ -343,7 +345,7 @@ def validate(self, chains: Optional[int]) -> None:
343345 'When fixed_param=True, cannot specify adaptation parameters.'
344346 )
345347
346- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
348+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
347349 """
348350 Compose CmdStan command for method-specific non-default arguments.
349351 """
@@ -467,7 +469,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
467469 positive_float (self .tol_param , 'tol_param' )
468470 positive_int (self .history_size , 'history_size' )
469471
470- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
472+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
471473 """compose command string for CmdStan for non-default arg values."""
472474 cmd .append ('method=optimize' )
473475 if self .algorithm :
@@ -511,7 +513,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
511513 raise ValueError (f'Invalid path for mode file: { self .mode } ' )
512514 positive_int (self .draws , 'draws' )
513515
514- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
516+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
515517 """compose command string for CmdStan for non-default arg values."""
516518 cmd .append ('method=laplace' )
517519 cmd .append (f'mode={ self .mode } ' )
@@ -579,7 +581,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
579581 positive_int (self .num_draws , 'num_draws' )
580582 positive_int (self .num_elbo_draws , 'num_elbo_draws' )
581583
582- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
584+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
583585 """compose command string for CmdStan for non-default arg values."""
584586 cmd .append ('method=pathfinder' )
585587
@@ -624,12 +626,13 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]:
624626class GenerateQuantitiesArgs :
625627 """Arguments needed for generate_quantities method."""
626628
627- def __init__ (self , csv_files : List [str ]) -> None :
629+ def __init__ (self , csv_files : list [str ]) -> None :
628630 """Initialize object."""
629631 self .sample_csv_files = csv_files
630632
631633 def validate (
632- self , chains : Optional [int ] = None # pylint: disable=unused-argument
634+ self ,
635+ chains : Optional [int ] = None , # pylint: disable=unused-argument
633636 ) -> None :
634637 """
635638 Check arguments correctness and consistency.
@@ -642,7 +645,7 @@ def validate(
642645 'Invalid path for sample csv file: {}' .format (csv )
643646 )
644647
645- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
648+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
646649 """
647650 Compose CmdStan command for method-specific non-default arguments.
648651 """
@@ -681,7 +684,8 @@ def __init__(
681684 self .output_samples = output_samples
682685
683686 def validate (
684- self , chains : Optional [int ] = None # pylint: disable=unused-argument
687+ self ,
688+ chains : Optional [int ] = None , # pylint: disable=unused-argument
685689 ) -> None :
686690 """
687691 Check arguments correctness and consistency.
@@ -705,7 +709,7 @@ def validate(
705709 positive_int (self .output_samples , 'output_samples' )
706710
707711 # pylint: disable=unused-argument
708- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
712+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
709713 """
710714 Compose CmdStan command for method-specific non-default arguments.
711715 """
@@ -747,7 +751,7 @@ def __init__(
747751 self ,
748752 model_name : str ,
749753 model_exe : OptionalPath ,
750- chain_ids : Optional [List [int ]],
754+ chain_ids : Optional [list [int ]],
751755 method_args : Union [
752756 SamplerArgs ,
753757 OptimizeArgs ,
@@ -757,8 +761,8 @@ def __init__(
757761 PathfinderArgs ,
758762 ],
759763 data : Union [Mapping [str , Any ], str , None ] = None ,
760- seed : Union [int , List [int ], None ] = None ,
761- inits : Union [int , float , str , List [str ], None ] = None ,
764+ seed : Union [int , list [int ], None ] = None ,
765+ inits : Union [int , float , str , list [str ], None ] = None ,
762766 output_dir : OptionalPath = None ,
763767 sig_figs : Optional [int ] = None ,
764768 save_latent_dynamics : bool = False ,
@@ -959,11 +963,11 @@ def compose_command(
959963 * ,
960964 diagnostic_file : Optional [str ] = None ,
961965 profile_file : Optional [str ] = None ,
962- ) -> List [str ]:
966+ ) -> list [str ]:
963967 """
964968 Compose CmdStan command for non-default arguments.
965969 """
966- cmd : List [str ] = []
970+ cmd : list [str ] = []
967971 if idx is not None and self .chain_ids is not None :
968972 if idx < 0 or idx > len (self .chain_ids ) - 1 :
969973 raise ValueError (
0 commit comments