Skip to content

Commit 3f266f8

Browse files
authored
Merge pull request #82 from grahamWroberts/main
updated TreePipeline operations
2 parents 9523786 + cf8f04b commit 3f266f8

File tree

1 file changed

+58
-4
lines changed

1 file changed

+58
-4
lines changed

AFL/double_agent/TreePipeline.py

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def set_classifier(self, classifier_instance):
3535

3636
def calculate(self, dataset):
3737
data = self._get_variable(dataset)
38-
predicted_classes = self.classifier.predict(np.log10(data))
38+
predicted_classes = self.classifier.predict(data)
3939
dataset[self.output_variable] = ('sample', predicted_classes)
4040
return(self)
4141

@@ -64,13 +64,67 @@ def calculate(self, dataset):
6464
key = dataset[self.key_variable].data
6565
print(np.unique(key))
6666
print(self.morphology)
67-
inds = np.where(np.equal(key, self.morphology))
68-
predictions = self.regression.predict(np.log10(data[inds]))
67+
inds = np.where(np.equal(key, self.morphology))[0]
68+
predictions = self.regression.predict(data[inds])
6969
if self.output_variable in dataset.data_vars:
7070
output = dataset[self.output_variable].data
7171
else:
7272
output = np.nan * np.ones(data.shape[0])
73-
output[inds] = predictions
73+
print("INDS")
74+
print(inds.shape)
75+
print("PREDS")
76+
print(predictions.shape)
77+
output[inds] = predictions.reshape(-1)
7478
dataset[self.output_variable] = ('sample', output)
7579
return(self)
7680

81+
class ThresholdClassificationPipeline(PipelineOp):
82+
def __init__(self, input_variable, output_variable, components, threshold, name = "mixture_separation"):
83+
super().__init__(input_variable = input_variable,
84+
output_variable = output_variable,
85+
name = name)
86+
self.components = components
87+
self.threshold = threshold
88+
89+
def calculate(self, dataset):
90+
data = self._get_variable(dataset)
91+
labs = []
92+
for i in range(data.shape[0]):
93+
d = data.data[i]
94+
print(d)
95+
print(type(d))
96+
comps = self.components[d]
97+
measures = np.array([dataset[c].data[i] for c in comps])
98+
portions = measures/np.sum(measures)
99+
print(np.where(portions > self.threshold)[0])
100+
if any(portions >= self.threshold):
101+
labs += [comps[np.where(portions >= self.threshold)[0][0]]]
102+
else:
103+
labs += [d]
104+
dataset[self.output_variable] = ('sample', labs)
105+
return(self)
106+
107+
class FlatAddition(PipelineOp):
108+
def __init__(self, input_variable, output_variable, value, name = "flat_addition"):
109+
super().__init__(input_variable = input_variable,
110+
output_variable = output_variable,
111+
name = name)
112+
self.value = value
113+
114+
def calculate(self, dataset):
115+
data = self._get_variable(dataset)
116+
dataset[self.output_variable] = data+self.value
117+
return(self)
118+
119+
class IntEncoding(PipelineOp):
120+
def __init__(self, input_variable, output_variable, classes, name = "label_encoder"):
121+
super().__init__(input_variable = input_variable,
122+
output_variable = output_variable,
123+
name = name)
124+
self.classes = classes
125+
self.encoding = {c:i for i,c in enumerate(classes)}
126+
127+
def calculate(self, dataset):
128+
data = self._get_variable(dataset)
129+
dataset[self.output_variable] = ('sample', [self.encoding[l] for l in data.data])
130+
return(self)

0 commit comments

Comments
 (0)