diff --git a/ehrapy/preprocessing/_imputation.py b/ehrapy/preprocessing/_imputation.py index 1a4db5b2..04e890bf 100644 --- a/ehrapy/preprocessing/_imputation.py +++ b/ehrapy/preprocessing/_imputation.py @@ -268,7 +268,6 @@ def simple_impute( @_check_feature_types @use_ehrdata(deprecated_after="1.0.0") -@function_2D_only() @spinner("Performing KNN impute") def knn_impute( edata: EHRData | AnnData, @@ -285,6 +284,8 @@ def knn_impute( """Imputes missing values in the input data object using K-nearest neighbor imputation. If required, the data needs to be properly encoded as this imputation requires numerical data only. + For 2D data, if layer is `None`, `edata.X` is used directly. + For 3D data, the time axis of the layer is folded into the observation axis (one row per observation and timepoint) before imputation, and the result is reshaped back to 3D afterwards. .. warning:: Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors. @@ -296,7 +297,7 @@ def knn_impute( var_names: A list of variable names indicating which columns to impute. If `None`, all columns are imputed. Default is `None`. n_neighbors: Number of neighbors to use when performing the imputation. - layer: The layer to impute. + layer: The layer to impute. Required when the input data is 3D. copy: Whether to perform the imputation on a copy of the original data object. If `True`, the original object remains unmodified. backend: The implementation to use for the KNN imputation. 
@@ -317,13 +318,29 @@ def knn_impute( Examples: >>> import ehrdata as ed >>> import ehrapy as ep - >>> edata = ed.dt.mimic_2() - >>> ed.infer_feature_types(edata) - >>> ep.pp.knn_impute(edata) + >>> edata_3d = ed.dt.ehrdata_blobs(n_variables=3, n_observations=3, base_timepoints=2, missing_values=0.3) + >>> edata_imputed = ep.pp.knn_impute(edata_3d, layer="tem_data", copy=True) + + Example Output: + + >>> edata_3d.layers["tem_data"][0, :, :] + [[-12.12732884, -18.37304373], + [ nan, -0.91339411], + [ nan, -7.88514984]] + >>> edata_imputed.layers["tem_data"][0, :, :] + [[-12.12732884, -18.37304373], + [ -0.07689509, -0.91339411], + [ -2.75584421, -7.88514984]] + """ if copy: edata = edata.copy() + if edata.X is None and layer is None: # if edata is 3D + raise ValueError( + "3D imputation requires a layer to be specified. Pass the layer containing the full temporal data." + ) + _warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer) if backend not in {"scikit-learn", "faiss"}: @@ -387,15 +404,44 @@ def _knn_impute( "Can only impute numerical data. Try to restrict imputation to certain columns using " "var_names parameter or perform an encoding of your data." 
) - X = edata.X if layer is None else edata.layers[layer] - complete_numerical_columns = np.array(numerical_indices)[~np.isnan(X[:, numerical_indices]).any(axis=0)].tolist() - imputer_data_indices = var_indices + [i for i in complete_numerical_columns if i not in var_indices] - imputer_x = X[::, imputer_data_indices].astype("float64") - - if layer is None: - edata.X[::, imputer_data_indices] = imputer.fit_transform(imputer_x) + mtx = edata.X if layer is None else edata.layers[layer] + var_indices_original = var_indices + is_3d = False + input_dtype = mtx.dtype if np.issubdtype(mtx.dtype, np.floating) else np.float64 + + # if input data is 3D, merge the time axis into the observation axis before passing it to the imputer: each (observation, timepoint) pair becomes a row + if mtx.ndim == 3: + is_3d = True + n_obs, n_vars, n_t = mtx.shape + mtx = ( + mtx[:, var_indices, :] + .astype(input_dtype, copy=True) + .transpose(0, 2, 1) + .reshape(n_obs * n_t, len(var_indices)) + ) + numerical_indices = list(range(len(var_indices))) + var_indices = numerical_indices + + # complete columns to be used as anchors + complete_numerical_columns = np.array(numerical_indices)[~np.isnan(mtx[:, numerical_indices]).any(axis=0)].tolist() + + imputer_data_indices = var_indices + [ + column for column in complete_numerical_columns if column not in var_indices + ] # requested columns first, then complete numerical columns used as anchors + imputer_x = mtx[:, imputer_data_indices].astype(input_dtype, copy=True) + X_imputed = imputer.fit_transform(imputer_x) + + if is_3d: + # slice back to only requested columns and transpose back to n_obs, n_vars, n_t + X_imputed = ( + X_imputed[:, : len(var_indices_original)].reshape(n_obs, n_t, len(var_indices_original)).transpose(0, 2, 1) + ) + edata.layers[layer][:, var_indices_original, :] = X_imputed else: - edata.layers[layer][::, imputer_data_indices] = imputer.fit_transform(imputer_x) + if layer is None: + edata.X[:, imputer_data_indices] = X_imputed + else: + edata.layers[layer][:, imputer_data_indices] = X_imputed 
@use_ehrdata(deprecated_after="1.0.0") diff --git a/tests/conftest.py b/tests/conftest.py index 0c29527b..0fb87c4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -185,7 +185,9 @@ def edata_mini(): @pytest.fixture -def edata_mini_3D_missing_values(): +def edata_mini_3D_missing_values(request): + only_numerical = getattr(request, "param", False) + tiny_mixed_array = np.array( [ [[138, 139], [78, np.nan], [77, 76], [1, 2], ["A", "B"], ["Yes", np.nan]], @@ -195,8 +197,25 @@ def edata_mini_3D_missing_values(): ], dtype=object, ) - n_obs, n_vars, _ = tiny_mixed_array.shape - return ed.EHRData(shape=(n_obs, n_vars), layers={DEFAULT_TEM_LAYER_NAME: tiny_mixed_array}) + + if only_numerical: + layer = tiny_mixed_array[:, :4, :] + feature_types = [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG] + else: + layer = tiny_mixed_array + feature_types = [ + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + NUMERIC_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + ] + + n_obs, n_vars, _ = layer.shape + edata = ed.EHRData(shape=(n_obs, n_vars), layers={DEFAULT_TEM_LAYER_NAME: layer}) + edata.var[FEATURE_TYPE_KEY] = feature_types + return edata @pytest.fixture diff --git a/tests/preprocessing/test_imputation.py b/tests/preprocessing/test_imputation.py index a0144c6f..11f64474 100644 --- a/tests/preprocessing/test_imputation.py +++ b/tests/preprocessing/test_imputation.py @@ -278,10 +278,45 @@ def test_simple_impute_invalid_strategy(impute_edata): simple_impute(impute_edata, strategy="invalid_strategy", copy=True) # type: ignore -def test_knn_impute_3D_edata(edata_blob_small): - knn_impute(edata_blob_small, layer="layer_2") - with pytest.raises(ValueError, match=r"only supports 2D data"): - knn_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME) +@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True) +def test_knn_impute_3d_numerical(edata_mini_3D_missing_values): + edata = edata_mini_3D_missing_values.copy() + edata_imputed = knn_impute(edata, 
layer=DEFAULT_TEM_LAYER_NAME, copy=True) + _base_check_imputation( + edata_mini_3D_missing_values, + edata_imputed, + before_imputation_layer=DEFAULT_TEM_LAYER_NAME, + after_imputation_layer=DEFAULT_TEM_LAYER_NAME, + ) + + +@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True) +def test_knn_impute_3d_scikit_backend(edata_mini_3D_missing_values): + edata = edata_mini_3D_missing_values.copy() + edata_imputed = knn_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, copy=True, backend="scikit-learn") + _base_check_imputation( + edata_mini_3D_missing_values, + edata_imputed, + before_imputation_layer=DEFAULT_TEM_LAYER_NAME, + after_imputation_layer=DEFAULT_TEM_LAYER_NAME, + ) + + +def test_knn_impute_3d_var_names_subset(edata_mini_3D_missing_values): + edata = edata_mini_3D_missing_values.copy() + imputed = knn_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, var_names=["1", "2"], copy=True) + edata_imputed = imputed[:, :2].copy() + _base_check_imputation( + edata_mini_3D_missing_values[:, :2], + edata_imputed, + before_imputation_layer=DEFAULT_TEM_LAYER_NAME, + after_imputation_layer=DEFAULT_TEM_LAYER_NAME, + ) + + +def test_knn_impute_3d_layer_none(edata_mini_3D_missing_values): + with pytest.raises(ValueError, match="requires a layer"): + knn_impute(edata_mini_3D_missing_values, copy=True) def test_knn_impute_check_backend(impute_num_edata):