Skip to content

Commit ebd02ee

Browse files
authored
Merge pull request #567 from stan-dev/fix/566-overzealous-string-search
Emit warnings rather than errors when sampling has exceptions
2 parents 62ab5db + 34be12f commit ebd02ee

File tree

4 files changed

+241
-2
lines changed

4 files changed

+241
-2
lines changed

cmdstanpy/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1176,14 +1176,21 @@ def sample(
11761176
)
11771177

11781178
errors = runset.get_err_msgs()
1179-
if errors or not runset._check_retcodes():
1179+
if not runset._check_retcodes():
11801180
msg = (
11811181
f'Error during sampling:\n{errors}\n'
11821182
+ f'Command and output files:\n{repr(runset)}\n'
11831183
+ 'Consider re-running with show_console=True if the above'
11841184
+ ' output is unclear!'
11851185
)
11861186
raise RuntimeError(msg)
1187+
if errors:
1188+
msg = (
1189+
f'Non-fatal error during sampling:\n{errors}\n'
1190+
+ 'Consider re-running with show_console=True if the above'
1191+
+ ' output is unclear!'
1192+
)
1193+
get_logger().warning(msg)
11871194

11881195
mcmc = CmdStanMCMC(runset)
11891196
return mcmc
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
{
2+
"x": [
3+
-0.7138385321976433,
4+
0.4935568322412949,
5+
-1.3814289386914758,
6+
-1.3083830490780055,
7+
0.8594237856855641,
8+
-0.5828201693491124,
9+
0.41472726513726604,
10+
-1.3096088915626156,
11+
1.7488214394544888,
12+
-0.05582078688738188,
13+
-1.3353224485998705,
14+
-1.7564315946342453,
15+
0.48308013370199504,
16+
0.5172253170304838,
17+
-0.6719992453398311,
18+
1.2602689891739454,
19+
1.412043997788473,
20+
0.20641523053812474,
21+
0.5201078975322044,
22+
1.5723251985671267,
23+
1.3189655696099178,
24+
-0.9111971580521261,
25+
0.6393467227482712,
26+
0.4650415534005901,
27+
-0.5452984948053263,
28+
-1.5130216352610142,
29+
-0.22862960142670605,
30+
0.6160324741955371,
31+
-0.6031016074719313,
32+
0.40594275421437453,
33+
0.38508886717718016,
34+
-1.0837724674547375,
35+
2.3915682063493477,
36+
-0.07839179990381143,
37+
0.35656973307438666,
38+
-0.2766314484210899,
39+
0.7302456204684461,
40+
0.08565219038050095,
41+
1.445060390202212,
42+
-0.21089429382395108,
43+
-0.004354007671109179,
44+
1.4964952578881543,
45+
-1.2377789502217904,
46+
1.1630189497108676,
47+
0.1868059860973441,
48+
1.0013767320556666,
49+
0.17712191472558358,
50+
2.2449838737047405,
51+
0.5723783157495013,
52+
-0.5866656693806949,
53+
0.5553141820829578,
54+
-1.0832025825922065,
55+
0.021397186668024173,
56+
1.0990490407588707,
57+
-0.47857707758054946,
58+
-1.0028128626128172,
59+
-0.6541883777415636,
60+
-2.157420733621822,
61+
1.6494513567711457,
62+
0.5326574665837476,
63+
-1.2488626570176349,
64+
1.8640634506194615,
65+
-2.0015397292660646,
66+
-1.165907209784441,
67+
1.243275871223251,
68+
1.2777634989564302,
69+
-0.2540735659288359,
70+
-0.7115336394378458,
71+
2.5123445378647937,
72+
-2.3620686383068192,
73+
0.704711924978527,
74+
0.20486730052438465,
75+
-1.1481828167175874,
76+
-0.5024671222986463,
77+
1.1909459096186445,
78+
-1.807532371463237,
79+
1.5452345378964871,
80+
0.9386793566797123,
81+
0.06108017364184471,
82+
0.44078590430668,
83+
1.5390990507903806,
84+
-0.14982494315528966,
85+
0.008041862641953225,
86+
-1.059362371600711,
87+
-0.6458557072831949,
88+
0.12174558659406327,
89+
1.6343750745549894,
90+
-0.006893657965902042,
91+
-1.980063711207591,
92+
-0.17769336879869058,
93+
-0.04671777679973899,
94+
1.4532456591905836,
95+
-1.4029866194860872,
96+
-1.5024900502317675,
97+
1.4402148888890611,
98+
-1.0727333689589642,
99+
-0.6228552056768534,
100+
0.3715735241216498,
101+
-0.0371718458010515,
102+
0.5704601721659295
103+
],
104+
"y": [
105+
0.6272150230712336,
106+
0.9805824577664779,
107+
0.3422690552470167,
108+
-0.05146489851488395,
109+
0.9046913839133488,
110+
0.6925433494574392,
111+
0.8699413738367437,
112+
0.640327008582535,
113+
0.39187605728779296,
114+
0.21904955777565455,
115+
0.3614380235604899,
116+
-0.25588465129027244,
117+
0.291608586376539,
118+
-0.12437104481728656,
119+
0.39699986298775203,
120+
1.091844069641923,
121+
1.0625362033977446,
122+
1.1487142324952053,
123+
0.6499370411907814,
124+
0.7847133816156049,
125+
0.949209681430482,
126+
0.15135972083512655,
127+
0.6742472860252764,
128+
0.5405321316929282,
129+
0.7549499219120746,
130+
0.16511085413710583,
131+
-0.02498036164663142,
132+
0.6153043951506284,
133+
0.6000720457920419,
134+
1.3480303063520085,
135+
0.3158006969256599,
136+
0.3637296025891455,
137+
1.1594053723296478,
138+
0.5970385913046845,
139+
0.5946450550774676,
140+
0.8203529738679207,
141+
1.3062075618596822,
142+
0.9328024321713398,
143+
0.7743457946449005,
144+
0.3251962426760894,
145+
0.6054087282349804,
146+
0.7934998350966271,
147+
0.8831208411206976,
148+
0.7894760893728714,
149+
0.8976865289354727,
150+
0.44217549478023965,
151+
0.38284855187218103,
152+
1.095780907185152,
153+
0.6549679342197391,
154+
0.23387184425362467,
155+
0.6049693148613544,
156+
0.4028784177001803,
157+
0.6250343428214425,
158+
0.567319927379435,
159+
0.5189952397494277,
160+
0.4671091755131367,
161+
0.748868640987746,
162+
0.29251932173366013,
163+
0.8245978901245103,
164+
0.673071240496385,
165+
0.775559253840947,
166+
0.9595878211742477,
167+
0.05153108078358959,
168+
0.6958343543502022,
169+
1.2920074084872448,
170+
0.936543215882786,
171+
-0.17826337795252567,
172+
0.7061739968427398,
173+
1.4127238083350044,
174+
0.4014500370341516,
175+
0.792652667869967,
176+
0.3352243339909373,
177+
0.4509809274388745,
178+
0.6568885296125428,
179+
0.6184394273698025,
180+
0.18864416763201924,
181+
1.0875801300709682,
182+
0.9177333084623992,
183+
0.5919873026862967,
184+
0.9297749171130603,
185+
0.814520099852734,
186+
0.8465674480525497,
187+
0.028620564426356743,
188+
0.2639034807770741,
189+
0.9495620232603935,
190+
0.4000550997372736,
191+
0.7506260736682799,
192+
0.8257633294514362,
193+
0.1294350823440999,
194+
0.9276942903866539,
195+
0.8068646299718158,
196+
0.7626459434600884,
197+
0.4057250786938288,
198+
0.6216165062482524,
199+
0.9989570513537478,
200+
0.12153467332478934,
201+
0.9175681261236525,
202+
0.5831590468381432,
203+
0.6800265152948471,
204+
1.0027971008115968
205+
],
206+
"N": 100
207+
}

