Skip to content

Commit 71444c3

Browse files
committed
Update test
1 parent c3905f2 commit 71444c3

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

test/test_log_prob.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
import logging
44
import os
5-
from test import CustomTestCase
5+
import re
6+
from test import check_present
67

7-
from testfixtures import LogCapture, StringComparison
8+
import pytest
89

910
from cmdstanpy.model import CmdStanModel
1011
from cmdstanpy.utils import EXTENSION
@@ -18,25 +19,26 @@
1819
BERN_BASENAME = 'bernoulli'
1920

2021

21-
class CmdStanLogProb(CustomTestCase):
22-
def test_lp_good(self):
23-
model = CmdStanModel(stan_file=BERN_STAN)
24-
x = model.log_prob({"theta": 0.1}, data=BERN_DATA)
25-
assert "lp_" in x.columns
26-
27-
def test_lp_bad(self):
28-
model = CmdStanModel(stan_file=BERN_STAN)
29-
30-
with LogCapture(level=logging.ERROR) as log:
31-
with self.assertRaisesRegex(
32-
RuntimeError, "failed with return code"
33-
):
34-
model.log_prob({"not_here": 0.1}, data=BERN_DATA)
35-
36-
log.check_present(
37-
(
38-
'cmdstanpy',
39-
'ERROR',
40-
StringComparison(r"(?s).*parameter theta not found.*"),
41-
)
42-
)
22+
def test_lp_good() -> None:
23+
model = CmdStanModel(stan_file=BERN_STAN)
24+
x = model.log_prob({"theta": 0.1}, data=BERN_DATA)
25+
assert "lp_" in x.columns
26+
27+
28+
def test_lp_bad(
29+
caplog: pytest.LogCaptureFixture,
30+
) -> None:
31+
model = CmdStanModel(stan_file=BERN_STAN)
32+
33+
with caplog.at_level(logging.ERROR):
34+
with pytest.raises(RuntimeError, match="failed with return code"):
35+
model.log_prob({"not_here": 0.1}, data=BERN_DATA)
36+
37+
check_present(
38+
caplog,
39+
(
40+
'cmdstanpy',
41+
'ERROR',
42+
re.compile(r"(?s).*parameter theta not found.*"),
43+
),
44+
)

0 commit comments

Comments
 (0)