Skip to content

Commit 4053bc1

Browse files
address comments left by @yisz on #59 (Add SQL Metrics Implementation);
1 parent da8a897 commit 4053bc1

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

continuous_eval/metrics/code/sql/sql_deterministic_metrics.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from continuous_eval.metrics.base import Metric
66

7-
87
class SQLSyntaxMatch(Metric):
98
"""
109
This metric evaluates the syntactic similarity between the generated SQL query and a set of ground truth queries.
@@ -27,10 +26,19 @@ def __call__(self, answer: str, ground_truth_answers: Union[List[str], str]):
2726

2827
# Compare the formatted answer with each formatted ground truth answer
2928
for formatted_gt in formatted_ground_truths:
30-
# Simple string comparison for now, can be improved with more sophisticated methods
31-
match_score = float(formatted_answer == formatted_gt)
29+
# Replace simple string comparison with AST comparison
30+
match_score = float(self.compare_ast(formatted_answer, formatted_gt))
3231
if match_score > max_match_score:
3332
max_match_score = match_score
3433

3534
# Return the maximum match score
3635
return {"SQL_Syntax_Match_Score": max_match_score}
36+
37+
def compare_ast(self, query1: str, query2: str) -> float:
38+
# Parse the queries into ASTs
39+
ast1 = sqlparse.parse(query1)
40+
ast2 = sqlparse.parse(query2)
41+
42+
# Compare the structure of the ASTs
43+
# This is a placeholder and would need to be replaced with a real implementation
44+
return float(ast1 == ast2)

0 commit comments

Comments
 (0)