1515from cmdstanpy .cmdstan_args import CmdStanArgs , VariationalArgs
1616from cmdstanpy .model import CmdStanModel
1717from cmdstanpy .stanfit import CmdStanVB , RunSet , from_csv
18+ from cmdstanpy .utils .cmdstan import cmdstan_version_before
1819
1920HERE = os .path .dirname (os .path .abspath (__file__ ))
2021DATAFILES_PATH = os .path .join (HERE , 'data' )
@@ -150,12 +151,12 @@ def test_variational_good() -> None:
150151 'mu[2]' ,
151152 )
152153 # fixed seed, id=1 by default will give known output values
153- assert variational .eta == 100
154+ assert variational .eta == 10
154155 np .testing .assert_almost_equal (
155- variational .variational_params_dict ['mu[1]' ], 311.545 , decimal = 2
156+ variational .variational_params_dict ['mu[1]' ], 302.142 , decimal = 2
156157 )
157158 np .testing .assert_almost_equal (
158- variational .variational_params_dict ['mu[2]' ], 532.801 , decimal = 2
159+ variational .variational_params_dict ['mu[2]' ], 361.005 , decimal = 2
159160 )
160161 np .testing .assert_almost_equal (
161162 variational .variational_params_np [0 ],
@@ -177,7 +178,11 @@ def test_variational_eta_small() -> None:
177178 DATAFILES_PATH , 'variational' , 'eta_should_be_small.stan'
178179 )
179180 model = CmdStanModel (stan_file = stan )
180- variational = model .variational (algorithm = 'meanfield' , seed = 12345 )
181+ if cmdstan_version_before (2 , 35 ):
182+ seed = 12345
183+ else :
184+ seed = 1234
185+ variational = model .variational (algorithm = 'meanfield' , seed = seed )
181186 assert variational .column_names == (
182187 'lp__' ,
183188 'log_p__' ,
0 commit comments