Skip to content

Commit 544c4f5

Browse files
Update snowdialect.py
Implement optimization changes suggested by sfc-gh-aling. Thank you!
1 parent 4ffa3ae commit 544c4f5

File tree

1 file changed

+66
-161
lines changed

1 file changed

+66
-161
lines changed

src/snowflake/sqlalchemy/snowdialect.py

Lines changed: 66 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -318,35 +318,22 @@ def get_check_constraints(self, connection, table_name, schema, **kw):
318318
return []
319319

320320
@reflection.cache
321-
def _get_table_primary_keys(self, connection, schema, table_name, **kw):
322-
fully_qualified_path = self._denormalize_quote_join(
323-
schema, self.denormalize_name(table_name)
324-
)
325-
result = connection.execute(
326-
text(
327-
f"SHOW /* sqlalchemy:_get_table_primary_keys */ PRIMARY KEYS IN TABLE {fully_qualified_path}"
321+
def _get_schema_primary_keys(self, connection, schema, table_name=None, **kw):
322+
if table_name is not None:
323+
fully_qualified_path = self._denormalize_quote_join(
324+
schema, self.denormalize_name(table_name)
328325
)
329-
)
330-
ans = {}
331-
for row in result:
332-
table_name = self.normalize_name(row._mapping["table_name"])
333-
if table_name not in ans:
334-
ans[table_name] = {
335-
"constrained_columns": [],
336-
"name": self.normalize_name(row._mapping["constraint_name"]),
337-
}
338-
ans[table_name]["constrained_columns"].append(
339-
self.normalize_name(row._mapping["column_name"])
326+
result = connection.execute(
327+
text(
328+
f"SHOW /* sqlalchemy:_get_schema_primary_keys */ PRIMARY KEYS IN TABLE {fully_qualified_path}"
329+
)
340330
)
341-
return ans
342-
343-
@reflection.cache
344-
def _get_schema_primary_keys(self, connection, schema, **kw):
345-
result = connection.execute(
346-
text(
347-
f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
331+
else:
332+
result = connection.execute(
333+
text(
334+
f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
335+
)
348336
)
349-
)
350337
ans = {}
351338
for row in result:
352339
table_name = self.normalize_name(row._mapping["table_name"])
@@ -368,54 +355,28 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
368355
full_schema_name = self._denormalize_quote_join(
369356
current_database, schema if schema else current_schema
370357
)
371-
372-
if table_name is not None:
373-
return self._get_table_primary_keys(
374-
connection,
375-
self.denormalize_name(full_schema_name),
376-
self.denormalize_name(table_name),
377-
**kw,
378-
).get(table_name, {"constrained_columns": [], "name": None})
379-
else:
380-
return self._get_schema_primary_keys(
381-
connection, self.denormalize_name(full_schema_name), **kw
382-
).get(table_name, {"constrained_columns": [], "name": None})
358+
return self._get_schema_primary_keys(
359+
connection,
360+
self.denormalize_name(full_schema_name),
361+
table_name=self.denormalize_name(table_name),
362+
**kw,
363+
).get(table_name, {"constrained_columns": [], "name": None})
383364

384365
@reflection.cache
385-
def _get_table_unique_constraints(self, connection, schema, table_name, **kw):
386-
fully_qualified_path = self._denormalize_quote_join(schema, table_name)
387-
result = connection.execute(
388-
text(
389-
f"SHOW /* sqlalchemy:_get_table_unique_constraints */ UNIQUE KEYS IN TABLE {fully_qualified_path}"
366+
def _get_schema_unique_constraints(self, connection, schema, table_name=None, **kw):
367+
if table_name is not None:
368+
fully_qualified_path = self._denormalize_quote_join(schema, table_name)
369+
result = connection.execute(
370+
text(
371+
f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN TABLE {fully_qualified_path}"
372+
)
390373
)
391-
)
392-
unique_constraints = {}
393-
for row in result:
394-
name = self.normalize_name(row._mapping["constraint_name"])
395-
if name not in unique_constraints:
396-
unique_constraints[name] = {
397-
"column_names": [self.normalize_name(row._mapping["column_name"])],
398-
"name": name,
399-
"table_name": self.normalize_name(row._mapping["table_name"]),
400-
}
401-
else:
402-
unique_constraints[name]["column_names"].append(
403-
self.normalize_name(row._mapping["column_name"])
374+
else:
375+
result = connection.execute(
376+
text(
377+
f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
404378
)
405-
406-
ans = defaultdict(list)
407-
for constraint in unique_constraints.values():
408-
table_name = constraint.pop("table_name")
409-
ans[table_name].append(constraint)
410-
return ans
411-
412-
@reflection.cache
413-
def _get_schema_unique_constraints(self, connection, schema, **kw):
414-
result = connection.execute(
415-
text(
416-
f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
417379
)
418-
)
419380
unique_constraints = {}
420381
for row in result:
421382
name = self.normalize_name(row._mapping["constraint_name"])
@@ -444,91 +405,31 @@ def get_unique_constraints(self, connection, table_name, schema, **kw):
444405
full_schema_name = self._denormalize_quote_join(
445406
current_database, schema if schema else current_schema
446407
)
447-
if table_name is not None:
448-
return self._get_table_unique_constraints(
449-
connection,
450-
self.denormalize_name(full_schema_name),
451-
self.denormalize_name(table_name),
452-
**kw,
453-
).get(table_name, [])
454-
else:
455-
return self._get_schema_unique_constraints(
456-
connection, self.denormalize_name(full_schema_name), **kw
457-
).get(table_name, [])
408+
return self._get_schema_unique_constraints(
409+
connection,
410+
self.denormalize_name(full_schema_name),
411+
table_name=self.denormalize_name(table_name),
412+
**kw,
413+
).get(table_name, [])
458414

459415
@reflection.cache
460-
def _get_table_foreign_keys(self, connection, schema, table_name, **kw):
416+
def _get_schema_foreign_keys(self, connection, schema, table_name=None, **kw):
461417
_, current_schema = self._current_database_schema(connection, **kw)
462-
fully_qualified_path = self._denormalize_quote_join(
463-
schema, self.denormalize_name(table_name)
464-
)
465-
result = connection.execute(
466-
text(
467-
f"SHOW /* sqlalchemy:_get_table_foreign_keys */ IMPORTED KEYS IN TABLE {fully_qualified_path}"
418+
if table_name is not None:
419+
fully_qualified_path = self._denormalize_quote_join(
420+
schema, self.denormalize_name(table_name)
468421
)
469-
)
470-
foreign_key_map = {}
471-
for row in result:
472-
name = self.normalize_name(row._mapping["fk_name"])
473-
if name not in foreign_key_map:
474-
referred_schema = self.normalize_name(row._mapping["pk_schema_name"])
475-
foreign_key_map[name] = {
476-
"constrained_columns": [
477-
self.normalize_name(row._mapping["fk_column_name"])
478-
],
479-
# referred schema should be None in context where it doesn't need to be specified
480-
# https://docs.sqlalchemy.org/en/14/core/reflection.html#reflection-schema-qualified-interaction
481-
"referred_schema": (
482-
referred_schema
483-
if referred_schema
484-
not in (self.default_schema_name, current_schema)
485-
else None
486-
),
487-
"referred_table": self.normalize_name(
488-
row._mapping["pk_table_name"]
489-
),
490-
"referred_columns": [
491-
self.normalize_name(row._mapping["pk_column_name"])
492-
],
493-
"name": name,
494-
"table_name": self.normalize_name(row._mapping["fk_table_name"]),
495-
}
496-
options = {}
497-
if self.normalize_name(row._mapping["delete_rule"]) != "NO ACTION":
498-
options["ondelete"] = self.normalize_name(
499-
row._mapping["delete_rule"]
500-
)
501-
if self.normalize_name(row._mapping["update_rule"]) != "NO ACTION":
502-
options["onupdate"] = self.normalize_name(
503-
row._mapping["update_rule"]
504-
)
505-
foreign_key_map[name]["options"] = options
506-
else:
507-
foreign_key_map[name]["constrained_columns"].append(
508-
self.normalize_name(row._mapping["fk_column_name"])
509-
)
510-
foreign_key_map[name]["referred_columns"].append(
511-
self.normalize_name(row._mapping["pk_column_name"])
422+
result = connection.execute(
423+
text(
424+
f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN TABLE {fully_qualified_path}"
512425
)
513-
514-
ans = {}
515-
516-
for _, v in foreign_key_map.items():
517-
if v["table_name"] not in ans:
518-
ans[v["table_name"]] = []
519-
ans[v["table_name"]].append(
520-
{k2: v2 for k2, v2 in v.items() if k2 != "table_name"}
521426
)
522-
return ans
523-
524-
@reflection.cache
525-
def _get_schema_foreign_keys(self, connection, schema, **kw):
526-
_, current_schema = self._current_database_schema(connection, **kw)
527-
result = connection.execute(
528-
text(
529-
f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
427+
else:
428+
result = connection.execute(
429+
text(
430+
f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
431+
)
530432
)
531-
)
532433
foreign_key_map = {}
533434
for row in result:
534435
name = self.normalize_name(row._mapping["fk_name"])
@@ -595,17 +496,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
595496
current_database, schema if schema else current_schema
596497
)
597498

598-
if table_name is not None:
599-
foreign_key_map = self._get_table_foreign_keys(
600-
connection,
601-
self.denormalize_name(full_schema_name),
602-
self.denormalize_name(table_name),
603-
**kw,
604-
)
605-
else:
606-
foreign_key_map = self._get_schema_foreign_keys(
607-
connection, self.denormalize_name(full_schema_name), **kw
608-
)
499+
foreign_key_map = self._get_schema_foreign_keys(
500+
connection,
501+
self.denormalize_name(full_schema_name),
502+
table_name=self.denormalize_name(table_name),
503+
**kw,
504+
)
609505
return foreign_key_map.get(table_name, [])
610506

611507
@reflection.cache
@@ -716,8 +612,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
716612
ans = []
717613
current_database, _ = self._current_database_schema(connection, **kw)
718614
full_schema_name = self._denormalize_quote_join(current_database, schema)
719-
table_primary_keys = self._get_table_primary_keys(
720-
connection, full_schema_name, table_name, **kw
615+
table_primary_keys = self._get_schema_primary_keys(
616+
connection, full_schema_name, table_name=table_name, **kw
721617
)
722618
result = connection.execute(
723619
text(
@@ -732,7 +628,9 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
732628
ic.is_nullable,
733629
ic.column_default,
734630
ic.is_identity,
735-
ic.comment
631+
ic.comment,
632+
ic.identity_start,
633+
ic.identity_increment
736634
FROM information_schema.columns ic
737635
WHERE ic.table_schema=:table_schema
738636
AND ic.table_name=:table_name
@@ -754,6 +652,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
754652
column_default,
755653
is_identity,
756654
comment,
655+
identity_start,
656+
identity_increment,
757657
) in result:
758658
table_name = self.normalize_name(table_name)
759659
column_name = self.normalize_name(column_name)
@@ -796,6 +696,11 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
796696
else False,
797697
}
798698
)
699+
if is_identity == "YES":
700+
ans[-1]["identity"] = {
701+
"start": identity_start,
702+
"increment": identity_increment,
703+
}
799704
return ans
800705

801706
def get_columns(self, connection, table_name, schema=None, **kw):

0 commit comments

Comments
 (0)