Skip to content

Commit 1219274

Browse files
authored
Merge pull request #619 from stan-dev/fix/617-tol-params
Tighten optimization argument requirements, print command
2 parents 949bbe5 + 017b52e commit 1219274

File tree

4 files changed

+51
-39
lines changed

4 files changed

+51
-39
lines changed

cmdstanpy/cmdstan_args.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def __init__(
396396
history_size: Optional[int] = None,
397397
) -> None:
398398

399-
self.algorithm = algorithm
399+
self.algorithm = algorithm or ""
400400
self.init_alpha = init_alpha
401401
self.iter = iter
402402
self.save_iterations = save_iterations
@@ -414,20 +414,17 @@ def validate(
414414
"""
415415
Check arguments correctness and consistency.
416416
"""
417-
if (
418-
self.algorithm is not None
419-
and self.algorithm not in self.OPTIMIZE_ALGOS
420-
):
417+
if self.algorithm and self.algorithm not in self.OPTIMIZE_ALGOS:
421418
raise ValueError(
422419
'Please specify optimizer algorithms as one of [{}]'.format(
423420
', '.join(self.OPTIMIZE_ALGOS)
424421
)
425422
)
426423

427424
if self.init_alpha is not None:
428-
if self.algorithm == 'Newton':
425+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
429426
raise ValueError(
430-
'init_alpha must not be set when algorithm is Newton'
427+
'init_alpha requires that algorithm be set to bfgs or lbfgs'
431428
)
432429
if isinstance(self.init_alpha, float):
433430
if self.init_alpha <= 0:
@@ -443,9 +440,9 @@ def validate(
443440
raise ValueError('iter must be type of int')
444441

445442
if self.tol_obj is not None:
446-
if self.algorithm == 'Newton':
443+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
447444
raise ValueError(
448-
'tol_obj must not be set when algorithm is Newton'
445+
'tol_obj requires that algorithm be set to bfgs or lbfgs'
449446
)
450447
if isinstance(self.tol_obj, float):
451448
if self.tol_obj <= 0:
@@ -454,9 +451,10 @@ def validate(
454451
raise ValueError('tol_obj must be type of float')
455452

456453
if self.tol_rel_obj is not None:
457-
if self.algorithm == 'Newton':
454+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
458455
raise ValueError(
459-
'tol_rel_obj must not be set when algorithm is Newton'
456+
'tol_rel_obj requires that algorithm be set to bfgs'
457+
' or lbfgs'
460458
)
461459
if isinstance(self.tol_rel_obj, float):
462460
if self.tol_rel_obj <= 0:
@@ -465,9 +463,9 @@ def validate(
465463
raise ValueError('tol_rel_obj must be type of float')
466464

467465
if self.tol_grad is not None:
468-
if self.algorithm == 'Newton':
466+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
469467
raise ValueError(
470-
'tol_grad must not be set when algorithm is Newton'
468+
'tol_grad requires that algorithm be set to bfgs or lbfgs'
471469
)
472470
if isinstance(self.tol_grad, float):
473471
if self.tol_grad <= 0:
@@ -476,9 +474,10 @@ def validate(
476474
raise ValueError('tol_grad must be type of float')
477475

478476
if self.tol_rel_grad is not None:
479-
if self.algorithm == 'Newton':
477+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
480478
raise ValueError(
481-
'tol_rel_grad must not be set when algorithm is Newton'
479+
'tol_rel_grad requires that algorithm be set to bfgs'
480+
' or lbfgs'
482481
)
483482
if isinstance(self.tol_rel_grad, float):
484483
if self.tol_rel_grad <= 0:
@@ -487,9 +486,9 @@ def validate(
487486
raise ValueError('tol_rel_grad must be type of float')
488487

489488
if self.tol_param is not None:
490-
if self.algorithm == 'Newton':
489+
if self.algorithm.lower() not in {'lbfgs', 'bfgs'}:
491490
raise ValueError(
492-
'tol_param must not be set when algorithm is Newton'
491+
'tol_param requires that algorithm be set to bfgs or lbfgs'
493492
)
494493
if isinstance(self.tol_param, float):
495494
if self.tol_param <= 0:
@@ -498,10 +497,9 @@ def validate(
498497
raise ValueError('tol_param must be type of float')
499498

500499
if self.history_size is not None:
501-
if self.algorithm == 'Newton' or self.algorithm == 'BFGS':
500+
if self.algorithm.lower() != 'lbfgs':
502501
raise ValueError(
503-
'history_size must not be set when algorithm is '
504-
'Newton or BFGS'
502+
'history_size requires that algorithm be set to lbfgs'
505503
)
506504
if isinstance(self.history_size, int):
507505
if self.history_size < 0:

cmdstanpy/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,9 @@ def optimize(
701701
self._run_cmdstan(runset, dummy_chain_id, show_console=show_console)
702702

703703
if not runset._check_retcodes():
704-
msg = 'Error during optimization: {}'.format(runset.get_err_msgs())
704+
msg = "Error during optimization! Command '{}' failed: {}".format(
705+
' '.join(runset.cmd(0)), runset.get_err_msgs()
706+
)
705707
if 'Line search failed' in msg and not require_converged:
706708
get_logger().warning(msg)
707709
else:

test/test_cmdstan_args.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,14 @@ def test_args_algorithm(self):
3434
self.assertIn('algorithm=newton', ' '.join(cmd))
3535

3636
def test_args_algorithm_init_alpha(self):
37-
args = OptimizeArgs(init_alpha=2e-4)
37+
args = OptimizeArgs(algorithm='bfgs', init_alpha=2e-4)
3838
args.validate()
3939
cmd = args.compose(None, cmd=['output'])
4040

4141
self.assertIn('init_alpha=0.0002', ' '.join(cmd))
42+
args = OptimizeArgs(init_alpha=2e-4)
43+
with self.assertRaises(ValueError):
44+
args.validate()
4245
args = OptimizeArgs(init_alpha=-1.0)
4346
with self.assertRaises(ValueError):
4447
args.validate()

test/test_optimize.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -432,9 +432,7 @@ def test_parameters_and_optimizer_compatible(self):
432432
jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')
433433
jinit = os.path.join(DATAFILES_PATH, 'bernoulli.init.json')
434434

435-
with self.assertRaisesRegex(
436-
ValueError, 'must not be set when algorithm is Newton'
437-
):
435+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
438436
model.optimize(
439437
data=jdata,
440438
seed=1239812093,
@@ -443,9 +441,7 @@ def test_parameters_and_optimizer_compatible(self):
443441
tol_obj=1,
444442
)
445443

446-
with self.assertRaisesRegex(
447-
ValueError, 'must not be set when algorithm is Newton'
448-
):
444+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
449445
model.optimize(
450446
data=jdata,
451447
seed=1239812093,
@@ -454,9 +450,7 @@ def test_parameters_and_optimizer_compatible(self):
454450
tol_rel_obj=1,
455451
)
456452

457-
with self.assertRaisesRegex(
458-
ValueError, 'must not be set when algorithm is Newton'
459-
):
453+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
460454
model.optimize(
461455
data=jdata,
462456
seed=1239812093,
@@ -465,9 +459,7 @@ def test_parameters_and_optimizer_compatible(self):
465459
tol_grad=1,
466460
)
467461

468-
with self.assertRaisesRegex(
469-
ValueError, 'must not be set when algorithm is Newton'
470-
):
462+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
471463
model.optimize(
472464
data=jdata,
473465
seed=1239812093,
@@ -476,9 +468,15 @@ def test_parameters_and_optimizer_compatible(self):
476468
tol_rel_grad=1,
477469
)
478470

479-
with self.assertRaisesRegex(
480-
ValueError, 'must not be set when algorithm is Newton'
481-
):
471+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
472+
model.optimize(
473+
data=jdata,
474+
seed=1239812093,
475+
inits=jinit,
476+
tol_rel_grad=1,
477+
)
478+
479+
with self.assertRaisesRegex(ValueError, 'bfgs or lbfgs'):
482480
model.optimize(
483481
data=jdata,
484482
seed=1239812093,
@@ -489,7 +487,7 @@ def test_parameters_and_optimizer_compatible(self):
489487

490488
with self.assertRaisesRegex(
491489
ValueError,
492-
'history_size must not be set when algorithm is Newton or BFGS',
490+
'lbfgs',
493491
):
494492
model.optimize(
495493
data=jdata,
@@ -501,7 +499,7 @@ def test_parameters_and_optimizer_compatible(self):
501499

502500
with self.assertRaisesRegex(
503501
ValueError,
504-
'history_size must not be set when algorithm is Newton or BFGS',
502+
'lbfgs',
505503
):
506504
model.optimize(
507505
data=jdata,
@@ -511,6 +509,17 @@ def test_parameters_and_optimizer_compatible(self):
511509
history_size=1,
512510
)
513511

512+
with self.assertRaisesRegex(
513+
ValueError,
514+
'lbfgs',
515+
):
516+
model.optimize(
517+
data=jdata,
518+
seed=1239812093,
519+
inits=jinit,
520+
history_size=1,
521+
)
522+
514523
def test_optimize_good_dict(self):
515524
exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
516525
stan = os.path.join(DATAFILES_PATH, 'bernoulli.stan')

0 commit comments

Comments
 (0)