Skip to content

Commit 128c5f8

Browse files
author
Alexander Hillsley
committed
scripts to evaluate models
1 parent 8d130d3 commit 128c5f8

File tree

3 files changed

+611
-0
lines changed

3 files changed

+611
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import yaml
2+
3+
import pandas as pd
4+
import torch
5+
6+
from ops_model.data import data_loader
7+
8+
# Metadata columns carried in the feature CSVs that are NOT model features;
# callers drop these before treating the remaining columns as class scores.
NONFEATURE_COLUMNS = [
    "label_str",
    "label_int",
    "sgRNA",
    "well",
    "experiment",
    "x_position",
    "y_position",
]
17+
18+
19+
def eval_classification_accuracy(
20+
scores_df: pd.DataFrame, labels_df: pd.DataFrame, label_column: str = None
21+
) -> dict:
22+
"""
23+
Evaluate classification accuracy given prediction scores and true labels.
24+
25+
Args:
26+
scores_df: DataFrame with classification scores, shape (n_samples, n_classes).
27+
labels_df: DataFrame containing the true integer labels.
28+
label_column: Name of the column in labels_df containing the labels.
29+
30+
Returns:
31+
dict: Dictionary containing accuracy metrics.
32+
"""
33+
# Convert to tensors
34+
scores = torch.from_numpy(scores_df.values)
35+
36+
# Extract true labels
37+
if label_column is not None:
38+
labels = torch.from_numpy(labels_df[label_column].values)
39+
elif len(labels_df.columns) == 1:
40+
labels = torch.from_numpy(labels_df.iloc[:, 0].values)
41+
else:
42+
labels = torch.from_numpy(labels_df.values.flatten())
43+
44+
# Get predictions and calculate accuracy
45+
predictions = torch.argmax(scores, dim=1)
46+
accuracy = (predictions == labels).float().mean().item()
47+
48+
return {
49+
"accuracy": accuracy,
50+
"correct": (predictions == labels).sum().item(),
51+
"total": len(labels),
52+
"predictions": predictions.numpy(),
53+
}
54+
55+
return
56+
57+
58+
def cnn_inference(config_path: str):
    """Build the OPS data manager and dataloaders described by a YAML config.

    NOTE(review): this stops after constructing the dataloaders — no model is
    loaded and nothing is returned. Looks like work in progress; confirm
    before relying on it for actual inference.
    """
    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    # Single hard-coded experiment -> well/site list; further experiments are
    # kept here commented out for later runs.
    experiments = {
        "ops0031_20250424": ["A/1/0", "A/2/0", "A/3/0"],
        # "ops0053_20250709": ["A/1/0", "A/2/0", "A/3/0"],
        # "ops0079_20250916": ["A/1/0", "A/2/0", "A/3/0"],
        # "ops0064_20250811": ["A/1/0", "A/2/0", "A/3/0"],
        # "ops0065_20250812": ["A/1/0", "A/2/0", "A/3/0"],
    }
    run_name = cfg["run_name"]  # read for parity with the config; unused here

    dm_cfg = cfg["data_manager"]
    manager = data_loader.OpsDataManager(
        experiments=experiments,
        batch_size=dm_cfg["batch_size"],
        data_split=tuple(dm_cfg["data_split"]),
        out_channels=dm_cfg["out_channels"],
        initial_yx_patch_size=tuple(dm_cfg["initial_yx_patch_size"]),
        final_yx_patch_size=tuple(dm_cfg["final_yx_patch_size"]),
        verbose=False,
    )

    # Construct dataloaders first (without sampler) to get the train indices
    manager.construct_dataloaders(
        num_workers=dm_cfg["num_workers"],
        dataset_type=cfg["dataset_type"],
        basic_kwargs=dm_cfg.get("basic_kwargs"),
        balanced_sampling=dm_cfg.get("balanced_sampling", False),
    )
