Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 59 additions & 13 deletions ehrapy/preprocessing/_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def simple_impute(

@_check_feature_types
@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
@spinner("Performing KNN impute")
def knn_impute(
edata: EHRData | AnnData,
Expand All @@ -285,6 +284,8 @@ def knn_impute(
"""Imputes missing values in the input data object using K-nearest neighbor imputation.

If required, the data needs to be properly encoded as this imputation requires numerical data only.
For 2D data, if layer is `None`, `edata.X` is used directly.
For 3D data, the layer is flattened along axis 0 before imputation and reshaped back to 3D afterwards.

.. warning::
Currently, both `n_neighbours` and `n_neighbors` are accepted as parameters for the number of neighbors.
Expand All @@ -296,7 +297,7 @@ def knn_impute(
var_names: A list of variable names indicating which columns to impute.
If `None`, all columns are imputed. Default is `None`.
n_neighbors: Number of neighbors to use when performing the imputation.
layer: The layer to impute.
layer: The layer to impute. Required when the input data is 3D.
copy: Whether to perform the imputation on a copy of the original data object.
If `True`, the original object remains unmodified.
backend: The implementation to use for the KNN imputation.
Expand All @@ -317,13 +318,29 @@ def knn_impute(
Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> ed.infer_feature_types(edata)
>>> ep.pp.knn_impute(edata)
>>> edata_3d = ed.dt.ehrdata_blobs(n_variables=3, n_observations=3, base_timepoints=2, missing_values=0.3)
>>> edata_imputed = ep.pp.knn_impute(edata_3d, layer="tem_data", copy=True)

Example Output:

>>> edata_3d.layers["tem_data"][0, :, :]
[[-12.12732884, -18.37304373],
[ nan, -0.91339411],
[ nan, -7.88514984]]
>>> edata_imputed.layers["tem_data"][0, :, :]
[[-12.12732884, -18.37304373],
[ -0.07689509, -0.91339411],
[ -2.75584421, -7.88514984]]

"""
if copy:
edata = edata.copy()

if edata.X is None and layer is None: # if edata is 3D
Comment thread
eroell marked this conversation as resolved.
raise ValueError(
"3D imputation requires a layer to be specified. Pass the layer containing the full temporal data."
)

_warn_imputation_threshold(edata, var_names, threshold=warning_threshold, layer=layer)

if backend not in {"scikit-learn", "faiss"}:
Expand Down Expand Up @@ -387,15 +404,44 @@ def _knn_impute(
"Can only impute numerical data. Try to restrict imputation to certain columns using "
"var_names parameter or perform an encoding of your data."
)
X = edata.X if layer is None else edata.layers[layer]
complete_numerical_columns = np.array(numerical_indices)[~np.isnan(X[:, numerical_indices]).any(axis=0)].tolist()
imputer_data_indices = var_indices + [i for i in complete_numerical_columns if i not in var_indices]
imputer_x = X[::, imputer_data_indices].astype("float64")

if layer is None:
edata.X[::, imputer_data_indices] = imputer.fit_transform(imputer_x)
mtx = edata.X if layer is None else edata.layers[layer]
var_indices_original = var_indices
is_3d = False
Comment thread
sueoglu marked this conversation as resolved.
input_dtype = mtx.dtype if np.issubdtype(mtx.dtype, np.floating) else np.float64

# if input data is 3D, flatten along axis 0 before passing it to the imputer: each timepoint becomes a row
if mtx.ndim == 3:
is_3d = True
n_obs, n_vars, n_t = mtx.shape
mtx = (
mtx[:, var_indices, :]
.astype(input_dtype, copy=True)
.transpose(0, 2, 1)
.reshape(n_obs * n_t, len(var_indices))
)
numerical_indices = list(range(len(var_indices)))
var_indices = numerical_indices

# complete columns to be used as anchors
complete_numerical_columns = np.array(numerical_indices)[~np.isnan(mtx[:, numerical_indices]).any(axis=0)].tolist()

imputer_data_indices = var_indices + [
column for column in complete_numerical_columns if column not in var_indices
] # columns to impute
imputer_x = mtx[:, imputer_data_indices].astype(input_dtype, copy=True)
X_imputed = imputer.fit_transform(imputer_x)

if is_3d:
# slice back to only requested columns and transpose back to n_obs, n_var, n_t
X_imputed = (
X_imputed[:, : len(var_indices_original)].reshape(n_obs, n_t, len(var_indices_original)).transpose(0, 2, 1)
)
edata.layers[layer][:, var_indices_original, :] = X_imputed
else:
edata.layers[layer][::, imputer_data_indices] = imputer.fit_transform(imputer_x)
if layer is None:
edata.X[:, imputer_data_indices] = X_imputed
else:
edata.layers[layer][:, imputer_data_indices] = X_imputed


@use_ehrdata(deprecated_after="1.0.0")
Expand Down
25 changes: 22 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ def edata_mini():


@pytest.fixture
def edata_mini_3D_missing_values():
def edata_mini_3D_missing_values(request):
only_numerical = getattr(request, "param", False)

tiny_mixed_array = np.array(
[
[[138, 139], [78, np.nan], [77, 76], [1, 2], ["A", "B"], ["Yes", np.nan]],
Expand All @@ -195,8 +197,25 @@ def edata_mini_3D_missing_values():
],
dtype=object,
)
n_obs, n_vars, _ = tiny_mixed_array.shape
return ed.EHRData(shape=(n_obs, n_vars), layers={DEFAULT_TEM_LAYER_NAME: tiny_mixed_array})

if only_numerical:
layer = tiny_mixed_array[:, :4, :]
feature_types = [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG]
else:
layer = tiny_mixed_array
feature_types = [
NUMERIC_TAG,
NUMERIC_TAG,
NUMERIC_TAG,
NUMERIC_TAG,
CATEGORICAL_TAG,
CATEGORICAL_TAG,
]

n_obs, n_vars, _ = layer.shape
edata = ed.EHRData(shape=(n_obs, n_vars), layers={DEFAULT_TEM_LAYER_NAME: layer})
edata.var[FEATURE_TYPE_KEY] = feature_types
return edata


@pytest.fixture
Expand Down
43 changes: 39 additions & 4 deletions tests/preprocessing/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,10 +278,45 @@ def test_simple_impute_invalid_strategy(impute_edata):
simple_impute(impute_edata, strategy="invalid_strategy", copy=True) # type: ignore


def test_knn_impute_3D_edata(edata_blob_small):
    """A 2D layer imputes fine; a 3D temporal layer is rejected with a ValueError."""
    # The plain 2D layer should go through without raising.
    knn_impute(edata_blob_small, layer="layer_2")
    # The temporal (3D) layer is not supported and must raise.
    expected_pattern = r"only supports 2D data"
    with pytest.raises(ValueError, match=expected_pattern):
        knn_impute(edata_blob_small, layer=DEFAULT_TEM_LAYER_NAME)
@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True)
def test_knn_impute_3d_numerical(edata_mini_3D_missing_values):
    """KNN imputation on a purely numerical 3D temporal layer fills the missing values."""
    before = edata_mini_3D_missing_values
    # copy=True so the fixture object stays untouched for the before/after comparison
    after = knn_impute(before.copy(), layer=DEFAULT_TEM_LAYER_NAME, copy=True)
    _base_check_imputation(
        before,
        after,
        before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
        after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
    )


@pytest.mark.parametrize("edata_mini_3D_missing_values", [True], indirect=True)
def test_knn_impute_3d_scikit_backend(edata_mini_3D_missing_values):
    """3D KNN imputation also works with the explicit scikit-learn backend."""
    before = edata_mini_3D_missing_values
    # copy=True keeps the fixture pristine so it can serve as the "before" reference
    after = knn_impute(
        before.copy(),
        layer=DEFAULT_TEM_LAYER_NAME,
        copy=True,
        backend="scikit-learn",
    )
    _base_check_imputation(
        before,
        after,
        before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
        after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
    )


def test_knn_impute_3d_var_names_subset(edata_mini_3D_missing_values):
    """Restricting imputation via `var_names` fills only the requested variables.

    Note: two review-UI artifact lines ("Comment thread" / "... marked this
    conversation as resolved.") were embedded in this function body and have
    been removed — they were not valid Python and broke the module.
    """
    edata = edata_mini_3D_missing_values.copy()
    # Impute only the first two variables of the temporal layer.
    imputed = knn_impute(edata, layer=DEFAULT_TEM_LAYER_NAME, var_names=["1", "2"], copy=True)
    # Compare just the imputed variable slice against the original.
    edata_imputed = imputed[:, :2].copy()
    _base_check_imputation(
        edata_mini_3D_missing_values[:, :2],
        edata_imputed,
        before_imputation_layer=DEFAULT_TEM_LAYER_NAME,
        after_imputation_layer=DEFAULT_TEM_LAYER_NAME,
    )


def test_knn_impute_3d_layer_none(edata_mini_3D_missing_values):
    """Calling knn_impute on 3D data without specifying a layer must raise."""
    error_pattern = "requires a layer"
    with pytest.raises(ValueError, match=error_pattern):
        knn_impute(edata_mini_3D_missing_values, copy=True)


def test_knn_impute_check_backend(impute_num_edata):
Expand Down
Loading