Skip to content

Commit fa568b5

Browse files
committed
Refactor p_norm throughout
The former approximate p_norm method used vector norm from numpy, giving incorrect results. Refactor exact p_norm and move to auxiliary as _p_norm. Now both exact and approximate p_norm methods call _p_norm
1 parent 01cd6a5 commit fa568b5

File tree

4 files changed

+42
-36
lines changed

4 files changed

+42
-36
lines changed

persim/landscapes/approximate.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from operator import itemgetter
77
from .base import PersLandscape
8-
from .auxiliary import union_vals, ndsnap_regular
8+
from .auxiliary import union_vals, ndsnap_regular, _p_norm
99

1010
__all__ = ["PersLandscapeApprox"]
1111

@@ -364,7 +364,8 @@ def p_norm(self, p: int = 2) -> float:
364364
p: float, default 2
365365
value p of the L_{`p`} norm
366366
"""
367-
return np.sum([np.linalg.norm(depth, p) for depth in self.values])
367+
super().p_norm(p=p)
368+
return _p_norm(p=p, critical_pairs=self.values_to_pairs())
368369

369370
def sup_norm(self) -> float:
370371
"""

persim/landscapes/auxiliary.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,7 @@ def union_crit_pairs(A, B):
9191
else:
9292
result_pairs.append(
9393
slope_to_pos_interp(
94-
sum_slopes(
95-
pos_to_slope_interp(a),
96-
pos_to_slope_interp(b),
97-
)
94+
sum_slopes(pos_to_slope_interp(a), pos_to_slope_interp(b),)
9895
)
9996
)
10097
return result_pairs
@@ -187,3 +184,33 @@ def ndsnap_regular(points, *grid_axes):
187184
best = np.argmin(np.abs(diff), axis=0)
188185
snapped.append(ax[best])
189186
return np.array(snapped).T
187+
188+
189+
def _p_norm(p: float, critical_pairs: list = []):
190+
"""
191+
Compute `p` norm of interpolated piecewise linear function defined from list of
192+
critical pairs.
193+
"""
194+
result = 0.0
195+
for l in critical_pairs:
196+
for [[x0, y0], [x1, y1]] in zip(l, l[1:]):
197+
if y0 == y1:
198+
# horizontal line segment
199+
result += (np.abs(y0) ** p) * (x1 - x0)
200+
continue
201+
# slope is well-defined
202+
slope = (y1 - y0) / (x1 - x0)
203+
b = y0 - slope * x0
204+
# segment crosses the x-axis
205+
if (y0 < 0 and y1 > 0) or (y0 > 0 and y1 < 0):
206+
z = -b / slope
207+
ev_x1 = (slope * x1 + b) ** (p + 1) / (slope * (p + 1))
208+
ev_x0 = (slope * x0 + b) ** (p + 1) / (slope * (p + 1))
209+
ev_z = (slope * z + +b) ** (p + 1) / (slope * (p + 1))
210+
result += np.abs(ev_x1 + ev_x0 - 2 * ev_z)
211+
# segment does not cross the x-axis
212+
else:
213+
ev_x1 = (slope * x1 + b) ** (p + 1) / (slope * (p + 1))
214+
ev_x0 = (slope * x0 + b) ** (p + 1) / (slope * (p + 1))
215+
result += np.abs(ev_x1 - ev_x0)
216+
return (result) ** (1.0 / p)

persim/landscapes/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ def __init__(self, dgms: list = [], hom_deg: int = 0) -> None:
3838

3939
@abstractmethod
4040
def p_norm(self, p: int = 2) -> float:
41-
pass
41+
if p < -1 or -1 < p < 0:
42+
raise ValueError(f"p can't be negative, but {p} was passed")
43+
self.compute_landscape()
44+
if p == -1:
45+
return self.sup_norm()
4246

4347
@abstractmethod
4448
def sup_norm(self) -> float:

persim/landscapes/exact.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from operator import itemgetter
88

99
from .approximate import PersLandscapeApprox
10-
from .auxiliary import union_crit_pairs
10+
from .auxiliary import union_crit_pairs, _p_norm
1111
from .base import PersLandscape
1212

1313
__all__ = ["PersLandscapeExact"]
@@ -402,34 +402,8 @@ def p_norm(self, p: int = 2) -> float:
402402
p: float, default 2
403403
value p of the L_{`p`} norm
404404
"""
405-
if p == -1:
406-
return self.infinity_norm()
407-
if p < -1 or -1 < p < 0:
408-
raise ValueError(f"p can't be negative, but {p} was passed")
409-
self.compute_landscape()
410-
result = 0.0
411-
for l in self.critical_pairs:
412-
for [[x0, y0], [x1, y1]] in zip(l, l[1:]):
413-
if y0 == y1:
414-
# horizontal line segment
415-
result += (np.abs(y0) ** p) * (x1 - x0)
416-
continue
417-
# slope is well-defined
418-
slope = (y1 - y0) / (x1 - x0)
419-
b = y0 - slope * x0
420-
# segment crosses the x-axis
421-
if (y0 < 0 and y1 > 0) or (y0 > 0 and y1 < 0):
422-
z = -b / slope
423-
ev_x1 = (slope * x1 + b) ** (p + 1) / (slope * (p + 1))
424-
ev_x0 = (slope * x0 + b) ** (p + 1) / (slope * (p + 1))
425-
ev_z = (slope * z + +b) ** (p + 1) / (slope * (p + 1))
426-
result += np.abs(ev_x1 + ev_x0 - 2 * ev_z)
427-
# segment does not cross the x-axis
428-
else:
429-
ev_x1 = (slope * x1 + b) ** (p + 1) / (slope * (p + 1))
430-
ev_x0 = (slope * x0 + b) ** (p + 1) / (slope * (p + 1))
431-
result += np.abs(ev_x1 - ev_x0)
432-
return (result) ** (1.0 / p)
405+
super().p_norm(p=p)
406+
return _p_norm(p=p, critical_pairs=self.critical_pairs)
433407

434408
def sup_norm(self) -> float:
435409
"""

0 commit comments

Comments
 (0)