Skip to content

Commit 4de9d3c

Browse files
SiegeLordExtensorflower-gardener
authored andcommitted
Inference Gym: Fix ground truth for eight schools.
The Stan model had the incorrect prior compared to the TFP code. The old tests passed presumably because their ESS was so bad that the z-test had no power whatsoever. I boosted the number of leapfrog steps to increase it substantially. PiperOrigin-RevId: 384281260
1 parent 3da47c7 commit 4de9d3c

File tree

4 files changed

+33
-33
lines changed

4 files changed

+33
-33
lines changed

spinoffs/inference_gym/inference_gym/targets/eight_schools_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def testEightSchoolsHMC(self):
4646
model,
4747
num_chains=4,
4848
num_steps=4000,
49-
num_leapfrog_steps=3,
49+
num_leapfrog_steps=10,
5050
step_size=0.4,
5151
)
5252

spinoffs/inference_gym/inference_gym/targets/ground_truth/eight_schools.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,58 +27,58 @@
2727
import numpy as np
2828

2929
IDENTITY_AVG_EFFECT_MEAN = np.array([
30-
3.2137590385402826,
30+
5.75975286026516,
3131
]).reshape(())
3232

3333
IDENTITY_AVG_EFFECT_MEAN_STANDARD_ERROR = np.array([
34-
0.011278555132469106,
34+
0.01786934334960865,
3535
]).reshape(())
3636

3737
IDENTITY_AVG_EFFECT_STANDARD_DEVIATION = np.array([
38-
3.972567360644285,
38+
5.46540575237222,
3939
]).reshape(())
4040

4141
IDENTITY_LOG_STDDEV_MEAN = np.array([
42-
2.4598093936288152,
42+
2.45326548536025,
4343
]).reshape(())
4444

4545
IDENTITY_LOG_STDDEV_MEAN_STANDARD_ERROR = np.array([
46-
0.0018535105349442114,
46+
0.0019127444310499044,
4747
]).reshape(())
4848

4949
IDENTITY_LOG_STDDEV_STANDARD_DEVIATION = np.array([
50-
0.5110436981514841,
50+
0.5146604914005865,
5151
]).reshape(())
5252

5353
IDENTITY_SCHOOL_EFFECTS_MEAN = np.array([
54-
13.355582058626903,
55-
6.102211570244123,
56-
1.024737484826771,
57-
5.369929612269985,
58-
0.7926315948389041,
59-
2.244649709577478,
60-
11.763530998898284,
61-
6.261180511721017,
54+
14.76492662128668,
55+
7.1559568069514174,
56+
2.5889680578429823,
57+
6.556760136283709,
58+
1.8189620982794488,
59+
3.3973176011560953,
60+
12.791827392809507,
61+
7.94913889128149,
6262
]).reshape((8,))
6363

6464
IDENTITY_SCHOOL_EFFECTS_MEAN_STANDARD_ERROR = np.array([
65-
0.02544285870818445,
66-
0.014957463084240069,
67-
0.021316316118379103,
68-
0.01575696229480018,
69-
0.01373758544135346,
70-
0.01550588664268995,
71-
0.017107240447608994,
72-
0.02209825465307125,
65+
0.024677770813437375,
66+
0.014741663125129497,
67+
0.02289079528770855,
68+
0.015999446293377167,
69+
0.014481515793592113,
70+
0.016516334290635457,
71+
0.01649735808006444,
72+
0.02296023745612875,
7373
]).reshape((8,))
7474

7575
IDENTITY_SCHOOL_EFFECTS_STANDARD_DEVIATION = np.array([
76-
10.807906054798417,
77-
7.766699608891697,
78-
10.244640099796499,
79-
8.219198474822052,
80-
7.303418026739246,
81-
8.258823240336538,
82-
8.19256497826584,
83-
10.673579208829045,
76+
10.796464460046044,
77+
7.813843501513539,
78+
10.473283439128995,
79+
8.314419652710829,
80+
7.45775465347688,
81+
8.454005086842871,
82+
8.16339794409753,
83+
10.913796824041532,
8484
]).reshape((8,))

spinoffs/inference_gym/inference_gym/tools/get_ground_truth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
venv=$(mktemp -d)
2424
virtualenv -p python3.6 $venv
2525
source $venv/bin/activate
26-
pip install cmdstanpy==0.9 pandas numpy tf-nightly tfp-nightly tfds-nightly
26+
pip install 'cmdstanpy>=0.9.0' pandas numpy tf-nightly tfp-nightly tfds-nightly
2727
install_cmdstan
2828
2929
python -m inference_gym.tools.get_ground_truth \

spinoffs/inference_gym/inference_gym/tools/stan/eight_schools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def eight_schools():
4949
school_effects <- std_school_effects * exp(log_stddev) + avg_effect;
5050
}
5151
model {
52-
avg_effect ~ normal(0, 5);
52+
avg_effect ~ normal(0, 10);
5353
log_stddev ~ normal(5, 1);
5454
std_school_effects ~ normal(0, 1);
5555
treatment_effects ~ normal(school_effects, treatment_stddevs);

0 commit comments

Comments
 (0)