Skip to content

Commit e4f16fa

Browse files
authored
Merge pull request #682 from neutrinoceros/dep/ditch-munch
DEP/RFC: refactor out dependency on `munch`
2 parents 5f3f93d + 89fc4b2 commit e4f16fa

File tree

2 files changed

+48
-53
lines changed

2 files changed

+48
-53
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ dependencies = [
3232
"emcee>=3.1.0",
3333
"hciplot>=0.2.4",
3434
"matplotlib>=3.7.0",
35-
"munch>=3.0.0",
3635
"nestle>=0.2.0",
3736
"numpy>=1.21.2",
3837
"pandas>=1.3.3",

src/vip_hci/metrics/roc.py

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,11 @@
1010
from hciplot import plot_frames
1111
from scipy import stats
1212
from photutils.segmentation import detect_sources
13-
from munch import Munch
1413
from ..config import time_ini, timing, Progressbar
1514
from ..fm import cube_inject_companions
1615
from ..psfsub.svd import SVDecomposer
1716
from ..var import frame_center, get_annulus_segments, get_circle
1817

19-
# TODO: remove the munch dependency
20-
2118

2219
class EvalRoc(object):
2320
"""
@@ -68,7 +65,7 @@ def add_algo(self, name, algo, color, symbol, thresholds):
6865
thresholds : list of lists
6966
7067
"""
71-
self.methods.append(Munch(algo=algo, name=name, color=color,
68+
self.methods.append(dict(algo=algo, name=name, color=color,
7269
symbol=symbol, thresholds=thresholds))
7370

7471
def inject_and_postprocess(self, patch_size, cevr=0.9,
@@ -97,11 +94,11 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
9794
print("{}% of CEVR with {} PCs".format(cevr, self.optpcs))
9895

9996
# for m in methods:
100-
# if hasattr(m, "ncomp") and m.ncomp is None: # PCA
101-
# m.ncomp = self.optpcs
97+
# if m.get("ncomp", object()) is None: # PCA
98+
# m["ncomp"] = self.optpcs
10299
#
103-
# if hasattr(m, "rank") and m.rank is None: # LLSG
104-
# m.rank = self.optpcs
100+
# if m.get("rank", object()) is None: # LLSG
101+
# m["rank"] = self.optpcs
105102

106103
#
107104
# ------> this should be moved inside the HCIPostProcAlgo classes!
@@ -135,8 +132,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
135132
self.thetas.append(theta)
136133

137134
for m in self.methods:
138-
m.frames = []
139-
m.probmaps = []
135+
m["frames"] = []
136+
m["probmaps"] = []
140137

141138
self.list_xy = []
142139

@@ -157,7 +154,7 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
157154
# TODO: this is not elegant at all.
158155
# shallow copy. Should not copy e.g. the cube in memory,
159156
# just reference it.
160-
algo = copy.copy(m.algo)
157+
algo = copy.copy(m["algo"])
161158
_dataset = copy.copy(self.dataset)
162159
_dataset.cube = cufc
163160

@@ -169,8 +166,8 @@ def inject_and_postprocess(self, patch_size, cevr=0.9,
169166
algo.run(dataset=_dataset, verbose=False)
170167
algo.make_snrmap(approximated=True, nproc=nproc, verbose=False)
171168

172-
m.frames.append(algo.frame_final)
173-
m.probmaps.append(algo.snr_map)
169+
m["frames"].append(algo.frame_final)
170+
m["probmaps"].append(algo.snr_map)
174171

175172
timing(starttime)
176173

@@ -192,22 +189,22 @@ def compute_tpr_fps(self, **kwargs):
192189
starttime = time_ini()
193190

194191
for m in self.methods:
195-
m.detections = []
196-
m.fps = []
197-
m.bmaps = []
192+
m["detections"] = []
193+
m["fps"] = []
194+
m["bmaps"] = []
198195

199196
print('Evaluating injections:')
200197
for i in Progressbar(range(self.n_injections)):
201198
x, y = self.list_xy[i]
202199

203200
for m in self.methods:
204201
dets, fps, bmaps = compute_binary_map(
205-
m.probmaps[i], m.thresholds, fwhm=self.dataset.fwhm,
202+
m["probmaps"][i], m["thresholds"], fwhm=self.dataset.fwhm,
206203
injections=(x, y), **kwargs
207204
)
208-
m.detections.append(dets)
209-
m.fps.append(fps)
210-
m.bmaps.append(bmaps)
205+
m["detections"].append(dets)
206+
m["fps"].append(fps)
207+
m["bmaps"].append(bmaps)
211208

212209
timing(starttime)
213210

@@ -245,9 +242,9 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,
245242

246243
if vmax == 'max':
247244
# TODO: document this feature.
248-
vmax = np.concatenate([m.frames[i] for m in self.methods if
249-
hasattr(m, "frames") and
250-
len(m.frames) >= i]).max()/2
245+
vmax = np.concatenate([m["frames"][i] for m in self.methods if
246+
"frames" in m and
247+
len(m["frames"]) >= i]).max()/2
251248

252249
# print information
253250
print('X,Y: {}'.format(self.list_xy[i]))
@@ -258,33 +255,32 @@ def plot_detmaps(self, i=None, thr=9, dpi=100,
258255
if plot_type in [1, "horiz"]:
259256
for m in self.methods:
260257
print('detection state: {} | false postives: {}'.format(
261-
m.detections[i][thr], m.fps[i][thr]))
262-
labels = ('{} frame'.format(m.name), '{} S/Nmap'.format(m.name),
263-
'Thresholded at {:.1f}'.format(m.thresholds[thr]))
264-
plot_frames((m.frames[i] if len(m.frames) >= i else
265-
np.zeros((2, 2)), m.probmaps[i], m.bmaps[i][thr]),
258+
m["detections"][i][thr], m["fps"][i][thr]))
259+
labels = (f"{m['name']} frame", f"{m['name']} S/Nmap",
260+
f"Thresholded at {m['thresholds'][thr]:.1f}")
261+
plot_frames((m["frames"][i] if len(m["frames"]) >= i else
262+
np.zeros((2, 2)), m["probmaps"][i], m["bmaps"][i][thr]),
266263
label=labels, dpi=dpi, horsp=0.2, axis=axis,
267264
grid=grid, cmap=['viridis', 'viridis', 'gray'])
268265

269266
elif plot_type in [2, "vert"]:
270-
labels = tuple('{} frame'.format(m.name) for m in self.methods if
271-
hasattr(m, "frames") and len(m.frames) >= i)
272-
plot_frames(tuple(m.frames[i] for m in self.methods if
273-
hasattr(m, "frames") and len(m.frames) >= i),
267+
labels = tuple(f"{m['name']} frame" for m in self.methods if
268+
"frames" in m and len(m["frames"]) >= i)
269+
plot_frames(tuple(m["frames"][i] for m in self.methods if
270+
"frames" in m and len(m["frames"]) >= i),
274271
dpi=dpi, label=labels, vmax=vmax, vmin=vmin, axis=axis,
275272
grid=grid)
276273

277-
plot_frames(tuple(m.probmaps[i] for m in self.methods), dpi=dpi,
278-
label=tuple(['{} S/Nmap'.format(m.name) for m in
274+
plot_frames(tuple(m["probmaps"][i] for m in self.methods), dpi=dpi,
275+
label=tuple([f"{m['name']} S/Nmap" for m in
279276
self.methods]), axis=axis, grid=grid)
280277

281278
for m in self.methods:
282-
msg = '{} detection: {}, FPs: {}'
283-
print(msg.format(m.name, m.detections[i][thr], m.fps[i][thr]))
279+
print(f"{m['name']} detection: {m['detections'][i][thr]}, FPs: {m['fps'][i][thr]}")
284280

285-
labels = tuple('Thresholded at {:.1f}'.format(m.thresholds[thr])
281+
labels = tuple(f"Thresholded at {m['thresholds'][thr]:.1f}"
286282
for m in self.methods)
287-
plot_frames(tuple(m.bmaps[i][thr] for m in self.methods),
283+
plot_frames(tuple(m["bmaps"][i][thr] for m in self.methods),
288284
dpi=dpi, label=labels, axis=axis, grid=grid,
289285
colorbar=False, cmap='bone')
290286
else:
@@ -342,40 +338,40 @@ def plot_roc_curves(self, dpi=100, figsize=(5, 5), xmin=None, xmax=None,
342338
# "SODIRF": dict(color="#9467bd", symbol="s"),
343339
# "SODINN": dict(color="#1f77b4", symbol="p"),
344340
# "SODINN-pw": dict(color="#1f77b4", symbol="p")
345-
# } # maps m.name to plot style
341+
# } # maps m["name"] to plot style
346342

347343
for i, m in enumerate(self.methods):
348344

349345
if not hasattr(m, "detections") or not hasattr(m, "fps"):
350346
raise AttributeError("method #{} has no detections/fps. Run"
351347
"`compute_tpr_fps` first.".format(i))
352348

353-
m.tpr = np.zeros((n_thresholds))
354-
m.mean_fps = np.zeros((n_thresholds))
349+
m["tpr"] = np.zeros(n_thresholds)
350+
m["mean_fps"] = np.zeros(n_thresholds)
355351

356352
for j in range(n_thresholds):
357-
m.tpr[j] = np.asarray(m.detections)[:, j].tolist().count(1) / \
353+
m["tpr"][j] = np.asarray(m["detections"])[:, j].tolist().count(1) / \
358354
self.n_injections
359-
m.mean_fps[j] = np.asarray(m.fps)[:, j].mean()
355+
m["mean_fps"][j] = np.asarray(m["fps"])[:, j].mean()
360356

361-
plt.plot(m.mean_fps, m.tpr, '--', color=m.color, **linekw)
362-
plt.plot(m.mean_fps, m.tpr, m.symbol, label=m.name, color=m.color,
357+
plt.plot(m["mean_fps"], m["tpr"], '--', color=m["color"], **linekw)
358+
plt.plot(m["mean_fps"], m["tpr"], m["symbol"], label=m["name"], color=m["color"],
363359
**markerkw)
364360

365361
if show_data_labels:
366362
if label_skip_one[i]:
367-
lab_x = m.mean_fps[1::2]
368-
lab_y = m.tpr[1::2]
369-
thr = m.thresholds[1::2]
363+
lab_x = m["mean_fps"][1::2]
364+
lab_y = m["tpr"][1::2]
365+
thr = m["thresholds"][1::2]
370366
else:
371-
lab_x = m.mean_fps
372-
lab_y = m.tpr
373-
thr = m.thresholds
367+
lab_x = m["mean_fps"]
368+
lab_y = m["tpr"]
369+
thr = m["thresholds"]
374370

375371
for i, xy in enumerate(zip(lab_x + label_gap[0],
376372
lab_y + label_gap[1])):
377373
labels.append(ax.annotate('{:.2f}'.format(thr[i]),
378-
xy=xy, xycoords='data', color=m.color,
374+
xy=xy, xycoords='data', color=m["color"],
379375
**labelskw))
380376
# TODO: reverse order of `self.methods` for better annot.
381377
# z-index?

0 commit comments

Comments
 (0)