55import os
66from enum import Enum , auto
77from time import time
8- from typing import Any , List , Mapping , Optional , Union
8+ from typing import Any , Mapping , Optional , Union
99
1010import numpy as np
1111from numpy .random import default_rng
@@ -58,8 +58,8 @@ def __init__(
5858 thin : Optional [int ] = None ,
5959 max_treedepth : Optional [int ] = None ,
6060 metric_type : Optional [str ] = None ,
61- metric_file : Union [str , List [str ], None ] = None ,
62- step_size : Union [float , List [float ], None ] = None ,
61+ metric_file : Union [str , list [str ], None ] = None ,
62+ step_size : Union [float , list [float ], None ] = None ,
6363 adapt_engaged : bool = True ,
6464 adapt_delta : Optional [float ] = None ,
6565 adapt_init_phase : Optional [int ] = None ,
@@ -75,7 +75,7 @@ def __init__(
7575 self .thin = thin
7676 self .max_treedepth = max_treedepth
7777 self .metric_type : Optional [str ] = metric_type
78- self .metric_file : Union [str , List [str ], None ] = metric_file
78+ self .metric_file : Union [str , list [str ], None ] = metric_file
7979 self .step_size = step_size
8080 self .adapt_engaged = adapt_engaged
8181 self .adapt_delta = adapt_delta
@@ -152,8 +152,9 @@ def validate(self, chains: Optional[int]) -> None:
152152 ):
153153 if self .step_size <= 0 :
154154 raise ValueError (
155- 'Argument "step_size" must be > 0, '
156- 'found {}.' .format (self .step_size )
155+ 'Argument "step_size" must be > 0, found {}.' .format (
156+ self .step_size
157+ )
157158 )
158159 else :
159160 if len (self .step_size ) != chains :
@@ -226,7 +227,7 @@ def validate(self, chains: Optional[int]) -> None:
226227 'When fixed_param=True, cannot specify adaptation parameters.'
227228 )
228229
229- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
230+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
230231 """
231232 Compose CmdStan command for method-specific non-default arguments.
232233 """
@@ -350,7 +351,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
350351 positive_float (self .tol_param , 'tol_param' )
351352 positive_int (self .history_size , 'history_size' )
352353
353- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
354+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
354355 """compose command string for CmdStan for non-default arg values."""
355356 cmd .append ('method=optimize' )
356357 if self .algorithm :
@@ -394,7 +395,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
394395 raise ValueError (f'Invalid path for mode file: { self .mode } ' )
395396 positive_int (self .draws , 'draws' )
396397
397- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
398+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
398399 """compose command string for CmdStan for non-default arg values."""
399400 cmd .append ('method=laplace' )
400401 cmd .append (f'mode={ self .mode } ' )
@@ -462,7 +463,7 @@ def validate(self, _chains: Optional[int] = None) -> None:
462463 positive_int (self .num_draws , 'num_draws' )
463464 positive_int (self .num_elbo_draws , 'num_elbo_draws' )
464465
465- def compose (self , _idx : int , cmd : List [str ]) -> List [str ]:
466+ def compose (self , _idx : int , cmd : list [str ]) -> list [str ]:
466467 """compose command string for CmdStan for non-default arg values."""
467468 cmd .append ('method=pathfinder' )
468469
@@ -507,12 +508,13 @@ def compose(self, _idx: int, cmd: List[str]) -> List[str]:
507508class GenerateQuantitiesArgs :
508509 """Arguments needed for generate_quantities method."""
509510
510- def __init__ (self , csv_files : List [str ]) -> None :
511+ def __init__ (self , csv_files : list [str ]) -> None :
511512 """Initialize object."""
512513 self .sample_csv_files = csv_files
513514
514515 def validate (
515- self , chains : Optional [int ] = None # pylint: disable=unused-argument
516+ self ,
517+ chains : Optional [int ] = None , # pylint: disable=unused-argument
516518 ) -> None :
517519 """
518520 Check arguments correctness and consistency.
@@ -525,7 +527,7 @@ def validate(
525527 'Invalid path for sample csv file: {}' .format (csv )
526528 )
527529
528- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
530+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
529531 """
530532 Compose CmdStan command for method-specific non-default arguments.
531533 """
@@ -564,7 +566,8 @@ def __init__(
564566 self .output_samples = output_samples
565567
566568 def validate (
567- self , chains : Optional [int ] = None # pylint: disable=unused-argument
569+ self ,
570+ chains : Optional [int ] = None , # pylint: disable=unused-argument
568571 ) -> None :
569572 """
570573 Check arguments correctness and consistency.
@@ -588,7 +591,7 @@ def validate(
588591 positive_int (self .output_samples , 'output_samples' )
589592
590593 # pylint: disable=unused-argument
591- def compose (self , idx : int , cmd : List [str ]) -> List [str ]:
594+ def compose (self , idx : int , cmd : list [str ]) -> list [str ]:
592595 """
593596 Compose CmdStan command for method-specific non-default arguments.
594597 """
@@ -630,7 +633,7 @@ def __init__(
630633 self ,
631634 model_name : str ,
632635 model_exe : OptionalPath ,
633- chain_ids : Optional [List [int ]],
636+ chain_ids : Optional [list [int ]],
634637 method_args : Union [
635638 SamplerArgs ,
636639 OptimizeArgs ,
@@ -640,8 +643,8 @@ def __init__(
640643 PathfinderArgs ,
641644 ],
642645 data : Union [Mapping [str , Any ], str , None ] = None ,
643- seed : Union [int , List [int ], None ] = None ,
644- inits : Union [int , float , str , List [str ], None ] = None ,
646+ seed : Union [int , list [int ], None ] = None ,
647+ inits : Union [int , float , str , list [str ], None ] = None ,
645648 output_dir : OptionalPath = None ,
646649 sig_figs : Optional [int ] = None ,
647650 save_latent_dynamics : bool = False ,
@@ -842,11 +845,11 @@ def compose_command(
842845 * ,
843846 diagnostic_file : Optional [str ] = None ,
844847 profile_file : Optional [str ] = None ,
845- ) -> List [str ]:
848+ ) -> list [str ]:
846849 """
847850 Compose CmdStan command for non-default arguments.
848851 """
849- cmd : List [str ] = []
852+ cmd : list [str ] = []
850853 if idx is not None and self .chain_ids is not None :
851854 if idx < 0 or idx > len (self .chain_ids ) - 1 :
852855 raise ValueError (
0 commit comments