Skip to content

Commit 018515c

Browse files
author
Francis Motta
committed
Implemented parallelization in PersistenceImager().transform(), updated the "Classification with persistence images" notebook to use the new PersistenceImager() class, added a parallelization example to the "Persistence Images" notebook, and made all unit tests use numpy.testing assert methods
1 parent 92e80b8 commit 018515c

File tree

5 files changed

+223
-131
lines changed

5 files changed

+223
-131
lines changed

docs/notebooks/Classification with persistence images.ipynb

Lines changed: 58 additions & 34 deletions
Large diffs are not rendered by default.

docs/notebooks/PersImage_update.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
},
1818
{
1919
"cell_type": "code",
20-
"execution_count": 20,
20+
"execution_count": 1,
2121
"metadata": {},
2222
"outputs": [],
2323
"source": [
@@ -168,16 +168,17 @@
168168
},
169169
{
170170
"cell_type": "code",
171-
"execution_count": 1,
171+
"execution_count": 9,
172172
"metadata": {},
173173
"outputs": [],
174174
"source": [
175+
"# Verify unit tests pass\n",
175176
"%run ../../test/test_persim_update.py"
176177
]
177178
},
178179
{
179180
"cell_type": "code",
180-
"execution_count": 2,
181+
"execution_count": 10,
181182
"metadata": {},
182183
"outputs": [],
183184
"source": [
@@ -200,8 +201,7 @@
200201
"\n",
201202
"TestTransformOutput().test_lists_of_lists()\n",
202203
"TestTransformOutput().test_n_pixels()\n",
203-
"TestTransformOutput().test_multiple_diagrams()\n",
204-
"\n"
204+
"TestTransformOutput().test_multiple_diagrams()"
205205
]
206206
},
207207
{

docs/notebooks/Persistence Images.ipynb

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
},
1717
{
1818
"cell_type": "code",
19-
"execution_count": 2,
19+
"execution_count": 3,
2020
"metadata": {},
2121
"outputs": [],
2222
"source": [
2323
"from itertools import product\n",
2424
"\n",
25+
"import time\n",
2526
"import numpy as np\n",
2627
"from sklearn import datasets\n",
2728
"from scipy.stats import multivariate_normal as mvn\n",
@@ -428,6 +429,84 @@
428429
"pimgr.plot_image(pimgr.transform(H1_dgm))"
429430
]
430431
},
432+
{
433+
"cell_type": "markdown",
434+
"metadata": {},
435+
"source": [
436+
"## Parallelization"
437+
]
438+
},
439+
{
440+
"cell_type": "code",
441+
"execution_count": 7,
442+
"metadata": {},
443+
"outputs": [
444+
{
445+
"name": "stdout",
446+
"output_type": "stream",
447+
"text": [
448+
"Execution time in serial: 0.172504 sec.\n",
449+
"Execution time in parallel: 0.145611 sec.\n"
450+
]
451+
}
452+
],
453+
"source": [
454+
"# For diagrams with small numbers of persistence pairs, overhead costs may not justify parallelization\n",
455+
"# Also, initial run of job in parallel is very costly. Run twice to see speed gains.\n",
456+
"import time\n",
457+
"num_diagrams = 100\n",
458+
"min_pairs = 50\n",
459+
"max_pairs = 100\n",
460+
"\n",
461+
"pimgr = PersistenceImager()\n",
462+
"dgms = [np.random.rand(np.random.randint(min_pairs, max_pairs), 2) for _ in range(num_diagrams)]\n",
463+
"\n",
464+
"pimgr.fit(dgms)\n",
465+
"\n",
466+
"start_time = time.time()\n",
467+
"pimgr.transform(dgms)\n",
468+
"print(\"Execution time in serial: %g sec.\" % (time.time() - start_time))\n",
469+
"\n",
470+
"start_time = time.time()\n",
471+
"pimgr.transform(dgms, n_jobs=-1)\n",
472+
"print(\"Execution time in parallel: %g sec.\" % (time.time() - start_time))"
473+
]
474+
},
475+
{
476+
"cell_type": "code",
477+
"execution_count": 8,
478+
"metadata": {},
479+
"outputs": [
480+
{
481+
"name": "stdout",
482+
"output_type": "stream",
483+
"text": [
484+
"Execution time in serial: 1.59773 sec.\n",
485+
"Execution time in parallel: 0.386466 sec.\n"
486+
]
487+
}
488+
],
489+
"source": [
490+
"# For larger diagrams, speed up can be significant\n",
491+
"import time\n",
492+
"num_diagrams = 100\n",
493+
"min_pairs = 500\n",
494+
"max_pairs = 1000\n",
495+
"\n",
496+
"pimgr = PersistenceImager()\n",
497+
"dgms = [np.random.rand(np.random.randint(min_pairs, max_pairs), 2) for _ in range(num_diagrams)]\n",
498+
"\n",
499+
"pimgr.fit(dgms)\n",
500+
"\n",
501+
"start_time = time.time()\n",
502+
"pimgr.transform(dgms)\n",
503+
"print(\"Execution time in serial: %g sec.\" % (time.time() - start_time))\n",
504+
"\n",
505+
"start_time = time.time()\n",
506+
"pimgr.transform(dgms, n_jobs=-1)\n",
507+
"print(\"Execution time in parallel: %g sec.\" % (time.time() - start_time))"
508+
]
509+
},
431510
{
432511
"cell_type": "code",
433512
"execution_count": null,

persim/images.py

Lines changed: 53 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import collections
44

55
from joblib import Parallel, delayed
6-
from multiprocessing import Pool
76

87
import copy
98
import numpy as np
@@ -476,8 +475,6 @@ def transform(self, pers_dgms, skew=True, n_jobs=None):
476475
"""
477476
if n_jobs is not None:
478477
parallelize = True
479-
if n_jobs == -1:
480-
n_jobs = None
481478
else:
482479
parallelize = False
483480

@@ -488,15 +485,10 @@ def transform(self, pers_dgms, skew=True, n_jobs=None):
488485
# convert to a list of diagrams if necessary
489486
pers_dgms, singular = self._ensure_iterable(pers_dgms)
490487

491-
# TODO: Parallellize over collection of diagrams
492488
if parallelize:
493-
fxn = lambda pers_dgm: self._transform(pers_dgm, skew=skew)
494-
pool = Pool(n_jobs)
495-
pers_imgs = pool.map(fxn, pers_dgms)
496-
#pers_imgs = Parallel(n_jobs=n_jobs)(delayed(self._transform)(pers_dgm, skew=skew) for pers_dgm in pers_dgms)
497-
pool.close()
489+
pers_imgs = Parallel(n_jobs=n_jobs)(delayed(_transform)(pers_dgm, skew, self.resolution, self.weight, self.weight_params, self.kernel, self.kernel_params, self._bpnts, self._ppnts) for pers_dgm in pers_dgms)
498490
else:
499-
pers_imgs = [self._transform(pers_dgm, skew=skew) for pers_dgm in pers_dgms]
491+
pers_imgs = [_transform(pers_dgm, skew=skew, resolution=self.resolution, weight=self.weight, weight_params=self.weight_params, kernel=self.kernel, kernel_params=self.kernel_params, _bpnts=self._bpnts, _ppnts=self._ppnts) for pers_dgm in pers_dgms]
500492

501493
if singular:
502494
pers_imgs = pers_imgs[0]
@@ -521,58 +513,7 @@ def fit_transform(self, pers_dgms, skew=True):
521513
pers_imgs = self.transform(pers_dgms, skew=skew)
522514

523515
return pers_imgs
524-
525-
def _transform(self, pers_dgm, skew=True):
526-
"""
527-
Transform a persistence diagram to a persistence image using the parameters specified in the PersistenceImager
528-
object instance
529-
:param pers_dgm: (N,2) numpy array of persistence pairs encoding a persistence diagram
530-
:param skew: boolean flag indicating if diagram needs to be converted to birth-persistence coordinates
531-
(default: True)
532-
:return: numpy array encoding the persistence image
533-
"""
534-
pers_dgm = np.copy(pers_dgm)
535-
pers_img = np.zeros(self.resolution)
536-
n = pers_dgm.shape[0]
537-
general_flag = True
538-
539-
# if necessary convert from birth-death coordinates to birth-persistence coordinates
540-
if skew:
541-
pers_dgm[:, 1] = pers_dgm[:, 1] - pers_dgm[:, 0]
542-
543-
# compute weights at each persistence pair
544-
wts = self.weight(pers_dgm[:, 0], pers_dgm[:, 1], **self.weight_params)
545-
546-
# handle the special case of a standard, isotropic Gaussian kernel
547-
if self.kernel == bvncdf:
548-
general_flag = False
549-
sigma = self.kernel_params['sigma']
550-
551-
# sigma is specified by a single variance
552-
if isinstance(sigma, (int, float)):
553-
sigma = np.array([[sigma, 0.0], [0.0, sigma]], dtype=np.float64)
554-
555-
if (sigma[0, 0] == sigma[1, 1] and sigma[0, 1] == 0.0):
556-
sigma = np.sqrt(sigma[0, 0])
557-
for i in range(n):
558-
ncdf_b = _norm_cdf((self._bpnts - pers_dgm[i, 0]) / sigma)
559-
ncdf_p = _norm_cdf((self._ppnts - pers_dgm[i, 1]) / sigma)
560-
curr_img = ncdf_p[None, :] * ncdf_b[:, None]
561-
pers_img += wts[i]*(curr_img[1:, 1:] - curr_img[:-1, 1:] - curr_img[1:, :-1] + curr_img[:-1, :-1])
562-
else:
563-
general_flag = True
564-
565-
# handle the general case
566-
if general_flag:
567-
bb, pp = np.meshgrid(self._bpnts, self._ppnts, indexing='ij')
568-
bb = bb.flatten(order='C')
569-
pp = pp.flatten(order='C')
570-
for i in range(n):
571-
curr_img = np.reshape(self.kernel(bb, pp, mu=pers_dgm[i, :], **self.kernel_params),
572-
(self.resolution[0]+1, self.resolution[1]+1), order='C')
573-
pers_img += wts[i]*(curr_img[1:, 1:] - curr_img[:-1, 1:] - curr_img[1:, :-1] + curr_img[:-1, :-1])
574-
575-
return pers_img
516+
576517

577518
def _ensure_iterable(self, pers_dgms):
578519
# if first entry of first entry is not iterable, then diagrams is singular and we need to make it a list of diagrams
@@ -669,7 +610,6 @@ def plot_image(self, pers_img, ax=None, out_file=None):
669610
if out_file:
670611
plt.savefig(out_file, bbox_inches='tight')
671612

672-
673613
def dict_print(dict_in):
674614
# print dictionary contents in human-readable format
675615
if dict_in is None:
@@ -682,7 +622,57 @@ def dict_print(dict_in):
682622

683623
return str_out
684624

625+
def _transform(pers_dgm, skew=True, resolution=None, weight=None, weight_params=None, kernel=None, kernel_params=None, _bpnts=None, _ppnts=None):
626+
"""
627+
Transform a persistence diagram to a persistence image using the parameters specified in the PersistenceImager
628+
object instance
629+
:param pers_dgm: (N,2) numpy array of persistence pairs encoding a persistence diagram
630+
:param skew: boolean flag indicating if diagram needs to be converted to birth-persistence coordinates
631+
(default: True)
632+
:return: numpy array encoding the persistence image
633+
"""
634+
pers_dgm = np.copy(pers_dgm)
635+
pers_img = np.zeros(resolution)
636+
n = pers_dgm.shape[0]
637+
general_flag = True
638+
639+
# if necessary convert from birth-death coordinates to birth-persistence coordinates
640+
if skew:
641+
pers_dgm[:, 1] = pers_dgm[:, 1] - pers_dgm[:, 0]
642+
643+
# compute weights at each persistence pair
644+
wts = weight(pers_dgm[:, 0], pers_dgm[:, 1], **weight_params)
645+
646+
# handle the special case of a standard, isotropic Gaussian kernel
647+
if kernel == bvncdf:
648+
general_flag = False
649+
sigma = kernel_params['sigma']
685650

651+
# sigma is specified by a single variance
652+
if isinstance(sigma, (int, float)):
653+
sigma = np.array([[sigma, 0.0], [0.0, sigma]], dtype=np.float64)
654+
655+
if (sigma[0, 0] == sigma[1, 1] and sigma[0, 1] == 0.0):
656+
sigma = np.sqrt(sigma[0, 0])
657+
for i in range(n):
658+
ncdf_b = _norm_cdf((_bpnts - pers_dgm[i, 0]) / sigma)
659+
ncdf_p = _norm_cdf((_ppnts - pers_dgm[i, 1]) / sigma)
660+
curr_img = ncdf_p[None, :] * ncdf_b[:, None]
661+
pers_img += wts[i]*(curr_img[1:, 1:] - curr_img[:-1, 1:] - curr_img[1:, :-1] + curr_img[:-1, :-1])
662+
else:
663+
general_flag = True
664+
665+
# handle the general case
666+
if general_flag:
667+
bb, pp = np.meshgrid(_bpnts, _ppnts, indexing='ij')
668+
bb = bb.flatten(order='C')
669+
pp = pp.flatten(order='C')
670+
for i in range(n):
671+
curr_img = np.reshape(kernel(bb, pp, mu=pers_dgm[i, :], **kernel_params),
672+
(resolution[0]+1, resolution[1]+1), order='C')
673+
pers_img += wts[i]*(curr_img[1:, 1:] - curr_img[:-1, 1:] - curr_img[1:, :-1] + curr_img[:-1, :-1])
674+
675+
return pers_img
686676
"""
687677
Kernel functions:
688678

0 commit comments

Comments
 (0)