Skip to content

Commit 28844ab

Browse files
authored
Increase test coverage & drop python versions (#84)
* some more testing * new environment * make use of the toml defaults * update python versions in CI/CD
1 parent f202df1 commit 28844ab

File tree

10 files changed

+1141
-1019
lines changed

10 files changed

+1141
-1019
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ jobs:
1515
fetch-depth: 0
1616
- uses: actions/setup-python@v4
1717
with:
18-
python-version: 3.9
18+
python-version: 3.10
1919
- run: pip install mkdocs mkdocs-material mkdocs-markdownextradata-plugin mkdocs-git-revision-date-localized-plugin "mkdocstrings[python]"
2020
- run: mkdocs gh-deploy --force

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
strategy:
1010
matrix:
1111
os: [ubuntu-latest, windows-latest]
12-
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
12+
python-version: ['3.10', '3.11', '3.12']
1313

1414
steps:
1515
- uses: actions/checkout@v2

Makefile

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,11 @@ test-generate-baseline:
44
poetry run pytest --mpl-generate-path=tests/example-plots tests/test_example_plots.py
55

66
test:
7-
poetry run pytest \
8-
--mpl --mpl-baseline-path=tests/example-plots \
9-
--cov=conjugate \
10-
--cov-report=xml --cov-report=term-missing \
11-
tests
7+
poetry run pytest tests
128

139
cov:
14-
poetry run pytest \
15-
--mpl --mpl-baseline-path=tests/example-plots \
16-
--cov=conjugate \
17-
--cov-report=html --cov-report=term-missing \
18-
tests
10+
poetry run pytest tests
11+
coverage html
1912
open htmlcov/index.html
2013

2114
format:

conjugate/_compound_gamma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def compound_gamma_pdf(x, a, b, q):
99

1010

1111
class compound_gamma:
12-
""" "Implementation to work like scipy distribution classes.
12+
"""Implementation to work like scipy distribution classes.
1313
1414
Reference:
1515
https://en.wikipedia.org/wiki/Beta_prime_distribution#Generalization

conjugate/distributions.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""These are the supported distributions based on the conjugate models.
22
3-
Many have the `dist` attribute which is a <a href=https://docs.scipy.org/doc/scipy/reference/stats.html>scipy.stats distribution</a> object. From there,
3+
Many have the `dist` attribute which is a <a href=https://docs.scipy.org/doc/scipy/reference/stats.html>scipy.stats distribution</a> object. From there,
44
you can use the methods from scipy.stats to get the pdf, cdf, etc.
55
66
Distributions can be plotted using the `plot_pmf` or `plot_pdf` methods of the distribution.
77
8-
```python
9-
from conjugate.distribution import Beta
8+
```python
9+
from conjugate.distribution import Beta
1010
1111
beta = Beta(1, 1)
12-
scipy_dist = beta.dist
12+
scipy_dist = beta.dist
1313
1414
print(scipy_dist.mean())
1515
# 0.5
@@ -21,9 +21,9 @@
2121
beta.plot_pmf(label="beta distribution")
2222
```
2323
24-
Distributions like Poisson can be added with other Poissons or multiplied by numerical values in order to scale rate. For instance,
24+
Distributions like Poisson can be added with other Poissons or multiplied by numerical values in order to scale rate. For instance,
2525
26-
```python
26+
```python
2727
daily_rate = 0.25
2828
daily_pois = Poisson(lam=daily_rate)
2929
@@ -34,6 +34,7 @@
3434
Below are the currently supported distributions
3535
3636
"""
37+
3738
from dataclasses import dataclass
3839
from typing import Any, Tuple, Union
3940

@@ -1004,6 +1005,27 @@ def sample_variance(self, size: int, random_state=None) -> NUMERIC:
10041005

10051006
return 1 / precision
10061007

1008+
def sample_mean(
1009+
self,
1010+
size: int,
1011+
return_variance: bool = False,
1012+
random_state=None,
1013+
) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]:
1014+
"""Sample mean from the normal distribution.
1015+
1016+
Args:
1017+
size: number of samples
1018+
return_variance: whether to return variance as well
1019+
random_state: random state
1020+
1021+
Returns:
1022+
samples from the normal distribution
1023+
1024+
"""
1025+
return self.sample_beta(
1026+
size=size, return_variance=return_variance, random_state=random_state
1027+
)
1028+
10071029
def sample_beta(
10081030
self, size: int, return_variance: bool = False, random_state=None
10091031
) -> Union[NUMERIC, Tuple[NUMERIC, NUMERIC]]:

conjugate/models.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
77
- Bernoulli / Binomial
88
- Negative Binomial
9-
- Geometric
9+
- Geometric
1010
- Hypergeometric
1111
- Poisson
1212
- Categorical / Multinomial
1313
1414
## Continuous
1515
16-
- Normal
16+
- Normal
1717
- Multivariate Normal
1818
- Linear Regression (Normal)
1919
- Log Normal
@@ -29,6 +29,7 @@
2929
Below are the supported models
3030
3131
"""
32+
3233
from typing import Tuple
3334

3435
import numpy as np
@@ -250,7 +251,7 @@ def negative_binomial_beta_posterior_predictive(
250251
BetaNegativeBinomial posterior predictive distribution
251252
252253
"""
253-
return BetaNegativeBinomial(r=r, alpha=beta.alpha, beta=beta.beta)
254+
return BetaNegativeBinomial(n=r, alpha=beta.alpha, beta=beta.beta)
254255

255256

256257
def hypergeometric_beta_binomial(

poetry.lock

Lines changed: 863 additions & 954 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "conjugate-models"
3-
version = "0.7.1"
3+
version = "0.8.0"
44
description = "Bayesian Conjugate Models in Python"
55
authors = ["Will Dean <wd60622@gmail.com>"]
66
license = "MIT"
@@ -22,10 +22,15 @@ classifiers = [
2222
"Operating System :: POSIX",
2323
"Operating System :: Unix",
2424
"Operating System :: MacOS",
25+
"Programming Language :: Python",
26+
"Programming Language :: Python :: 3",
27+
"Programming Language :: Python :: 3.10",
28+
"Programming Language :: Python :: 3.11",
29+
"Programming Language :: Python :: 3.12",
2530
]
2631

2732
[tool.poetry.dependencies]
28-
python = ">=3.8,<4.0"
33+
python = ">=3.10,<4.0"
2934
matplotlib = "*"
3035
numpy = "*"
3136
scipy = "*"
@@ -50,7 +55,7 @@ mkdocs-material = "^9.1.17"
5055
mkdocstrings = {extras = ["python"], version = "^0.22.0"}
5156

5257
[tool.pytest.ini_options]
53-
addopts = "--mpl --mpl-baseline-path=tests/example-plots --cov=conjugate --cov-report=xml --cov-report=term-missing tests"
58+
addopts = "--mpl --mpl-baseline-path=tests/example-plots --cov=conjugate --cov-report=xml --cov-report=term-missing"
5459

5560
[tool.tox]
5661
legacy_tox_ini = """
@@ -61,16 +66,14 @@ legacy_tox_ini = """
6166
py312
6267
py311
6368
py310
64-
py39
65-
py38
6669
6770
[testenv]
6871
deps =
6972
pytest
7073
pytest-cov
7174
pytest-mpl
7275
pypika
73-
commands = pytest
76+
commands = pytest tests
7477
"""
7578

7679
[build-system]

tests/test_distributions.py

Lines changed: 144 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,35 @@
99
from scipy import __version__ as scipy_version
1010

1111
from conjugate.distributions import (
12-
get_beta_param_from_mean_and_alpha,
1312
Beta,
13+
BetaBinomial,
1414
BetaNegativeBinomial,
15+
Geometric,
16+
Binomial,
17+
CompoundGamma,
1518
Dirichlet,
16-
Gamma,
1719
Exponential,
18-
NegativeBinomial,
19-
Poisson,
20-
Normal,
20+
Gamma,
21+
NormalGamma,
22+
InverseGamma,
23+
InverseWishart,
24+
Hypergeometric,
25+
LogNormal,
26+
Lomax,
2127
MultivariateNormal,
22-
StudentT,
2328
MultivariateStudentT,
24-
InverseWishart,
25-
NormalInverseWishart,
26-
InverseGamma,
29+
NegativeBinomial,
30+
Normal,
2731
NormalInverseGamma,
28-
CompoundGamma,
32+
NormalInverseWishart,
33+
Pareto,
34+
Poisson,
35+
ScaledInverseChiSquared,
36+
StudentT,
37+
Uniform,
38+
VectorizedDist,
39+
VonMises,
40+
get_beta_param_from_mean_and_alpha,
2941
)
3042

3143

@@ -47,6 +59,18 @@ def test_beta(alpha, beta) -> None:
4759
assert isinstance(ax, plt.Axes)
4860

4961

62+
def test_beta_uninformative() -> None:
63+
beta = Beta.uninformative()
64+
assert beta.alpha == 1.0
65+
assert beta.beta == 1.0
66+
67+
68+
def test_beta_from_success_and_failures() -> None:
69+
beta = Beta.from_successes_and_failures(successes=0, failures=0)
70+
assert beta.alpha == 1.0
71+
assert beta.beta == 1.0
72+
73+
5074
@pytest.mark.parametrize("mean", [0.025, 0.5, 0.75])
5175
@pytest.mark.parametrize("alpha", [1, 10, 100])
5276
def test_beta_mean_constructor(mean: float, alpha: float) -> None:
@@ -251,6 +275,9 @@ def test_normal_inverse_wishart() -> None:
251275
mean = distribution.sample_mean(size=1)
252276
assert mean.shape == (1, 2)
253277

278+
_, variance = distribution.sample_mean(size=1, return_variance=True)
279+
assert variance.shape == (1, 2, 2)
280+
254281

255282
@pytest.mark.parametrize("n_features", [1, 2, 3])
256283
@pytest.mark.parametrize("n_samples", [1, 2, 10])
@@ -296,3 +323,110 @@ def test_normal_inverse_gamma(n_features, n_samples) -> None:
296323
def test_compound_gamma(a, b, q, size) -> None:
297324
dist = CompoundGamma(alpha=1, beta=1, lam=1)
298325
assert dist.dist.rvs(size=size).shape == size
326+
327+
328+
def test_binomial_max_value() -> None:
329+
n = np.array([10, 20, 15])
330+
p = 0.5
331+
binomial = Binomial(n=n, p=p)
332+
333+
assert binomial.max_value == 20
334+
335+
336+
def test_dirichlet_rvs() -> None:
337+
dirichlet = Dirichlet(alpha=np.array([[1, 2, 3], [1, 1, 1]]))
338+
339+
dist = dirichlet.dist
340+
assert isinstance(dist, VectorizedDist)
341+
samples = dist.rvs(size=10)
342+
assert isinstance(samples, np.ndarray)
343+
344+
345+
@pytest.mark.parametrize(
346+
"dist",
347+
[
348+
Binomial(n=10, p=0.5),
349+
Geometric(p=0.5),
350+
BetaBinomial(n=10, alpha=1, beta=1),
351+
Poisson(lam=1),
352+
BetaNegativeBinomial(n=10, alpha=1, beta=1),
353+
NegativeBinomial(n=10, p=0.5),
354+
Hypergeometric(N=100, k=5, n=10),
355+
],
356+
)
357+
def test_plot_pmf(dist) -> None:
358+
dist.max_value = 10
359+
ax = dist.plot_pmf()
360+
assert isinstance(ax, plt.Axes)
361+
362+
363+
@pytest.mark.parametrize(
364+
"dist",
365+
[
366+
Beta(alpha=1, beta=1),
367+
Gamma(alpha=1, beta=1),
368+
CompoundGamma(alpha=1, beta=1, lam=1),
369+
LogNormal(mu=1, sigma=1),
370+
ScaledInverseChiSquared(nu=1, sigma2=1),
371+
VonMises(mu=0, kappa=1),
372+
Lomax(alpha=1, lam=1),
373+
StudentT(mu=0, sigma=1, nu=10),
374+
InverseGamma(alpha=1, beta=1),
375+
Pareto(x_m=10, alpha=1),
376+
Uniform(low=10, high=20),
377+
Normal(mu=0, sigma=1),
378+
Exponential(lam=1),
379+
],
380+
)
381+
def test_plot_pdf(dist) -> None:
382+
dist.max_value = 10
383+
dist.min_value = 0
384+
ax = dist.plot_pdf()
385+
assert isinstance(ax, plt.Axes)
386+
387+
388+
def test_normal_gamma() -> None:
389+
normal_gamma = NormalGamma(
390+
mu=0,
391+
lam=1,
392+
alpha=1,
393+
beta=1,
394+
)
395+
396+
assert normal_gamma.gamma == Gamma(alpha=1, beta=1)
397+
398+
mean = normal_gamma.sample_mean(size=10)
399+
assert mean.shape == (10,)
400+
401+
_, variance = normal_gamma.sample_mean(size=1, return_variance=True)
402+
assert variance.shape == (1,)
403+
404+
405+
def test_scaled_inverse_chi_squared_round_trip() -> None:
406+
inverse_gamma = InverseGamma(alpha=1, beta=1)
407+
scaled_inverse_gamma = ScaledInverseChiSquared.from_inverse_gamma(inverse_gamma)
408+
back_again = scaled_inverse_gamma.to_inverse_gamma()
409+
410+
assert inverse_gamma == back_again
411+
412+
413+
def test_combining_poisson() -> None:
414+
poisson_1 = Poisson(lam=1)
415+
poisson_2 = Poisson(lam=2)
416+
poisson_3 = poisson_1 + poisson_2
417+
assert poisson_3 == Poisson(lam=3)
418+
419+
420+
def test_scaling_of_normal() -> None:
421+
normal = Normal(mu=0, sigma=1)
422+
423+
scaled_normal = 4 * normal
424+
assert scaled_normal == Normal(mu=0, sigma=2)
425+
426+
427+
def test_normal_alternative_constructors() -> None:
428+
assert Normal.uninformative() == Normal(mu=0, sigma=1)
429+
assert Normal.from_mean_and_variance(mean=0, variance=4) == Normal(mu=0, sigma=2)
430+
assert Normal.from_mean_and_precision(mean=0, precision=1 / 4) == Normal(
431+
mu=0, sigma=2
432+
)

0 commit comments

Comments
 (0)