You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
#V0.1.0 -- A very early version of the code which contains the basic structure of the tree hierarchy, as well as some functionality for reading and writing from various json formats.
5
+
6
+
importnumpyasnp
7
+
fromsklearn.svmimportSVC
8
+
fromsklearn.kernel_ridgeimportKernelRidgeasKRR
9
+
importjson
10
+
fromjoblibimportdump, load
11
+
importjoblib
12
+
fromioimportBytesIO
13
+
importbase64
14
+
15
+
#TreeHierarchy
16
+
# allows construction of custom hierarchical classifiers
17
+
# these allow for multiclass classification with indepenently tuned classifiers at each branch
18
+
# this is a recursie structure
19
+
classTreeHierarchy:
20
+
21
+
def_init_(self):
22
+
self.left=None
23
+
self.right=None
24
+
self.content=None
25
+
self.terminal=None
26
+
self.classA=None
27
+
self.classB=None
28
+
#self.__dict__ = self.vars()
29
+
30
+
31
+
defvars(self):
32
+
outdict= {}
33
+
forkinself.dir():
34
+
v=getattr(self, k, None)
35
+
ifvisnotNone:
36
+
ifisinstance(v, TreeHierarchy):
37
+
outdict[k] =vars(v)
38
+
elifisinsstance(v, np.ndarray):
39
+
outdict[k] =list(v)
40
+
elifisinstance(v, SVC):
41
+
tempdict= {}
42
+
fork2, v2inv.__dict__:
43
+
ifisinstance(v2, np.ndarray):
44
+
tempdict[k2] =list(v2)
45
+
else:
46
+
tempdict[k2] =v2
47
+
outdict[k] =tempdict
48
+
return(outdict)
49
+
# return(None)
50
+
51
+
defadd_content(self, contant, terminal):
52
+
self.content=content
53
+
self.terminal=terminal
54
+
55
+
defadd_left(self, entity):
56
+
self.left=entity
57
+
58
+
defadd_right(self, entity):
59
+
self.right=entity
60
+
61
+
#Fit follows the sklearn fit structure and recursively calls fit on each component tree
62
+
deffit(self, X, y):
63
+
ifnotgetattr(self, 'terminal', False):
64
+
templabels=np.zeros(X.shape[0])
65
+
forlinnp.unique(self.classB):
66
+
foriinrange(len(y)):
67
+
ify[i] ==l:
68
+
templabels[i] =1
69
+
self.entity.fit(X, templabels)
70
+
ia=np.where(templabels==0)[0]
71
+
ib=np.where(templabels==1)[0]
72
+
self.left.fit(X[ia], y[ia])
73
+
self.right.fit(X[ib], y[ib])
74
+
return
75
+
76
+
77
+
defpredict(self, X):
78
+
ifX.shape[0] ==0:
79
+
y=np.zeros(0)
80
+
elifnotgetattr(self, 'terminal', False):
81
+
inds=np.arange(X.shape[0])
82
+
temp_y=self.entity.predict(X)
83
+
ia=np.where(np.logical_not(temp_y))[0]
84
+
ib=np.where(temp_y)[0]
85
+
return_y=np.empty(X.shape[0], dtype=object)
86
+
y_a=self.left.predict(X[ia])
87
+
y_b=self.right.predict(X[ib])
88
+
forvinrange(ia.shape[0]):
89
+
vi=ia[v]
90
+
return_y[vi] =y_a[v]
91
+
forvinrange(ib.shape[0]):
92
+
vi=ib[v]
93
+
return_y[vi] =y_b[v]
94
+
y=return_y
95
+
else:
96
+
y=np.array([self.terminal] *X.shape[0])
97
+
return(y)
98
+
99
+
#Structure from json takes as input a json structure and constructs the tree based on that structure
100
+
defstructure_from_json(self, J):
101
+
if'class'inJ.keys():
102
+
self.terminal=J['class']
103
+
else:
104
+
if'jobfile'inJ.keys():
105
+
self.entity=load(J['jobfile'])
106
+
print('Loading %s'%(J['jobfile']))
107
+
elif'classifier'inJ.keys():
108
+
print(J['classifier'])
109
+
self.entity=classifier_from_json(J['classifier'])
110
+
print(self.entity)
111
+
else:
112
+
self.entity=None
113
+
self.left=TreeHierarchy()
114
+
self.left.structure_from_json(J['left'])
115
+
self.right=TreeHierarchy()
116
+
self.right.structure_from_json(J['right'])
117
+
self.classA=J['classA']
118
+
self.classB=J['classB']
119
+
120
+
#Pass to json.dumps as an encoder class
121
+
#decontructs the tree, component SVCs KRRs and npArrays into primitives as well.
122
+
classTreeEncoder(json.JSONEncoder):
123
+
124
+
125
+
defdefault(self, obj):
126
+
"""If input object is an ndarray it will be converted into a dict
127
+
holding dtype, shape and the data, base64 encoded.
128
+
"""
129
+
ifisinstance(obj, np.ndarray):
130
+
ifobj.flags['C_CONTIGUOUS']:
131
+
obj_data=obj.data
132
+
else:
133
+
cont_obj=np.ascontiguousarray(obj)
134
+
assert(cont_obj.flags['C_CONTIGUOUS'])
135
+
obj_data=cont_obj.data
136
+
data_b64=base64.b64encode(obj_data)
137
+
returndict(__ndarray__=data_b64.decode('utf-8'),
138
+
dtype=str(obj.dtype),
139
+
shape=obj.shape)
140
+
elifisinstance(obj, SVC):
141
+
return(obj.__dict__)
142
+
elifisinstance(obj, KRR):
143
+
return(obj.__dict__)
144
+
elifisinstance(obj, TreeHierarchy):
145
+
return(obj.__dict__)
146
+
else:
147
+
# Let the base class default method raise the TypeError
148
+
super().default(obj)
149
+
150
+
defjson_decoder(dct):
151
+
"""Decodes a previously encoded TreeHierarchy, numpy ndarray with proper shape and dtype, SVC, or KRR.
152
+
153
+
:param dct: (dict) json encoded ndarray
154
+
:return: (ndarray) if input was an encoded ndarray
0 commit comments