33import numpy as np
44
55import pystan
6- from pystan ._compat import PY2
76
87
98class TestUserInits (unittest .TestCase ):
109
11- model_code = """
12- data {
13- real x;
14- }
15- parameters {
16- real mu;
17- }
18- model {
19- x ~ normal(mu,1);
20- }
21- """
22-
23- data = dict (x = 2 )
10+ @classmethod
11+ def setUpClass (cls ):
12+ model_code = """
13+ data {
14+ real x;
15+ }
16+ parameters {
17+ real mu;
18+ }
19+ model {
20+ x ~ normal(mu,1);
21+ }
22+ """
23+ cls .data = dict (x = 2 )
24+ cls .model = pystan .StanModel (model_code = model_code )
2425
2526 def test_user_init (self ):
26- model_code = self .model_code
27- data = self .data
28- fit1 = pystan .stan (model_code = model_code , iter = 10 , chains = 1 , seed = 2 ,
29- data = data , init = [dict (mu = 4 )], warmup = 0 )
27+ model , data = self .model , self .data
28+ fit1 = model .sampling (iter = 10 , chains = 1 , seed = 2 , data = data , init = [dict (mu = 4 )], warmup = 0 )
3029 self .assertEqual (fit1 .get_inits ()[0 ]['mu' ], 4 )
31- fit2 = pystan .stan (model_code = model_code , iter = 10 , chains = 1 , seed = 2 ,
32- data = data , init = [dict (mu = 400 )], warmup = 0 )
30+ fit2 = model .sampling (iter = 10 , chains = 1 , seed = 2 , data = data , init = [dict (mu = 400 )], warmup = 0 )
3331 self .assertEqual (fit2 .get_inits ()[0 ]['mu' ], 400 )
3432 self .assertFalse (all (fit1 .extract ()['mu' ] == fit2 .extract ()['mu' ]))
3533
3634 def test_user_initfun (self ):
37- model_code = self .model_code
38- data = self .data
35+ model , data = self .model , self .data
3936
4037 def make_inits (chain_id ):
4138 return dict (mu = chain_id )
4239
43- fit1 = pystan .stan (model_code = model_code , iter = 10 , chains = 4 , seed = 2 ,
44- data = data , init = make_inits , warmup = 0 )
40+ fit1 = model .sampling (iter = 10 , chains = 4 , seed = 2 , data = data , init = make_inits , warmup = 0 )
4541 for i , inits in enumerate (fit1 .get_inits ()):
4642 self .assertEqual (inits ['mu' ], i )
4743
4844 def test_user_initfun_chainid (self ):
49- model_code = self .model_code
50- data = self .data
45+ model , data = self .model , self .data
5146
5247 def make_inits (chain_id ):
5348 return dict (mu = chain_id )
5449
5550 chain_id = [9 , 10 , 11 , 12 ]
56- fit1 = pystan . stan ( model_code = model_code , iter = 10 , chains = 4 , seed = 2 ,
57- data = data , init = make_inits , warmup = 0 , chain_id = chain_id )
51+ fit1 = model . sampling ( iter = 10 , chains = 4 , seed = 2 , data = data ,
52+ init = make_inits , warmup = 0 , chain_id = chain_id )
5853 for i , inits in zip (chain_id , fit1 .get_inits ()):
5954 self .assertEqual (inits ['mu' ], i )
6055
@@ -79,44 +74,44 @@ def test_user_init_unspecified(self):
7974
8075class TestUserInitsMatrix (unittest .TestCase ):
8176
82- model_code = """
83- data {
84- int<lower=2> K;
85- int<lower=1> D;
86- }
87- parameters {
88- matrix[K,D] beta;
89- }
90- model {
91- for (k in 1:K)
92- for (d in 1:D)
93- beta[k,d] ~ normal(if_else(d==2,100, 0),1);
94- }"""
95- model = pystan .StanModel (model_code = model_code )
96- data = dict (K = 3 , D = 4 )
77+ @classmethod
78+ def setUpClass (cls ):
79+ model_code = """
80+ data {
81+ int<lower=2> K;
82+ int<lower=1> D;
83+ }
84+ parameters {
85+ matrix[K,D] beta;
86+ }
87+ model {
88+ for (k in 1:K)
89+ for (d in 1:D)
90+ beta[k,d] ~ normal(if_else(d==2,100, 0),1);
91+ }
92+ """
93+ cls .model = pystan .StanModel (model_code = model_code )
94+ cls .data = dict (K = 3 , D = 4 )
9795
9896 def test_user_init (self ):
99- model_code = self .model_code
100- data = self .data
97+ model , data = self .model , self .data
10198 beta = np .ones ((data ['K' ], data ['D' ]))
102- fit1 = pystan . stan ( model_code = model_code , iter = 10 , chains = 1 , seed = 2 ,
103- data = data , init = [dict (beta = beta )], warmup = 0 )
99+ fit1 = model . sampling ( iter = 10 , chains = 1 , seed = 2 ,
100+ data = data , init = [dict (beta = beta )], warmup = 0 )
104101 np .testing .assert_equal (fit1 .get_inits ()[0 ]['beta' ], beta )
105102 beta = 5 * np .ones ((data ['K' ], data ['D' ]))
106- fit2 = pystan . stan ( model_code = model_code , iter = 10 , chains = 1 , seed = 2 ,
107- data = data , init = [dict (beta = beta )], warmup = 0 )
103+ fit2 = model . sampling ( iter = 10 , chains = 1 , seed = 2 ,
104+ data = data , init = [dict (beta = beta )], warmup = 0 )
108105 np .testing .assert_equal (fit2 .get_inits ()[0 ]['beta' ], beta )
109106
110107 def test_user_initfun (self ):
111- model_code = self .model_code
112- data = self .data
108+ model , data = self .model , self .data
113109
114110 beta = np .ones ((data ['K' ], data ['D' ]))
115111
116112 def make_inits (chain_id ):
117113 return dict (beta = beta * chain_id )
118114
119- fit1 = pystan .stan (model_code = model_code , iter = 10 , chains = 4 , seed = 2 ,
120- data = data , init = make_inits , warmup = 0 )
115+ fit1 = model .sampling (iter = 10 , chains = 4 , seed = 2 , data = data , init = make_inits , warmup = 0 )
121116 for i , inits in enumerate (fit1 .get_inits ()):
122117 np .testing .assert_equal (beta * i , inits ['beta' ])
0 commit comments