22# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
33#
44import operator
5- import re
65from collections import defaultdict
76from enum import Enum
87from functools import reduce
1615from sqlalchemy .engine import URL , default , reflection
1716from sqlalchemy .schema import Table
1817from sqlalchemy .sql import text
19- from sqlalchemy .sql .elements import quoted_name
2018from sqlalchemy .sql .sqltypes import NullType
2119from sqlalchemy .types import FLOAT , Date , DateTime , Float , Time
2220
2321from snowflake .connector import errors as sf_errors
2422from snowflake .connector .connection import DEFAULT_CONFIGURATION
2523from snowflake .connector .constants import UTF8
2624from snowflake .sqlalchemy .compat import returns_unicode
25+ from snowflake .sqlalchemy .name_utils import _NameUtils
26+ from snowflake .sqlalchemy .structured_type_info_manager import _StructuredTypeInfoManager
2727
2828from ._constants import DIALECT_NAME
2929from .base import (
4242)
4343from .parser .custom_type_parser import * # noqa
4444from .parser .custom_type_parser import _CUSTOM_DECIMAL # noqa
45- from .parser .custom_type_parser import ischema_names , parse_index_columns , parse_type
45+ from .parser .custom_type_parser import ischema_names , parse_index_columns
4646from .sql .custom_schema .custom_table_prefix import CustomTablePrefix
4747from .util import (
4848 _update_connection_application_name ,
@@ -157,6 +157,7 @@ def __init__(
157157 super ().__init__ (isolation_level = isolation_level , ** kwargs )
158158 self .force_div_is_floordiv = force_div_is_floordiv
159159 self .div_is_floordiv = force_div_is_floordiv
160+ self .name_utils = _NameUtils (self .identifier_preparer )
160161
161162 def initialize (self , connection ):
162163 super ().initialize (connection )
@@ -282,29 +283,10 @@ def _has_object(self, connection, object_type, object_name, schema=None):
282283 raise
283284
284285 def normalize_name (self , name ):
285- if name is None :
286- return None
287- if name == "" :
288- return ""
289- if name .upper () == name and not self .identifier_preparer ._requires_quotes (
290- name .lower ()
291- ):
292- return name .lower ()
293- elif name .lower () == name :
294- return quoted_name (name , quote = True )
295- else :
296- return name
286+ return self .name_utils .normalize_name (name )
297287
298288 def denormalize_name (self , name ):
299- if name is None :
300- return None
301- if name == "" :
302- return ""
303- elif name .lower () == name and not self .identifier_preparer ._requires_quotes (
304- name .lower ()
305- ):
306- name = name .upper ()
307- return name
289+ return self .name_utils .denormalize_name (name )
308290
309291 def _denormalize_quote_join (self , * idents ):
310292 ip = self .identifier_preparer
@@ -491,53 +473,31 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
491473 )
492474 return foreign_key_map .get (table_name , [])
493475
494- def table_columns_as_dict (self , columns ):
495- result = {}
496- for column in columns :
497- result [column ["name" ]] = column
498- return result
499-
500476 @reflection .cache
501477 def _get_schema_columns (self , connection , schema , ** kw ):
502478 """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return
503479 None, as it is cacheable and is an unexpected return type for this function"""
504480 ans = {}
505- current_database , _ = self ._current_database_schema (connection , ** kw )
481+
482+ schema_name = self .denormalize_name (schema )
483+
484+ result = self ._query_all_columns_info (connection , schema_name , ** kw )
485+ if result is None :
486+ return None
487+
488+ current_database , default_schema = self ._current_database_schema (
489+ connection , ** kw
490+ )
506491 full_schema_name = self ._denormalize_quote_join (current_database , schema )
507- full_columns_descriptions = {}
508- try :
509- schema_primary_keys = self ._get_schema_primary_keys (
510- connection , full_schema_name , ** kw
511- )
512- schema_name = self .denormalize_name (schema )
513492
514- result = connection .execute (
515- text (
516- """
517- SELECT /* sqlalchemy:_get_schema_columns */
518- ic.table_name,
519- ic.column_name,
520- ic.data_type,
521- ic.character_maximum_length,
522- ic.numeric_precision,
523- ic.numeric_scale,
524- ic.is_nullable,
525- ic.column_default,
526- ic.is_identity,
527- ic.comment,
528- ic.identity_start,
529- ic.identity_increment
530- FROM information_schema.columns ic
531- WHERE ic.table_schema=:table_schema
532- ORDER BY ic.ordinal_position"""
533- ),
534- {"table_schema" : schema_name },
535- )
536- except sa_exc .ProgrammingError as pe :
537- if pe .orig .errno == 90030 :
538- # This means that there are too many tables in the schema, we need to go more granular
539- return None # None triggers _get_table_columns while staying cacheable
540- raise
493+ schema_primary_keys = self ._get_schema_primary_keys (
494+ connection , full_schema_name , ** kw
495+ )
496+
497+ structured_type_info_manager = _StructuredTypeInfoManager (
498+ connection , self .name_utils , default_schema
499+ )
500+
541501 for (
542502 table_name ,
543503 column_name ,
@@ -572,25 +532,11 @@ def _get_schema_columns(self, connection, schema, **kw):
572532 elif issubclass (col_type , (sqltypes .String , sqltypes .BINARY )):
573533 col_type_kw ["length" ] = character_maximum_length
574534 elif issubclass (col_type , StructuredType ):
575- if (schema_name , table_name ) not in full_columns_descriptions :
576- full_columns_descriptions [(schema_name , table_name )] = (
577- self .table_columns_as_dict (
578- self ._get_table_columns (
579- connection , table_name , schema_name
580- )
581- )
582- )
583-
584- if (
585- (schema_name , table_name ) in full_columns_descriptions
586- and column_name
587- in full_columns_descriptions [(schema_name , table_name )]
588- ):
589- ans [table_name ].append (
590- full_columns_descriptions [(schema_name , table_name )][
591- column_name
592- ]
593- )
535+ column_info = structured_type_info_manager .get_column_info (
536+ schema_name , table_name , column_name , ** kw
537+ )
538+ if column_info :
539+ ans [table_name ].append (column_info )
594540 continue
595541 else :
596542 col_type = NullType
@@ -628,72 +574,6 @@ def _get_schema_columns(self, connection, schema, **kw):
628574 }
629575 return ans
630576
631- @reflection .cache
632- def _get_table_columns (self , connection , table_name , schema = None , ** kw ):
633- """Get all columns in a table in a schema"""
634- ans = []
635- current_database , default_schema = self ._current_database_schema (
636- connection , ** kw
637- )
638- schema = schema if schema else default_schema
639- table_schema = self .denormalize_name (schema )
640- table_name = self .denormalize_name (table_name )
641- result = connection .execute (
642- text (
643- "DESC /* sqlalchemy:_get_schema_columns */"
644- f" TABLE { table_schema } .{ table_name } TYPE = COLUMNS"
645- )
646- )
647- for desc_data in result :
648- column_name = desc_data [0 ]
649- coltype = desc_data [1 ]
650- is_nullable = desc_data [3 ]
651- column_default = desc_data [4 ]
652- primary_key = desc_data [5 ]
653- comment = desc_data [9 ]
654-
655- column_name = self .normalize_name (column_name )
656- if column_name .startswith ("sys_clustering_column" ):
657- continue # ignoring clustering column
658- type_instance = parse_type (coltype )
659- if isinstance (type_instance , NullType ):
660- sa_util .warn (
661- f"Did not recognize type '{ coltype } ' of column '{ column_name } '"
662- )
663-
664- identity = None
665- match = re .match (
666- r"IDENTITY START (?P<start>\d+) INCREMENT (?P<increment>\d+) (?P<order_type>ORDER|NOORDER)" ,
667- column_default if column_default else "" ,
668- )
669- if match :
670- identity = {
671- "start" : int (match .group ("start" )),
672- "increment" : int (match .group ("increment" )),
673- "order_type" : match .group ("order_type" ),
674- }
675- is_identity = identity is not None
676-
677- ans .append (
678- {
679- "name" : column_name ,
680- "type" : type_instance ,
681- "nullable" : is_nullable == "Y" ,
682- "default" : None if is_identity else column_default ,
683- "autoincrement" : is_identity ,
684- "comment" : comment if comment != "" else None ,
685- "primary_key" : primary_key == "Y" ,
686- }
687- )
688-
689- if is_identity :
690- ans [- 1 ]["identity" ] = identity
691-
692- # If we didn't find any columns for the table, the table doesn't exist.
693- if len (ans ) == 0 :
694- raise sa_exc .NoSuchTableError ()
695- return ans
696-
697577 def get_columns (self , connection , table_name , schema = None , ** kw ):
698578 """
699579 Gets all column info given the table info
@@ -704,8 +584,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
704584
705585 schema_columns = self ._get_schema_columns (connection , schema , ** kw )
706586 if schema_columns is None :
587+ column_info_manager = _StructuredTypeInfoManager (
588+ connection , self .name_utils , self .default_schema_name
589+ )
707590 # Too many results, fall back to only query about single table
708- return self . _get_table_columns ( connection , table_name , schema , ** kw )
591+ return column_info_manager . get_table_columns ( table_name , schema )
709592 normalized_table_name = self .normalize_name (table_name )
710593 if normalized_table_name not in schema_columns :
711594 raise sa_exc .NoSuchTableError ()
@@ -719,6 +602,37 @@ def get_prefixes_from_data(self, name_to_index_map, row, **kw):
719602 prefixes_found .append (valid_prefix .name )
720603 return prefixes_found
721604
605+ @reflection .cache
606+ def _query_all_columns_info (self , connection , schema_name , ** kw ):
607+ try :
608+ return connection .execute (
609+ text (
610+ """
611+ SELECT /* sqlalchemy:_get_schema_columns */
612+ ic.table_name,
613+ ic.column_name,
614+ ic.data_type,
615+ ic.character_maximum_length,
616+ ic.numeric_precision,
617+ ic.numeric_scale,
618+ ic.is_nullable,
619+ ic.column_default,
620+ ic.is_identity,
621+ ic.comment,
622+ ic.identity_start,
623+ ic.identity_increment
624+ FROM information_schema.columns ic
625+ WHERE ic.table_schema=:table_schema
626+ ORDER BY ic.ordinal_position"""
627+ ),
628+ {"table_schema" : schema_name },
629+ )
630+ except sa_exc .ProgrammingError as pe :
631+ if pe .orig .errno == 90030 :
632+ # This means that there are too many tables in the schema, we need to go more granular
633+ return None # None triggers get_table_columns while staying cacheable
634+ raise
635+
722636 @reflection .cache
723637 def _get_schema_tables_info (self , connection , schema = None , ** kw ):
724638 """
0 commit comments