Skip to content

Commit 77e4b0f

Browse files
committed
performance improvements for spot location refinement code; pinning some packages in requirements.txt
1 parent 1a4c9df commit 77e4b0f

File tree

2 files changed

+27
-21
lines changed

2 files changed

+27
-21
lines changed

deepcell_spots/applications/polaris.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -436,19 +436,25 @@ def _predict(self,
436436
print('Refining spot locations.')
437437
dec_prob_im = np.zeros((spots_image.shape[:3]))
438438

439-
for i in range(len(df_results)):
440-
gene = df_results.loc[i, 'predicted_name']
441-
if gene in ['Background', 'Unknown']:
442-
continue
443-
if 'Blank' in gene:
444-
continue
445-
446-
x = df_results.loc[i, 'x']
447-
y = df_results.loc[i, 'y']
448-
b = df_results.loc[i, 'batch_id']
449-
prob = max_proj_images[b, x, y]
450-
451-
dec_prob_im[b, x, y] = prob
439+
# Mask out unwanted genes
440+
mask_valid = ~df_results['predicted_name'].isin(['Background', 'Unknown']) & \
441+
~df_results['predicted_name'].str.contains('Blank', na=False)
442+
valid_df = df_results[mask_valid]
443+
444+
b = valid_df['batch_id'].to_numpy()
445+
x = valid_df['x'].to_numpy()
446+
y = valid_df['y'].to_numpy()
447+
448+
dec_prob_im[b, x, y] = max_proj_images[b, x, y]
449+
450+
451+
# Pre-create a dictionary for fast lookups
452+
lookup_dict = {}
453+
for idx, row in df_results.iterrows():
454+
key = (row['batch_id'], row['x'], row['y'])
455+
if key not in lookup_dict:
456+
lookup_dict[key] = []
457+
lookup_dict[key].append(idx)
452458

453459
mask = []
454460
for b in range(spots_image.shape[0]):
@@ -459,10 +465,10 @@ def _predict(self,
459465
x = decoded_spots_locations[0][i, 0]
460466
y = decoded_spots_locations[0][i, 1]
461467

462-
mask.append(df_results.loc[(df_results.x==x) &
463-
(df_results.y==y) &
464-
(df_results.batch_id==b)].index[0])
465-
468+
key = (b, x, y)
469+
if key in lookup_dict:
470+
mask.append(lookup_dict[key][0])
471+
466472
df_results = df_results.loc[mask]
467473

468474
return df_results, segmentation_result

requirements.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ jupyter>=1.0.0,<2
88
networkx>=2.1
99
opencv-python-headless<5
1010
deepcell>=0.12.7
11-
trackpy
11+
trackpy~=0.6.0
1212
tqdm
1313
plotly
14-
statsmodels
14+
statsmodels~=0.14.0
1515
--extra-index-url https://download.pytorch.org/whl/cpu # install the cpu only version of torch and torchvision
16-
torch
16+
torch~=2.3.0
1717
torchvision
18-
pyro-ppl
18+
pyro-ppl>=1.8.0

0 commit comments

Comments
 (0)