File tree Expand file tree Collapse file tree 4 files changed +33
-33
lines changed
spinoffs/inference_gym/inference_gym Expand file tree Collapse file tree 4 files changed +33
-33
lines changed Original file line number Diff line number Diff line change @@ -46,7 +46,7 @@ def testEightSchoolsHMC(self):
46
46
model ,
47
47
num_chains = 4 ,
48
48
num_steps = 4000 ,
49
- num_leapfrog_steps = 3 ,
49
+ num_leapfrog_steps = 10 ,
50
50
step_size = 0.4 ,
51
51
)
52
52
Original file line number Diff line number Diff line change 27
27
import numpy as np
28
28
29
29
IDENTITY_AVG_EFFECT_MEAN = np .array ([
30
- 3.2137590385402826 ,
30
+ 5.75975286026516 ,
31
31
]).reshape (())
32
32
33
33
IDENTITY_AVG_EFFECT_MEAN_STANDARD_ERROR = np .array ([
34
- 0.011278555132469106 ,
34
+ 0.01786934334960865 ,
35
35
]).reshape (())
36
36
37
37
IDENTITY_AVG_EFFECT_STANDARD_DEVIATION = np .array ([
38
- 3.972567360644285 ,
38
+ 5.46540575237222 ,
39
39
]).reshape (())
40
40
41
41
IDENTITY_LOG_STDDEV_MEAN = np .array ([
42
- 2.4598093936288152 ,
42
+ 2.45326548536025 ,
43
43
]).reshape (())
44
44
45
45
IDENTITY_LOG_STDDEV_MEAN_STANDARD_ERROR = np .array ([
46
- 0.0018535105349442114 ,
46
+ 0.0019127444310499044 ,
47
47
]).reshape (())
48
48
49
49
IDENTITY_LOG_STDDEV_STANDARD_DEVIATION = np .array ([
50
- 0.5110436981514841 ,
50
+ 0.5146604914005865 ,
51
51
]).reshape (())
52
52
53
53
IDENTITY_SCHOOL_EFFECTS_MEAN = np .array ([
54
- 13.355582058626903 ,
55
- 6.102211570244123 ,
56
- 1.024737484826771 ,
57
- 5.369929612269985 ,
58
- 0.7926315948389041 ,
59
- 2.244649709577478 ,
60
- 11.763530998898284 ,
61
- 6.261180511721017 ,
54
+ 14.76492662128668 ,
55
+ 7.1559568069514174 ,
56
+ 2.5889680578429823 ,
57
+ 6.556760136283709 ,
58
+ 1.8189620982794488 ,
59
+ 3.3973176011560953 ,
60
+ 12.791827392809507 ,
61
+ 7.94913889128149 ,
62
62
]).reshape ((8 ,))
63
63
64
64
IDENTITY_SCHOOL_EFFECTS_MEAN_STANDARD_ERROR = np .array ([
65
- 0.02544285870818445 ,
66
- 0.014957463084240069 ,
67
- 0.021316316118379103 ,
68
- 0.01575696229480018 ,
69
- 0.01373758544135346 ,
70
- 0.01550588664268995 ,
71
- 0.017107240447608994 ,
72
- 0.02209825465307125 ,
65
+ 0.024677770813437375 ,
66
+ 0.014741663125129497 ,
67
+ 0.02289079528770855 ,
68
+ 0.015999446293377167 ,
69
+ 0.014481515793592113 ,
70
+ 0.016516334290635457 ,
71
+ 0.01649735808006444 ,
72
+ 0.02296023745612875 ,
73
73
]).reshape ((8 ,))
74
74
75
75
IDENTITY_SCHOOL_EFFECTS_STANDARD_DEVIATION = np .array ([
76
- 10.807906054798417 ,
77
- 7.766699608891697 ,
78
- 10.244640099796499 ,
79
- 8.219198474822052 ,
80
- 7.303418026739246 ,
81
- 8.258823240336538 ,
82
- 8.19256497826584 ,
83
- 10.673579208829045 ,
76
+ 10.796464460046044 ,
77
+ 7.813843501513539 ,
78
+ 10.473283439128995 ,
79
+ 8.314419652710829 ,
80
+ 7.45775465347688 ,
81
+ 8.454005086842871 ,
82
+ 8.16339794409753 ,
83
+ 10.913796824041532 ,
84
84
]).reshape ((8 ,))
Original file line number Diff line number Diff line change 23
23
venv=$(mktemp -d)
24
24
virtualenv -p python3.6 $venv
25
25
source $venv/bin/activate
26
- pip install cmdstanpy== 0.9 pandas numpy tf-nightly tfp-nightly tfds-nightly
26
+ pip install ' cmdstanpy>= 0.9.0' pandas numpy tf-nightly tfp-nightly tfds-nightly
27
27
install_cmdstan
28
28
29
29
python -m inference_gym.tools.get_ground_truth \
Original file line number Diff line number Diff line change @@ -49,7 +49,7 @@ def eight_schools():
49
49
school_effects <- std_school_effects * exp(log_stddev) + avg_effect;
50
50
}
51
51
model {
52
- avg_effect ~ normal(0, 5 );
52
+ avg_effect ~ normal(0, 10 );
53
53
log_stddev ~ normal(5, 1);
54
54
std_school_effects ~ normal(0, 1);
55
55
treatment_effects ~ normal(school_effects, treatment_stddevs);
You can’t perform that action at this time.
0 commit comments