|
| 1 | +import logging |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Dict, List, Optional, Union |
| 4 | + |
| 5 | +from sqlglot import diff, parse_one, transpile |
| 6 | +from sqlglot.diff import Insert, Keep, Move, Remove, Update |
| 7 | +from sqlglot.optimizer import optimize |
| 8 | + |
| 9 | +from continuous_eval.metrics.base import Metric |
| 10 | + |
| 11 | +logger = logging.getLogger("metrics") |
| 12 | + |
| 13 | + |
@dataclass(frozen=True)
class ASTDiffWeightConfig:
    """
    Per-operation penalties used when scoring the diff between two SQL ASTs.

    Each field is the cost charged for one diff operation of that type. A larger
    value marks the operation as a bigger departure from the reference query.
    Instances are immutable (frozen dataclass).
    """

    # Nodes that are identical in both trees cost nothing.
    keep: float = 0.0
    # A node whose function or value was modified: the heaviest penalty.
    update: float = 1.5
    # A node present only in the target tree.
    insert: float = 1.0
    # A node present only in the source tree.
    remove: float = 1.0
    # A node that merely changed position: the lightest non-zero penalty.
    move: float = 0.5
    # Fallback cost for any diff operation not covered by the fields above.
    default: float = 1.0
| 32 | + |
| 33 | + |
| 34 | +class _SQLMetric: |
| 35 | + def __init__(self, optimize: bool = False, schema: Optional[Dict] = None): |
| 36 | + self._optimize = optimize |
| 37 | + self._schema = schema |
| 38 | + |
| 39 | + def _prepare_query(self, sql: str): |
| 40 | + """ |
| 41 | + Parse, transpile, and optionally optimize a SQL query. |
| 42 | + """ |
| 43 | + formatted_sql = transpile(sql, pretty=True, comments=False)[0] |
| 44 | + if self._optimize: |
| 45 | + try: |
| 46 | + optimized_sql = optimize(parse_one(formatted_sql), schema=self._schema).sql(pretty=True) |
| 47 | + return optimized_sql |
| 48 | + except Exception as e: |
| 49 | + logger.warning(f"Failed to optimize SQL query given schema: {e}. Using unoptimized query.") |
| 50 | + return formatted_sql |
| 51 | + return formatted_sql |
| 52 | + |
| 53 | + |
class SQLSyntaxMatch(Metric, _SQLMetric):
    """
    This metric evaluates the syntactic similarity between the generated SQL query and a set of ground truth queries.
    It uses the sqlglot library to format the queries and then compares the formatted text for exact equality.
    """

    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        """
        optimize: when True, queries are also passed through the sqlglot optimizer before comparison.
        schema: optional schema mapping forwarded to the optimizer.
        """
        super(SQLSyntaxMatch, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        """
        Score the generated query against one or more reference queries.

        answer: the generated SQL query.
        ground_truth_answers: a single reference query or a list of them.
        **kwargs: accepted and ignored, mirroring SQLASTSimilarity.__call__.

        Returns {"SQL_Syntax_Match": 1.0} when the normalized answer equals at least one
        normalized reference, otherwise {"SQL_Syntax_Match": 0.0} (also for an empty list).
        Errors raised while normalizing a query propagate to the caller.
        """
        # A bare string is ONE reference query; without this wrap the loop below
        # would iterate over its individual characters.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        transformed_answer = self._prepare_query(answer)
        # Normalize every reference up front so an invalid one is always reported.
        transformed_ground_truths = [self._prepare_query(gt) for gt in ground_truth_answers]

        is_match = any(transformed_gt == transformed_answer for transformed_gt in transformed_ground_truths)
        return {"SQL_Syntax_Match": float(is_match)}
| 77 | + |
| 78 | + |
class SQLASTSimilarity(Metric, _SQLMetric):
    """
    Compare SQL queries using AST similarity, weighting each kind of tree edit differently.

    The score is 1.0 for identical trees and decreases towards 0.0 as the weighted
    number of edits grows. -1.0 is reserved for "could not be computed".
    """

    def __init__(
        self,
        optimize: bool = False,
        schema: Optional[Dict] = None,
        diff_weights: ASTDiffWeightConfig = ASTDiffWeightConfig(),
    ):
        """
        optimize: when True, queries are also passed through the sqlglot optimizer before parsing.
        schema: optional schema mapping forwarded to the optimizer.
        diff_weights: cost charged for each kind of diff operation.
        """
        super(SQLASTSimilarity, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)
        self._diff_weights = diff_weights

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        """
        Score the generated query against one or more reference queries.

        answer: the generated SQL query.
        ground_truth_answers: a single reference query or a list of them.
        **kwargs: accepted and ignored.

        Returns {"SQL_AST_Similarity": score} where score is the best similarity over
        all references, or -1.0 when any query cannot be normalized or parsed, or when
        no reference is supplied.
        """
        # A bare string is ONE reference query; without this wrap the comprehension
        # below would iterate over its individual characters.
        if isinstance(ground_truth_answers, str):
            ground_truth_answers = [ground_truth_answers]

        try:
            # Normalization itself parses the SQL, so it must live inside the try:
            # otherwise invalid SQL raises before the -1.0 fallback can apply.
            answer_tree = parse_one(self._prepare_query(answer))
            ground_truth_trees = [parse_one(self._prepare_query(gt)) for gt in ground_truth_answers]
        except Exception as e:
            logger.warning("Failed to parse SQL query for AST comparison: %s", e)
            return {"SQL_AST_Similarity": -1.0}

        similarity_scores = [
            self._calculate_similarity(answer_tree, ground_truth_tree) for ground_truth_tree in ground_truth_trees
        ]

        return {
            "SQL_AST_Similarity": max(similarity_scores, default=-1.0),
        }

    def _calculate_similarity(self, tree1, tree2):
        """
        Return the similarity of two parsed queries.

        The weighted cost of all diff operations is divided by the node count of the
        larger tree and subtracted from 1. The result is never below 0.0.
        """
        diff_result = diff(tree1, tree2)
        total_changes = sum(self._apply_weights(change) for change in diff_result)
        # Count nodes without materializing the traversal into a list.
        max_nodes = max(sum(1 for _ in tree1.walk()), sum(1 for _ in tree2.walk()))
        if max_nodes <= 0:
            return 1.0
        # Weighted edits can exceed the node count (e.g. every node removed and
        # re-inserted), which would yield a negative score and could even equal the
        # -1.0 error sentinel. Clamp so that real scores stay non-negative.
        return max(0.0, 1 - (total_changes / max_nodes))

    def _apply_weights(self, change):
        """
        Assign weights to different types of changes based on their expected impact on query semantics.
        """
        weights = self._diff_weights
        # Checked in order; the first matching diff operation type wins.
        weight_by_type = (
            (Keep, weights.keep),
            (Update, weights.update),
            (Insert, weights.insert),
            (Remove, weights.remove),
            (Move, weights.move),
        )
        for change_type, weight in weight_by_type:
            if isinstance(change, change_type):
                return weight
        return weights.default
0 commit comments