Skip to content

Commit af740e9

Browse files
committed
fix example
1 parent 05b077b commit af740e9

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

_doc/examples/plot_benchmark_rf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
139139
depth = [6, 8, 10, 12, 14]
140140
Regressor = RandomForestRegressor
141141

142+
# avoid duplicates on machine with 1 or 2 cores.
143+
n_jobs = list(sorted(set(n_jobs), reverse=True))
142144

143145
##############################################
144146
# Benchmark parameters
@@ -273,13 +275,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
273275
subdf = df[(df.n_estimators == n_estimators) & (df.n_jobs == n_j)]
274276
if subdf.shape[0] == 0:
275277
continue
276-
try:
277-
piv = subdf.pivot(index="max_depth", columns="name", values=["avg", "med"])
278-
except Exception as e:
279-
from io import StringIO
280-
st = StringIO()
281-
subdf.to_csv(st, index=False)
282-
raise AssertionError(st.getvalue()) from e
278+
piv = subdf.pivot(index="max_depth", columns="name", values=["avg", "med"])
283279
piv.plot(ax=ax, title=f"jobs={n_j}, trees={n_estimators}")
284280
ax.set_ylabel(f"n_jobs={n_j}", fontsize="small")
285281
ax.set_xlabel("max_depth", fontsize="small")

0 commit comments

Comments
 (0)