Skip to content

Commit 4cbc654

Browse files
committed
bugfix: information was inappropriately shared between files during RT correction
1 parent b8367a1 commit 4cbc654

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

src/MSDpostprocess/models.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,11 @@
88
# import shelve
99
import re
1010
from functools import cache
11-
import warnings
1211

1312
import pickle
1413
import numpy as np
1514
import pandas as pd
16-
with warnings.catch_warnings():
17-
warnings.filterwarnings("ignore", message="brainpy")
18-
from brainpy import isotopic_variants
15+
from brainpy import isotopic_variants
1916
import statsmodels.api as sm
2017
from lineartree import LinearTreeClassifier
2118
from sklearn.preprocessing import StandardScaler
@@ -57,7 +54,6 @@ def fit(self, data):
5754
data = data.copy()
5855
data = data[np.isfinite(data['label'])]
5956
data = self.preprocess(data.copy())
60-
# self.classifier = LogisticRegression(solver = 'liblinear')
6157
self.classifier = CalibratedClassifierCV(GBC(), ensemble=False)
6258
self.classifier.fit(data[self.features], data['label'])
6359
self.logs.info(f'Fit {self.model}')
@@ -172,25 +168,27 @@ def rt_error(self, subset):
172168
frac = self.lowess_frac,
173169
it = 3,
174170
xvals = subset['Average Rt(min)'])
175-
rt_error = subset['Reference RT'].to_numpy() - regression
171+
subset['rt_error'] = subset['Reference RT'].to_numpy() - regression
176172
#record regression for QC purposes
177173
file = next(f for f in subset['file'])
178174
self.rt_observed[file] = subset['Average Rt(min)']
179175
self.rt_expected[file] = subset['Reference RT']
180176
self.rt_predictions[file] = regression
181177
self.rt_calls[file] = subset['call']
182178
self.logs.debug(f'RT regression fit for {file} used {np.sum(subset["call"])} observations')
183-
return rt_error
179+
return subset
184180

185181
def correct_data(self, data):
    """Compute the RT error file by file and return the combined table.

    Rows are scored with ``self._predict_prob`` and labelled with
    ``self.predict``; the labels are stored in a temporary ``call`` column
    so that ``self.rt_error`` can read them. Every file's rows are then
    handed to ``self.rt_error`` separately, so one file's regression never
    sees another file's rows, and the per-file results are concatenated.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``file`` column, plus whatever columns
        ``self._predict_prob`` and ``self.rt_error`` read.
        NOTE: the ``call`` column is written onto this frame in place,
        as before; only the returned frame has it removed.

    Returns
    -------
    pandas.DataFrame
        The frames returned by ``self.rt_error`` (one per file, in order of
        first appearance of each file), concatenated, without the temporary
        ``call`` column. Original index labels are kept.
    """
    #identify high confidence subset for correction
    scores = self._predict_prob(data)
    data['call'] = self.predict(scores)

    #build lowess regressions
    # Series.unique() keeps first-appearance order, so the output row order
    # is reproducible (iterating a set of file names is not).
    subsets = []
    for file in data['file'].unique():
        # Pass an explicit copy: rt_error adds a column to its argument, and
        # writing to a boolean-mask selection raises chained-assignment
        # warnings and must never touch another file's rows.
        per_file = data[data['file'] == file].copy()
        subsets.append(self.rt_error(per_file))
    data = pd.concat(subsets)

    # DataFrame.drop returns a new frame; the result has to be kept,
    # otherwise the temporary 'call' column leaks into the output.
    data = data.drop(columns = ['call'])

    self.logs.info('RT correction has been applied.')
    return data

0 commit comments

Comments
 (0)