
Commit 3d07421

SiegeLordEx authored and tensorflower-gardener committed

Inference Gym: Update to cmdstanpy 0.9

PiperOrigin-RevId: 379575291

1 parent 7d28f9b, commit 3d07421

12 files changed, +24 -24 lines changed

spinoffs/inference_gym/inference_gym/tools/get_ground_truth.py
Lines changed: 3 additions & 3 deletions

@@ -23,7 +23,7 @@
 venv=$(mktemp -d)
 virtualenv -p python3.6 $venv
 source $venv/bin/activate
-pip install cmdstanpy==0.8 pandas numpy tf-nightly tfp-nightly tfds-nightly
+pip install cmdstanpy==0.9 pandas numpy tf-nightly tfp-nightly tfds-nightly
 install_cmdstan
 
 python -m inference_gym.tools.get_ground_truth \
@@ -91,7 +91,7 @@ def main(argv):
   stan_model = getattr(targets, FLAGS.target)()
 
   with stan_model.sample_fn(
-      sampling_iters=FLAGS.stan_samples,
+      iter_sampling=FLAGS.stan_samples,
       chains=FLAGS.stan_chains,
       show_progress=True) as mcmc_output:
     summary = mcmc_output.summary()
@@ -114,7 +114,7 @@ def main(argv):
       # very slow and wastes memory. Consider reading the CSV files ourselves.
 
       # sample shape is [num_samples, num_chains, num_columns]
-      chain = mcmc_output.sample[:, chain_id, :]
+      chain = mcmc_output.draws()[:, chain_id, :]
       dataframe = pd.DataFrame(chain, columns=mcmc_output.column_names)
 
       transformed_samples = fn(dataframe)
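Note: the two call-site changes above track the cmdstanpy 0.9 API, where the sample call's `sampling_iters` argument became `iter_sampling` and posterior draws are read through a `draws()` method instead of the old `.sample` attribute. A minimal sketch of that usage outside the Inference Gym wrapper (the model file and data below are placeholders, not part of this repo):

import cmdstanpy
import pandas as pd

# Hypothetical model/data, only to illustrate the cmdstanpy 0.9 call shape.
model = cmdstanpy.CmdStanModel(stan_file='model.stan')
fit = model.sample(
    data={'n': 10},        # placeholder data dict
    iter_sampling=1000,    # was `sampling_iters` in cmdstanpy 0.8
    chains=4,
    show_progress=True)

# draws() replaces the old `.sample` attribute; its shape is
# [iter_sampling, chains, num_columns], matching the comment in the diff.
chain0 = fit.draws()[:, 0, :]
dataframe = pd.DataFrame(chain0, columns=fit.column_names)
print(dataframe.describe())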

spinoffs/inference_gym/inference_gym/tools/stan/brownian_motion.py
Lines changed: 2 additions & 2 deletions

@@ -76,7 +76,7 @@ def brownian_motion(observed_locs, innovation_noise_scale,
 
   def _ext_identity(samples):
     """Extracts the values of all latent variables."""
-    locs = util.get_columns(samples, r'^loc\.\d+$')
+    locs = util.get_columns(samples, r'^loc\[\d+\]$')
     return locs
 
   extract_fns = {'identity': _ext_identity}
@@ -137,7 +137,7 @@ def _ext_identity(samples):
             samples, r'^innovation_noise_scale$')[:, 0],
         'observation_noise_scale': util.get_columns(
            samples, r'^observation_noise_scale$')[:, 0],
-        'locs': util.get_columns(samples, r'^loc\.\d+$')}
+        'locs': util.get_columns(samples, r'^loc\[\d+\]$')}
 
   extract_fns = {'identity': _ext_identity}
 
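The regex edits in this and the remaining files track a second cmdstanpy 0.9 change: indexed parameters are reported with Stan-style bracketed names (`loc[1]`, `loc[2]`, ...) rather than dotted ones (`loc.1`). `util.get_columns` itself is not part of this diff; a rough stand-in, assuming it selects the matching DataFrame columns as an array:

import re
import numpy as np
import pandas as pd

def get_columns(samples, regex):
  """Rough stand-in for util.get_columns, which this diff does not show."""
  names = [name for name in samples.columns if re.match(regex, name)]
  return samples[names].values

# With cmdstanpy 0.9 column names, the bracketed pattern matches; the old
# dotted pattern would come back empty.
samples = pd.DataFrame(np.zeros((5, 3)), columns=['lp__', 'loc[1]', 'loc[2]'])
locs = get_columns(samples, r'^loc\[\d+\]$')  # shape (5, 2)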

spinoffs/inference_gym/inference_gym/tools/stan/eight_schools.py
Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def _ext_identity(samples):
     res['log_stddev'] = util.get_columns(
         samples, r'^log_stddev$')[:, 0]
     res['school_effects'] = util.get_columns(
-        samples, r'^school_effects\.\d+$')
+        samples, r'^school_effects\[\d+\]$')
     return res
 
   extract_fns = {'identity': _ext_identity}

spinoffs/inference_gym/inference_gym/tools/stan/item_response_theory.py
Lines changed: 3 additions & 3 deletions

@@ -150,19 +150,19 @@ def _ext_identity(samples):
     )[:, 0]
     res['student_ability'] = util.get_columns(
         samples,
-        r'^student_ability\.\d+$',
+        r'^student_ability\[\d+\]$',
     )
     res['question_difficulty'] = util.get_columns(
         samples,
-        r'^question_difficulty\.\d+$',
+        r'^question_difficulty\[\d+\]$',
     )
     return res
 
   def _ext_test_nll(samples):
     return util.get_columns(samples, r'^test_nll$')[:, 0]
 
   def _ext_per_example_test_nll(samples):
-    return util.get_columns(samples, r'^per_example_test_nll\.\d+$')
+    return util.get_columns(samples, r'^per_example_test_nll\[\d+\]$')
 
   extract_fns = {'identity': _ext_identity}
   if have_test:

spinoffs/inference_gym/inference_gym/tools/stan/log_gaussian_cox_process.py
Lines changed: 1 addition & 1 deletion

@@ -113,7 +113,7 @@ def _ext_identity(samples):
     )[:, 0]
     res['log_intensity'] = util.get_columns(
         samples,
-        r'^log_intensity\.\d+$',
+        r'^log_intensity\[\d+\]$',
     )
     return res
 

spinoffs/inference_gym/inference_gym/tools/stan/logistic_regression.py
Lines changed: 2 additions & 2 deletions

@@ -112,13 +112,13 @@ def logistic_regression(
   model = util.cached_stan_model(code)
 
   def _ext_identity(samples):
-    return util.get_columns(samples, r'^weights\.\d+$')
+    return util.get_columns(samples, r'^weights\[\d+\]$')
 
   def _ext_test_nll(samples):
     return util.get_columns(samples, r'^test_nll$')[:, 0]
 
   def _ext_per_example_test_nll(samples):
-    return util.get_columns(samples, r'^per_example_test_nll\.\d+$')
+    return util.get_columns(samples, r'^per_example_test_nll\[\d+\]$')
 
   extract_fns = {'identity': _ext_identity}
   if have_test:

spinoffs/inference_gym/inference_gym/tools/stan/lorenz_system.py
Lines changed: 2 additions & 2 deletions

@@ -107,7 +107,7 @@ def partially_observed_lorenz_system(observed_values, innovation_scale,
 
   def _ext_identity(samples):
     """Extracts the values of all latent variables."""
-    latents = util.get_columns(samples, r'^latents\.\d+\.\d+$')
+    latents = util.get_columns(samples, r'^latents\[\d+,\d+\]$')
     # Last two dimensions are swapped in Stan output.
     return latents.reshape((-1, 3, 30)).swapaxes(1, 2)
 
@@ -173,7 +173,7 @@ def partially_observed_lorenz_system_unknown_scales(
 
   def _ext_identity(samples):
     """Extracts the values of all latent variables."""
-    latents = util.get_columns(samples, r'^latents\.\d+\.\d+$')
+    latents = util.get_columns(samples, r'^latents\[\d+,\d+\]$')
     return {
         'innovation_scale': util.get_columns(samples,
                                              r'^innovation_scale$')[:, 0],
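The two-index `latents[i,j]` rename is the same column-naming change; the surrounding reshape/swapaxes (unchanged by this commit) undoes what the code comment describes as Stan's swapped output of the 30x3 latent state. A small numpy check of that shape manipulation, assuming the columns arrive grouped by the second index as the existing code implies:

import numpy as np

# Fake draws: 4 posterior samples of a 30x3 latent matrix flattened as
# latents[1,1], latents[2,1], ..., latents[30,1], latents[1,2], ...
num_draws = 4
flat = np.arange(num_draws * 90).reshape(num_draws, 90)
# Reshape to (draws, 3, 30) and swap the last two axes to get (draws, 30, 3).
latents = flat.reshape((-1, 3, 30)).swapaxes(1, 2)
assert latents.shape == (num_draws, 30, 3)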

spinoffs/inference_gym/inference_gym/tools/stan/probit_regression.py
Lines changed: 2 additions & 2 deletions

@@ -113,13 +113,13 @@ def probit_regression(
   model = util.cached_stan_model(code)
 
   def _ext_identity(samples):
-    return util.get_columns(samples, r'^weights\.\d+$')
+    return util.get_columns(samples, r'^weights\[\d+\]$')
 
   def _ext_test_nll(samples):
     return util.get_columns(samples, r'^test_nll$')[:, 0]
 
   def _ext_per_example_test_nll(samples):
-    return util.get_columns(samples, r'^per_example_test_nll\.\d+$')
+    return util.get_columns(samples, r'^per_example_test_nll\[\d+\]$')
 
   extract_fns = {'identity': _ext_identity}
   if have_test:

spinoffs/inference_gym/inference_gym/tools/stan/radon_contextual_effects.py
Lines changed: 2 additions & 2 deletions

@@ -109,8 +109,8 @@ def _ext_identity(samples):
         samples, r'^county_effect_mean$')[:, 0]
     res['county_effect_scale'] = util.get_columns(
         samples, r'^county_effect_scale$')[:, 0]
-    res['county_effect'] = util.get_columns(samples, r'^county_effect\.\d+$')
-    res['weight'] = util.get_columns(samples, r'^weight\.\d+$')
+    res['county_effect'] = util.get_columns(samples, r'^county_effect\[\d+\]$')
+    res['weight'] = util.get_columns(samples, r'^weight\[\d+\]$')
     res['log_radon_scale'] = (
         util.get_columns(samples, r'^log_radon_scale$')[:, 0])
     return res

spinoffs/inference_gym/inference_gym/tools/stan/radon_contextual_effects_halfnormal.py
Lines changed: 2 additions & 2 deletions

@@ -113,8 +113,8 @@ def _ext_identity(samples):
         samples, r'^county_effect_mean$')[:, 0]
     res['county_effect_scale'] = util.get_columns(
         samples, r'^county_effect_scale$')[:, 0]
-    res['county_effect'] = util.get_columns(samples, r'^county_effect\.\d+$')
-    res['weight'] = util.get_columns(samples, r'^weight\.\d+$')
+    res['county_effect'] = util.get_columns(samples, r'^county_effect\[\d+\]$')
+    res['weight'] = util.get_columns(samples, r'^weight\[\d+\]$')
     res['log_radon_scale'] = (
         util.get_columns(samples, r'^log_radon_scale$')[:, 0])
     return res
