Skip to content

Commit 9523786

Browse files
authored
Merge pull request #76 from grahamWroberts/main
added tree to pipeline
2 parents d6ddbb5 + d1f52ab commit 9523786

File tree

7 files changed

+1394
-0
lines changed

7 files changed

+1394
-0
lines changed

AFL/double_agent/TreePipeline.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from AFL.double_agent import *
2+
import numpy as np
3+
import pandas as pd
4+
from matplotlib import pyplot as plt
5+
import matplotlib
6+
#import tune_all_decisions as tad
7+
import itertools
8+
import joblib
9+
from io import BytesIO
10+
import xarray as xr
11+
import json
12+
import TreeHierarchy as te
13+
from sklearn.metrics import classification_report as cr
14+
from sklearn.metrics import root_mean_squared_error as RMSE
15+
from sklearn.metrics import mean_absolute_error as MAE
16+
from sklearn.metrics import mean_absolute_percentage_error as MAPE
17+
18+
#PipelineOp constructor for classification tree
19+
#The tree itself is defined in TreeHierarchy
20+
#This constructor follows the expected PipelineOp syntax
21+
# input_variable: the name of the input feature in the xarray
22+
# output_variable: the name of the variable to add/modify in the xarray dataset
23+
# model_definition: A dictionary containing an encoding of a TreeHierarchy object. The encoder is contained in TreeHierarchy.
24+
class ClassificationPipeline(PipelineOp):
    """PipelineOp wrapper around a TreeHierarchy classification tree.

    Parameters
    ----------
    input_variable : str
        Name of the input feature variable in the xarray dataset.
    output_variable : str
        Name of the variable to add/modify in the xarray dataset.
    model_definition : dict
        Dictionary encoding of a TreeHierarchy object; the encoder lives in
        TreeHierarchy.
    name : str
        Display name of this PipelineOp.
    """

    def __init__(self, input_variable, output_variable, model_definition, name="Classifier"):
        super().__init__(
            input_variable=input_variable,
            output_variable=output_variable,
            name=name,
        )
        # Rebuild the trained tree from its JSON-serializable encoding.
        self.classifier = te.json_decoder(model_definition)

    def set_classifier(self, classifier_instance):
        """Swap in an already-constructed classifier instance."""
        self.classifier = classifier_instance

    def calculate(self, dataset):
        """Predict a class label per sample and write it to ``output_variable``.

        The input data is log10-transformed before prediction — presumably the
        tree was trained on log-scaled curves (TODO confirm against training code).
        """
        features = self._get_variable(dataset)
        labels = self.classifier.predict(np.log10(features))
        dataset[self.output_variable] = ('sample', labels)
        return self
41+
42+
#PipelineOp constructor for a regressor
43+
#This constructor follows the expected PipelineOp syntax, with some important considerations
44+
# input_variable: the name of the input feature in the xarray
45+
# output_variable: the name of the variable to add/modify in the xarray dataset
46+
# key_variable: the name of the variable that contains morphology information in the xarray, could be ground_truth_labels, predicted_labels, etc.
47+
# morphology: the morphology that this model is trained on
48+
# model_definition: a dictionary containing a complete definition of a trained regression model; the encoder in TreeHierarchy also works for this
49+
#NOTE: Each regressor only works for one parameter for one morphology; if multiple morphologies share a parameter (i.e., radius is common to many morphologies), then they should each operate on the SAME output_variable.
50+
#Each RegressionPipeline will only modify output_variable where key_variable==morphology; place multiple PipelineOps in the same pipeline to perform regression over all parameters and morphologies.
51+
class RegressionPipeline(PipelineOp):
    """PipelineOp wrapper around a TreeHierarchy-decoded regression model.

    Each regressor handles one parameter for one morphology. If multiple
    morphologies share a parameter (e.g. radius), each should write to the
    SAME ``output_variable``; this op only modifies entries where
    ``key_variable == morphology``, so several RegressionPipelines placed in
    one pipeline together cover all parameters and morphologies.

    Parameters
    ----------
    input_variable : str
        Name of the input feature variable in the xarray dataset.
    output_variable : str
        Name of the variable to add/modify in the xarray dataset.
    key_variable : str
        Name of the variable containing morphology information (e.g.
        ground_truth_labels, predicted_labels).
    morphology : str
        The morphology this model is trained on.
    model_definition : dict
        Complete definition of a trained model; the TreeHierarchy encoder
        also works for this.
    name : str
        Display name of this PipelineOp.
        NOTE(review): the default "Classifier" appears copy-pasted from
        ClassificationPipeline; kept for backward compatibility.
    """

    def __init__(self, input_variable, output_variable, key_variable, morphology, model_definition, name="Classifier"):
        super().__init__(
            input_variable=input_variable,
            output_variable=output_variable,
            name=name,
        )
        self.key_variable = key_variable
        self.morphology = morphology
        self.regression = te.json_decoder(model_definition)

    def calculate(self, dataset):
        """Regress values for samples matching this op's morphology.

        Only entries where ``key_variable == morphology`` are written; all
        other entries are left untouched (or NaN if the output variable did
        not exist yet).
        """
        data = self._get_variable(dataset)
        key = dataset[self.key_variable].data
        # Select only the samples belonging to this op's morphology.
        # (Debug prints of np.unique(key)/self.morphology removed.)
        inds = np.where(np.equal(key, self.morphology))
        # log10-transform mirrors ClassificationPipeline — presumably matches
        # the training-time preprocessing (TODO confirm).
        predictions = self.regression.predict(np.log10(data[inds]))
        if self.output_variable in dataset.data_vars:
            # Another regressor may already have filled other morphologies.
            output = dataset[self.output_variable].data
        else:
            # First writer for this variable: initialize with NaN so samples
            # of other morphologies are visibly unfilled.
            output = np.nan * np.ones(data.shape[0])
        output[inds] = predictions
        dataset[self.output_variable] = ('sample', output)
        return self
