Skip to content

Commit 56fb50d

Browse files
authored
Merge pull request #840 from amas0/add-compilation-kwargs
Add multithreading and stanc_optimization arguments to CmdStanModel
2 parents 709da77 + 64f0f1b commit 56fb50d

File tree

5 files changed

+91
-35
lines changed

5 files changed

+91
-35
lines changed

cmdstanpy/compilation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,3 +486,23 @@ def format_stan_file(
486486

487487
except (ValueError, RuntimeError) as e:
488488
raise RuntimeError("Stanc formatting failed") from e
489+
490+
491+
def resolve_cpp_options(
492+
cpp_options: dict[str, Any] | None, multithreading: bool
493+
) -> dict[str, Any]:
494+
out = cpp_options or {}
495+
out = out.copy()
496+
if multithreading and "STAN_THREADS" not in out:
497+
out["STAN_THREADS"] = "TRUE"
498+
return out
499+
500+
501+
def resolve_stanc_options(
502+
stanc_options: dict[str, Any] | None, stanc_optimizations: bool
503+
) -> dict[str, Any]:
504+
out = stanc_options or {}
505+
out = out.copy()
506+
if stanc_optimizations and "O" not in out:
507+
out["O"] = 1
508+
return out

cmdstanpy/model.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ def __init__(
9393
stanc_options: dict[str, Any] | None = None,
9494
cpp_options: dict[str, Any] | None = None,
9595
user_header: OptionalPath = None,
96+
*,
97+
multithreading: bool = False,
98+
stanc_optimizations: bool = False,
9699
) -> None:
97100
"""
98101
Initialize object given constructor args.
@@ -101,14 +104,27 @@ def __init__(
101104
:param exe_file: Path to compiled executable file.
102105
:param force_compile: Whether or not to force recompilation if
103106
executable file already exists.
104-
:param stanc_options: Options for stanc compiler.
105-
:param cpp_options: Options for C++ compiler.
107+
:param multithreading: Enables multithreading in a Stan model.
108+
Equivalent to `cpp_options = {"STAN_THREADS": "TRUE"}`.
109+
Defaults to False.
110+
:param stanc_optimizations: Enables O1 optimizations in the
111+
stanc compiler. Equivalent to `stanc_options = {"O": 1}`.
112+
Defaults to False.
113+
:param stanc_options: Options for stanc compiler. Note, this
114+
will override the `stanc_optimizations` if in conflict.
115+
:param cpp_options: Options for C++ compiler. Note, this will
116+
override the `multithreading` option if in conflict.
106117
:param user_header: A path to a header file to include during C++
107118
compilation.
108119
"""
109120
self._name = ''
110121
self._stan_file = None
111-
self._stanc_options: dict[str, Any] = stanc_options or {}
122+
self._stanc_options = compilation.resolve_stanc_options(
123+
stanc_options, stanc_optimizations
124+
)
125+
cpp_options = compilation.resolve_cpp_options(
126+
cpp_options, multithreading
127+
)
112128

113129
self._fixed_param = False
114130

docsrc/users-guide/examples/Pathfinder.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
"outputs": [],
4545
"source": [
4646
"import os\n",
47-
"from cmdstanpy.model import CmdStanModel, cmdstan_path"
47+
"from cmdstanpy import CmdStanModel, cmdstan_path"
4848
]
4949
},
5050
{

docsrc/users-guide/workflow.rst

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,14 @@ managing the resulting inference for a single model and set of inputs.
3939
Compile the Stan model
4040
^^^^^^^^^^^^^^^^^^^^^^
4141

42-
The: :class:`CmdStanModel` class provides methods
43-
to compile and run the Stan program.
44-
A CmdStanModel object can be instantiated by specifying
45-
either a Stan file or the executable file, or both.
46-
If only the Stan file path is specified, the constructor will
47-
check for the existence of a correspondingly named exe file in
48-
the same directory. If found, it will use this as the exe file path.
49-
50-
By default, when a CmdStanModel object is instantiated from a Stan file,
51-
the constructor will compile the model as needed.
52-
The constructor argument `compile` controls this behavior.
53-
54-
* ``compile=False``: never compile the Stan file.
55-
* ``compile="Force"``: always compile the Stan file.
56-
* ``compile=True``: (default) compile the Stan file as needed, i.e., if no exe file exists or if the Stan file is newer than the exe file.
42+
The :class:`CmdStanModel` class provides methods to compile and run the Stan
43+
program. A CmdStanModel object can be instantiated by specifying a Stan file,
44+
the executable file, or both. If only the Stan file path is specified, the
45+
constructor will check for the existence of a correspondingly named executable in
46+
the same directory. If found, it will use this as the exe file path.
47+
48+
When a CmdStanModel object is instantiated from a Stan file, the constructor
49+
will compile the model if the executable is non-existent or out-of-date.
5750

5851
.. code-block:: python
5952
@@ -67,8 +60,8 @@ The constructor argument `compile` controls this behavior.
6760
my_model.exe_file
6861
my_model.code()
6962
70-
The CmdStanModel class also provides the :meth:`~CmdStanModel.compile` method,
71-
which can be called at any point to (re)compile the model as needed.
63+
The ``force_compile=True`` argument can be passed to the CmdStanModel
64+
constructor, which will force (re)compilation of the model.
7265

7366
Model compilation is carried out via the GNU Make build tool.
7467
The CmdStan ``makefile`` contains a set of general rules which
@@ -83,28 +76,30 @@ Model compilation is done in two steps:
8376
* The C++ compiler compiles the generated code and links in
8477
the necessary supporting libraries.
8578

86-
Therefore, both the constructor and the ``compile`` method
87-
allow optional arguments ``stanc_options`` and ``cpp_options`` which
88-
specify options for each compilation step.
89-
Options are specified as a Python dictionary mapping
90-
compiler option names to appropriate values.
79+
The constructor accepts arguments to specify both ``stanc`` and C++ compilation
80+
options, if desired. Passing `multithreading=True` enables the **STAN_THREADS**
81+
C++ flag, which is needed to parallelize within-chain computations, such as
82+
with ``reduce_sum``, or to parallelize the NUTS-HMC sampler across chains.
83+
Passing ``stanc_optimizations=True`` will enable ``O1`` optimizations in the
84+
``stanc`` compiler.
9185

92-
In order parallelize within-chain computations using the
93-
Stan language ``reduce_sum`` function, or to parallelize
94-
running the NUTS-HMC sampler across chains,
95-
the Stan model must be compiled with
96-
C++ compiler flag **STAN_THREADS**.
97-
While any value can be used,
98-
we recommend the value ``True``, e.g.:
86+
Outside of these common options, the constructor accepts the optional arguments
87+
``stanc_options`` and ``cpp_options``, which allow specifying arbitrary
88+
compilation options. Some more advanced Stan features, like MPI or OpenCL
89+
support, require using these. Note that if the lower-level compilation options
90+
conflict with an argument like ``multithreading=True``, the option in
91+
``stanc_options`` or ``cpp_options`` takes precedence.
9992

93+
An example model compilation that enables multithreading and
94+
basic optimization can be done like so:
10095

10196
.. code-block:: python
10297
10398
import os
10499
from cmdstanpy import CmdStanModel
105100
106101
my_stanfile = os.path.join('.', 'my_model.stan')
107-
my_model = CmdStanModel(stan_file=my_stanfile, cpp_options={'STAN_THREADS':'true'})
102+
my_model = CmdStanModel(stan_file=my_stanfile, multithreading=True, stanc_optimizations=True)
108103
109104
110105
Assemble input and initialization data

test/test_compilation.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99

1010
import pytest
1111

12-
from cmdstanpy.compilation import CompilerOptions, format_stan_file
12+
from cmdstanpy.compilation import (
13+
CompilerOptions,
14+
format_stan_file,
15+
resolve_cpp_options,
16+
resolve_stanc_options,
17+
)
1318

1419
HERE = os.path.dirname(os.path.abspath(__file__))
1520
DATAFILES_PATH = os.path.join(HERE, 'data')
@@ -225,3 +230,23 @@ def test_model_format_options() -> None:
225230
formatted = sys_stdout.getvalue()
226231
assert formatted.count('{') == 3
227232
assert formatted.count('(') == 1
233+
234+
235+
def test_compilation_options_resolution() -> None:
236+
out = resolve_cpp_options(None, multithreading=False)
237+
assert not out
238+
out = resolve_cpp_options(None, multithreading=True)
239+
assert out == {"STAN_THREADS": "TRUE"}
240+
out = resolve_cpp_options({"STAN_THREADS": ""}, multithreading=True)
241+
assert out == {"STAN_THREADS": ""}
242+
out = resolve_cpp_options({"STAN_OPENCL": "TRUE"}, multithreading=True)
243+
assert out == {"STAN_THREADS": "TRUE", "STAN_OPENCL": "TRUE"}
244+
245+
out = resolve_stanc_options(None, stanc_optimizations=False)
246+
assert not out
247+
out = resolve_stanc_options(None, stanc_optimizations=True)
248+
assert out == {"O": 1}
249+
out = resolve_stanc_options({"O": 0}, stanc_optimizations=True)
250+
assert out == {"O": 0}
251+
out = resolve_stanc_options({"O": "experimental"}, stanc_optimizations=True)
252+
assert out == {"O": "experimental"}

0 commit comments

Comments
 (0)