diff --git a/curepy/__init__.py b/curepy/__init__.py index 5af0891..4d1cc52 100644 --- a/curepy/__init__.py +++ b/curepy/__init__.py @@ -35,7 +35,12 @@ # Utilities from curepy.utilities.plotting import plot_corner from curepy.utilities.maths import lnlike -from curepy.utilities.distributions import ln_uniform, ln_normal, ln_multi_normal, ln_trunc_normal +from curepy.utilities.distributions import ( + ln_uniform, + ln_normal, + ln_multi_normal, + ln_trunc_normal, +) from curepy.utilities.utilities import flatten_array, reshape_array, format_correlation from ._version import get_versions diff --git a/curepy/container/measurement.py b/curepy/container/measurement.py index ed1e8ba..3a9a12d 100644 --- a/curepy/container/measurement.py +++ b/curepy/container/measurement.py @@ -14,6 +14,7 @@ def __init__( u_y_rand: Optional[np.ndarray] = None, u_y_syst: Optional[np.ndarray] = None, corr_y: Optional[Union[str, np.ndarray]] = None, + skip_invcov: bool = False, ) -> None: """ Container class for measurement variable data. @@ -29,6 +30,7 @@ def __init__( Accepted values: ``None``, ``"rand"`` (random), ``"syst"`` (systematic), or a square matrix whose side length equals the length of ``y``. + :param skip_invcov: If ``True``, skip the computation of the inverse covariance matrix (which is only needed for certain retrieval methods like optimal estimation). """ u_y_total, corr_y = self._format_uncertainty( @@ -43,9 +45,13 @@ def __init__( self.corr_y = util.format_correlation(self.y_flat, corr_y) + self.corr_y, self.cholesky, self.W = self.return_corr_cholesky_whitening( + self.corr_y + ) + self._check_shapes(self.y_flat, self.u_y_flat, self.corr_y) - if corr_y is not None: + if corr_y is not None and not skip_invcov: self.invcov = self.calculate_inv_cov(self.u_y_flat, self.corr_y) else: self.invcov = None @@ -136,6 +142,30 @@ def _format_uncertainty(u_total, u_rand, u_syst, corr): tot_corr = cm.convert_cov_to_corr(tot_cov, tot) return tot, tot_corr + @staticmethod + def return_corr_cholesky_whitening(corr: Optional[np.ndarray]) -> tuple: + """ + Return the correlation matrix, its Cholesky decomposition, and the whitening matrix. + + :param corr: Correlation matrix, or ``None``. + :returns: Tuple of ``(corr, cholesky, W)`` where ``cholesky`` is + the Cholesky decomposition of the correlation matrix, or + ``None`` if ``corr`` is ``None``, and ``W`` is the whitening matrix. + """ + if corr is not None: + try: + cholesky = np.linalg.cholesky(corr) + W = np.linalg.solve(cholesky, np.eye(cholesky.shape[0])) + return corr, cholesky, W + except np.linalg.LinAlgError: + # If the correlation matrix is not positive definite, use the nearest positive definite matrix + corr_pd = cm.nearestPD_cholesky(corr, return_cholesky=False, corr=True) + cholesky = np.linalg.cholesky(corr_pd) + W = np.linalg.solve(cholesky, np.eye(cholesky.shape[0])) + return corr_pd, cholesky, W + else: + return None, None, None + @staticmethod def calculate_inv_cov(unc: np.ndarray, corr: np.ndarray) -> np.ndarray: """ @@ -151,5 +181,4 @@ def calculate_inv_cov(unc: np.ndarray, corr: np.ndarray) -> np.ndarray: if np.array_equal(cov, np.diag(np.diag(cov))): return np.diag(1 / np.diag(cov)) else: - # might need a check for PD here return np.linalg.inv(cov) diff --git a/curepy/container/tests/test_measurement.py b/curepy/container/tests/test_measurement.py index 9373553..70abba9 100644 --- a/curepy/container/tests/test_measurement.py +++ b/curepy/container/tests/test_measurement.py @@ -40,6 +40,8 @@ def test_calculate_inv_cov(self, mock_convert_corr_to_cov, mock_inv): def test_init_format_correlation_called( self, mock_format, mock_check, mock_convert ): + # Configure mock to return a valid correlation matrix + mock_format.return_value = np.eye(len(y)) meas = Measurement(y, u_y, corr_y="rand") diff --git a/curepy/retrieval_methods/base.py b/curepy/retrieval_methods/base.py index 090bbff..7a30a75 100644 --- a/curepy/retrieval_methods/base.py +++ b/curepy/retrieval_methods/base.py @@ -127,35 +127,44 @@ def find_chisum( ).flatten() ) diff = modelled_data - self.retrieval_input.measurement_obj.y_flat + + # Only normalize by u_y_flat if it's available + if self.retrieval_input.measurement_obj.u_y_flat is not None: + diff_norm = diff / self.retrieval_input.measurement_obj.u_y_flat + else: + diff_norm = diff + if np.isfinite(np.sum(diff)): - if self.retrieval_input.measurement_obj.invcov is None: - return np.sum( - (diff) ** 2 / self.retrieval_input.measurement_obj.u_y_flat**2 - ) + if self.retrieval_input.measurement_obj.cholesky is None: + chisq = np.sum( + (diff_norm) ** 2 + ) # this is equivalent to using an identity matrix for the inverse covariance, which is appropriate when only uncorrelated uncertainties are available + else: if len(repeat_dims) == 0: - return np.dot( - np.dot(diff.T, self.retrieval_input.measurement_obj.invcov), - diff, - ) + y = self.retrieval_input.measurement_obj.W @ diff_norm + chisq = y.T @ y elif len(repeat_dims) == 1: sum = 0 for i in range(diff.shape[repeat_dims[0]]): - diffi = np.take(diff, i, repeat_dims[0]) - sum += np.dot( - np.dot( - diffi.T, self.retrieval_input.measurement_obj.invcov - ), - diffi, - ) - return sum + diff_norm_i = np.take(diff_norm, i, repeat_dims[0]) + y = self.retrieval_input.measurement_obj.W @ diff_norm_i + sum += y.T @ y + chisq = sum else: raise ValueError( "Methods for multiple repeat dimensions are not yet implemented," ) else: print("The difference between model and observations is infinite") - return np.inf + chisq = np.inf + + if chisq < 0: + raise ValueError( + "The chi-squared cost is negative, which should not be possible. Check the inputs and the measurement function for errors." + ) + + return chisq def lnprob(self, theta: np.ndarray) -> float: """ diff --git a/curepy/retrieval_methods/optimal_estimation.py b/curepy/retrieval_methods/optimal_estimation.py index 2c54f25..a1d98f0 100644 --- a/curepy/retrieval_methods/optimal_estimation.py +++ b/curepy/retrieval_methods/optimal_estimation.py @@ -56,7 +56,7 @@ def _run_retrieval( self.retrieval_input.measurement_function_obj.initial_guess ) - res = minimize(-self.lnprob, theta_0) + res = minimize(lambda theta: -self.lnprob(theta), theta_0) if self.Jx is None: Jx = self.calculate_Jx(res.x) diff --git a/curepy/retrieval_methods/tests/test_base.py b/curepy/retrieval_methods/tests/test_base.py index 40fcaef..f62514b 100644 --- a/curepy/retrieval_methods/tests/test_base.py +++ b/curepy/retrieval_methods/tests/test_base.py @@ -14,6 +14,8 @@ def make_mock_retrieval_input_for_chisum( invcov=None, u_y=None, b=None, + L=None, + W=None, ): retrieval_input = RetrievalInput() # measurement function object @@ -27,6 +29,9 @@ def make_mock_retrieval_input_for_chisum( retrieval_input.measurement_obj.y_flat = np.array(y_flat) retrieval_input.measurement_obj.invcov = invcov retrieval_input.measurement_obj.u_y_flat = u_y + retrieval_input.measurement_obj.cholesky = L + retrieval_input.measurement_obj.W = W + # ancillary retrieval_input.ancillary_obj = MagicMock() retrieval_input.ancillary_obj.b = b @@ -137,7 +142,7 @@ def test_multiple_repeat_dims_raises_error(self): retrieval_input = make_mock_retrieval_input_for_chisum( measurement_function_output=np.array([1.0, 2.0]), y_flat=np.array([1.0, 1.0]), - invcov=np.eye(2), + L=np.eye(2), u_y=None, b=None, ) @@ -181,6 +186,8 @@ def test_chisum_with_invcov_no_repeat(self): invcov=invcov, u_y=None, b=None, + L=np.eye(2), + W=np.eye(2), ) dr = DummyRetrieval() diff --git a/examples/multidimensional_MCMC_example.py b/examples/multidimensional_MCMC_example.py index 8e6a03f..a368e77 100644 --- a/examples/multidimensional_MCMC_example.py +++ b/examples/multidimensional_MCMC_example.py @@ -25,7 +25,7 @@ def quadratic(a, b, c, x, d): y = data + noise meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10]) -meas = Measurement(y, noise, "rand") +meas = Measurement(y, noise, corr_y="rand") ancill = AncillaryParameter([x, d], [None, 1], [None, None], b_MC_steps=3) inputs = RetrievalInput(meas_func, meas, ancill) diff --git a/examples/multidimensional_LPU_example.py b/examples/multidimensional_OE_example.py similarity index 93% rename from examples/multidimensional_LPU_example.py rename to examples/multidimensional_OE_example.py index 3de870f..445f399 100644 --- a/examples/multidimensional_LPU_example.py +++ b/examples/multidimensional_OE_example.py @@ -2,7 +2,7 @@ MeasurementFunction, Measurement, AncillaryParameter, - LPU, + OE, RetrievalInput, ) @@ -24,7 +24,7 @@ def quadratic(a, b, c, x, d): y = data + noise meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10]) -meas = Measurement(y, noise, "rand") +meas = Measurement(y, noise, corr_y="rand") ancill = AncillaryParameter( [x, d], [0.01 * np.ones_like(x), 1], @@ -41,7 +41,7 @@ def quadratic(a, b, c, x, d): inputs = RetrievalInput(meas_func, meas, ancill) -ret = LPU() +ret = OE() results = ret.run_retrieval(inputs) diff --git a/examples/simple_quadratic_MCMC_example.py b/examples/simple_quadratic_MCMC_example.py index c81f8ac..ddeb422 100644 --- a/examples/simple_quadratic_MCMC_example.py +++ b/examples/simple_quadratic_MCMC_example.py @@ -28,7 +28,7 @@ def quadratic(a, b, c, x, d): y = data + noise meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10]) -meas = Measurement(y, noise, np.eye(len(x))) +meas = Measurement(y, noise, corr_y=np.eye(len(x))) ancill = AncillaryParameter([x, d], [None, 1], [np.eye(len(x)), None], b_MC_steps=3) prior = Prior( ["normal"] * 3, diff --git a/examples/simple_quadratic_LPU_example.py b/examples/simple_quadratic_OE_example.py similarity index 88% rename from examples/simple_quadratic_LPU_example.py rename to examples/simple_quadratic_OE_example.py index 4e53b8d..47bdcb7 100644 --- a/examples/simple_quadratic_LPU_example.py +++ b/examples/simple_quadratic_OE_example.py @@ -2,7 +2,7 @@ MeasurementFunction, Measurement, AncillaryParameter, - LPU, + OE, RetrievalInput, Prior, ) @@ -26,7 +26,7 @@ def quadratic(a, b, c, x, d): y = data + noise meas_func = MeasurementFunction(quadratic, [0.5, 0.2, -10]) -meas = Measurement(y, noise, np.eye(len(x))) +meas = Measurement(y, noise, corr_y=np.eye(len(x))) ancill = AncillaryParameter([x, d], [None, 0.05], [None, None]) prior = Prior( ["normal"] * 3, @@ -36,7 +36,7 @@ def quadratic(a, b, c, x, d): inputs = RetrievalInput(meas_func, meas, ancill, prior) -ret = LPU() +ret = OE() results = ret.run_retrieval(inputs) @@ -44,4 +44,4 @@ def quadratic(a, b, c, x, d): print(results.uncertainties) plt.plot(x, quadratic(*results.values, x, d)) plt.scatter(x, y, alpha=0.5, c="orange") -plt.savefig(os.path.join(example_dir, "LPU_test.png")) +plt.savefig(os.path.join(example_dir, "OE_test.png"))