Skip to content

Commit 997a60b

Browse files
authored
Merge pull request #612 from stan-dev/fix/complex-output-ordering
Fix/complex output ordering
2 parents 1f88a4d + 94528b7 commit 997a60b

File tree

11 files changed

+79
-18
lines changed

11 files changed

+79
-18
lines changed

cmdstanpy/stanfit/gq.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,15 @@ def stan_variable(
523523
col_idxs = self._metadata.stan_vars_cols[var]
524524
if len(col_idxs) > 0:
525525
dims.extend(self._metadata.stan_vars_dims[var])
526-
# pylint: disable=redundant-keyword-arg
527-
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
526+
527+
draws = self._draws[draw1:, :, col_idxs]
528+
528529
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
529-
draws = draws[..., 0] + 1j * draws[..., 1]
530+
draws = draws[..., ::2] + 1j * draws[..., 1::2]
531+
dims = dims[:-1]
532+
533+
draws = draws.reshape(dims, order='F')
534+
530535
return draws
531536

532537
def stan_variables(self, inc_warmup: bool = False) -> Dict[str, np.ndarray]:

cmdstanpy/stanfit/mcmc.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,9 +747,14 @@ def stan_variable(
747747
col_idxs = self._metadata.stan_vars_cols[var]
748748
if len(col_idxs) > 0:
749749
dims.extend(self._metadata.stan_vars_dims[var])
750-
draws = self._draws[draw1:, :, col_idxs].reshape(dims, order='F')
750+
draws = self._draws[draw1:, :, col_idxs]
751+
751752
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
752-
draws = draws[..., 0] + 1j * draws[..., 1]
753+
draws = draws[..., ::2] + 1j * draws[..., 1::2]
754+
dims = dims[:-1]
755+
756+
draws = draws.reshape(dims, order='F')
757+
753758
return draws
754759

755760
def stan_variables(self) -> Dict[str, np.ndarray]:

cmdstanpy/stanfit/mle.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,17 @@ def stan_variable(
218218
dims = (num_rows,) + self._metadata.stan_vars_dims[var]
219219
# pylint: disable=redundant-keyword-arg
220220
if num_rows > 1:
221-
result = self._all_iters[:, col_idxs].reshape(dims, order='F')
221+
result = self._all_iters[:, col_idxs]
222222
else:
223-
result = self._mle[col_idxs].reshape(dims[1:], order="F")
223+
result = self._mle[col_idxs]
224+
dims = dims[1:]
224225

225226
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
226-
result = result[..., 0] + 1j * result[..., 1]
227+
result = result[..., ::2] + 1j * result[..., 1::2]
228+
dims = dims[:-1]
229+
230+
result = result.reshape(dims, order='F')
231+
227232
return result
228233

229234
else: # scalar var

cmdstanpy/stanfit/vb.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,13 @@ def stan_variable(self, var: str) -> Union[np.ndarray, float]:
143143
shape: Tuple[int, ...] = ()
144144
if len(col_idxs) > 1:
145145
shape = self._metadata.stan_vars_dims[var]
146-
result: np.ndarray = np.asarray(self._variational_mean)[
147-
col_idxs
148-
].reshape(shape, order="F")
149-
146+
result: np.ndarray = np.asarray(self._variational_mean)[col_idxs]
150147
if self._metadata.stan_vars_types[var] == BaseType.COMPLEX:
151-
result = result[..., 0] + 1j * result[..., 1]
148+
result = result[..., ::2] + 1j * result[..., 1::2]
149+
shape = shape[:-1]
150+
151+
result = result.reshape(shape, order="F")
152+
152153
return result
153154
else:
154155
return float(self._variational_mean[col_idxs[0]])

cmdstanpy/utils/data_munging.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,14 @@ def build_xarray_data(
4444
if dims:
4545
var_dims += tuple(f"{var_name}_dim_{i}" for i in range(len(dims)))
4646

47-
draws = drawset[start_row:, :, col_idxs].reshape(
48-
*drawset.shape[:2], *dims, order="F"
49-
)
47+
draws = drawset[start_row:, :, col_idxs]
48+
5049
if var_type == BaseType.COMPLEX:
51-
draws = draws[..., 0] + 1j * draws[..., 1]
50+
draws = draws[..., ::2] + 1j * draws[..., 1::2]
5251
var_dims = var_dims[:-1]
52+
dims = dims[:-1]
53+
54+
draws = draws.reshape(*drawset.shape[:2], *dims, order="F")
5355

5456
data[var_name] = (
5557
var_dims,

docsrc/changes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ What's New
77

88
For full changes, see the `Releases page <https://github.com/stan-dev/cmdstanpy/releases>`__ on GitHub.
99

10+
CmdStanPy 1.0.6
11+
---------------
12+
13+
- Fixed an issue where complex number containers in Stan program outputs were not being read in properly by CmdStanPy. The output would have the correct shape, but the values would be mixed up.
1014

1115
CmdStanPy 1.0.6
1216
---------------

test/data/complex_var.stan

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ generated quantities {
1616
{{0, 1}, {0, 2}, {0, 3}}};
1717
array[2, 3] complex zs = {{3, 4i, 5}, {1i, 2i, 3i}};
1818
complex z = 3 + 4i;
19-
19+
2020
array[2] int imag = {3, 4};
21+
22+
complex_matrix[2,3] zs_mat = to_matrix(zs);
2123
}
2224

test/test_generate_quantities.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,13 @@ def test_complex_output(self):
440440

441441
self.assertEqual(fit.stan_variable('zs').shape, (10, 2, 3))
442442
self.assertEqual(fit.stan_variable('z')[0], 3 + 4j)
443+
444+
self.assertTrue(
445+
np.allclose(
446+
fit.stan_variable('zs')[0], np.array([[3, 4j, 5], [1j, 2j, 3j]])
447+
)
448+
)
449+
443450
# make sure the name 'imag' isn't magic
444451
self.assertEqual(fit.stan_variable('imag').shape, (10, 2))
445452

test/test_optimize.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,13 @@ def test_complex_output(self):
606606

607607
self.assertEqual(fit.stan_variable('zs').shape, (2, 3))
608608
self.assertEqual(fit.stan_variable('z'), 3 + 4j)
609+
610+
self.assertTrue(
611+
np.allclose(
612+
fit.stan_variable('zs'), np.array([[3, 4j, 5], [1j, 2j, 3j]])
613+
)
614+
)
615+
609616
# make sure the name 'imag' isn't magic
610617
self.assertEqual(fit.stan_variable('imag').shape, (2,))
611618

test/test_sample.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,11 +1810,27 @@ def test_complex_output(self):
18101810
# make sure the name 'imag' isn't magic
18111811
self.assertEqual(fit.stan_variable('imag').shape, (10, 2))
18121812

1813+
self.assertTrue(
1814+
np.allclose(
1815+
fit.stan_variable('zs')[0], np.array([[3, 4j, 5], [1j, 2j, 3j]])
1816+
)
1817+
)
1818+
self.assertTrue(
1819+
np.allclose(
1820+
fit.stan_variable('zs_mat')[0],
1821+
np.array([[3, 4j, 5], [1j, 2j, 3j]]),
1822+
)
1823+
)
1824+
18131825
self.assertNotIn("zs_dim_2", fit.draws_xr())
18141826
# getting a raw scalar out of xarray is heavy
18151827
self.assertEqual(
18161828
fit.draws_xr().z.isel(chain=0, draw=1).data[()], 3 + 4j
18171829
)
1830+
np.testing.assert_allclose(
1831+
fit.draws_xr().zs.isel(chain=0, draw=1).data,
1832+
np.array([[3, 4j, 5], [1j, 2j, 3j]]),
1833+
)
18181834

18191835
def test_attrs(self):
18201836
stan = os.path.join(DATAFILES_PATH, 'named_output.stan')

0 commit comments

Comments
 (0)