@@ -396,7 +396,7 @@ def __init__(
396396 history_size : Optional [int ] = None ,
397397 ) -> None :
398398
399- self .algorithm = algorithm
399+ self .algorithm = algorithm or ""
400400 self .init_alpha = init_alpha
401401 self .iter = iter
402402 self .save_iterations = save_iterations
@@ -414,20 +414,17 @@ def validate(
414414 """
415415 Check arguments correctness and consistency.
416416 """
417- if (
418- self .algorithm is not None
419- and self .algorithm not in self .OPTIMIZE_ALGOS
420- ):
417+ if self .algorithm and self .algorithm not in self .OPTIMIZE_ALGOS :
421418 raise ValueError (
422419 'Please specify optimizer algorithms as one of [{}]' .format (
423420 ', ' .join (self .OPTIMIZE_ALGOS )
424421 )
425422 )
426423
427424 if self .init_alpha is not None :
428- if self .algorithm == 'Newton' :
425+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
429426 raise ValueError (
430- 'init_alpha must not be set when algorithm is Newton '
427+ 'init_alpha requires that algorithm be set to bfgs or lbfgs '
431428 )
432429 if isinstance (self .init_alpha , float ):
433430 if self .init_alpha <= 0 :
@@ -443,9 +440,9 @@ def validate(
443440 raise ValueError ('iter must be type of int' )
444441
445442 if self .tol_obj is not None :
446- if self .algorithm == 'Newton' :
443+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
447444 raise ValueError (
448- 'tol_obj must not be set when algorithm is Newton '
445+ 'tol_obj requires that algorithm be set to bfgs or lbfgs '
449446 )
450447 if isinstance (self .tol_obj , float ):
451448 if self .tol_obj <= 0 :
@@ -454,9 +451,10 @@ def validate(
454451 raise ValueError ('tol_obj must be type of float' )
455452
456453 if self .tol_rel_obj is not None :
457- if self .algorithm == 'Newton' :
454+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
458455 raise ValueError (
459- 'tol_rel_obj must not be set when algorithm is Newton'
456+ 'tol_rel_obj requires that algorithm be set to bfgs'
457+ ' or lbfgs'
460458 )
461459 if isinstance (self .tol_rel_obj , float ):
462460 if self .tol_rel_obj <= 0 :
@@ -465,9 +463,9 @@ def validate(
465463 raise ValueError ('tol_rel_obj must be type of float' )
466464
467465 if self .tol_grad is not None :
468- if self .algorithm == 'Newton' :
466+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
469467 raise ValueError (
470- 'tol_grad must not be set when algorithm is Newton '
468+ 'tol_grad requires that algorithm be set to bfgs or lbfgs '
471469 )
472470 if isinstance (self .tol_grad , float ):
473471 if self .tol_grad <= 0 :
@@ -476,9 +474,10 @@ def validate(
476474 raise ValueError ('tol_grad must be type of float' )
477475
478476 if self .tol_rel_grad is not None :
479- if self .algorithm == 'Newton' :
477+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
480478 raise ValueError (
481- 'tol_rel_grad must not be set when algorithm is Newton'
479+ 'tol_rel_grad requires that algorithm be set to bfgs'
480+ ' or lbfgs'
482481 )
483482 if isinstance (self .tol_rel_grad , float ):
484483 if self .tol_rel_grad <= 0 :
@@ -487,9 +486,9 @@ def validate(
487486 raise ValueError ('tol_rel_grad must be type of float' )
488487
489488 if self .tol_param is not None :
490- if self .algorithm == 'Newton' :
489+ if self .algorithm . lower () not in { 'lbfgs' , 'bfgs' } :
491490 raise ValueError (
492- 'tol_param must not be set when algorithm is Newton '
491+ 'tol_param requires that algorithm be set to bfgs or lbfgs '
493492 )
494493 if isinstance (self .tol_param , float ):
495494 if self .tol_param <= 0 :
@@ -498,10 +497,9 @@ def validate(
498497 raise ValueError ('tol_param must be type of float' )
499498
500499 if self .history_size is not None :
501- if self .algorithm == 'Newton' or self . algorithm == 'BFGS ' :
500+ if self .algorithm . lower () != 'lbfgs ' :
502501 raise ValueError (
503- 'history_size must not be set when algorithm is '
504- 'Newton or BFGS'
502+ 'history_size requires that algorithm be set to lbfgs'
505503 )
506504 if isinstance (self .history_size , int ):
507505 if self .history_size < 0 :
0 commit comments