76+

AFL/double_agent/data/classification_pipeline.json

Lines changed: 827 additions & 0 deletions
Large diffs are not rendered by default.
22 MB
Binary file not shown.

AFL/double_agent/data/example_tree_structure.json

Lines changed: 401 additions & 0 deletions
Large diffs are not rendered by default.
307 KB
Binary file not shown.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ automation = [
9696
"requests",
9797
]
9898

99+
mlmodels = [
100+
"TreeHierarchy @ git+https://github.com/grahamRobertsW/TreeHierarchy"
101+
]
102+
99103
dev = [
100104
"black",
101105
"mypy",

tests/test_classifier_pipeline.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
Unit tests for the AFL.double_agent.PipelineOp module.
3+
"""
4+
5+
import pytest
6+
import numpy as np
7+
import xarray as xr
8+
import json
9+
import os
10+
11+
from tests.utils import MockPipelineOp
12+
from AFL.double_agent import TreePipeline as tp
13+
from AFL.double_agent import (Pipeline, LogLogTransform)
14+
from sklearn.svm import SVC
15+
from AFL.double_agent.data import (
16+
get_data_dir,
17+
list_datasets,
18+
load_dataset,
19+
example_dataset1,
20+
)
21+
from TreeHierarchy import (
22+
TreeHierarchy,
23+
json_decoder
24+
)
25+
26+
27+
@pytest.mark.unit
class TestClassificationPipeline:
    """Tests constructing a ClassificationPipeline from a tree JSON definition."""

    def test_classifier_creation(self):
        # Use a context manager so the file handle is closed (the original
        # used open(...).read() and leaked the handle); json.load reads
        # directly from the file object.
        with open(os.path.join(get_data_dir(), "example_tree_structure.json"), "r") as f:
            classification_def = json.load(f)
        with Pipeline() as P:
            LogLogTransform("SAS_curves", "log_sas_curves")
            pipe = tp.ClassificationPipeline("SAS_curves", "predicted_labels", classification_def)
        # The decoded classifier should be a three-level TreeHierarchy with an
        # SVC at each internal node.
        assert isinstance(pipe, tp.ClassificationPipeline)
        assert isinstance(pipe.classifier, TreeHierarchy)
        assert isinstance(pipe.classifier.left, TreeHierarchy)
        assert isinstance(pipe.classifier.right, TreeHierarchy)
        assert isinstance(pipe.classifier.left.left, TreeHierarchy)
        assert isinstance(pipe.classifier.left.right, TreeHierarchy)
        assert isinstance(pipe.classifier.right.left, TreeHierarchy)
        assert isinstance(pipe.classifier.right.right, TreeHierarchy)
        assert isinstance(pipe.classifier.entity, SVC)
        assert isinstance(pipe.classifier.left.entity, SVC)
        assert isinstance(pipe.classifier.right.entity, SVC)
47+
48+
@pytest.mark.unit
class TestClassificationPipelineLoaded:
    """Tests loading a saved ClassificationPipeline from a pipeline JSON file."""

    def test_classifier_load(self):
        save_path = os.path.join(get_data_dir(), "classification_pipeline.json")
        with Pipeline.read_json(str(save_path)) as P:
            # P[1] is the ClassificationPipeline op in the saved pipeline;
            # its classifier should round-trip as a three-level TreeHierarchy
            # with an SVC at each internal node.
            assert isinstance(P[1], tp.ClassificationPipeline)
            assert isinstance(P[1].classifier, TreeHierarchy)
            assert isinstance(P[1].classifier.left, TreeHierarchy)
            assert isinstance(P[1].classifier.right, TreeHierarchy)
            assert isinstance(P[1].classifier.left.left, TreeHierarchy)
            assert isinstance(P[1].classifier.left.right, TreeHierarchy)
            assert isinstance(P[1].classifier.right.left, TreeHierarchy)
            assert isinstance(P[1].classifier.right.right, TreeHierarchy)
            assert isinstance(P[1].classifier.entity, SVC)
            assert isinstance(P[1].classifier.left.entity, SVC)
            assert isinstance(P[1].classifier.right.entity, SVC)
68+
69+
@pytest.mark.unit
class TestClassificationPipelinePerformance:
    """Tests that a loaded pipeline reproduces the reference predictions."""

    def test_classifier_load(self):
        save_path = os.path.join(get_data_dir(), "classification_pipeline.json")
        data = load_dataset("example_classification_data")
        ref = load_dataset("reference_predictions")
        with Pipeline.read_json(str(save_path)) as P:
            out = P.calculate(data)
        # Predictions must match the stored reference exactly.
        np.testing.assert_array_equal(
            out["predicted_test_labels"].data,
            ref["reference_predictions"].data,
        )
83+
84+
85+
86+

0 commit comments

Comments
 (0)