|
1 | 1 | """CmdStanModel""" |
2 | 2 |
|
| 3 | +import io |
3 | 4 | import os |
4 | 5 | import platform |
5 | 6 | import re |
@@ -2171,3 +2172,118 @@ def progress_hook(line: str, idx: int) -> None: |
2171 | 2172 | pbars[idx].postfix[0]["value"] = mline |
2172 | 2173 |
|
2173 | 2174 | return progress_hook |
| 2175 | + |
| 2176 | + def diagnose( |
| 2177 | + self, |
| 2178 | + inits: Union[Dict[str, Any], str, os.PathLike, None] = None, |
| 2179 | + data: Union[Mapping[str, Any], str, os.PathLike, None] = None, |
| 2180 | + *, |
| 2181 | + epsilon: Optional[float] = None, |
| 2182 | + error: Optional[float] = None, |
| 2183 | + require_gradients_ok: bool = True, |
| 2184 | + sig_figs: Optional[int] = None, |
| 2185 | + ) -> pd.DataFrame: |
| 2186 | + """ |
| 2187 | + Run diagnostics to calculate the gradients at the specified parameter |
| 2188 | + values and compare them with gradients calculated by finite differences. |
| 2189 | +
|
| 2190 | + :param inits: Specifies how the sampler initializes parameter values. |
| 2191 | + Initialization is either uniform random on a range centered on 0, |
| 2192 | + exactly 0, or a dictionary or file of initial values for some or |
| 2193 | + all parameters in the model. The default initialization behavior |
| 2194 | + will initialize all parameter values on range [-2, 2] on the |
| 2195 | + *unconstrained* support. The following value types are allowed: |
| 2196 | + * Single number, n > 0 - initialization range is [-n, n]. |
| 2197 | + * 0 - all parameters are initialized to 0. |
| 2198 | + * dictionary - pairs parameter name : initial value. |
| 2199 | + * string - pathname to a JSON or Rdump data file. |
| 2200 | +
|
| 2201 | + :param data: Values for all data variables in the model, specified |
| 2202 | + either as a dictionary with entries matching the data variables, |
| 2203 | + or as the path of a data file in JSON or Rdump format. |
| 2204 | +
|
| 2205 | + :param sig_figs: Numerical precision used for output CSV and text files. |
| 2206 | + Must be an integer between 1 and 18. If unspecified, the default |
| 2207 | + precision for the system file I/O is used; the usual value is 6. |
| 2208 | +
|
| 2209 | + :param epsilon: Step size for finite difference gradients. |
| 2210 | +
|
| 2211 | + :param error: Absolute error threshold for comparing autodiff and finite |
| 2212 | + difference gradients. |
| 2213 | +
|
| 2214 | + :param require_gradients_ok: Whether or not to raise an error if Stan |
| 2215 | + reports that the difference between autodiff gradients and finite |
| 2216 | + difference gradients exceed the error threshold. |
| 2217 | +
|
| 2218 | + :return: A pandas.DataFrame containing columns |
| 2219 | + * "param_idx": increasing parameter index. |
| 2220 | + * "value": Parameter value. |
| 2221 | + * "model": Gradients evaluated using autodiff. |
| 2222 | + * "finite_diff": Gradients evaluated using finite differences. |
| 2223 | + * "error": Delta between autodiff and finite difference gradients. |
| 2224 | +
|
| 2225 | + Gradients are evaluated in the unconstrained space. |
| 2226 | + """ |
| 2227 | + |
| 2228 | + with temp_single_json(data) as _data, \ |
| 2229 | + temp_single_json(inits) as _inits: |
| 2230 | + cmd = [ |
| 2231 | + str(self.exe_file), |
| 2232 | + "diagnose", |
| 2233 | + "test=gradient", |
| 2234 | + ] |
| 2235 | + if epsilon is not None: |
| 2236 | + cmd.append(f"epsilon={epsilon}") |
| 2237 | + if error is not None: |
| 2238 | + cmd.append(f"epsilon={error}") |
| 2239 | + if _data is not None: |
| 2240 | + cmd += ["data", f"file={_data}"] |
| 2241 | + if _inits is not None: |
| 2242 | + cmd.append(f"init={_inits}") |
| 2243 | + |
| 2244 | + output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR) |
| 2245 | + |
| 2246 | + output = os.path.join(output_dir, "output.csv") |
| 2247 | + cmd += ["output", f"file={output}"] |
| 2248 | + if sig_figs is not None: |
| 2249 | + cmd.append(f"sig_figs={sig_figs}") |
| 2250 | + |
| 2251 | + get_logger().debug("Cmd: %s", str(cmd)) |
| 2252 | + |
| 2253 | + proc = subprocess.run( |
| 2254 | + cmd, capture_output=True, check=False, text=True |
| 2255 | + ) |
| 2256 | + if proc.returncode: |
| 2257 | + get_logger().error( |
| 2258 | + "'diagnose' command failed!\nstdout:%s\nstderr:%s", |
| 2259 | + proc.stdout, |
| 2260 | + proc.stderr, |
| 2261 | + ) |
| 2262 | + if require_gradients_ok: |
| 2263 | + raise RuntimeError( |
| 2264 | + "The difference between autodiff and finite difference " |
| 2265 | + "gradients may exceed the error threshold. If you " |
| 2266 | + "would like to inspect the output, re-call with " |
| 2267 | + "`require_gradients_ok=False`." |
| 2268 | + ) |
| 2269 | + get_logger().warning( |
| 2270 | + "The difference between autodiff and finite difference " |
| 2271 | + "gradients may exceed the error threshold. Proceeding " |
| 2272 | + "because `require_gradients_ok` is set to `False`." |
| 2273 | + ) |
| 2274 | + |
| 2275 | + # Read the text and get the last chunk separated by a single # char. |
| 2276 | + try: |
| 2277 | + with open(output) as handle: |
| 2278 | + text = handle.read() |
| 2279 | + except FileNotFoundError as exc: |
| 2280 | + raise RuntimeError( |
| 2281 | + "Output of 'diagnose' command does not exist." |
| 2282 | + ) from exc |
| 2283 | + *_, table = re.split(r"#\s*\n", text) |
| 2284 | + table = ( |
| 2285 | + re.sub(r"^#\s*", "", table, flags=re.M) |
| 2286 | + .replace("param idx", "param_idx") |
| 2287 | + .replace("finite diff", "finite_diff") |
| 2288 | + ) |
| 2289 | + return pd.read_csv(io.StringIO(table), sep=r"\s+") |
0 commit comments