diff --git a/v6-kaplan-meier-py/central.py b/v6-kaplan-meier-py/central.py index 0b38ee3..9f9a953 100644 --- a/v6-kaplan-meier-py/central.py +++ b/v6-kaplan-meier-py/central.py @@ -1,14 +1,9 @@ -""" -This file contains all central algorithm functions. It is important to note -that the central method is executed on a node, just like any other method. - -The results in a return statement are sent to the vantage6 server (after -encryption if that is enabled). -""" +from typing import Union import pandas as pd +import numpy as np -from typing import Dict, List, Union +from scipy import stats from vantage6.algorithm.client import AlgorithmClient from vantage6.algorithm.tools.util import info, error from vantage6.algorithm.tools.decorators import algorithm_client @@ -23,8 +18,8 @@ def kaplan_meier_central( client: AlgorithmClient, time_column_name: str, censor_column_name: str, - organizations_to_include: List[int] | None = None, -) -> Dict[str, Union[str, List[str]]]: + organizations_to_include: list[int] | None = None, +) -> dict[str, Union[str, list[str]]]: """ Central part of the Federated Kaplan-Meier curve computation. @@ -95,13 +90,42 @@ def kaplan_meier_central( km["hazard"] = km["observed"] / km["at_risk"] km["survival_cdf"] = (1 - km["hazard"]).cumprod() + info("Computing confidence intervals") + n_i = km["at_risk"] + S_t = km["survival_cdf"] + d_i = km["observed"] + cumulative_var = 0 + ci_bounds = [] + + if n_i > d_i and n_i > 0: + cumulative_var += d_i / (n_i * (n_i - d_i)) + + # Calculate confidence interval using cumulative variance + if n_i > d_i and n_i > 0 and S_t > 0: + std_err = S_t * np.sqrt(cumulative_var) + z = stats.norm.ppf(1 - 0.05 / 2) # 95% CI + + # Use log-log transformation consistently for all cases + theta = np.log(-np.log(S_t)) + se_theta = std_err / (S_t * np.abs(np.log(S_t))) + lower = np.exp(-np.exp(theta + z * se_theta)) + upper = np.exp(-np.exp(theta - z * se_theta)) + + ci_bounds.append((lower, upper)) + else: + # If we have no information, use the previous bounds or (0,1) for first point + if ci_bounds: + ci_bounds.append(ci_bounds[-1]) + else: + ci_bounds.append((0, 1)) + info("Kaplan-Meier curve computed") return km.to_json() def _start_partial_and_collect_results( - client: AlgorithmClient, method: str, organizations_to_include: List[int], **kwargs -) -> List[Dict[str, Union[str, List[str]]]]: + client: AlgorithmClient, method: str, organizations_to_include: list[int], **kwargs +) -> list[dict[str, Union[str, list[str]]]]: """ Launches a partial task to multiple organizations and collects their results when ready. @@ -112,14 +136,14 @@ def _start_partial_and_collect_results( The vantage6 client used for communication with the server. method : str The method/function to be executed as a subtask by the organizations. - organization_ids : List[int] + organization_ids : list[int] A list of organization IDs to which the subtask will be distributed. **kwargs : dict Additional keyword arguments to be passed to the method/function. Returns ------- - List[Dict[str, Union[str, List[str]]]] + list[dict[str, Union[str, list[str]]]] A list of dictionaries containing results obtained from the organizations. """ info(f"Including {len(organizations_to_include)} organizations in the analysis") diff --git a/v6-kaplan-meier-py/partial.py b/v6-kaplan-meier-py/partial.py index 0998a00..0622cb9 100644 --- a/v6-kaplan-meier-py/partial.py +++ b/v6-kaplan-meier-py/partial.py @@ -1,8 +1,10 @@ import re + +from typing import List + import pandas as pd import numpy as np -from typing import List from vantage6.algorithm.tools.util import get_env_var, info, warn, error from vantage6.algorithm.tools.decorators import data from vantage6.algorithm.tools.exceptions import InputError, EnvironmentVariableError