from typing import List, Union

import sqlparse

from continuous_eval.metrics.base import Metric


class SQLSyntaxMatch(Metric):
    """
    This metric evaluates the syntactic similarity between the generated SQL query
    and a set of ground truth queries. It uses the sqlparse library to normalize
    the queries (consistent indentation, uppercase keywords) before comparing them.
    """

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str]):
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        # Format the answer and ground truth answers with sqlparse so that
        # whitespace and keyword-case differences do not affect the comparison.
        formatted_answer = sqlparse.format(answer, reindent=True, keyword_case="upper")
        formatted_ground_truths = [
            sqlparse.format(gt, reindent=True, keyword_case="upper")
            for gt in ground_truth_answers
        ]

        # Exact string comparison for now; this can be improved with more
        # sophisticated methods (e.g. AST-based comparison). The score is 1.0
        # if the formatted answer matches any formatted ground truth, else 0.0.
        match_score = float(
            any(formatted_answer == gt for gt in formatted_ground_truths)
        )

        return {"SQL_Syntax_Match_Score": match_score}