@@ -89,7 +89,7 @@ def test_bernoulli_good(stanfile: str):
8989 assert bern_fit .draws ().shape == (100 , 2 , len (BERNOULLI_COLS ))
9090 assert bern_fit .metric_type == 'diag_e'
9191 assert bern_fit .step_size .shape == (2 ,)
92- assert bern_fit .metric .shape == (2 , 1 )
92+ assert bern_fit .inv_metric .shape == (2 , 1 )
9393
9494 assert bern_fit .draws (concat_chains = True ).shape == (
9595 200 ,
@@ -125,7 +125,7 @@ def test_bernoulli_good(stanfile: str):
125125 assert bern_sample .shape == (100 , 2 , len (BERNOULLI_COLS ))
126126 assert bern_fit .metric_type == 'dense_e'
127127 assert bern_fit .step_size .shape == (2 ,)
128- assert bern_fit .metric .shape == (2 , 1 , 1 )
128+ assert bern_fit .inv_metric .shape == (2 , 1 , 1 )
129129
130130 bern_fit = bern_model .sample (
131131 data = jdata ,
@@ -186,9 +186,7 @@ def test_bernoulli_good(stanfile: str):
186186
187187
188188@pytest .mark .parametrize ("stanfile" , ["bernoulli.stan" ])
189- def test_bernoulli_unit_e (
190- stanfile : str , caplog : pytest .LogCaptureFixture
191- ) -> None :
189+ def test_bernoulli_unit_e (stanfile : str ) -> None :
192190 stan = os .path .join (DATAFILES_PATH , stanfile )
193191 bern_model = CmdStanModel (stan_file = stan )
194192
@@ -204,19 +202,9 @@ def test_bernoulli_unit_e(
204202 show_progress = False ,
205203 )
206204 assert bern_fit .metric_type == 'unit_e'
207- assert bern_fit .metric is None
205+ assert bern_fit .inv_metric is None
208206 assert bern_fit .step_size .shape == (2 ,)
209- with caplog .at_level (logging .INFO ):
210- logging .getLogger ()
211- assert bern_fit .metric is None
212- check_present (
213- caplog ,
214- (
215- 'cmdstanpy' ,
216- 'INFO' ,
217- 'Unit diagnonal metric, inverse mass matrix size unknown.' ,
218- ),
219- )
207+
220208 assert bern_fit .draws ().shape == (100 , 2 , len (BERNOULLI_COLS ))
221209
222210
@@ -535,7 +523,7 @@ def test_fixed_param_good() -> None:
535523 )
536524 assert datagen_fit .runset ._args .method == Method .SAMPLE
537525 assert datagen_fit .metric_type is None
538- assert datagen_fit .metric is None
526+ assert datagen_fit .inv_metric is None
539527 assert datagen_fit .step_size is None
540528 assert datagen_fit .divergences is None
541529 assert datagen_fit .max_treedepths is None
@@ -638,7 +626,7 @@ def test_fixed_param_good() -> None:
638626 assert datagen_fit .column_names == tuple (column_names )
639627 assert datagen_fit .num_draws_sampling == 100
640628 assert datagen_fit .draws ().shape == (100 , 1 , len (column_names ))
641- assert datagen_fit .metric is None
629+ assert datagen_fit .inv_metric is None
642630 assert datagen_fit .metric_type is None
643631 assert datagen_fit .step_size is None
644632
@@ -860,7 +848,7 @@ def test_validate_big_run() -> None:
860848 assert fit .column_names == tuple (column_names )
861849 assert fit .metric_type == 'diag_e'
862850 assert fit .step_size .shape == (2 ,)
863- assert fit .metric .shape == (2 , 2095 )
851+ assert fit .inv_metric .shape == (2 , 2095 )
864852 assert fit .draws ().shape == (1000 , 2 , 2102 )
865853 assert fit .draws_pd (vars = ['phi' ]).shape == (2000 , 2095 )
866854 with raises_nested (ValueError , r'Unknown variable: gamma' ):
@@ -2136,8 +2124,8 @@ def test_sample_dense_mass_matrix():
21362124 linear_model = CmdStanModel (stan_file = stan )
21372125
21382126 fit = linear_model .sample (data = jdata , metric = "dense_e" , chains = 2 )
2139- assert fit .metric is not None
2140- assert fit .metric .shape == (2 , 3 , 3 )
2127+ assert fit .inv_metric is not None
2128+ assert fit .inv_metric .shape == (2 , 3 , 3 )
21412129
21422130
21432131def test_no_output_draws ():
0 commit comments