Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ Source code is also available at:

# Release Notes

- v1.4.5(Unreleased)

- Improved performance of looking up columns in tables.

- v1.4.4(Nov 16, 2022)

- Fixed a bug that percent signs in a non-compiled statement should not be interpolated with emtpy sequence when executed.
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ universal = 1

[metadata]
name = snowflake-sqlalchemy
version = 1.4.4
version = 1.4.5
description = Snowflake SQLAlchemy Dialect
long_description = file: DESCRIPTION.md
long_description_content_type = text/markdown
Expand Down
70 changes: 51 additions & 19 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,16 @@ def get_check_constraints(self, connection, table_name, schema, **kw):
return []

@reflection.cache
def _get_schema_primary_keys(self, connection, schema, **kw):
def _get_schema_primary_keys(self, connection, schema, table_name=None, **kw):
fully_qualified_path = (
self._denormalize_quote_join(schema, self.denormalize_name(table_name))
if table_name is not None
else schema
)
result = connection.execute(
text(
f"SHOW /* sqlalchemy:_get_schema_primary_keys */PRIMARY KEYS IN SCHEMA {schema}"
f"SHOW /* sqlalchemy:_get_schema_primary_keys */ PRIMARY KEYS IN "
f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}"
)
)
ans = {}
Expand All @@ -346,14 +352,23 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
current_database, schema if schema else current_schema
)
return self._get_schema_primary_keys(
connection, self.denormalize_name(full_schema_name), **kw
connection,
self.denormalize_name(full_schema_name),
table_name=self.denormalize_name(table_name),
**kw,
).get(table_name, {"constrained_columns": [], "name": None})

@reflection.cache
def _get_schema_unique_constraints(self, connection, schema, **kw):
def _get_schema_unique_constraints(self, connection, schema, table_name=None, **kw):
fully_qualified_path = (
self._denormalize_quote_join(schema, self.denormalize_name(table_name))
if table_name is not None
else schema
)
result = connection.execute(
text(
f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN SCHEMA {schema}"
f"SHOW /* sqlalchemy:_get_schema_unique_constraints */ UNIQUE KEYS IN "
f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}"
)
)
unique_constraints = {}
Expand Down Expand Up @@ -385,15 +400,24 @@ def get_unique_constraints(self, connection, table_name, schema, **kw):
current_database, schema if schema else current_schema
)
return self._get_schema_unique_constraints(
connection, self.denormalize_name(full_schema_name), **kw
connection,
self.denormalize_name(full_schema_name),
table_name=self.denormalize_name(table_name),
**kw,
).get(table_name, [])

@reflection.cache
def _get_schema_foreign_keys(self, connection, schema, **kw):
def _get_schema_foreign_keys(self, connection, schema, table_name=None, **kw):
_, current_schema = self._current_database_schema(connection, **kw)
fully_qualified_path = (
self._denormalize_quote_join(schema, self.denormalize_name(table_name))
if table_name is not None
else schema
)
result = connection.execute(
text(
f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN SCHEMA {schema}"
f"SHOW /* sqlalchemy:_get_schema_foreign_keys */ IMPORTED KEYS IN "
f"{'TABLE' if table_name is not None else 'SCHEMA'} {fully_qualified_path}"
)
)
foreign_key_map = {}
Expand Down Expand Up @@ -463,7 +487,10 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
)

foreign_key_map = self._get_schema_foreign_keys(
connection, self.denormalize_name(full_schema_name), **kw
connection,
self.denormalize_name(full_schema_name),
table_name=self.denormalize_name(table_name),
**kw,
)
return foreign_key_map.get(table_name, [])

Expand Down Expand Up @@ -575,8 +602,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
ans = []
current_database, _ = self._current_database_schema(connection, **kw)
full_schema_name = self._denormalize_quote_join(current_database, schema)
schema_primary_keys = self._get_schema_primary_keys(
connection, full_schema_name, **kw
table_primary_keys = self._get_schema_primary_keys(
connection, full_schema_name, table_name=table_name, **kw
)
result = connection.execute(
text(
Expand All @@ -591,7 +618,9 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
ic.is_nullable,
ic.column_default,
ic.is_identity,
ic.comment
ic.comment,
ic.identity_start,
ic.identity_increment
FROM information_schema.columns ic
WHERE ic.table_schema=:table_schema
AND ic.table_name=:table_name
Expand All @@ -613,6 +642,8 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
column_default,
is_identity,
comment,
identity_start,
identity_increment,
) in result:
table_name = self.normalize_name(table_name)
column_name = self.normalize_name(column_name)
Expand All @@ -637,7 +668,7 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):

type_instance = col_type(**col_type_kw)

current_table_pks = schema_primary_keys.get(table_name)
current_table_pks = table_primary_keys.get(table_name)

ans.append(
{
Expand All @@ -649,12 +680,17 @@ def _get_table_columns(self, connection, table_name, schema=None, **kw):
"comment": comment if comment != "" else None,
"primary_key": (
column_name
in schema_primary_keys[table_name]["constrained_columns"]
in table_primary_keys[table_name]["constrained_columns"]
)
if current_table_pks
else False,
}
)
if is_identity == "YES":
ans[-1]["identity"] = {
"start": identity_start,
"increment": identity_increment,
}
return ans

def get_columns(self, connection, table_name, schema=None, **kw):
Expand All @@ -665,11 +701,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
if not schema:
_, schema = self._current_database_schema(connection, **kw)

schema_columns = self._get_schema_columns(connection, schema, **kw)
if schema_columns is None:
# Too many results, fall back to only query about single table
return self._get_table_columns(connection, table_name, schema, **kw)
return schema_columns[self.normalize_name(table_name)]
return self._get_table_columns(connection, table_name, schema, **kw)

@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
Expand Down