Add 3D support for knn_impute #1041

Merged
sueoglu merged 14 commits into main from longitudinal/issue-947
Apr 14, 2026

Conversation

@sueoglu
Collaborator

@sueoglu sueoglu commented Mar 27, 2026

fixes #947
ep.pp.knn_impute() now supports 3D EHRData.

Description of changes

  • Updated _knn_impute to handle 3D arrays by first flattening them along axis 0 and then passing the result to the imputer
  • Added a test for 3D numerical imputation using the existing edata_mini_3D_missing_values fixture

Technical details
For a 3D array with shape (n_obs, n_var, n_t):

  1. The array is sliced to the requested variables, transposed to (n_obs, n_t, n_var), and reshaped to (n_obs * n_t, n_var), so that each patient-timepoint pair becomes a row
  2. The existing complete-column logic runs on the flattened 2D array
  3. After imputation, the result is reshaped back to (n_obs, n_t, n_var) and transposed to (n_obs, n_var, n_t), then written back to only the requested variable positions in the layer
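
The three steps can be sketched with plain NumPy. This is only an illustration under assumptions, not the code from this PR: the array sizes and `var_indices` are made up, and a simple column-mean fill stands in for the KNN imputer so that the example runs without extra dependencies.

```python
import numpy as np

rng = np.random.default_rng(0)
n_obs, n_var, n_t = 4, 3, 5
layer = rng.normal(size=(n_obs, n_var, n_t))
layer[0, 1, 2] = np.nan  # inject two missing values
layer[3, 0, 4] = np.nan

var_indices = [0, 1]  # hypothetical: the variables requested for imputation

# Step 1: slice the requested variables, move time next to obs, and flatten
# so that every (patient, timepoint) pair becomes one row.
flat = (
    layer[:, var_indices, :]
    .transpose(0, 2, 1)
    .reshape(n_obs * n_t, len(var_indices))
)

# Step 2: stand-in for the imputer (column mean instead of KNN).
col_means = np.nanmean(flat, axis=0)
flat_imputed = np.where(np.isnan(flat), col_means, flat)

# Step 3: undo the flattening and write back only the requested variables.
restored = flat_imputed.reshape(n_obs, n_t, len(var_indices)).transpose(0, 2, 1)
result = layer.copy()
result[:, var_indices, :] = restored
```

Row `i * n_t + t` of `flat` holds patient `i` at timepoint `t`, which is why the transpose has to happen before the reshape and be reversed after it.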

@sueoglu
Collaborator Author

sueoglu commented Apr 9, 2026

Small note:

After our discussion, I used vertical stacking as the reshaping strategy. I decided not to use the ehrapy decorator _apply_over_time_axis and instead implemented the reshaping logic inside _knn_impute because:

  • It is designed for functions that accept an array as input, whereas _knn_impute accepts an EHRData object.

  • For 3D input, var_indices and numerical_indices need to be remapped to the flattened array, and the decorator has no mechanism to handle this.

  • The anchor-column logic appends extra complete columns to imputer_data_indices, so X_imputed has more columns than the input X. With the decorator, the reshape step would break because the decorator assumes that the input shape is preserved.
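
The last point can be illustrated on a toy 2D matrix. The values and index lists below are invented for the example; only the index arithmetic mirrors the logic described above.

```python
import numpy as np

X = np.array(
    [
        [1.0, np.nan, 10.0, 0.5],
        [2.0, 4.0, 20.0, np.nan],
        [np.nan, 6.0, 30.0, 1.5],
    ]
)
var_indices = [0, 1]  # hypothetical: columns the caller wants imputed
numerical_indices = [0, 1, 2, 3]

# Numerical columns without any missing value can serve as anchors.
complete_numerical_columns = np.array(numerical_indices)[
    ~np.isnan(X[:, numerical_indices]).any(axis=0)
].tolist()

# The imputer sees the requested columns plus any additional complete columns.
imputer_data_indices = var_indices + [
    i for i in complete_numerical_columns if i not in var_indices
]
imputer_x = X[:, imputer_data_indices]

print(complete_numerical_columns)  # [2]
print(X[:, var_indices].shape)     # (3, 2)
print(imputer_x.shape)             # (3, 3)
```

Because the matrix handed to the imputer is one column wider than the requested slice, a wrapper that reshapes the output under the assumption of an unchanged shape cannot be applied directly.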

Also, here is an evaluation script comparing the two 3D stacking strategies (horizontal vs. vertical). RMSE and MAE are computed at the imputed positions of a dataset with 40% missingness, against the corresponding positions in the complete dataset, to assess imputation quality.

import numpy as np
import ehrdata as ed
import ehrapy as ep
from ehrdata.core.constants import FEATURE_TYPE_KEY, NUMERIC_TAG


def _knn_impute_with_mode(edata, var_names, n_neighbors, layer, temporal_mode):
    edata = edata.copy()

    from fknni import FastKNNImputer

    imputer = FastKNNImputer(n_neighbors=n_neighbors)

    if var_names is None:
        var_names = edata.var_names
    var_indices = edata.var_names.get_indexer(var_names).tolist()

    numerical_var_names = edata.var_names[edata.var[FEATURE_TYPE_KEY] == NUMERIC_TAG]
    numerical_indices = edata.var_names.get_indexer(numerical_var_names).tolist()

    X = edata.X if layer is None else edata.layers[layer]
    var_indices_original = var_indices
    is_3d = False

    if X.ndim == 3:
        is_3d = True
        n_obs, n_vars, n_t = X.shape
        if temporal_mode == "vertical":
            X = X[:, var_indices, :].astype("float64").transpose(0, 2, 1).reshape(n_obs * n_t, len(var_indices))
        else:  # horizontal
            X = X[:, var_indices, :].astype("float64").reshape(n_obs, len(var_indices) * n_t)
        numerical_indices = list(range(X.shape[1]))
        var_indices = numerical_indices

    complete_numerical_columns = np.array(numerical_indices)[
        ~np.isnan(X[:, numerical_indices]).any(axis=0)
    ].tolist()
    imputer_data_indices = var_indices + [i for i in complete_numerical_columns if i not in var_indices]
    imputer_x = X[:, imputer_data_indices].astype("float64")
    X_imputed = imputer.fit_transform(imputer_x)

    if is_3d:
        if temporal_mode == "vertical":
            X_imputed = X_imputed[:, :len(var_indices_original)].reshape(n_obs, n_t, len(var_indices_original)).transpose(0, 2, 1)
        else:
            X_imputed = X_imputed[:, :len(var_indices_original) * n_t].reshape(n_obs, len(var_indices_original), n_t)
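        # Write back to the array that was read above (X when no layer is given).
        target = edata.X if layer is None else edata.layers[layer]
        target[:, var_indices_original, :] = X_imputed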
    else:
        if layer is None:
            edata.X[:, imputer_data_indices] = X_imputed
        else:
            edata.layers[layer][:, imputer_data_indices] = X_imputed

    return edata


edata_complete = ed.dt.ehrdata_blobs(
    n_variables=10, missing_values=0.0, n_observations=100,
    base_timepoints=10, random_state=42, seasonality=True
)
truth = edata_complete.layers["tem_data"]

edata = ed.dt.ehrdata_blobs(
    n_variables=10, missing_values=0.4, n_observations=100,
    base_timepoints=10, random_state=42, seasonality=True
)
ed.infer_feature_types(edata) 
mask = np.isnan(edata.layers["tem_data"])

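def evaluate(imputed, truth, mask, label):
    """Print RMSE and MAE (raw and range-normalized) over the originally missing positions."""
    diff = imputed[mask] - truth[mask]
    rmse = np.sqrt(np.mean(diff**2))
    mae = np.mean(np.abs(diff))
    # Normalize by the range of the true values at the masked positions.
    scale = truth[mask].max() - truth[mask].min()
    nrmse = rmse / scale
    nmae = mae / scale
    print(f"{label:12s} RMSE: {rmse:.4f}, MAE: {mae:.4f}, NRMSE: {nrmse:.4f}, NMAE: {nmae:.4f}")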


for mode in ["vertical", "horizontal"]:
    edata_imputed = _knn_impute_with_mode(edata, var_names=None, n_neighbors=5, layer="tem_data", temporal_mode=mode)
    evaluate(edata_imputed.layers["tem_data"], truth, mask, mode)

Output:


vertical     RMSE: 3.7324, MAE: 2.6651, NRMSE: 0.0949, NMAE: 0.0677
horizontal   RMSE: 2.1849, MAE: 1.5606, NRMSE: 0.0555, NMAE: 0.0397
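
To make the two layouts concrete, here is a toy example of how each one folds the time axis. The sizes are chosen arbitrarily and are not those of the evaluation above; the reshapes follow the same pattern as in the script.

```python
import numpy as np

n_obs, n_var, n_t = 2, 3, 4
# Entry (i, v, t) has the value i * 12 + v * 4 + t, so positions are easy to trace.
X = np.arange(n_obs * n_var * n_t, dtype="float64").reshape(n_obs, n_var, n_t)

# Vertical: one row per (patient, timepoint), one column per variable.
vertical = X.transpose(0, 2, 1).reshape(n_obs * n_t, n_var)

# Horizontal: one row per patient, one column per (variable, timepoint).
horizontal = X.reshape(n_obs, n_var * n_t)

print(vertical.shape)    # (8, 3)
print(horizontal.shape)  # (2, 12)
print(vertical[1])       # [1. 5. 9.]  -> patient 0 at timepoint 1
```

In the vertical layout neighbors are searched among patient-timepoint rows described by `n_var` features, while in the horizontal layout neighbors are whole patients described by `n_var * n_t` features.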

@sueoglu sueoglu requested review from Zethson and eroell and removed request for eroell April 10, 2026 11:56
Member

@Zethson Zethson left a comment

Thank you!
I think it looks good already but I have some nitpicky comments.

Comment thread tests/preprocessing/test_imputation.py
Comment thread ehrapy/preprocessing/_imputation.py
Comment thread ehrapy/preprocessing/_imputation.py Outdated
Comment thread ehrapy/preprocessing/_imputation.py Outdated
Comment thread ehrapy/preprocessing/_imputation.py Outdated
Comment thread ehrapy/preprocessing/_imputation.py Outdated
Comment thread tests/preprocessing/test_imputation.py Outdated
Comment thread tests/preprocessing/test_imputation.py Outdated
@Zethson Zethson marked this pull request as ready for review April 10, 2026 12:03
Collaborator

@eroell eroell left a comment

I don't have much more to add. All the earlier discussion points are well described in your PR comments, and I think this is a nice cleanup of this imputation function.

Comment thread ehrapy/preprocessing/_imputation.py
Comment thread ehrapy/preprocessing/_imputation.py Outdated
@sueoglu sueoglu requested a review from eroell April 14, 2026 07:23
Collaborator

@eroell eroell left a comment

Cool!

Can you add the following to the documentation?

  • A quick preview that shows a dataset with missing values, then the imputation call, and then the result with the missing values gone. Use either ehrdata blobs or one of the pre-loaded datasets, ideally without requiring any import other than ehrdata and ehrapy.
  • 1-2 sentences on how this works in the 2D vs. 3D case.

No need to re-request review afterwards, when this is done you can merge!

@sueoglu sueoglu merged commit 41e2dfd into main Apr 14, 2026
16 of 19 checks passed
@sueoglu sueoglu deleted the longitudinal/issue-947 branch April 14, 2026 10:45
Development

Successfully merging this pull request may close these issues.

Longitudinal knn_impute

3 participants