Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 38 additions & 14 deletions v6-kaplan-meier-py/central.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
"""
This file contains all central algorithm functions. It is important to note
that the central method is executed on a node, just like any other method.

The results in a return statement are sent to the vantage6 server (after
encryption if that is enabled).
"""
from typing import Union

import pandas as pd
import numpy as np

from typing import Dict, List, Union
from scipy import stats
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scipy is not listed in requirements.txt - maybe it is pulled in as a dependency of one of the other packages? I am not sure whether you need to add it explicitly.
The same applies to numpy and pandas, actually.

from vantage6.algorithm.client import AlgorithmClient
from vantage6.algorithm.tools.util import info, error
from vantage6.algorithm.tools.decorators import algorithm_client
Expand All @@ -23,8 +18,8 @@ def kaplan_meier_central(
client: AlgorithmClient,
time_column_name: str,
censor_column_name: str,
organizations_to_include: List[int] | None = None,
) -> Dict[str, Union[str, List[str]]]:
organizations_to_include: list[int] | None = None,
) -> dict[str, Union[str, list[str]]]:
"""
Central part of the Federated Kaplan-Meier curve computation.

Expand Down Expand Up @@ -95,13 +90,42 @@ def kaplan_meier_central(
km["hazard"] = km["observed"] / km["at_risk"]
km["survival_cdf"] = (1 - km["hazard"]).cumprod()

info("Computing confidence intervals")
n_i = km["at_risk"]
S_t = km["survival_cdf"]
d_i = km["observed"]
cumulative_var = 0
ci_bounds = []

if n_i > d_i and n_i > 0:
cumulative_var += d_i / (n_i * (n_i - d_i))

# Calculate confidence interval using cumulative variance
if n_i > d_i and n_i > 0 and S_t > 0:
std_err = S_t * np.sqrt(cumulative_var)
z = stats.norm.ppf(1 - 0.05 / 2) # 95% CI
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to think about this code for a moment. Wouldn't it be more readable to define a named constant such as CI_95_PCT_TWOSIDED = 0.975 and pass that to stats.norm.ppf instead?


# Use log-log transformation consistently for all cases
theta = np.log(-np.log(S_t))
se_theta = std_err / (S_t * np.abs(np.log(S_t)))
lower = np.exp(-np.exp(theta + z * se_theta))
upper = np.exp(-np.exp(theta - z * se_theta))

ci_bounds.append((lower, upper))
else:
# If we have no information, use the previous bounds or (0,1) for first point
if ci_bounds:
ci_bounds.append(ci_bounds[-1])
else:
ci_bounds.append((0, 1))

info("Kaplan-Meier curve computed")
return km.to_json()


def _start_partial_and_collect_results(
client: AlgorithmClient, method: str, organizations_to_include: List[int], **kwargs
) -> List[Dict[str, Union[str, List[str]]]]:
client: AlgorithmClient, method: str, organizations_to_include: list[int], **kwargs
) -> list[dict[str, Union[str, list[str]]]]:
"""
Launches a partial task to multiple organizations and collects their results when
ready.
Expand All @@ -112,14 +136,14 @@ def _start_partial_and_collect_results(
The vantage6 client used for communication with the server.
method : str
The method/function to be executed as a subtask by the organizations.
organization_ids : List[int]
organization_ids : list[int]
A list of organization IDs to which the subtask will be distributed.
**kwargs : dict
Additional keyword arguments to be passed to the method/function.

Returns
-------
List[Dict[str, Union[str, List[str]]]]
list[dict[str, Union[str, list[str]]]]
A list of dictionaries containing results obtained from the organizations.
"""
info(f"Including {len(organizations_to_include)} organizations in the analysis")
Expand Down
4 changes: 3 additions & 1 deletion v6-kaplan-meier-py/partial.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import re

from typing import List

import pandas as pd
import numpy as np

from typing import List
from vantage6.algorithm.tools.util import get_env_var, info, warn, error
from vantage6.algorithm.tools.decorators import data
from vantage6.algorithm.tools.exceptions import InputError, EnvironmentVariableError
Expand Down