Skip to content

Commit 844e97a

Browse files
committed
MLFlow Eval
1 parent 9a6d5ba commit 844e97a

File tree

1 file changed

+209
-0
lines changed

1 file changed

+209
-0
lines changed
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
"""
2+
MLflow evaluation functions for RCA-Annotator.
3+
4+
Provides functions to download annotation files from jumpbox and log
5+
evaluation results to MLflow.
6+
"""
7+
8+
import argparse
9+
import json
10+
import os
11+
import sys
12+
from pathlib import Path
13+
from typing import Any
14+
import sys
15+
import mlflow # type: ignore
16+
17+
from jumpbox_io import download_from_jumpbox
18+
19+
20+
def load_annotation(job_id: str) -> dict[str, Any] | None:
21+
"""
22+
Load annotation_draft.json for a given job_id.
23+
24+
Args:
25+
job_id: Job ID to load annotation for
26+
27+
Returns:
28+
Annotation dict if successful, None on failure
29+
"""
30+
annotation_file = Path(".analysis") / job_id / "annotation_draft.json"
31+
32+
if not annotation_file.exists():
33+
print(f" Error: Annotation file not found: {annotation_file}")
34+
return None
35+
36+
try:
37+
with open(annotation_file, "r") as f:
38+
annotation = json.load(f)
39+
return annotation
40+
except json.JSONDecodeError as e:
41+
print(f" Error: Invalid JSON in {annotation_file}: {e}")
42+
return None
43+
except Exception as e:
44+
print(f" Error: Failed to read {annotation_file}: {e}")
45+
return None
46+
47+
def download_annotations_for_eval(
    job_ids: list[str], jumpbox_uri: str | None = None
) -> dict[str, dict[str, Any]]:
    """
    Download and load annotations for multiple jobs.

    Args:
        job_ids: List of job IDs to download
        jumpbox_uri: JUMPBOX_URI connection string (defaults to env var)

    Returns:
        Dict mapping job_id to annotation data (failed jobs are omitted)
    """
    total = len(job_ids)
    loaded: dict[str, dict[str, Any]] = {}

    print(f"Downloading annotations for {total} jobs...")

    for index, job_id in enumerate(job_ids, start=1):
        print(f"\n[{index}/{total}] Job {job_id}")

        # Guard clauses: skip this job on download or parse failure.
        if not download_from_jumpbox(job_id, jumpbox_uri):
            print(" ✗ Download failed")
            continue

        annotation = load_annotation(job_id)
        if not annotation:
            print(" ✗ Failed to load annotation")
            continue

        loaded[job_id] = annotation
        print(" ✓ Loaded annotation")

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"Downloaded {len(loaded)}/{total} annotations")
    print(separator)

    return loaded
82+
83+
84+
def log_annotation_feedback(trace_id: str, annotations: dict[str, Any]) -> None:
    """Record each section of an annotation as MLflow feedback on a trace.

    Emits one feedback entry for the root cause (when present), then one per
    evidence item, recommendation, alternative diagnosis, contributing
    factor, and consistency-check entry.
    """
    rc = annotations.get("root_cause", {})
    if rc:
        mlflow.log_feedback(
            trace_id=trace_id,
            name="Root Cause",
            value=f"Category: {rc.get('category')} \n Confidence: {rc.get('confidence')}",
            rationale=rc.get("summary"),
        )

    for item in annotations.get("evidence", []):
        mlflow.log_feedback(
            trace_id=trace_id,
            name="Evidence",
            value=f"{item.get('source')}: {item.get('message')} \n Confidence {item.get('confidence')}",
        )

    for rec in annotations.get("recommendations", []):
        mlflow.log_feedback(
            trace_id=trace_id,
            name="Recommendation",
            value=f"Priority: {rec.get('priority')} \n Action: {rec.get('action')}",
            rationale=f"File: {rec.get('file')}",
        )

    for alt in annotations.get("alternative_diagnoses", []):
        mlflow.log_feedback(
            trace_id=trace_id,
            name="Alternative Diagnosis",
            value=f"Category: {alt.get('category')} \n Summary: {alt.get('summary')}",
            rationale=alt.get("why_wrong"),
        )

    for factor in annotations.get("contributing_factors", []):
        mlflow.log_feedback(trace_id=trace_id, name="Contributing Factor", value=factor)

    for check_name, check_value in annotations.get("consistency_check", {}).items():
        mlflow.log_feedback(
            trace_id=trace_id,
            name=f"Consistency Check: {check_name}",
            value=f"{check_name}: {check_value}",
        )
