3
3
import collections
4
4
5
5
from joblib import Parallel , delayed
6
- from multiprocessing import Pool
7
6
8
7
import copy
9
8
import numpy as np
@@ -476,8 +475,6 @@ def transform(self, pers_dgms, skew=True, n_jobs=None):
476
475
"""
477
476
if n_jobs is not None :
478
477
parallelize = True
479
- if n_jobs == - 1 :
480
- n_jobs = None
481
478
else :
482
479
parallelize = False
483
480
@@ -488,15 +485,10 @@ def transform(self, pers_dgms, skew=True, n_jobs=None):
488
485
# convert to a list of diagrams if necessary
489
486
pers_dgms , singular = self ._ensure_iterable (pers_dgms )
490
487
491
- # TODO: Parallellize over collection of diagrams
492
488
if parallelize :
493
- fxn = lambda pers_dgm : self ._transform (pers_dgm , skew = skew )
494
- pool = Pool (n_jobs )
495
- pers_imgs = pool .map (fxn , pers_dgms )
496
- #pers_imgs = Parallel(n_jobs=n_jobs)(delayed(self._transform)(pers_dgm, skew=skew) for pers_dgm in pers_dgms)
497
- pool .close ()
489
+ pers_imgs = Parallel (n_jobs = n_jobs )(delayed (_transform )(pers_dgm , skew , self .resolution , self .weight , self .weight_params , self .kernel , self .kernel_params , self ._bpnts , self ._ppnts ) for pers_dgm in pers_dgms )
498
490
else :
499
- pers_imgs = [self . _transform (pers_dgm , skew = skew ) for pers_dgm in pers_dgms ]
491
+ pers_imgs = [_transform (pers_dgm , skew = skew , resolution = self . resolution , weight = self . weight , weight_params = self . weight_params , kernel = self . kernel , kernel_params = self . kernel_params , _bpnts = self . _bpnts , _ppnts = self . _ppnts ) for pers_dgm in pers_dgms ]
500
492
501
493
if singular :
502
494
pers_imgs = pers_imgs [0 ]
@@ -521,58 +513,7 @@ def fit_transform(self, pers_dgms, skew=True):
521
513
pers_imgs = self .transform (pers_dgms , skew = skew )
522
514
523
515
return pers_imgs
524
-
525
- def _transform (self , pers_dgm , skew = True ):
526
- """
527
- Transform a persistence diagram to a persistence image using the parameters specified in the PersistenceImager
528
- object instance
529
- :param pers_dgm: (N,2) numpy array of persistence pairs encoding a persistence diagram
530
- :param skew: boolean flag indicating if diagram needs to be converted to birth-persistence coordinates
531
- (default: True)
532
- :return: numpy array encoding the persistence image
533
- """
534
- pers_dgm = np .copy (pers_dgm )
535
- pers_img = np .zeros (self .resolution )
536
- n = pers_dgm .shape [0 ]
537
- general_flag = True
538
-
539
- # if necessary convert from birth-death coordinates to birth-persistence coordinates
540
- if skew :
541
- pers_dgm [:, 1 ] = pers_dgm [:, 1 ] - pers_dgm [:, 0 ]
542
-
543
- # compute weights at each persistence pair
544
- wts = self .weight (pers_dgm [:, 0 ], pers_dgm [:, 1 ], ** self .weight_params )
545
-
546
- # handle the special case of a standard, isotropic Gaussian kernel
547
- if self .kernel == bvncdf :
548
- general_flag = False
549
- sigma = self .kernel_params ['sigma' ]
550
-
551
- # sigma is specified by a single variance
552
- if isinstance (sigma , (int , float )):
553
- sigma = np .array ([[sigma , 0.0 ], [0.0 , sigma ]], dtype = np .float64 )
554
-
555
- if (sigma [0 , 0 ] == sigma [1 , 1 ] and sigma [0 , 1 ] == 0.0 ):
556
- sigma = np .sqrt (sigma [0 , 0 ])
557
- for i in range (n ):
558
- ncdf_b = _norm_cdf ((self ._bpnts - pers_dgm [i , 0 ]) / sigma )
559
- ncdf_p = _norm_cdf ((self ._ppnts - pers_dgm [i , 1 ]) / sigma )
560
- curr_img = ncdf_p [None , :] * ncdf_b [:, None ]
561
- pers_img += wts [i ]* (curr_img [1 :, 1 :] - curr_img [:- 1 , 1 :] - curr_img [1 :, :- 1 ] + curr_img [:- 1 , :- 1 ])
562
- else :
563
- general_flag = True
564
-
565
- # handle the general case
566
- if general_flag :
567
- bb , pp = np .meshgrid (self ._bpnts , self ._ppnts , indexing = 'ij' )
568
- bb = bb .flatten (order = 'C' )
569
- pp = pp .flatten (order = 'C' )
570
- for i in range (n ):
571
- curr_img = np .reshape (self .kernel (bb , pp , mu = pers_dgm [i , :], ** self .kernel_params ),
572
- (self .resolution [0 ]+ 1 , self .resolution [1 ]+ 1 ), order = 'C' )
573
- pers_img += wts [i ]* (curr_img [1 :, 1 :] - curr_img [:- 1 , 1 :] - curr_img [1 :, :- 1 ] + curr_img [:- 1 , :- 1 ])
574
-
575
- return pers_img
516
+
576
517
577
518
def _ensure_iterable (self , pers_dgms ):
578
519
# if first entry of first entry is not iterable, then diagrams is singular and we need to make it a list of diagrams
@@ -669,7 +610,6 @@ def plot_image(self, pers_img, ax=None, out_file=None):
669
610
if out_file :
670
611
plt .savefig (out_file , bbox_inches = 'tight' )
671
612
672
-
673
613
def dict_print (dict_in ):
674
614
# print dictionary contents in human-readable format
675
615
if dict_in is None :
@@ -682,7 +622,57 @@ def dict_print(dict_in):
682
622
683
623
return str_out
684
624
625
+ def _transform (pers_dgm , skew = True , resolution = None , weight = None , weight_params = None , kernel = None , kernel_params = None , _bpnts = None , _ppnts = None ):
626
+ """
627
+ Transform a persistence diagram to a persistence image using the parameters specified in the PersistenceImager
628
+ object instance
629
+ :param pers_dgm: (N,2) numpy array of persistence pairs encoding a persistence diagram
630
+ :param skew: boolean flag indicating if diagram needs to be converted to birth-persistence coordinates
631
+ (default: True)
632
+ :return: numpy array encoding the persistence image
633
+ """
634
+ pers_dgm = np .copy (pers_dgm )
635
+ pers_img = np .zeros (resolution )
636
+ n = pers_dgm .shape [0 ]
637
+ general_flag = True
638
+
639
+ # if necessary convert from birth-death coordinates to birth-persistence coordinates
640
+ if skew :
641
+ pers_dgm [:, 1 ] = pers_dgm [:, 1 ] - pers_dgm [:, 0 ]
642
+
643
+ # compute weights at each persistence pair
644
+ wts = weight (pers_dgm [:, 0 ], pers_dgm [:, 1 ], ** weight_params )
645
+
646
+ # handle the special case of a standard, isotropic Gaussian kernel
647
+ if kernel == bvncdf :
648
+ general_flag = False
649
+ sigma = kernel_params ['sigma' ]
685
650
651
+ # sigma is specified by a single variance
652
+ if isinstance (sigma , (int , float )):
653
+ sigma = np .array ([[sigma , 0.0 ], [0.0 , sigma ]], dtype = np .float64 )
654
+
655
+ if (sigma [0 , 0 ] == sigma [1 , 1 ] and sigma [0 , 1 ] == 0.0 ):
656
+ sigma = np .sqrt (sigma [0 , 0 ])
657
+ for i in range (n ):
658
+ ncdf_b = _norm_cdf ((_bpnts - pers_dgm [i , 0 ]) / sigma )
659
+ ncdf_p = _norm_cdf ((_ppnts - pers_dgm [i , 1 ]) / sigma )
660
+ curr_img = ncdf_p [None , :] * ncdf_b [:, None ]
661
+ pers_img += wts [i ]* (curr_img [1 :, 1 :] - curr_img [:- 1 , 1 :] - curr_img [1 :, :- 1 ] + curr_img [:- 1 , :- 1 ])
662
+ else :
663
+ general_flag = True
664
+
665
+ # handle the general case
666
+ if general_flag :
667
+ bb , pp = np .meshgrid (_bpnts , _ppnts , indexing = 'ij' )
668
+ bb = bb .flatten (order = 'C' )
669
+ pp = pp .flatten (order = 'C' )
670
+ for i in range (n ):
671
+ curr_img = np .reshape (kernel (bb , pp , mu = pers_dgm [i , :], ** kernel_params ),
672
+ (resolution [0 ]+ 1 , resolution [1 ]+ 1 ), order = 'C' )
673
+ pers_img += wts [i ]* (curr_img [1 :, 1 :] - curr_img [:- 1 , 1 :] - curr_img [1 :, :- 1 ] + curr_img [:- 1 , :- 1 ])
674
+
675
+ return pers_img
686
676
"""
687
677
Kernel functions:
688
678
0 commit comments