Skip to content

Commit 01aa932

Browse files
committed
[WIP] MLFlow Eval
1 parent 9a6d5ba commit 01aa932

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
"""
2+
MLflow evaluation functions for RCA-Annotator.
3+
4+
Provides functions to download annotation files from jumpbox and log
5+
evaluation results to MLflow.
6+
"""
7+
8+
import json
9+
import os
10+
from pathlib import Path
11+
from typing import Any
12+
13+
import mlflow
14+
# from mlflow.entities import Feedback, AssessmentSource, AssessmentSourceType
15+
16+
from jumpbox_io import download_from_jumpbox
17+
18+
19+
def load_annotation(job_id: str) -> dict[str, Any] | None:
20+
"""
21+
Load annotation_draft.json for a given job_id.
22+
23+
Args:
24+
job_id: Job ID to load annotation for
25+
26+
Returns:
27+
Annotation dict if successful, None on failure
28+
"""
29+
annotation_file = Path(".analysis") / job_id / "annotation_draft.json"
30+
31+
if not annotation_file.exists():
32+
print(f" Error: Annotation file not found: {annotation_file}")
33+
return None
34+
35+
try:
36+
with open(annotation_file, "r") as f:
37+
annotation = json.load(f)
38+
return annotation
39+
except json.JSONDecodeError as e:
40+
print(f" Error: Invalid JSON in {annotation_file}: {e}")
41+
return None
42+
except Exception as e:
43+
print(f" Error: Failed to read {annotation_file}: {e}")
44+
return None
45+
46+
def download_annotations_for_eval(
    job_ids: list[str], jumpbox_uri: str | None = None
) -> dict[str, dict[str, Any]]:
    """
    Download and load annotations for multiple jobs.

    For each job, first pulls the files via ``download_from_jumpbox`` and
    then parses the local annotation draft. Jobs that fail either step are
    skipped (with a progress message) rather than aborting the batch.

    Args:
        job_ids: List of job IDs to download
        jumpbox_uri: JUMPBOX_URI connection string (defaults to env var)

    Returns:
        Dict mapping job_id to annotation data
    """
    total = len(job_ids)
    loaded: dict[str, dict[str, Any]] = {}

    print(f"Downloading annotations for {total} jobs...")

    for idx, job_id in enumerate(job_ids, start=1):
        print(f"\n[{idx}/{total}] Job {job_id}")

        # Guard clauses keep the happy path flat.
        if not download_from_jumpbox(job_id, jumpbox_uri):
            print(" ✗ Download failed")
            continue

        annotation = load_annotation(job_id)
        if not annotation:
            print(" ✗ Failed to load annotation")
            continue

        loaded[job_id] = annotation
        print(" ✓ Loaded annotation")

    separator = "=" * 60
    print(f"\n{separator}")
    print(f"Downloaded {len(loaded)}/{total} annotations")
    print(separator)

    return loaded
81+
82+
83+
def log_annotation_feedback(trace_id: str, annotations: dict[str, Any]) -> None:
    """
    Log annotation details as MLflow feedback.

    Emits one feedback entry per annotation section (root cause, evidence,
    recommendations, alternative diagnoses, contributing factors, and
    consistency checks), all attached to the given trace.
    """
    feedback = mlflow.log_feedback  # hoisted: called once per annotation item

    rc = annotations.get("root_cause", {})
    if rc:
        feedback(
            trace_id=trace_id,
            name="Root Cause",
            value=f"Category: {rc.get('category')} \n Confidence: {rc.get('confidence')}",
            rationale=rc.get("summary"),
        )

    for item in annotations.get("evidence", []):
        feedback(
            trace_id=trace_id,
            name="Evidence",
            value=f"{item.get('source')}: {item.get('message')} \n Confidence {item.get('confidence')}",
        )

    for rec in annotations.get("recommendations", []):
        feedback(
            trace_id=trace_id,
            name="Recommendation",
            value=f"Priority: {rec.get('priority')} \n Action: {rec.get('action')}",
            rationale=f"File: {rec.get('file')}",
        )

    for alt in annotations.get("alternative_diagnoses", []):
        feedback(
            trace_id=trace_id,
            name="Alternative Diagnosis",
            value=f"Category: {alt.get('category')} \n Summary: {alt.get('summary')}",
            rationale=alt.get("why_wrong"),
        )

    for factor in annotations.get("contributing_factors", []):
        feedback(trace_id=trace_id, name="Contributing Factor", value=factor)

    for check_name, check_value in annotations.get("consistency_check", {}).items():
        feedback(
            trace_id=trace_id,
            name=f"Consistency Check: {check_name}",
            value=f"{check_name}: {check_value}",
        )
126+
127+
def evaluate_jobs(
    job_ids: list[str],
    expected_responses: dict[str, dict[str, Any]] | None = None,
) -> None:
    """
    Run evaluation for the given job IDs.

    Downloads annotations for each job, logs each annotation's sections as
    MLflow feedback on a trace, optionally logs ground-truth expectations,
    and records run parameters.

    Args:
        job_ids: Job IDs to evaluate.
        expected_responses: Optional mapping of job_id to the ground-truth
            annotation expected for that job. When None, falls back to the
            built-in placeholder example (preserves the previous behavior);
            real runs should pass their own mapping.
    """
    tracking_uri = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:5000")
    mlflow.set_tracking_uri(tracking_uri)
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", "Default")
    mlflow.set_experiment(experiment_name)

    if expected_responses is None:
        # Placeholder ground truth carried over from the WIP version so the
        # no-argument call behaves exactly as before. Supply real data via
        # the `expected_responses` parameter instead of editing this dict.
        expected_responses = {
            "1234567": {
                "root_cause": {
                    "category": "Example",
                    "summary": "Example summary",
                    "confidence": "example",
                },
            }
        }

    with mlflow.start_run(run_name="ANNOTATOR_EVALUATION"):
        # Create a traced span within the run to ensure linkage
        with mlflow.start_span(name="download_annotations") as span:
            data = download_annotations_for_eval(job_ids)
            trace_id = span.request_id

        for job_id in job_ids:
            if job_id not in data:
                print(f"Warning: No data for job {job_id}")
                continue

            annotations = data[job_id]

            # Log feedback for annotation quality metrics
            log_annotation_feedback(trace_id, annotations)

            # Log ground truth annotation as expectation
            expected_response = expected_responses.get(job_id)
            if expected_response:
                mlflow.log_expectation(
                    trace_id=trace_id,
                    name=f"expected_response_{job_id}",
                    value=expected_response,
                )

        # Log run params
        mlflow.log_param("job_ids", job_ids)
        if job_ids and job_ids[0] in data:
            mlflow.log_param("annotator", data[job_ids[0]].get('annotator'))

        print(f"Traces are {trace_id}")
177+
178+
if __name__ == "__main__":
    # Ad-hoc smoke run against a single known job.
    evaluate_jobs(["2035512"])

0 commit comments

Comments
 (0)