127+
128+
def log_expectation(trace_id: str, annotations: dict[str, Any]) -> None:
    """Record the human-review portion of an annotation as MLflow expectations.

    Does nothing when the annotation carries no ``human_review`` section.
    """
    review = annotations.get("human_review", {})
    if not review:
        return

    # Overall human-review verdict.
    mlflow.log_expectation(
        trace_id=trace_id,
        name="Human Review",
        value=f"Summary accurate: {review.get('summary_accurate')} \n Summary comment: {review.get('summary_comment')}",
    )

    # Summary accuracy on its own.
    mlflow.log_expectation(
        trace_id=trace_id,
        name="Summary (Human Review)",
        value=f"Summary accurate: {review.get('summary_accurate')}",
    )

    # Evidence feedback from the reviewer.
    # NOTE(review): the trailing space in the name "Evidence " is preserved
    # from the original — confirm whether it is intentional.
    mlflow.log_expectation(
        trace_id=trace_id,
        name="Evidence ",
        value=review.get("evidence_feedback"),
    )

    # Difficulty rating review.
    mlflow.log_expectation(
        trace_id=trace_id,
        name="Difficulty",
        value=f"Difficulty appropriate: {review.get('difficulty_appropriate')}",
    )

    # Alternative diagnoses the human reviewer added.
    for alt in review.get("alternative_diagnoses_added", []):
        mlflow.log_expectation(
            trace_id=trace_id,
            name="Alternative Diagnosis Added",
            value=f"Category: {alt.get('category')} | Plausibility: {alt.get('plausibility')} \nSummary: {alt.get('summary')}",
        )
167+
168+
def evaluate_jobs(job_ids: list[str]) -> None:
    """Run the annotation evaluation for the given job IDs.

    Downloads each job's annotation inside a traced MLflow span, then logs
    the annotation contents as feedback and expectations against that
    single trace, and records run-level params.

    Args:
        job_ids: Job IDs to download and evaluate.
    """
    # Tracking destination and experiment default sensibly but are
    # overridable via the environment.
    tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
    mlflow.set_tracking_uri(tracking_uri)
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", "Default")
    mlflow.set_experiment(experiment_name)

    # Fix: the run handle was bound (`as run`) but never used — drop it.
    with mlflow.start_run(run_name="ANNOTATOR_EVALUATION"):
        # Create a traced span within the run to ensure linkage.
        with mlflow.start_span(name="download_annotations") as span:
            data = download_annotations_for_eval(job_ids)
            # NOTE(review): `span.request_id` is deprecated in recent MLflow
            # releases in favor of `span.trace_id` — confirm against the
            # pinned MLflow version before switching.
            trace_id = span.request_id

        for job_id in job_ids:
            if job_id not in data:
                print(f"Warning: No data for job {job_id}")
                continue

            annotations = data[job_id]

            # Log feedback for annotation quality metrics.
            log_annotation_feedback(trace_id, annotations)

            # Log ground truth annotation as expectation.
            log_expectation(trace_id, annotations)

        # Log run params.
        mlflow.log_param("job_ids", job_ids)
        if job_ids and job_ids[0] in data:
            mlflow.log_param("annotator", data[job_ids[0]].get("annotator"))

        print(f"Traces are {trace_id}")
200+
201+
def main():
    """CLI entry point: parse job IDs from argv and run the evaluation."""
    cli = argparse.ArgumentParser(description="Evaluate RCA annotations using MLflow.")
    cli.add_argument("job_ids", nargs="+", help="List of job IDs to evaluate")
    parsed = cli.parse_args()
    evaluate_jobs(parsed.job_ids)


if __name__ == "__main__":
    sys.exit(main())

0 commit comments

Comments
 (0)