test/data/linear_regression.stan

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
data {
2+
int<lower=0> N;
3+
vector[N] x;
4+
vector[N] y;
5+
}
6+
parameters {
7+
real alpha;
8+
real beta;
9+
real<lower=0> sigma;
10+
}
11+
model {
12+
y ~ normal(alpha + beta * x, sigma);
13+
}

test/test_sample.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from time import time
1515

1616
import numpy as np
17-
from testfixtures import LogCapture
17+
from testfixtures import LogCapture, StringComparison
1818

1919
try:
2020
import ujson as json
@@ -1238,6 +1238,18 @@ def test_validate_bad_run(self):
12381238
):
12391239
CmdStanMCMC(runset)
12401240

1241+
def test_sample_sporadic_exception(self):
1242+
stan = os.path.join(DATAFILES_PATH, 'linear_regression.stan')
1243+
jdata = os.path.join(DATAFILES_PATH, 'linear_regression.data.json')
1244+
linear_model = CmdStanModel(stan_file=stan)
1245+
# will produce a failure due to calling normal_lpdf with 0 for scale
1246+
# but then continue sampling normally
1247+
with LogCapture() as log:
1248+
linear_model.sample(data=jdata, inits=0)
1249+
log.check_present(
1250+
('cmdstanpy', 'WARNING', StringComparison(r"Non-fatal error.*"))
1251+
)
1252+
12411253
def test_save_warmup(self):
12421254
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')
12431255
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

0 commit comments

Comments
 (0)