Skip to content

Commit 341ea60

Browse files
committed
CmdStan 2.36 no longer requires fixed_param hacks
1 parent c5bcfb3 commit 341ea60

File tree

3 files changed

+31
-43
lines changed

3 files changed

+31
-43
lines changed

cmdstanpy/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ def __init__(
205205
self._compiler_options.add_include_path(path)
206206

207207
# try to detect models w/out parameters, needed for sampler
208-
if not cmdstan_version_before(
209-
2, 27
210-
): # unknown end of version range
208+
if not cmdstan_version_before(2, 27) and cmdstan_version_before(
209+
2, 36
210+
):
211211
try:
212212
model_info = self.src_info()
213213
if 'parameters' in model_info:

cmdstanpy/stanfit/mcmc.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def _assemble_draws(self) -> None:
412412
self._step_size[chain] = float(step_size.strip())
413413
if self._metadata.cmdstan_config['metric'] != 'unit_e':
414414
line = fd.readline().strip() # metric type
415-
line = fd.readline().lstrip(' #\t')
415+
line = fd.readline().lstrip(' #\t').rstrip()
416416
num_unconstrained_params = len(line.split(','))
417417
if chain == 0: # can't allocate w/o num params
418418
if self.metric_type == 'diag_e':
@@ -429,18 +429,21 @@ def _assemble_draws(self) -> None:
429429
),
430430
dtype=float,
431431
)
432-
if self.metric_type == 'diag_e':
433-
xs = line.split(',')
434-
self._metric[chain, :] = [float(x) for x in xs]
435-
else:
436-
xs = line.split(',')
437-
self._metric[chain, 0, :] = [float(x) for x in xs]
438-
for i in range(1, num_unconstrained_params):
439-
line = fd.readline().lstrip(' #\t').strip()
432+
if line:
433+
if self.metric_type == 'diag_e':
440434
xs = line.split(',')
441-
self._metric[chain, i, :] = [
435+
self._metric[chain, :] = [float(x) for x in xs]
436+
else:
437+
xs = line.strip().split(',')
438+
self._metric[chain, 0, :] = [
442439
float(x) for x in xs
443440
]
441+
for i in range(1, num_unconstrained_params):
442+
line = fd.readline().lstrip(' #\t').rstrip()
443+
xs = line.split(',')
444+
self._metric[chain, i, :] = [
445+
float(x) for x in xs
446+
]
444447
else: # unit_e changed in 2.34 to have an extra line
445448
pos = fd.tell()
446449
line = fd.readline().strip()

test/test_sample.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -642,23 +642,23 @@ def test_fixed_param_good() -> None:
642642
assert datagen_fit.step_size is None
643643

644644

645-
def test_fixed_param_unspecified() -> None:
645+
def test_sample_no_params() -> None:
646646
stan = os.path.join(DATAFILES_PATH, 'datagen_poisson_glm.stan')
647647
datagen_model = CmdStanModel(stan_file=stan)
648648
datagen_fit = datagen_model.sample(iter_sampling=100, show_progress=False)
649-
assert datagen_fit.step_size is None
649+
assert np.isnan(datagen_fit.step_size).all()
650650
summary = datagen_fit.summary()
651-
assert 'lp__' not in list(summary.index)
651+
assert 'lp__' in list(summary.index)
652652

653653
exe_only = os.path.join(DATAFILES_PATH, 'exe_only')
654654
shutil.copyfile(datagen_model.exe_file, exe_only)
655655
os.chmod(exe_only, 0o755)
656656
datagen2_model = CmdStanModel(exe_file=exe_only)
657657
datagen2_fit = datagen2_model.sample(iter_sampling=200, show_console=True)
658658
assert datagen2_fit.chains == 4
659-
assert datagen2_fit.step_size is None
659+
assert np.isnan(datagen2_fit.step_size).all()
660660
summary = datagen2_fit.summary()
661-
assert 'lp__' not in list(summary.index)
661+
assert 'lp__' in list(summary.index)
662662

663663

664664
def test_index_bounds_error() -> None:
@@ -823,7 +823,7 @@ def test_validate_good_run() -> None:
823823
assert 'Treedepth satisfactory for all transitions.' in diagnostics
824824
assert 'No divergent transitions found.' in diagnostics
825825
assert 'E-BFMI satisfactory' in diagnostics
826-
assert 'Effective sample size satisfactory.' in diagnostics
826+
assert 'effective sample size satisfactory' in diagnostics.lower()
827827

828828

829829
def test_validate_big_run() -> None:
@@ -1621,33 +1621,18 @@ def test_validate_sample_sig_figs(stanfile='bernoulli.stan'):
16211621

16221622

16231623
def test_validate_summary_sig_figs() -> None:
1624-
# construct CmdStanMCMC from logistic model output, config
1625-
exe = os.path.join(DATAFILES_PATH, 'logistic' + EXTENSION)
1626-
rdata = os.path.join(DATAFILES_PATH, 'logistic.data.R')
1627-
sampler_args = SamplerArgs(iter_sampling=100)
1628-
cmdstan_args = CmdStanArgs(
1629-
model_name='logistic',
1630-
model_exe=exe,
1631-
chain_ids=[1, 2, 3, 4],
1632-
seed=12345,
1633-
data=rdata,
1634-
output_dir=DATAFILES_PATH,
1635-
sig_figs=17,
1636-
method_args=sampler_args,
1624+
# construct CmdStanMCMC from logistic model output
1625+
fit = from_csv(
1626+
[
1627+
os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
1628+
os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
1629+
os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
1630+
os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
1631+
]
16371632
)
1638-
runset = RunSet(args=cmdstan_args, chains=4)
1639-
runset._csv_files = [
1640-
os.path.join(DATAFILES_PATH, 'logistic_output_1.csv'),
1641-
os.path.join(DATAFILES_PATH, 'logistic_output_2.csv'),
1642-
os.path.join(DATAFILES_PATH, 'logistic_output_3.csv'),
1643-
os.path.join(DATAFILES_PATH, 'logistic_output_4.csv'),
1644-
]
1645-
retcodes = runset._retcodes
1646-
for i in range(len(retcodes)):
1647-
runset._set_retcode(i, 0)
1648-
fit = CmdStanMCMC(runset)
16491633

16501634
sum_default = fit.summary()
1635+
16511636
beta1_default = format(sum_default.iloc[1, 0], '.18g')
16521637
assert beta1_default.startswith('1.3')
16531638

0 commit comments

Comments
 (0)