|
8 | 8 | # import shelve |
9 | 9 | import re |
10 | 10 | from functools import cache |
11 | | -import warnings |
12 | 11 |
|
13 | 12 | import pickle |
14 | 13 | import numpy as np |
15 | 14 | import pandas as pd |
16 | | -with warnings.catch_warnings(): |
17 | | - warnings.filterwarnings("ignore", message="brainpy") |
18 | | - from brainpy import isotopic_variants |
| 15 | +from brainpy import isotopic_variants |
19 | 16 | import statsmodels.api as sm |
20 | 17 | from lineartree import LinearTreeClassifier |
21 | 18 | from sklearn.preprocessing import StandardScaler |
@@ -57,7 +54,6 @@ def fit(self, data): |
57 | 54 | data = data.copy() |
58 | 55 | data = data[np.isfinite(data['label'])] |
59 | 56 | data = self.preprocess(data.copy()) |
60 | | - # self.classifier = LogisticRegression(solver = 'liblinear') |
61 | 57 | self.classifier = CalibratedClassifierCV(GBC(), ensemble=False) |
62 | 58 | self.classifier.fit(data[self.features], data['label']) |
63 | 59 | self.logs.info(f'Fit {self.model}') |
@@ -172,25 +168,27 @@ def rt_error(self, subset): |
172 | 168 | frac = self.lowess_frac, |
173 | 169 | it = 3, |
174 | 170 | xvals = subset['Average Rt(min)']) |
175 | | - rt_error = subset['Reference RT'].to_numpy() - regression |
| 171 | + subset['rt_error'] = subset['Reference RT'].to_numpy() - regression |
176 | 172 | #record regression for QC purposes |
177 | 173 | file = next(f for f in subset['file']) |
178 | 174 | self.rt_observed[file] = subset['Average Rt(min)'] |
179 | 175 | self.rt_expected[file] = subset['Reference RT'] |
180 | 176 | self.rt_predictions[file] = regression |
181 | 177 | self.rt_calls[file] = subset['call'] |
182 | 178 | self.logs.debug(f'RT regression fit for {file} used {np.sum(subset["call"])} observations') |
183 | | - return rt_error |
| 179 | + return subset |
184 | 180 |
|
185 | 181 | def correct_data(self, data): |
186 | 182 | #identify high confidence subset for correction |
187 | 183 | scores = self._predict_prob(data) |
188 | 184 | data['call'] = self.predict(scores) |
189 | 185 |
|
190 | 186 | #build lowess regressions |
191 | | - rt_error = data.groupby('file')[data.columns].apply(self.rt_error) |
192 | | - data['rt_error'] = [val for file in rt_error for val in file] |
193 | | - data = data.drop(columns = ['call']) |
| 187 | + subsets = [] |
| 188 | + for file in set(data['file']): |
| 189 | + subsets.append(self.rt_error(data[data['file'] == file])) |
| 190 | + data = pd.concat(subsets) |
| 191 | + data.drop(columns = ['call']) |
194 | 192 |
|
195 | 193 | self.logs.info('RT correction has been applied.') |
196 | 194 | return data |
|
0 commit comments