55import math
66import os
77from io import StringIO
8- from typing import Any , Hashable , MutableMapping , Optional , Sequence , Union
8+ from typing import Any , Hashable , MutableMapping , Sequence
99
1010import numpy as np
1111import pandas as pd
@@ -97,8 +97,8 @@ def __init__(
9797 self ._check_sampler_diagnostics ()
9898
9999 def create_inits (
100- self , seed : Optional [ int ] = None , chains : int = 4
101- ) -> Union [ list [dict [str , np .ndarray ]], dict [str , np .ndarray ] ]:
100+ self , seed : int | None = None , chains : int = 4
101+ ) -> list [dict [str , np .ndarray ]] | dict [str , np .ndarray ]:
102102 """
103103 Create initial values for the parameters of the model by
104104 randomly selecting draws from the MCMC samples. If the samples
@@ -211,7 +211,7 @@ def column_names(self) -> tuple[str, ...]:
211211 return self ._metadata .column_names
212212
213213 @property
214- def metric_type (self ) -> Optional [ str ] :
214+ def metric_type (self ) -> str | None :
215215 """
216216 Metric type used for adaptation, either 'diag_e' or 'dense_e', according
217217 to CmdStan arg 'metric'.
@@ -225,7 +225,7 @@ def metric_type(self) -> Optional[str]:
225225
226226 # TODO(2.0): remove
227227 @property
228- def metric (self ) -> Optional [ np .ndarray ] :
228+ def metric (self ) -> np .ndarray | None :
229229 """Deprecated. Use ``.inv_metric`` instead."""
230230 get_logger ().warning (
231231 'The "metric" property is deprecated, use "inv_metric" instead. '
@@ -234,7 +234,7 @@ def metric(self) -> Optional[np.ndarray]:
234234 return self .inv_metric
235235
236236 @property
237- def inv_metric (self ) -> Optional [ np .ndarray ] :
237+ def inv_metric (self ) -> np .ndarray | None :
238238 """
239239 Inverse mass matrix used by sampler for each chain.
240240 Returns a ``nchains x nparams`` array when metric_type is 'diag_e',
@@ -248,7 +248,7 @@ def inv_metric(self) -> Optional[np.ndarray]:
248248 return self ._metric
249249
250250 @property
251- def step_size (self ) -> Optional [ np .ndarray ] :
251+ def step_size (self ) -> np .ndarray | None :
252252 """
253253 Step size used by sampler for each chain.
254254 When sampler algorithm 'fixed_param' is specified, step size is None.
@@ -264,15 +264,15 @@ def thin(self) -> int:
264264 return self ._thin
265265
266266 @property
267- def divergences (self ) -> Optional [ np .ndarray ] :
267+ def divergences (self ) -> np .ndarray | None :
268268 """
269269 Per-chain total number of post-warmup divergent iterations.
270270 When sampler algorithm 'fixed_param' is specified, returns None.
271271 """
272272 return self ._divergences if not self ._is_fixed_param else None
273273
274274 @property
275- def max_treedepths (self ) -> Optional [ np .ndarray ] :
275+ def max_treedepths (self ) -> np .ndarray | None :
276276 """
277277 Per-chain total number of post-warmup iterations where the NUTS sampler
278278 reached the maximum allowed treedepth.
@@ -564,7 +564,7 @@ def summary(
564564 summary_data .index .name = None
565565 return summary_data [mask ]
566566
567- def diagnose (self ) -> Optional [ str ] :
567+ def diagnose (self ) -> str | None :
568568 """
569569 Run cmdstan/bin/diagnose over all output CSV files,
570570 return console output.
@@ -586,7 +586,7 @@ def diagnose(self) -> Optional[str]:
586586
587587 def draws_pd (
588588 self ,
589- vars : Union [ list [str ], str , None ] = None ,
589+ vars : list [str ] | str | None = None ,
590590 inc_warmup : bool = False ,
591591 ) -> pd .DataFrame :
592592 """
@@ -664,7 +664,7 @@ def draws_pd(
664664 )[cols ]
665665
666666 def draws_xr (
667- self , vars : Union [ str , list [str ], None ] = None , inc_warmup : bool = False
667+ self , vars : str | list [str ] | None = None , inc_warmup : bool = False
668668 ) -> "xr.Dataset" :
669669 """
670670 Returns the sampler draws as a xarray Dataset.
@@ -822,7 +822,7 @@ def method_variables(self) -> dict[str, np.ndarray]:
822822 for name , var in self ._metadata .method_vars .items ()
823823 }
824824
825- def save_csvfiles (self , dir : Optional [ str ] = None ) -> None :
825+ def save_csvfiles (self , dir : str | None = None ) -> None :
826826 """
827827 Move output CSV files to specified directory. If files were
828828 written to the temporary session directory, clean filename.
0 commit comments