Skip to content
This repository was archived by the owner on Mar 19, 2021. It is now read-only.

Commit b383786

Browse files
authored
Merge pull request #326 from ariddell/feature/speedup-tests
TST: streamline some tests to avoid CI time limit
2 parents 4f93670 + 24fb279 commit b383786

File tree

2 files changed

+56
-58
lines changed

2 files changed

+56
-58
lines changed

pystan/tests/test_misc_args.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@
77

88

99
class TestArgs(unittest.TestCase):
10+
@classmethod
11+
def setUpClass(cls):
12+
model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
13+
cls.model = pystan.StanModel(model_code=model_code)
1014

1115
def test_control(self):
16+
model = self.model
1217
assertRaisesRegex = self.assertRaisesRegexp if PY2 else self.assertRaisesRegex
13-
model_code = 'parameters {real y;} model {y ~ normal(0,1);}'
14-
1518
with assertRaisesRegex(ValueError, '`control` must be a dictionary'):
1619
control_invalid = 3
17-
pystan.stan(model_code=model_code, control=control_invalid)
20+
model.sampling(control=control_invalid)
1821
with assertRaisesRegex(ValueError, '`control` contains unknown'):
1922
control_invalid = dict(foo=3)
20-
pystan.stan(model_code=model_code, control=control_invalid)
23+
model.sampling(control=control_invalid)
2124
with assertRaisesRegex(ValueError, '`metric` must be one of'):
22-
pystan.stan(model_code=model_code, control={'metric': 'lorem-ipsum'})
25+
model.sampling(control={'metric': 'lorem-ipsum'})

pystan/tests/test_user_inits.py

Lines changed: 48 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,53 @@
33
import numpy as np
44

55
import pystan
6-
from pystan._compat import PY2
76

87

98
class TestUserInits(unittest.TestCase):
109

11-
model_code = """
12-
data {
13-
real x;
14-
}
15-
parameters {
16-
real mu;
17-
}
18-
model {
19-
x ~ normal(mu,1);
20-
}
21-
"""
22-
23-
data = dict(x=2)
10+
@classmethod
11+
def setUpClass(cls):
12+
model_code = """
13+
data {
14+
real x;
15+
}
16+
parameters {
17+
real mu;
18+
}
19+
model {
20+
x ~ normal(mu,1);
21+
}
22+
"""
23+
cls.data = dict(x=2)
24+
cls.model = pystan.StanModel(model_code=model_code)
2425

2526
def test_user_init(self):
26-
model_code = self.model_code
27-
data = self.data
28-
fit1 = pystan.stan(model_code=model_code, iter=10, chains=1, seed=2,
29-
data=data, init=[dict(mu=4)], warmup=0)
27+
model, data = self.model, self.data
28+
fit1 = model.sampling(iter=10, chains=1, seed=2, data=data, init=[dict(mu=4)], warmup=0)
3029
self.assertEqual(fit1.get_inits()[0]['mu'], 4)
31-
fit2 = pystan.stan(model_code=model_code, iter=10, chains=1, seed=2,
32-
data=data, init=[dict(mu=400)], warmup=0)
30+
fit2 = model.sampling(iter=10, chains=1, seed=2, data=data, init=[dict(mu=400)], warmup=0)
3331
self.assertEqual(fit2.get_inits()[0]['mu'], 400)
3432
self.assertFalse(all(fit1.extract()['mu'] == fit2.extract()['mu']))
3533

3634
def test_user_initfun(self):
37-
model_code = self.model_code
38-
data = self.data
35+
model, data = self.model, self.data
3936

4037
def make_inits(chain_id):
4138
return dict(mu=chain_id)
4239

43-
fit1 = pystan.stan(model_code=model_code, iter=10, chains=4, seed=2,
44-
data=data, init=make_inits, warmup=0)
40+
fit1 = model.sampling(iter=10, chains=4, seed=2, data=data, init=make_inits, warmup=0)
4541
for i, inits in enumerate(fit1.get_inits()):
4642
self.assertEqual(inits['mu'], i)
4743

