Skip to content

Commit e0e50b4

Browse files
authored
Merge pull request #6 from sdpython/doc
Updates one documentation example
2 parents a751cae + 31ab6ca commit e0e50b4

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

_doc/examples/plot_benchmark_rf.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
8181
###############################################
8282
# Or with the following command.
8383
out, err = run_cmd("cat /proc/cpuinfo")
84+
print(out)
8485

8586
###############################################
8687
# Fonction to measure inference time
@@ -124,8 +125,6 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
124125
# model for a random forest and onnxruntime after it was converted
125126
# into ONNX and for the following configurations.
126127

127-
legend = "parallel-batch-4096-block"
128-
129128
small = cpu_count() < 12
130129
if small:
131130
N = 1000
@@ -142,6 +141,8 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
142141
depth = [6, 8, 10, 12, 14]
143142
Regressor = RandomForestRegressor
144143

144+
legend = f"parallel-nf-{n_features}-"
145+
145146
# avoid duplicates on machine with 1 or 2 cores.
146147
n_jobs = list(sorted(set(n_jobs), reverse=True))
147148

@@ -175,7 +176,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
175176

176177
# parallelization
177178
cache_name = os.path.join(
178-
cache_dir, f"rf-J-{n_j}-E-{n_estimators}-D-{max_depth}.pkl"
179+
cache_dir, f"nf-{X.shape[1]}-rf-J-{n_j}-E-{n_estimators}-D-{max_depth}.pkl"
179180
)
180181
if os.path.exists(cache_name):
181182
with open(cache_name, "rb") as f:
@@ -196,7 +197,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
196197
so = SessionOptions()
197198
so.intra_op_num_threads = n_j
198199
cache_name = os.path.join(
199-
cache_dir, f"rf-J-{n_j}-E-{n_estimators}-D-{max_depth}.onnx"
200+
cache_dir, f"nf-{X.shape[1]}-rf-J-{n_j}-E-{n_estimators}-D-{max_depth}.onnx"
200201
)
201202
if os.path.exists(cache_name):
202203
sess = InferenceSession(cache_name, so)
@@ -268,7 +269,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
268269

269270

270271
fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
271-
fig.suptitle(f"{rf.__class__.__name__}")
272+
fig.suptitle(f"{rf.__class__.__name__}\nX.shape={X.shape}")
272273

273274
for n_j, n_estimators in tqdm(product(n_jobs, n_ests)):
274275
i = n_jobs.index(n_j)

setup.cfg

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[options]
2+
packages = find:
3+
4+
[options.packages.find]
5+
include = onnx_array_api*

setup.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
import os
33

4-
from setuptools import find_packages, setup
4+
from setuptools import setup
55

66
######################
77
# beginning of setup
@@ -11,8 +11,6 @@
1111
here = os.path.dirname(__file__)
1212
if here == "":
1313
here = "."
14-
packages = find_packages(where=here)
15-
package_dir = {k: os.path.join(here, k.replace(".", "/")) for k in packages}
1614
package_data = {}
1715

1816
try:
@@ -48,8 +46,6 @@
4846
author="Xavier Dupré",
4947
author_email="[email protected]",
5048
url="https://github.com/sdpython/onnx-array-api",
51-
packages=packages,
52-
package_dir=package_dir,
5349
package_data=package_data,
5450
setup_requires=["numpy", "scipy"],
5551
install_requires=requirements,

0 commit comments

Comments
 (0)