|
9 | 9 | import shutil |
10 | 10 | import subprocess |
11 | 11 | from copy import copy |
| 12 | +from datetime import datetime |
12 | 13 | from pathlib import Path |
13 | 14 | from typing import Any, Dict, Iterable, List, Optional, Union |
14 | 15 |
|
15 | 16 | from cmdstanpy.utils import get_logger |
16 | | -from cmdstanpy.utils.cmdstan import EXTENSION, cmdstan_path |
| 17 | +from cmdstanpy.utils.cmdstan import ( |
| 18 | + EXTENSION, |
| 19 | + cmdstan_path, |
| 20 | + cmdstan_version, |
| 21 | + cmdstan_version_before, |
| 22 | +) |
17 | 23 | from cmdstanpy.utils.command import do_command |
18 | 24 | from cmdstanpy.utils.filesystem import SanitizedOrTmpFilePath |
19 | 25 |
|
@@ -476,3 +482,98 @@ def compile_stan_file( |
476 | 482 | f"Failed to compile Stan model '{src}'. " f"Console:\n{console}" |
477 | 483 | ) |
478 | 484 | return str(exe_target) |
| 485 | + |
| 486 | + |
| 487 | +def format_stan_file( |
| 488 | + stan_file: Union[str, os.PathLike], |
| 489 | + *, |
| 490 | + overwrite_file: bool = False, |
| 491 | + canonicalize: Union[bool, str, Iterable[str]] = False, |
| 492 | + max_line_length: int = 78, |
| 493 | + backup: bool = True, |
| 494 | + stanc_options: Optional[Dict[str, Any]] = None, |
| 495 | +) -> None: |
| 496 | + """ |
| 497 | + Run stanc's auto-formatter on the model code. Either saves directly |
| 498 | + back to the file or prints for inspection |
| 499 | +
|
| 500 | + :param stan_file: Path to Stan program file. |
| 501 | + :param overwrite_file: If True, save the updated code to disk, rather |
| 502 | + than printing it. By default False |
| 503 | + :param canonicalize: Whether or not the compiler should 'canonicalize' |
| 504 | + the Stan model, removing things like deprecated syntax. Default is |
| 505 | + False. If True, all canonicalizations are run. If it is a list of |
| 506 | + strings, those options are passed to stanc (new in Stan 2.29) |
| 507 | + :param max_line_length: Set the wrapping point for the formatter. The |
| 508 | + default value is 78, which wraps most lines by the 80th character. |
| 509 | + :param backup: If True, create a stanfile.bak backup before |
| 510 | + writing to the file. Only disable this if you're sure you have other |
| 511 | + copies of the file or are using a version control system like Git. |
| 512 | + :param stanc_options: Additional options to pass to the stanc compiler. |
| 513 | + """ |
| 514 | + stan_file = Path(stan_file).resolve() |
| 515 | + |
| 516 | + if not stan_file.exists(): |
| 517 | + raise ValueError(f'File does not exist: {stan_file}') |
| 518 | + |
| 519 | + try: |
| 520 | + cmd = ( |
| 521 | + [os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)] |
| 522 | + # handle include-paths, allow-undefined etc |
| 523 | + + CompilerOptions(stanc_options=stanc_options).compose_stanc(None) |
| 524 | + + [str(stan_file)] |
| 525 | + ) |
| 526 | + |
| 527 | + if canonicalize: |
| 528 | + if cmdstan_version_before(2, 29): |
| 529 | + if isinstance(canonicalize, bool): |
| 530 | + cmd.append('--print-canonical') |
| 531 | + else: |
| 532 | + raise ValueError( |
| 533 | + "Invalid arguments passed for current CmdStan" |
| 534 | + + " version({})\n".format( |
| 535 | + cmdstan_version() or "Unknown" |
| 536 | + ) |
| 537 | + + "--canonicalize requires 2.29 or higher" |
| 538 | + ) |
| 539 | + else: |
| 540 | + if isinstance(canonicalize, str): |
| 541 | + cmd.append('--canonicalize=' + canonicalize) |
| 542 | + elif isinstance(canonicalize, Iterable): |
| 543 | + cmd.append('--canonicalize=' + ','.join(canonicalize)) |
| 544 | + else: |
| 545 | + cmd.append('--print-canonical') |
| 546 | + |
| 547 | + # before 2.29, having both --print-canonical |
| 548 | + # and --auto-format printed twice |
| 549 | + if not (cmdstan_version_before(2, 29) and canonicalize): |
| 550 | + cmd.append('--auto-format') |
| 551 | + |
| 552 | + if not cmdstan_version_before(2, 29): |
| 553 | + cmd.append(f'--max-line-length={max_line_length}') |
| 554 | + elif max_line_length != 78: |
| 555 | + raise ValueError( |
| 556 | + "Invalid arguments passed for current CmdStan version" |
| 557 | + + " ({})\n".format(cmdstan_version() or "Unknown") |
| 558 | + + "--max-line-length requires 2.29 or higher" |
| 559 | + ) |
| 560 | + |
| 561 | + out = subprocess.run(cmd, capture_output=True, text=True, check=True) |
| 562 | + if out.stderr: |
| 563 | + get_logger().warning(out.stderr) |
| 564 | + result = out.stdout |
| 565 | + if overwrite_file: |
| 566 | + if result: |
| 567 | + if backup: |
| 568 | + shutil.copyfile( |
| 569 | + stan_file, |
| 570 | + str(stan_file) |
| 571 | + + '.bak-' |
| 572 | + + datetime.now().strftime("%Y%m%d%H%M%S"), |
| 573 | + ) |
| 574 | + stan_file.write_text(result) |
| 575 | + else: |
| 576 | + print(result) |
| 577 | + |
| 578 | + except (ValueError, RuntimeError) as e: |
| 579 | + raise RuntimeError("Stanc formatting failed") from e |
0 commit comments