run_local_benchmark.py
executable file · 119 lines (92 loc) · 3.65 KB
#!/usr/bin/env python3
"""
Run evaluation benchmark locally with Databricks connection.
Usage:
python3 run_local_benchmark.py
python3 run_local_benchmark.py --artifact-type tables --batch-size 5
"""
import os
import sys
import argparse
import logging
from pathlib import Path
from dotenv import load_dotenv
# Load .env from project root
env_path = Path(__file__).parent / ".env"
if env_path.exists():
    load_dotenv(env_path)
    print(f"✅ Loaded .env from: {env_path}")
# Add src to path
sys.path.insert(0, str(Path(__file__).parent / "src"))
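# With src/ on sys.path, the local artifact_translation_package imported below is
# resolvable without installing it as a package.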
import mlflow
from databricks_langchain import ChatDatabricks
from artifact_translation_package.evaluation import run_benchmark, ModelConfig
from artifact_translation_package.evaluation.model_benchmark import create_default_model_configs


def get_experiment_name(custom_name: str = None) -> str:
    """Get experiment name, auto-detecting user if not provided."""
    if custom_name:
        return custom_name
    try:
        from databricks.sdk import WorkspaceClient
        username = WorkspaceClient().current_user.me().user_name
        return f"/Users/{username}/sql-translation-benchmark"
    except Exception:
        return "/Shared/sql-translation-benchmark"


def setup_mlflow(experiment_name: str) -> str:
    """Configure MLflow with Databricks tracking."""
    mlflow.set_tracking_uri("databricks")
    try:
        mlflow.set_experiment(experiment_name)
    except Exception:
        mlflow.create_experiment(experiment_name)
        mlflow.set_experiment(experiment_name)
    print(f"✅ MLflow experiment: {experiment_name}")
    return experiment_name


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="SQL translation benchmark")
    parser.add_argument("--artifact-type", default="tables", help="Artifact type")
    parser.add_argument("--dataset-source", help="Dataset JSON path")
    parser.add_argument("--experiment-name", help="MLflow experiment name")
    parser.add_argument("--batch-size", type=int, default=5, help="Batch size")
    parser.add_argument("--judge-endpoint", default="databricks-llama-4-maverick", help="LLM judge endpoint")
    parser.add_argument("--models", nargs="+", help="Model endpoints to test")
    return parser.parse_args()


def main():
    """Run benchmark."""
    args = parse_args()
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    print("\n" + "="*70)
    print("🚀 SQL Translation Benchmark")
    print("="*70)

    experiment_name = setup_mlflow(get_experiment_name(args.experiment_name))

    # Configure models
    if args.models:
        model_configs = [
            ModelConfig(name=ep, endpoint=ep, temperature=0.1, max_tokens=4000)
            for ep in args.models
        ]
    else:
        model_configs = create_default_model_configs()

    print(f"Models: {[c.name for c in model_configs]}")
    print(f"Artifact: {args.artifact_type}, Batch: {args.batch_size}")

    try:
        results_df = run_benchmark(
            artifact_type=args.artifact_type,
            dataset_source=args.dataset_source,
            experiment_name=experiment_name,
            model_configs=model_configs,
            batch_size=args.batch_size,
            judge_endpoint=args.judge_endpoint
        )
        print("\n" + "="*70)
        print("✅ Benchmark Complete!")
        print("="*70)
        print(results_df.to_string())
        print(f"\n🔗 View results: {experiment_name}")
    except Exception as e:
        print(f"\n❌ Failed: {e}")
        logging.error("Benchmark error", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()