@@ -158,6 +158,7 @@ def callback(res):
158158 nonlocal self
159159 assert self .CONSTRAINTS (res .x )
160160 assert self .CONSTRAINTS (res .xv [0 ])
161+ return True
161162
162163 for method in BUILTIN_METHODS :
163164 with self .subTest (method = method ):
@@ -255,7 +256,7 @@ def constraints(x):
255256 was_called = True
256257 return np .all (x > 0 )
257258
258- _minimize = partial (minimize , x0 = [0 ] * 2 , constraints = constraints , max_iter = 10 )
259+ _minimize = partial (minimize , x0 = [0 ] * 2 , constraints = constraints , max_iter = 5 )
259260 for method in BUILTIN_METHODS :
260261 with self .subTest (method = method ):
261262 was_called = False
@@ -376,16 +377,17 @@ def test_make_doc_plots(self):
376377 KWARGS = {
377378 plot_regret : dict (true_minimum = 0 ),
378379 plot_convergence : dict (yscale = 'log' , true_minimum = 0 ),
380+ plot_objective : dict (resolution = 12 )
379381 }
380- with chdir (( p := Path (__file__ ).parent ) .parent / f'html/ { p . stem } ' ):
382+ with chdir (Path (__file__ ).parent .parent / 'www ' ):
381383 for plot_func in PLOT_FUNCS :
382384 name = plot_func .__name__ .removeprefix ("plot_" )
383385 with self .subTest (plot = name ):
384386 try :
385387 fig = plot_func (* zip (BUILTIN_METHODS , results ),
386388 ** KWARGS .get (plot_func , {}))
387389 except TypeError :
388- fig = plot_func (results [0 ]) # FIXME: plot (1, 3) subplots
390+ fig = plot_func (results [0 ], ** KWARGS . get ( plot_func , {}) ) # FIXME: plot (1, 3) subplots
389391 fig .savefig (f'{ name } .svg' )
390392 plt .show ()
391393
@@ -406,17 +408,18 @@ def evaluate(x):
406408
407409 results = []
408410 for estimator in BUILTIN_ESTIMATORS :
409- optimizer = Optimizer (fun = None , bounds = [(- 2 , 2 )]* 2 , estimator = estimator , rng = 0 )
411+ optimizer = Optimizer (fun = None , bounds = [(- 2 , 2 )]* 4 , estimator = estimator , rng = 0 )
410412
411- for i in range (100 ):
412- suggested_x = optimizer .ask (n_candidates = 1 )
413+ for i in range (30 ):
414+ suggested_x = optimizer .ask (n_candidates = 2 )
413415 y = [evaluate (x ) for x in suggested_x ]
414416 optimizer .tell (y )
415417
416418 result = optimizer .run ()
417419 results .append (result )
418- fig = plot_convergence (* [(f'estimator={ e !r} ' , r ) for e , r in zip (BUILTIN_ESTIMATORS , results )], yscale = 'log' , true_minimum = 0 )
419- with chdir (Path (__file__ ).parent .parent / 'html' ):
420+ named_results = [(f'estimator={ e !r} ' , r ) for e , r in zip (BUILTIN_ESTIMATORS , results )]
421+ fig = plot_convergence (* named_results , true_minimum = 0 )
422+ with chdir (Path (__file__ ).parent .parent / 'www' ):
420423 fig .savefig ('convergence2.svg' )
421424 plt .show ()
422425
@@ -428,22 +431,22 @@ def test_website_example3(self):
428431 X , y = load_breast_cancer (return_X_y = True )
429432 clf = DecisionTreeClassifier (random_state = 0 )
430433 param_grid = {
431- 'max_depth' : list (range (1 , 30 )),
432- 'min_samples_split' : [2 , 5 , 10 , 20 , 50 , 100 ],
433- 'min_samples_leaf' : list (range (1 , 20 )),
434+ 'max_depth' : list (range (1 , 20 )),
435+ 'min_samples_split' : [2 , 5 , 10 , 50 , 100 ],
436+ 'min_samples_leaf' : list (range (1 , 10 )),
434437 'criterion' : ['gini' , 'entropy' ],
435438 'max_features' : [None , 'sqrt' , 'log2' ],
436439 }
437- search = GridSearchCV (clf , param_grid , cv = 2 , n_jobs = - 1 )
438- # Trying all ~20k combinations may take a long time ...
440+ search = GridSearchCV (clf , param_grid , cv = 2 , n_jobs = 1 )
441+ # Trying all ~6k combinations may take a long time ...
439442 search .fit (X , y )
440443 pprint (dict (sorted (search .best_params_ .items ())))
441444 print (search .best_score_ )
442445
443446 # Alternatively ...
444447 from sambo import SamboSearchCV
445448
446- search = SamboSearchCV (clf , param_grid , max_iter = 100 , cv = 2 , n_jobs = - 1 , rng = 0 )
449+ search = SamboSearchCV (clf , param_grid , max_iter = 100 , cv = 2 , n_jobs = 1 , rng = 0 )
447450 search .fit (X , y ) # Fast, good enough
448451 pprint (dict (sorted (search .best_params_ .items ())))
449452 print (search .best_score_ )
0 commit comments