Skip to content

Commit 5c6c642

Browse files
committed
add: tests
1 parent 6567e77 commit 5c6c642

File tree

2 files changed

+58
-0
lines changed

2 files changed

+58
-0
lines changed

tests/test_mod.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import numpy as np
2+
3+
import cripser
4+
import tcripser
5+
6+
7+
def test_cripser_module_on_3d_hole():
8+
arr = np.load('sample/3d_hole.npy').astype(np.float64)
9+
ph = cripser.computePH(arr, maxdim=2)
10+
assert ph.ndim == 2 and ph.shape[1] == 9
11+
dims = set(ph[:, 0].astype(int))
12+
# Expect H0 and H2 features on a 3D hole dataset
13+
assert 0 in dims
14+
assert 2 in dims
15+
16+
17+
def test_tcripser_module_on_3d_hole():
18+
arr = np.load('sample/3d_hole.npy').astype(np.float64)
19+
ph = tcripser.computePH(arr, maxdim=2)
20+
assert ph.ndim == 2 and ph.shape[1] == 9
21+
dims = set(ph[:, 0].astype(int))
22+
# Expect non-empty result and some higher-dimensional features
23+
assert len(dims) >= 1
24+
assert any(d in dims for d in (1, 2))

tests/test_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
3+
from cripser import (
4+
compute_ph,
5+
to_gudhi_diagrams,
6+
to_gudhi_persistence,
7+
group_by_dim,
8+
)
9+
10+
11+
def test_utils_converters_basic():
12+
# Simple 2D constant image ensures an infinite H0 bar
13+
arr = np.zeros((3, 3), dtype=np.float64)
14+
ph = compute_ph(arr, maxdim=1)
15+
16+
# Base shape checks
17+
assert ph.ndim == 2 and ph.shape[1] == 9
18+
19+
# Diagrams
20+
dgms = to_gudhi_diagrams(ph, maxdim=1)
21+
assert isinstance(dgms, list) and len(dgms) == 2
22+
assert dgms[0].ndim == 2 and dgms[0].shape[1] == 2
23+
# Expect at least one infinite death in H0
24+
assert np.isinf(dgms[0][:, 1]).any()
25+
26+
# Persistence list
27+
pers = to_gudhi_persistence(ph)
28+
assert any(d == 0 and np.isinf(bd[1]) for d, bd in pers)
29+
30+
# Group by dimension
31+
groups = group_by_dim(ph)
32+
assert len(groups) >= 1
33+
if len(groups[0]) > 0:
34+
assert np.all(groups[0][:, 0] == 0)

0 commit comments

Comments
 (0)