1010from hciplot import plot_frames
1111from scipy import stats
1212from photutils .segmentation import detect_sources
13- from munch import Munch
1413from ..config import time_ini , timing , Progressbar
1514from ..fm import cube_inject_companions
1615from ..psfsub .svd import SVDecomposer
1716from ..var import frame_center , get_annulus_segments , get_circle
1817
19- # TODO: remove the munch dependency
20-
2118
2219class EvalRoc (object ):
2320 """
@@ -68,7 +65,7 @@ def add_algo(self, name, algo, color, symbol, thresholds):
6865 thresholds : list of lists
6966
7067 """
71- self .methods .append (Munch (algo = algo , name = name , color = color ,
68+ self .methods .append (dict (algo = algo , name = name , color = color ,
7269 symbol = symbol , thresholds = thresholds ))
7370
7471 def inject_and_postprocess (self , patch_size , cevr = 0.9 ,
@@ -97,11 +94,11 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
9794 print ("{}% of CEVR with {} PCs" .format (cevr , self .optpcs ))
9895
9996 # for m in methods:
100- # if hasattr(m, "ncomp") and m.ncomp is None: # PCA
101- # m. ncomp = self.optpcs
97+ # if m.get( "ncomp", object()) is None: # PCA
98+ # m[" ncomp"] = self.optpcs
10299 #
103- # if hasattr(m, "rank") and m.rank is None: # LLSG
104- # m. rank = self.optpcs
100+ # if m.get( "rank", object()) is None: # LLSG
101+ # m[" rank"] = self.optpcs
105102
106103 #
107104 # ------> this should be moved inside the HCIPostProcAlgo classes!
@@ -135,8 +132,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
135132 self .thetas .append (theta )
136133
137134 for m in self .methods :
138- m . frames = []
139- m . probmaps = []
135+ m [ " frames" ] = []
136+ m [ " probmaps" ] = []
140137
141138 self .list_xy = []
142139
@@ -157,7 +154,7 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
157154 # TODO: this is not elegant at all.
158155 # shallow copy. Should not copy e.g. the cube in memory,
159156 # just reference it.
160- algo = copy .copy (m . algo )
157+ algo = copy .copy (m [ " algo" ] )
161158 _dataset = copy .copy (self .dataset )
162159 _dataset .cube = cufc
163160
@@ -169,8 +166,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
169166 algo .run (dataset = _dataset , verbose = False )
170167 algo .make_snrmap (approximated = True , nproc = nproc , verbose = False )
171168
172- m . frames .append (algo .frame_final )
173- m . probmaps .append (algo .snr_map )
169+ m [ " frames" ] .append (algo .frame_final )
170+ m [ " probmaps" ] .append (algo .snr_map )
174171
175172 timing (starttime )
176173
@@ -192,22 +189,22 @@ def compute_tpr_fps(self, **kwargs):
192189 starttime = time_ini ()
193190
194191 for m in self .methods :
195- m . detections = []
196- m . fps = []
197- m . bmaps = []
192+ m [ " detections" ] = []
193+ m [ " fps" ] = []
194+ m [ " bmaps" ] = []
198195
199196 print ('Evaluating injections:' )
200197 for i in Progressbar (range (self .n_injections )):
201198 x , y = self .list_xy [i ]
202199
203200 for m in self .methods :
204201 dets , fps , bmaps = compute_binary_map (
205- m . probmaps [i ], m . thresholds , fwhm = self .dataset .fwhm ,
202+ m [ " probmaps" ] [i ], m [ " thresholds" ] , fwhm = self .dataset .fwhm ,
206203 injections = (x , y ), ** kwargs
207204 )
208- m . detections .append (dets )
209- m . fps .append (fps )
210- m . bmaps .append (bmaps )
205+ m [ " detections" ] .append (dets )
206+ m [ " fps" ] .append (fps )
207+ m [ " bmaps" ] .append (bmaps )
211208
212209 timing (starttime )
213210
@@ -245,9 +242,9 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,
245242
246243 if vmax == 'max' :
247244 # TODO: document this feature.
248- vmax = np .concatenate ([m . frames [i ] for m in self .methods if
249- hasattr ( m , "frames" ) and
250- len (m . frames ) >= i ]).max ()/ 2
245+ vmax = np .concatenate ([m [ " frames" ] [i ] for m in self .methods if
246+ "frames" in m and
247+ len (m [ " frames" ] ) >= i ]).max ()/ 2
251248
252249 # print information
253250 print ('X,Y: {}' .format (self .list_xy [i ]))
@@ -258,33 +255,32 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,
258255 if plot_type in [1 , "horiz" ]:
259256 for m in self .methods :
260257 print ('detection state: {} | false postives: {}' .format (
261- m . detections [i ][thr ], m . fps [i ][thr ]))
262- labels = ('{ } frame' . format ( m . name ), '{ } S/Nmap' . format ( m . name ) ,
263- ' Thresholded at {:.1f}' . format ( m . thresholds [thr ]) )
264- plot_frames ((m . frames [i ] if len (m . frames ) >= i else
265- np .zeros ((2 , 2 )), m . probmaps [i ], m . bmaps [i ][thr ]),
258+ m [ " detections" ] [i ][thr ], m [ " fps" ] [i ][thr ]))
259+ labels = (f" { m [ 'name' ] } frame" , f" { m [ 'name' ] } S/Nmap" ,
260+ f" Thresholded at { m [ ' thresholds' ] [thr ]:.1f } " )
261+ plot_frames ((m [ " frames" ] [i ] if len (m [ " frames" ] ) >= i else
262+ np .zeros ((2 , 2 )), m [ " probmaps" ] [i ], m [ " bmaps" ] [i ][thr ]),
266263 label = labels , dpi = dpi , horsp = 0.2 , axis = axis ,
267264 grid = grid , cmap = ['viridis' , 'viridis' , 'gray' ])
268265
269266 elif plot_type in [2 , "vert" ]:
270- labels = tuple ('{ } frame' . format ( m . name ) for m in self .methods if
271- hasattr ( m , "frames" ) and len (m . frames ) >= i )
272- plot_frames (tuple (m . frames [i ] for m in self .methods if
273- hasattr ( m , "frames" ) and len (m . frames ) >= i ),
267+ labels = tuple (f" { m [ 'name' ] } frame" for m in self .methods if
268+ "frames" in m and len (m [ " frames" ] ) >= i )
269+ plot_frames (tuple (m [ " frames" ] [i ] for m in self .methods if
270+ "frames" in m and len (m [ " frames" ] ) >= i ),
274271 dpi = dpi , label = labels , vmax = vmax , vmin = vmin , axis = axis ,
275272 grid = grid )
276273
277- plot_frames (tuple (m . probmaps [i ] for m in self .methods ), dpi = dpi ,
278- label = tuple (['{ } S/Nmap' . format ( m . name ) for m in
274+ plot_frames (tuple (m [ " probmaps" ] [i ] for m in self .methods ), dpi = dpi ,
275+ label = tuple ([f" { m [ 'name' ] } S/Nmap" for m in
279276 self .methods ]), axis = axis , grid = grid )
280277
281278 for m in self .methods :
282- msg = '{} detection: {}, FPs: {}'
283- print (msg .format (m .name , m .detections [i ][thr ], m .fps [i ][thr ]))
279+ print (f"{ m ['name' ]} detection: { m ['detections' ][i ][thr ]} , FPs: { m ['fps' ][i ][thr ]} " )
284280
285- labels = tuple (' Thresholded at {:.1f}' . format ( m . thresholds [thr ])
281+ labels = tuple (f" Thresholded at { m [ ' thresholds' ] [thr ]:.1f } "
286282 for m in self .methods )
287- plot_frames (tuple (m . bmaps [i ][thr ] for m in self .methods ),
283+ plot_frames (tuple (m [ " bmaps" ] [i ][thr ] for m in self .methods ),
288284 dpi = dpi , label = labels , axis = axis , grid = grid ,
289285 colorbar = False , cmap = 'bone' )
290286 else :
@@ -342,40 +338,40 @@ def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None,
342338 # "SODIRF": dict(color="#9467bd", symbol="s"),
343339 # "SODINN": dict(color="#1f77b4", symbol="p"),
344340 # "SODINN-pw": dict(color="#1f77b4", symbol="p")
345- # } # maps m. name to plot style
341+ # } # maps m[" name"] to plot style
346342
347343 for i , m in enumerate (self .methods ):
348344
349345 if not hasattr (m , "detections" ) or not hasattr (m , "fps" ):
350346 raise AttributeError ("method #{} has no detections/fps. Run"
351347 "`compute_tpr_fps` first." .format (i ))
352348
353- m . tpr = np .zeros (( n_thresholds ) )
354- m . mean_fps = np .zeros (( n_thresholds ) )
349+ m [ " tpr" ] = np .zeros (n_thresholds )
350+ m [ " mean_fps" ] = np .zeros (n_thresholds )
355351
356352 for j in range (n_thresholds ):
357- m . tpr [j ] = np .asarray (m . detections )[:, j ].tolist ().count (1 ) / \
353+ m [ " tpr" ] [j ] = np .asarray (m [ " detections" ] )[:, j ].tolist ().count (1 ) / \
358354 self .n_injections
359- m . mean_fps [j ] = np .asarray (m . fps )[:, j ].mean ()
355+ m [ " mean_fps" ] [j ] = np .asarray (m [ " fps" ] )[:, j ].mean ()
360356
361- plt .plot (m . mean_fps , m . tpr , '--' , color = m . color , ** linekw )
362- plt .plot (m . mean_fps , m . tpr , m . symbol , label = m . name , color = m . color ,
357+ plt .plot (m [ " mean_fps" ] , m [ " tpr" ] , '--' , color = m [ " color" ] , ** linekw )
358+ plt .plot (m [ " mean_fps" ] , m [ " tpr" ] , m [ " symbol" ] , label = m [ " name" ] , color = m [ " color" ] ,
363359 ** markerkw )
364360
365361 if show_data_labels :
366362 if label_skip_one [i ]:
367- lab_x = m . mean_fps [1 ::2 ]
368- lab_y = m . tpr [1 ::2 ]
369- thr = m . thresholds [1 ::2 ]
363+ lab_x = m [ " mean_fps" ] [1 ::2 ]
364+ lab_y = m [ " tpr" ] [1 ::2 ]
365+ thr = m [ " thresholds" ] [1 ::2 ]
370366 else :
371- lab_x = m . mean_fps
372- lab_y = m . tpr
373- thr = m . thresholds
367+ lab_x = m [ " mean_fps" ]
368+ lab_y = m [ " tpr" ]
369+ thr = m [ " thresholds" ]
374370
375371 for i , xy in enumerate (zip (lab_x + label_gap [0 ],
376372 lab_y + label_gap [1 ])):
377373 labels .append (ax .annotate ('{:.2f}' .format (thr [i ]),
378- xy = xy , xycoords = 'data' , color = m . color ,
374+ xy = xy , xycoords = 'data' , color = m [ " color" ] ,
379375 ** labelskw ))
380376 # TODO: reverse order of `self.methods` for better annot.
381377 # z-index?
0 commit comments