Skip to content

Commit 8382877

Browse files
committed
improve tree plotting
1 parent 5ce9398 commit 8382877

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

_unittests/ut_plotting/test_text_plot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@ def test_onnx_text_plot_tree_cls(self):
6565
self.assertIn(" T y=", res)
6666
self.assertIn("n_classes=3", res)
6767

68+
def test_onnx_text_plot_tree_cls_2(self):
69+
iris = load_iris()
70+
X_train, y_train = iris.data.astype(numpy.float32), iris.target
71+
clr = DecisionTreeClassifier()
72+
clr.fit(X_train, y_train)
73+
model_def = to_onnx(
74+
clr, X_train.astype(numpy.float32), options={"zipmap": False}
75+
)
76+
res = onnx_text_plot_tree(model_def.graph.node[0])
77+
self.assertIn("n_classes=3", res)
78+
print(res)
79+
6880
@ignore_warnings((UserWarning, FutureWarning))
6981
def test_onnx_simple_text_plot_kmeans(self):
7082
x = numpy.random.randn(10, 3)

onnx_array_api/plotting/text_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def process_tree(atts, treeid):
114114
for i in range(len(short[f"{prefix}_treeids"])):
115115
idn = short[f"{prefix}_nodeids"][i]
116116
node = nodes[idn]
117-
node.target_nodeids = idn
118-
node.target_ids = short[f"{prefix}_ids"][i]
119-
node.target_weights = short[f"{prefix}_weights"][i]
117+
node.append_target(
118+
id=short[f"{prefix}_ids"][i], weight=short[f"{prefix}_weights"][i]
119+
)
120120

121121
def iterate(nodes, node, depth=0, true_false=""):
122122
node.depth = depth

0 commit comments

Comments
 (0)