4844
def test_user_initfun_chainid(self):
49-
model_code = self.model_code
50-
data = self.data
45+
model, data = self.model, self.data
5146

5247
def make_inits(chain_id):
5348
return dict(mu=chain_id)
5449

5550
chain_id = [9, 10, 11, 12]
56-
fit1 = pystan.stan(model_code=model_code, iter=10, chains=4, seed=2,
57-
data=data, init=make_inits, warmup=0, chain_id=chain_id)
51+
fit1 = model.sampling(iter=10, chains=4, seed=2, data=data,
52+
init=make_inits, warmup=0, chain_id=chain_id)
5853
for i, inits in zip(chain_id, fit1.get_inits()):
5954
self.assertEqual(inits['mu'], i)
6055

@@ -79,44 +74,44 @@ def test_user_init_unspecified(self):
7974

8075
class TestUserInitsMatrix(unittest.TestCase):
8176

82-
model_code = """
83-
data {
84-
int<lower=2> K;
85-
int<lower=1> D;
86-
}
87-
parameters {
88-
matrix[K,D] beta;
89-
}
90-
model {
91-
for (k in 1:K)
92-
for (d in 1:D)
93-
beta[k,d] ~ normal(if_else(d==2,100, 0),1);
94-
}"""
95-
model = pystan.StanModel(model_code=model_code)
96-
data = dict(K=3, D=4)
77+
@classmethod
78+
def setUpClass(cls):
79+
model_code = """
80+
data {
81+
int<lower=2> K;
82+
int<lower=1> D;
83+
}
84+
parameters {
85+
matrix[K,D] beta;
86+
}
87+
model {
88+
for (k in 1:K)
89+
for (d in 1:D)
90+
beta[k,d] ~ normal(if_else(d==2,100, 0),1);
91+
}
92+
"""
93+
cls.model = pystan.StanModel(model_code=model_code)
94+
cls.data = dict(K=3, D=4)
9795

9896
def test_user_init(self):
99-
model_code = self.model_code
100-
data = self.data
97+
model, data = self.model, self.data
10198
beta = np.ones((data['K'], data['D']))
102-
fit1 = pystan.stan(model_code=model_code, iter=10, chains=1, seed=2,
103-
data=data, init=[dict(beta=beta)], warmup=0)
99+
fit1 = model.sampling(iter=10, chains=1, seed=2,
100+
data=data, init=[dict(beta=beta)], warmup=0)
104101
np.testing.assert_equal(fit1.get_inits()[0]['beta'], beta)
105102
beta = 5 * np.ones((data['K'], data['D']))
106-
fit2 = pystan.stan(model_code=model_code, iter=10, chains=1, seed=2,
107-
data=data, init=[dict(beta=beta)], warmup=0)
103+
fit2 = model.sampling(iter=10, chains=1, seed=2,
104+
data=data, init=[dict(beta=beta)], warmup=0)
108105
np.testing.assert_equal(fit2.get_inits()[0]['beta'], beta)
109106

110107
def test_user_initfun(self):
111-
model_code = self.model_code
112-
data = self.data
108+
model, data = self.model, self.data
113109

114110
beta = np.ones((data['K'], data['D']))
115111

116112
def make_inits(chain_id):
117113
return dict(beta=beta * chain_id)
118114

119-
fit1 = pystan.stan(model_code=model_code, iter=10, chains=4, seed=2,
120-
data=data, init=make_inits, warmup=0)
115+
fit1 = model.sampling(iter=10, chains=4, seed=2, data=data, init=make_inits, warmup=0)
121116
for i, inits in enumerate(fit1.get_inits()):
122117
np.testing.assert_equal(beta * i, inits['beta'])

0 commit comments

Comments
 (0)