89+
90+
91+
if __name__ == "__main__":
    # Hard-coded path to the classification scores produced by the cytoself
    # inference script on the HPC cluster.
    save_path = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/cytoself_features/cytoself_classification_scores.csv"
    df = pd.read_csv(save_path)
    # True labels live in the "label_int" column; everything that is not a
    # metadata column is treated as a per-class score.
    labels_df = df[["label_int"]]
    classification_scores = df.drop(columns=NONFEATURE_COLUMNS)

    # NOTE(review): result is bound but never printed or saved — presumably
    # meant for interactive inspection; confirm.
    a = eval_classification_accuracy(
        scores_df=classification_scores, labels_df=labels_df, label_column="label_int"
    )
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import yaml
2+
from pathlib import Path
3+
4+
import pandas as pd
5+
import torch
6+
import lightning as L
7+
8+
from ops_model.data import data_loader
9+
from ops_model.models import cytoself_model
10+
11+
# Share tensors between DataLoader worker processes via the file system
# rather than file descriptors — presumably to avoid FD-limit errors with
# many workers; confirm against the cluster's ulimit settings.
torch.multiprocessing.set_sharing_strategy("file_system")
12+
13+
14+
def run_inference(
    config_path: str,
    checkpoint_path: str,
    output_path: str,
):
    """Score the full dataset with a cytoself checkpoint and aggregate outputs.

    Args:
        config_path: YAML config describing the data manager and model.
        checkpoint_path: Lightning checkpoint to load weights from.
        output_path: Directory where per-batch chunk CSVs are written and
            final aggregated CSVs are assembled.
    """
    out_dir = Path(output_path)

    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    model_type = cfg["model_type"]  # read for parity with the config; unused here
    dm_cfg = cfg["data_manager"]

    # data_split=(0, 0, 1) routes every sample to the test split so the
    # whole dataset gets scored.
    manager = data_loader.OpsDataManager(
        experiments=dm_cfg["experiments"],
        batch_size=dm_cfg["batch_size"],
        data_split=(0, 0, 1),
        out_channels=dm_cfg["out_channels"],
        initial_yx_patch_size=tuple(dm_cfg["initial_yx_patch_size"]),
        final_yx_patch_size=tuple(dm_cfg["final_yx_patch_size"]),
    )
    manager.construct_dataloaders(
        num_workers=dm_cfg["num_workers"],
        dataset_type=cfg["dataset_type"],
        basic_kwargs=dm_cfg.get("basic_kwargs"),
        triplet_kwargs=dm_cfg.get("triplet_kwargs"),
    )

    torch.set_float32_matmul_precision("medium")  # huge boost in speed

    model_cfg = cfg["model"]
    lit_model = cytoself_model.LitCytoSelf.load_from_checkpoint(
        checkpoint_path,
        emb_shapes=(
            tuple(model_cfg["embedding_shapes"][0]),
            tuple(model_cfg["embedding_shapes"][1]),
        ),
        vq_args=model_cfg["vq_args"],
        num_class=model_cfg["num_classes"],
        input_shape=tuple(model_cfg["input_shape"]),
        # NOTE(review): output shape deliberately mirrors input shape — confirm.
        output_shape=tuple(model_cfg["input_shape"]),
        fc_input_type=model_cfg["fc_input_type"],
        fc_output_idx=[model_cfg["fc_output_index"]],
    )

    # The writer callback persists each prediction batch to CSV chunks.
    writer = cytoself_model.CytoselfPredictionWriter(
        output_dir=out_dir,
        write_interval="batch",
        int_label_lut=manager.int_label_lut,
    )
    trainer = L.Trainer(
        devices=1,
        accelerator="gpu",
        callbacks=[writer],
        # limit_predict_batches=2
    )
    trainer.predict(lit_model, dataloaders=manager.test_loader)

    # Stitch the per-batch chunk CSVs into one final CSV per output kind.
    for chunk_name, csv_name in (
        ("emb_2_chunks", "cytoself_local_features.csv"),
        ("classification_scores", "cytoself_classification_scores.csv"),
        ("global_emb_metadata", "cytoself_global_metadata.csv"),
    ):
        aggregate_csvs(
            chunk_subdir=out_dir / chunk_name,
            final_csv_name=csv_name,
        )
90+
91+
92+
def aggregate_csvs(
93+
chunk_subdir: Path,
94+
final_csv_name: str,
95+
):
96+
print(f"\nLoading and concatenating chunks from {chunk_subdir.name}...")
97+
csv_files = sorted(chunk_subdir.glob("*.csv"))
98+
99+
if not csv_files:
100+
print("No feature files found!")
101+
return None
102+
103+
df_list = [pd.read_csv(csv_file) for csv_file in csv_files]
104+
final_df = pd.concat(df_list, ignore_index=True)
105+
106+
# Save the final concatenated dataframe
107+
final_path = chunk_subdir.parent / final_csv_name
108+
final_df.to_csv(final_path, index=False)
109+
print(f"Saved final concatenated features to {final_path}")
110+
print(f"Final dataframe shape: {final_df.shape}")
111+
112+
return
113+
114+
115+
if __name__ == "__main__":
    # Hard-coded HPC paths for a single evaluation run of the cytoself model.
    # NOTE(review): the checkpoint filename embeds the "val/total_loss" metric
    # and therefore contains a slash-separated path component — confirm this
    # resolves to the actual directory layout on disk.
    checkpoint_path = "/hpc/projects/intracellular_dashboard/ops/models/logs/cytoself/cytoself_20251202_2/cytoself_20251202_2-2025-12-04-global_step=0.000000-val/total_loss=330.02.ckpt"
    config_path = "/hpc/mydata/alexander.hillsley/ops/ops_model/configs/cytoself/cytoself_20251204.yml"
    output_path = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/3-assembly/cytoself_features"
    run_inference(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        output_path=output_path,
    )

0 commit comments

Comments
 (0)