Commit 8124a22

Add deterministic SQL metrics (#63)
1 parent 344f7e9 commit 8124a22

13 files changed: +597 additions, -268 deletions

README.md

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ print(metric(**datum))
  <tr>
    <td rowspan="2">Code Generation</td>
    <td>Deterministic</td>
-   <td>CodeStringMatch, PythonASTSimilarity</td>
+   <td>CodeStringMatch, PythonASTSimilarity, SQLSyntaxMatch, SQLASTSimilarity</td>
  </tr>
  <tr>
    <td>LLM-based</td>

continuous_eval/metrics/code/__init__.py

Whitespace-only changes.

continuous_eval/metrics/code/python/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

continuous_eval/metrics/code/python/code_deterministic_metrics.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 from typing import List, Union

 from munkres import Munkres
+from sqlglot import diff, parse_one
+from sqlglot.diff import Keep
 from thefuzz import fuzz

 from continuous_eval.metrics.base import Metric
continuous_eval/metrics/code/sql/deterministic.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sqlglot import diff, parse_one, transpile
from sqlglot.diff import Insert, Keep, Move, Remove, Update
from sqlglot.optimizer import optimize

from continuous_eval.metrics.base import Metric

logger = logging.getLogger("metrics")


@dataclass(frozen=True)
class ASTDiffWeightConfig:
    """
    Configuration for assigning weights to different types of changes in the AST diff.
    Higher weights indicate more significant changes, which are expected to have a greater impact on query semantics.
    """

    keep: float = 0.0
    # Updates are significant as they imply a modification in function or value.
    update: float = 1.5
    # Inserts affect the structure and content but are simpler than updates.
    insert: float = 1.0
    # Removes affect the structure and content but are simpler than updates.
    remove: float = 1.0
    # Moves are generally less impactful as they simply change the order.
    move: float = 0.5
    # Default weight for other types of changes.
    default: float = 1.0


class _SQLMetric:
    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        self._optimize = optimize
        self._schema = schema

    def _prepare_query(self, sql: str):
        """
        Parse, transpile, and optionally optimize a SQL query.
        """
        formatted_sql = transpile(sql, pretty=True, comments=False)[0]
        if self._optimize:
            try:
                optimized_sql = optimize(parse_one(formatted_sql), schema=self._schema).sql(pretty=True)
                return optimized_sql
            except Exception as e:
                logger.warning(f"Failed to optimize SQL query given schema: {e}. Using unoptimized query.")
                return formatted_sql
        return formatted_sql


class SQLSyntaxMatch(Metric, _SQLMetric):
    """
    This metric evaluates the syntactic similarity between the generated SQL query and a set of ground truth queries.
    It uses the sqlglot library to format and compare the SQL queries.
    """

    def __init__(self, optimize: bool = False, schema: Optional[Dict] = None):
        super(SQLSyntaxMatch, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str]):
        transformed_answer = self._prepare_query(answer)
        transformed_ground_truths = [self._prepare_query(gt) for gt in ground_truth_answers]

        max_match_score = 0.0

        for transformed_gt in transformed_ground_truths:
            match_score = float(transformed_answer == transformed_gt)
            if match_score > max_match_score:
                max_match_score = match_score

        return {"SQL_Syntax_Match": max_match_score}


class SQLASTSimilarity(Metric, _SQLMetric):
    """
    Compare SQL queries using AST similarity, weighting different types of changes differently and normalizing by tree size.
    """

    def __init__(
        self,
        optimize: bool = False,
        schema: Optional[Dict] = None,
        diff_weights: ASTDiffWeightConfig = ASTDiffWeightConfig(),
    ):
        super(SQLASTSimilarity, self).__init__()
        _SQLMetric.__init__(self, optimize=optimize, schema=schema)
        self._diff_weights = diff_weights

    def __call__(self, answer: str, ground_truth_answers: Union[List[str], str], **kwargs):
        transformed_answer = self._prepare_query(answer)
        transformed_ground_truths = [self._prepare_query(gt) for gt in ground_truth_answers]

        try:
            answer_tree = parse_one(transformed_answer)
            ground_truth_trees = [parse_one(gt) for gt in transformed_ground_truths]
        except Exception:
            return {"SQL_AST_Similarity": -1.0}

        similarity_scores = [
            self._calculate_similarity(answer_tree, ground_truth_tree) for ground_truth_tree in ground_truth_trees
        ]

        return {
            "SQL_AST_Similarity": max(similarity_scores) if similarity_scores else -1.0,
        }

    def _calculate_similarity(self, tree1, tree2):
        diff_result = diff(tree1, tree2)
        total_changes = sum(self._apply_weights(change) for change in diff_result)
        max_nodes = max(len(list(tree1.walk())), len(list(tree2.walk())))
        similarity_score = 1 - (total_changes / max_nodes) if max_nodes > 0 else 1
        return similarity_score

    def _apply_weights(self, change):
        """
        Assign weights to different types of changes based on their expected impact on query semantics.
        """
        if isinstance(change, Keep):
            return self._diff_weights.keep
        elif isinstance(change, Update):
            return self._diff_weights.update
        elif isinstance(change, Insert):
            return self._diff_weights.insert
        elif isinstance(change, Remove):
            return self._diff_weights.remove
        elif isinstance(change, Move):
            return self._diff_weights.move
        else:
            return self._diff_weights.default
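
To see the scoring logic in isolation, the following is a minimal standalone sketch (not part of the commit) of the weighted diff normalization that `SQLASTSimilarity._calculate_similarity` performs, assuming only that sqlglot is installed. It hard-codes the default `ASTDiffWeightConfig` weights and skips the transpile/optimize preprocessing done by `_prepare_query`.

```python
from sqlglot import diff, parse_one
from sqlglot.diff import Insert, Keep, Move, Remove, Update

# Default weights, mirroring ASTDiffWeightConfig above.
WEIGHTS = {Keep: 0.0, Update: 1.5, Insert: 1.0, Remove: 1.0, Move: 0.5}


def ast_similarity(sql_a: str, sql_b: str) -> float:
    tree_a, tree_b = parse_one(sql_a), parse_one(sql_b)
    # Weight each change reported by sqlglot's tree diff, then normalize by the
    # node count of the larger tree, as in _calculate_similarity.
    total_changes = sum(WEIGHTS.get(type(change), 1.0) for change in diff(tree_a, tree_b))
    max_nodes = max(len(list(tree_a.walk())), len(list(tree_b.walk())))
    return 1 - total_changes / max_nodes if max_nodes > 0 else 1


# A reordered column list scores high but below 1.0 (the documentation added in
# this commit reports 0.9375 for this pair).
print(ast_similarity("SELECT name, age FROM customers", "SELECT age, name FROM customers"))
```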

docs/src/content/docs/metrics/Code/Deterministic/string_match.md renamed to docs/src/content/docs/metrics/Code/Deterministic/code_string_match.md

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 ---
-title: StringMatch
+title: Code String Match
 sidebar:
   order: 1
 ---
@@ -18,7 +18,7 @@ It outputs both the binary exact match score and the fuzzy match score in the ra
Required data items: `answer`, `ground_truth_answers`

```python
-from continuous_eval.metrics.code.python import CodeStringMatch
+from continuous_eval.metrics.code import CodeStringMatch

datum = {
    "answer": "def function(x, y):\n return x + y",

docs/src/content/docs/metrics/Code/Deterministic/python_ast_similarity.md

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ The metric depends on syntactically correct Python scripts to produce the Abstra
Required data items: `answer`, `ground_truth_answers`

```python
-from continuous_eval.metrics.code.python import PythonASTSimilarity
+from continuous_eval.metrics.code import PythonASTSimilarity

datum = {
    "answer": "def function(x, y):\n return x + y",
docs/src/content/docs/metrics/Code/Deterministic/sql_ast_similarity.md

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
---
title: SQL AST Similarity
sidebar:
  order: 1
---

### Definitions

**SQL AST Similarity** compares the structure of two SQL queries by analyzing their Abstract Syntax Trees (ASTs). This metric assesses similarity by matching the nodes within these trees, taking into account the statement types and their arrangement. Different types of tree changes (insert, remove, update, move, etc.) are weighted differently when calculating the final similarity score.

<br>

$$
\text{SQL AST Similarity} = 1 - \frac{\text{Total Weighted Changes}}{\text{Max Nodes in Either Tree}}
$$

<br>

:::note
The metric depends on syntactically correct SQL queries to produce the Abstract Syntax Trees (ASTs). If a query contains syntax errors and cannot be parsed, the metric yields a score of -1.0.
:::

<br>

### Example Usage

Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code import SQLASTSimilarity

datum = {
    "answer": "SELECT name, age FROM customers",
    "ground_truth_answers": ["SELECT age, name FROM customers"],
}

metric = SQLASTSimilarity()
print(metric(**datum))
```

You can optionally initialize the metric to optimize the SQL queries with the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer), and you can pass in a schema to aid optimization. For example:

```python
schema = {"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
sql_ast_similarity_optimized = SQLASTSimilarity(optimize=True, schema=schema)
```

You can also customize the weights assigned to different types of changes in the AST diff.
Higher weights indicate more significant changes, which are expected to have a greater impact on query semantics.

```python
from continuous_eval.metrics.code.sql.deterministic import ASTDiffWeightConfig

weights = ASTDiffWeightConfig(
    keep=0.0,
    update=2,
    insert=1.0,
    remove=1.5,
    move=0,
    default=0,
)
sql_ast_similarity = SQLASTSimilarity(diff_weights=weights)
```

### Example Output

```JSON
{
  "SQL_AST_Similarity": 0.9375
}
```
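
As a hedged illustration of how the formula can produce the example output (the actual change types and node counts come from sqlglot's diff and are assumptions here): if swapping the two selected columns yields two Move changes at the default weight of 0.5 each, and the larger tree contains 16 nodes, then

$$
\text{SQL AST Similarity} = 1 - \frac{2 \times 0.5}{16} = 1 - 0.0625 = 0.9375
$$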
docs/src/content/docs/metrics/Code/Deterministic/sql_syntax_match.md

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
---
title: SQL Syntax Match
sidebar:
  order: 2
---

## Definitions

**SQL Syntax Match** evaluates the syntactic equivalence between the generated SQL query and a set of ground truth queries. The comparison is strict on query structure but tolerates formatting differences (such as whitespace and keyword casing), since queries are normalized with sqlglot before being compared.

## Example Usage

Required data items: `answer`, `ground_truth_answers`

```python
from continuous_eval.metrics.code import SQLSyntaxMatch

datum = {
    "answer": "SELECT * FROM users;",
    "ground_truth_answers": [
        "SELECT * from users;"
    ],
}

metric = SQLSyntaxMatch()
print(metric(**datum))
```

You can optionally initialize the metric to optimize the SQL queries with the [sqlglot optimizer](https://github.com/tobymao/sqlglot?tab=readme-ov-file#sql-optimizer), and you can pass in a schema to aid optimization. For example:

```python
schema = {"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
sql_syntax_match_optimized = SQLSyntaxMatch(optimize=True, schema=schema)
```

## Example Output

```JSON
{
  "SQL_Syntax_Match": 1.0
}
```
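
To illustrate the formatting tolerance in practice, here is a minimal sketch (not part of the commit) of the normalization that `_prepare_query` applies before comparison; it assumes only that sqlglot is installed, and the exact pretty-printed output may vary by sqlglot version.

```python
from sqlglot import transpile

# Both spellings transpile to the same pretty-printed form, so SQLSyntaxMatch
# treats them as an exact match despite case and whitespace differences.
a = transpile("SELECT * FROM users;", pretty=True, comments=False)[0]
b = transpile("select *   from users;", pretty=True, comments=False)[0]
print(a == b)  # expected: True
```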
