Skip to content

Commit da8a897

Browse files
author
Ubuntu
committed
Add initial SQL metrics implementation
1 parent 5c152c2 commit da8a897

File tree

2 files changed

+1657
-1344
lines changed

2 files changed

+1657
-1344
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from typing import List, Union
2+
3+
import sqlparse
4+
5+
from continuous_eval.metrics.base import Metric
6+
7+
8+
class SQLSyntaxMatch(Metric):
9+
"""
10+
This metric evaluates the syntactic similarity between the generated SQL query and a set of ground truth queries.
11+
It uses the sqlparse library to format and compare the SQL queries.
12+
"""
13+
14+
def __call__(self, answer: str, ground_truth_answers: Union[List[str], str]):
15+
if isinstance(ground_truth_answers, str):
16+
ground_truth_answers = [ground_truth_answers]
17+
18+
# Format the answer and ground truth answers using sqlparse for consistent comparison
19+
formatted_answer = sqlparse.format(answer, reindent=True, keyword_case="upper")
20+
formatted_ground_truths = [
21+
sqlparse.format(gt, reindent=True, keyword_case="upper")
22+
for gt in ground_truth_answers
23+
]
24+
25+
# Initialize the maximum match score
26+
max_match_score = 0
27+
28+
# Compare the formatted answer with each formatted ground truth answer
29+
for formatted_gt in formatted_ground_truths:
30+
# Simple string comparison for now, can be improved with more sophisticated methods
31+
match_score = float(formatted_answer == formatted_gt)
32+
if match_score > max_match_score:
33+
max_match_score = match_score
34+
35+
# Return the maximum match score
36+
return {"SQL_Syntax_Match_Score": max_match_score}

0 commit comments

Comments
 (0)