|
| 1 | +import random |
| 2 | +import string |
| 3 | +from sqloxide import parse_sql |
| 4 | +from substrait.builders.extended_expression import ( |
| 5 | + UnboundExtendedExpression, |
| 6 | + column, |
| 7 | + scalar_function, |
| 8 | + literal, |
| 9 | + aggregate_function, |
| 10 | + window_function, |
| 11 | +) |
| 12 | +from substrait.builders.plan import ( |
| 13 | + read_named_table, |
| 14 | + project, |
| 15 | + filter, |
| 16 | + sort, |
| 17 | + fetch, |
| 18 | + set, |
| 19 | + join, |
| 20 | + aggregate, |
| 21 | +) |
| 22 | +from substrait.gen.proto import type_pb2 as stt |
| 23 | +from substrait.gen.proto import algebra_pb2 as stalg |
| 24 | +from substrait.extension_registry import ExtensionRegistry |
| 25 | +from typing import Callable |
| 26 | +from deepdiff import DeepDiff |
| 27 | + |
| 28 | +SchemaResolver = Callable[[str], stt.NamedStruct] |
| 29 | + |
| 30 | +function_mapping = { |
| 31 | + "Plus": ("functions_arithmetic.yaml", "add"), |
| 32 | + "Minus": ("functions_arithmetic.yaml", "subtract"), |
| 33 | + "Gt": ("functions_comparison.yaml", "gt"), |
| 34 | + "GtEq": ("functions_comparison.yaml", "gte"), |
| 35 | + "Lt": ("functions_comparison.yaml", "lt"), |
| 36 | + "Eq": ("functions_comparison.yaml", "equal"), |
| 37 | +} |
| 38 | + |
| 39 | +aggregate_function_mapping = {"SUM": ("functions_arithmetic.yaml", "sum")} |
| 40 | + |
| 41 | +window_function_mapping = { |
| 42 | + "row_number": ("functions_arithmetic.yaml", "row_number"), |
| 43 | +} |
| 44 | + |
| 45 | + |
| 46 | +def compare_dicts(dict1, dict2): |
| 47 | + diff = DeepDiff(dict1, dict2, exclude_regex_paths=["span"]) |
| 48 | + return len(diff) == 0 |
| 49 | + |
| 50 | + |
| 51 | +def translate_expression( |
| 52 | + ast: dict, |
| 53 | + schema_resolver: SchemaResolver, |
| 54 | + registry: ExtensionRegistry, |
| 55 | + measures: list[UnboundExtendedExpression], |
| 56 | + groupings: list[dict], |
| 57 | + alias: str = None, |
| 58 | +) -> UnboundExtendedExpression: |
| 59 | + assert len(ast) == 1 |
| 60 | + op = list(ast.keys())[0] |
| 61 | + |
| 62 | + if groupings: |
| 63 | + # This means we are parsing a projection after a grouping |
| 64 | + # Loop through used groupings for an identical ast and return it rather than recalculate |
| 65 | + for i, f in enumerate(groupings): |
| 66 | + if compare_dicts(ast, f): |
| 67 | + return column(i, alias=alias) |
| 68 | + |
| 69 | + ast = ast[op] |
| 70 | + |
| 71 | + if op == "Identifier": |
| 72 | + return column(ast["value"], alias=alias) |
| 73 | + elif op == "UnnamedExpr" or op == "expr" or op == "Unnamed" or op == "Expr": |
| 74 | + return translate_expression( |
| 75 | + ast, |
| 76 | + schema_resolver=schema_resolver, |
| 77 | + registry=registry, |
| 78 | + measures=measures, |
| 79 | + groupings=groupings, |
| 80 | + ) |
| 81 | + elif op == "ExprWithAlias": |
| 82 | + return translate_expression( |
| 83 | + ast["expr"], |
| 84 | + schema_resolver=schema_resolver, |
| 85 | + registry=registry, |
| 86 | + measures=measures, |
| 87 | + groupings=groupings, |
| 88 | + alias=ast["alias"]["value"], |
| 89 | + ) |
| 90 | + elif op == "BinaryOp": |
| 91 | + expressions = [ |
| 92 | + translate_expression( |
| 93 | + ast["left"], |
| 94 | + schema_resolver=schema_resolver, |
| 95 | + registry=registry, |
| 96 | + measures=measures, |
| 97 | + groupings=groupings, |
| 98 | + ), |
| 99 | + translate_expression( |
| 100 | + ast["right"], |
| 101 | + schema_resolver=schema_resolver, |
| 102 | + registry=registry, |
| 103 | + measures=measures, |
| 104 | + groupings=groupings, |
| 105 | + ), |
| 106 | + ] |
| 107 | + func = function_mapping[ast["op"]] |
| 108 | + return scalar_function(func[0], func[1], expressions=expressions, alias=alias) |
| 109 | + elif op == "Value": |
| 110 | + return literal( |
| 111 | + int(ast["value"]["Number"][0]), stt.Type(i64=stt.Type.I64()), alias=alias |
| 112 | + ) # TODO infer type |
| 113 | + elif op == "Function": |
| 114 | + expressions = [ |
| 115 | + translate_expression( |
| 116 | + e, |
| 117 | + schema_resolver=schema_resolver, |
| 118 | + registry=registry, |
| 119 | + measures=measures, |
| 120 | + groupings=groupings, |
| 121 | + ) |
| 122 | + for e in ast["args"]["List"]["args"] |
| 123 | + ] |
| 124 | + name = ast["name"][0]["Identifier"]["value"] |
| 125 | + |
| 126 | + if name in function_mapping: |
| 127 | + func = function_mapping[name] |
| 128 | + return scalar_function(func[0], func[1], *expressions, alias=alias) |
| 129 | + elif name in aggregate_function_mapping: |
| 130 | + # All measures need to be extracted out because substrait calculates measures in a separate rel |
| 131 | + # We generate a random name for the measure and return a column with that name for the projection to work |
| 132 | + # Start by checking if multiple measures are identical and reuse previously generated name |
| 133 | + for m in measures: |
| 134 | + if compare_dicts(ast, m[1]): |
| 135 | + return column(m[2], alias=alias) |
| 136 | + |
| 137 | + func = aggregate_function_mapping[name] |
| 138 | + random_name = "".join( |
| 139 | + random.choices(string.ascii_uppercase + string.digits, k=5) |
| 140 | + ) # TODO make this deterministic |
| 141 | + aggr = aggregate_function(func[0], func[1], expressions, alias=random_name) |
| 142 | + measures.append((aggr, ast, random_name)) |
| 143 | + return column(random_name, alias=alias) |
| 144 | + elif name in window_function_mapping: |
| 145 | + func = window_function_mapping[name] |
| 146 | + |
| 147 | + partitions = [ |
| 148 | + translate_expression( |
| 149 | + e, |
| 150 | + schema_resolver=schema_resolver, |
| 151 | + registry=registry, |
| 152 | + measures=measures, |
| 153 | + groupings=groupings, |
| 154 | + ) |
| 155 | + for e in ast["over"]["WindowSpec"]["partition_by"] |
| 156 | + ] |
| 157 | + |
| 158 | + return window_function( |
| 159 | + func[0], func[1], expressions, partitions=partitions, alias=alias |
| 160 | + ) |
| 161 | + |
| 162 | + else: |
| 163 | + raise Exception(f"Unknown function {name}") |
| 164 | + # elif op == "Wildcard": |
| 165 | + # return wildcard() |
| 166 | + else: |
| 167 | + raise Exception(f"Unknown op {op}") |
| 168 | + |
| 169 | + |
| 170 | +def translate(ast: dict, schema_resolver: SchemaResolver, registry: ExtensionRegistry): |
| 171 | + assert len(ast) == 1 |
| 172 | + op = list(ast.keys())[0] |
| 173 | + ast = ast[op] |
| 174 | + |
| 175 | + if op == "Query": |
| 176 | + relation = translate( |
| 177 | + ast["body"], schema_resolver=schema_resolver, registry=registry |
| 178 | + ) |
| 179 | + |
| 180 | + if ast["order_by"]: |
| 181 | + expressions = [ |
| 182 | + translate_expression( |
| 183 | + e["expr"], |
| 184 | + schema_resolver=schema_resolver, |
| 185 | + registry=registry, |
| 186 | + measures=None, |
| 187 | + groupings=None, |
| 188 | + ) |
| 189 | + for e in ast["order_by"]["kind"]["Expressions"] |
| 190 | + ] |
| 191 | + relation = sort(relation, expressions)(registry) |
| 192 | + |
| 193 | + if ast["limit_clause"]: |
| 194 | + limit_expression = translate_expression( |
| 195 | + ast["limit_clause"]["LimitOffset"]["limit"], |
| 196 | + schema_resolver=schema_resolver, |
| 197 | + registry=registry, |
| 198 | + measures=None, |
| 199 | + groupings=None, |
| 200 | + ) |
| 201 | + |
| 202 | + if ast["limit_clause"]["LimitOffset"]["offset"]: |
| 203 | + offset_expression = translate_expression( |
| 204 | + ast["limit_clause"]["LimitOffset"]["offset"]["value"], |
| 205 | + schema_resolver=schema_resolver, |
| 206 | + registry=registry, |
| 207 | + measures=None, |
| 208 | + groupings=None, |
| 209 | + ) |
| 210 | + else: |
| 211 | + offset_expression = None |
| 212 | + |
| 213 | + relation = fetch(relation, offset_expression, limit_expression)(registry) |
| 214 | + |
| 215 | + return relation |
| 216 | + elif op == "Select": |
| 217 | + relation = translate( |
| 218 | + ast["from"][0]["relation"], |
| 219 | + schema_resolver=schema_resolver, |
| 220 | + registry=registry, |
| 221 | + ) |
| 222 | + |
| 223 | + if ast["from"][0]["joins"]: |
| 224 | + for _join in ast["from"][0]["joins"]: |
| 225 | + join_type_mapping = { |
| 226 | + "Inner": stalg.JoinRel.JOIN_TYPE_INNER, |
| 227 | + "Left": stalg.JoinRel.JOIN_TYPE_LEFT, |
| 228 | + "LeftOuter": stalg.JoinRel.JOIN_TYPE_LEFT, |
| 229 | + "RightOuter": stalg.JoinRel.JOIN_TYPE_RIGHT, |
| 230 | + "Right": stalg.JoinRel.JOIN_TYPE_RIGHT, |
| 231 | + } |
| 232 | + right = translate( |
| 233 | + _join["relation"], |
| 234 | + schema_resolver=schema_resolver, |
| 235 | + registry=registry, |
| 236 | + ) |
| 237 | + |
| 238 | + join_type = list(_join["join_operator"].keys())[0] |
| 239 | + |
| 240 | + expression = translate_expression( |
| 241 | + _join["join_operator"][join_type]["On"], |
| 242 | + schema_resolver=schema_resolver, |
| 243 | + registry=registry, |
| 244 | + measures=None, |
| 245 | + groupings=None, |
| 246 | + ) |
| 247 | + |
| 248 | + relation = join( |
| 249 | + relation, right, expression, join_type_mapping[join_type] |
| 250 | + )(registry) |
| 251 | + |
| 252 | + if "selection" in ast and ast["selection"]: |
| 253 | + where_expression = translate_expression( |
| 254 | + ast["selection"], |
| 255 | + schema_resolver=schema_resolver, |
| 256 | + registry=registry, |
| 257 | + measures=None, |
| 258 | + groupings=None, |
| 259 | + ) |
| 260 | + relation = filter(relation, where_expression)(registry) |
| 261 | + |
| 262 | + if ast["group_by"] and ast["group_by"]["Expressions"][0]: |
| 263 | + groupings = ast["group_by"]["Expressions"][0] |
| 264 | + grouping_expressions = [ |
| 265 | + translate_expression( |
| 266 | + e, |
| 267 | + schema_resolver=schema_resolver, |
| 268 | + registry=registry, |
| 269 | + measures=None, |
| 270 | + groupings=None, |
| 271 | + ) |
| 272 | + for e in groupings |
| 273 | + ] |
| 274 | + else: |
| 275 | + groupings = [] |
| 276 | + grouping_expressions = [] |
| 277 | + |
| 278 | + measures = [] |
| 279 | + |
| 280 | + projection = [ |
| 281 | + translate_expression( |
| 282 | + p, |
| 283 | + schema_resolver=schema_resolver, |
| 284 | + registry=registry, |
| 285 | + measures=measures, |
| 286 | + groupings=groupings, |
| 287 | + ) |
| 288 | + for p in ast["projection"] |
| 289 | + ] |
| 290 | + |
| 291 | + if ast["having"]: |
| 292 | + having_predicate = translate_expression( |
| 293 | + ast["having"], |
| 294 | + schema_resolver=schema_resolver, |
| 295 | + registry=registry, |
| 296 | + measures=measures, |
| 297 | + groupings=[], |
| 298 | + ) |
| 299 | + else: |
| 300 | + having_predicate = None |
| 301 | + |
| 302 | + if measures or groupings: |
| 303 | + relation = aggregate( |
| 304 | + relation, grouping_expressions, [e[0] for e in measures] |
| 305 | + )(registry) |
| 306 | + |
| 307 | + if having_predicate: |
| 308 | + relation = filter(relation, having_predicate)(registry) |
| 309 | + |
| 310 | + return project(relation, expressions=projection)(registry) |
| 311 | + elif op == "Table": |
| 312 | + name = ast["name"][0]["Identifier"]["value"] |
| 313 | + return read_named_table(name, schema_resolver(name)) |
| 314 | + elif op == "SetOperation": |
| 315 | + # TODO more than 2 inputs to a set operation |
| 316 | + left = translate( |
| 317 | + ast["left"], schema_resolver=schema_resolver, registry=registry |
| 318 | + ) |
| 319 | + right = translate( |
| 320 | + ast["right"], schema_resolver=schema_resolver, registry=registry |
| 321 | + ) |
| 322 | + if ast["op"] == "Union": |
| 323 | + set_op = ( |
| 324 | + stalg.SetRel.SET_OP_UNION_ALL |
| 325 | + if ast["set_quantifier"] == "All" |
| 326 | + else stalg.SetRel.SET_OP_UNION_DISTINCT |
| 327 | + ) |
| 328 | + else: |
| 329 | + raise Exception("") |
| 330 | + |
| 331 | + return set([left, right], set_op)(registry) |
| 332 | + else: |
| 333 | + raise Exception(f"Unknown op {op}") |
| 334 | + |
| 335 | + |
| 336 | +def convert(query: str, dialect: str, schema_resolver: SchemaResolver): |
| 337 | + ast = parse_sql(sql=query, dialect=dialect)[0] |
| 338 | + registry = ExtensionRegistry(load_default_extensions=True) |
| 339 | + return translate(ast, schema_resolver=schema_resolver, registry=registry) |
0 commit comments