2
2
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3
3
#
4
4
import operator
5
- import re
6
5
from collections import defaultdict
7
6
from enum import Enum
8
7
from functools import reduce
16
15
from sqlalchemy .engine import URL , default , reflection
17
16
from sqlalchemy .schema import Table
18
17
from sqlalchemy .sql import text
19
- from sqlalchemy .sql .elements import quoted_name
20
18
from sqlalchemy .sql .sqltypes import NullType
21
19
from sqlalchemy .types import FLOAT , Date , DateTime , Float , Time
22
20
23
21
from snowflake .connector import errors as sf_errors
24
22
from snowflake .connector .connection import DEFAULT_CONFIGURATION
25
23
from snowflake .connector .constants import UTF8
26
24
from snowflake .sqlalchemy .compat import returns_unicode
25
+ from snowflake .sqlalchemy .name_utils import _NameUtils
26
+ from snowflake .sqlalchemy .structured_type_info_manager import _StructuredTypeInfoManager
27
27
28
28
from ._constants import DIALECT_NAME
29
29
from .base import (
42
42
)
43
43
from .parser .custom_type_parser import * # noqa
44
44
from .parser .custom_type_parser import _CUSTOM_DECIMAL # noqa
45
- from .parser .custom_type_parser import ischema_names , parse_index_columns , parse_type
45
+ from .parser .custom_type_parser import ischema_names , parse_index_columns
46
46
from .sql .custom_schema .custom_table_prefix import CustomTablePrefix
47
47
from .util import (
48
48
_update_connection_application_name ,
@@ -157,6 +157,7 @@ def __init__(
157
157
super ().__init__ (isolation_level = isolation_level , ** kwargs )
158
158
self .force_div_is_floordiv = force_div_is_floordiv
159
159
self .div_is_floordiv = force_div_is_floordiv
160
+ self .name_utils = _NameUtils (self .identifier_preparer )
160
161
161
162
def initialize (self , connection ):
162
163
super ().initialize (connection )
@@ -282,29 +283,10 @@ def _has_object(self, connection, object_type, object_name, schema=None):
282
283
raise
283
284
284
285
def normalize_name (self , name ):
285
- if name is None :
286
- return None
287
- if name == "" :
288
- return ""
289
- if name .upper () == name and not self .identifier_preparer ._requires_quotes (
290
- name .lower ()
291
- ):
292
- return name .lower ()
293
- elif name .lower () == name :
294
- return quoted_name (name , quote = True )
295
- else :
296
- return name
286
+ return self .name_utils .normalize_name (name )
297
287
298
288
def denormalize_name (self , name ):
299
- if name is None :
300
- return None
301
- if name == "" :
302
- return ""
303
- elif name .lower () == name and not self .identifier_preparer ._requires_quotes (
304
- name .lower ()
305
- ):
306
- name = name .upper ()
307
- return name
289
+ return self .name_utils .denormalize_name (name )
308
290
309
291
def _denormalize_quote_join (self , * idents ):
310
292
ip = self .identifier_preparer
@@ -491,53 +473,31 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw):
491
473
)
492
474
return foreign_key_map .get (table_name , [])
493
475
494
- def table_columns_as_dict (self , columns ):
495
- result = {}
496
- for column in columns :
497
- result [column ["name" ]] = column
498
- return result
499
-
500
476
@reflection .cache
501
477
def _get_schema_columns (self , connection , schema , ** kw ):
502
478
"""Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return
503
479
None, as it is cacheable and is an unexpected return type for this function"""
504
480
ans = {}
505
- current_database , _ = self ._current_database_schema (connection , ** kw )
481
+
482
+ schema_name = self .denormalize_name (schema )
483
+
484
+ result = self ._query_all_columns_info (connection , schema_name , ** kw )
485
+ if result is None :
486
+ return None
487
+
488
+ current_database , default_schema = self ._current_database_schema (
489
+ connection , ** kw
490
+ )
506
491
full_schema_name = self ._denormalize_quote_join (current_database , schema )
507
- full_columns_descriptions = {}
508
- try :
509
- schema_primary_keys = self ._get_schema_primary_keys (
510
- connection , full_schema_name , ** kw
511
- )
512
- schema_name = self .denormalize_name (schema )
513
492
514
- result = connection .execute (
515
- text (
516
- """
517
- SELECT /* sqlalchemy:_get_schema_columns */
518
- ic.table_name,
519
- ic.column_name,
520
- ic.data_type,
521
- ic.character_maximum_length,
522
- ic.numeric_precision,
523
- ic.numeric_scale,
524
- ic.is_nullable,
525
- ic.column_default,
526
- ic.is_identity,
527
- ic.comment,
528
- ic.identity_start,
529
- ic.identity_increment
530
- FROM information_schema.columns ic
531
- WHERE ic.table_schema=:table_schema
532
- ORDER BY ic.ordinal_position"""
533
- ),
534
- {"table_schema" : schema_name },
535
- )
536
- except sa_exc .ProgrammingError as pe :
537
- if pe .orig .errno == 90030 :
538
- # This means that there are too many tables in the schema, we need to go more granular
539
- return None # None triggers _get_table_columns while staying cacheable
540
- raise
493
+ schema_primary_keys = self ._get_schema_primary_keys (
494
+ connection , full_schema_name , ** kw
495
+ )
496
+
497
+ structured_type_info_manager = _StructuredTypeInfoManager (
498
+ connection , self .name_utils , default_schema
499
+ )
500
+
541
501
for (
542
502
table_name ,
543
503
column_name ,
@@ -572,25 +532,11 @@ def _get_schema_columns(self, connection, schema, **kw):
572
532
elif issubclass (col_type , (sqltypes .String , sqltypes .BINARY )):
573
533
col_type_kw ["length" ] = character_maximum_length
574
534
elif issubclass (col_type , StructuredType ):
575
- if (schema_name , table_name ) not in full_columns_descriptions :
576
- full_columns_descriptions [(schema_name , table_name )] = (
577
- self .table_columns_as_dict (
578
- self ._get_table_columns (
579
- connection , table_name , schema_name
580
- )
581
- )
582
- )
583
-
584
- if (
585
- (schema_name , table_name ) in full_columns_descriptions
586
- and column_name
587
- in full_columns_descriptions [(schema_name , table_name )]
588
- ):
589
- ans [table_name ].append (
590
- full_columns_descriptions [(schema_name , table_name )][
591
- column_name
592
- ]
593
- )
535
+ column_info = structured_type_info_manager .get_column_info (
536
+ schema_name , table_name , column_name , ** kw
537
+ )
538
+ if column_info :
539
+ ans [table_name ].append (column_info )
594
540
continue
595
541
else :
596
542
col_type = NullType
@@ -628,72 +574,6 @@ def _get_schema_columns(self, connection, schema, **kw):
628
574
}
629
575
return ans
630
576
631
- @reflection .cache
632
- def _get_table_columns (self , connection , table_name , schema = None , ** kw ):
633
- """Get all columns in a table in a schema"""
634
- ans = []
635
- current_database , default_schema = self ._current_database_schema (
636
- connection , ** kw
637
- )
638
- schema = schema if schema else default_schema
639
- table_schema = self .denormalize_name (schema )
640
- table_name = self .denormalize_name (table_name )
641
- result = connection .execute (
642
- text (
643
- "DESC /* sqlalchemy:_get_schema_columns */"
644
- f" TABLE { table_schema } .{ table_name } TYPE = COLUMNS"
645
- )
646
- )
647
- for desc_data in result :
648
- column_name = desc_data [0 ]
649
- coltype = desc_data [1 ]
650
- is_nullable = desc_data [3 ]
651
- column_default = desc_data [4 ]
652
- primary_key = desc_data [5 ]
653
- comment = desc_data [9 ]
654
-
655
- column_name = self .normalize_name (column_name )
656
- if column_name .startswith ("sys_clustering_column" ):
657
- continue # ignoring clustering column
658
- type_instance = parse_type (coltype )
659
- if isinstance (type_instance , NullType ):
660
- sa_util .warn (
661
- f"Did not recognize type '{ coltype } ' of column '{ column_name } '"
662
- )
663
-
664
- identity = None
665
- match = re .match (
666
- r"IDENTITY START (?P<start>\d+) INCREMENT (?P<increment>\d+) (?P<order_type>ORDER|NOORDER)" ,
667
- column_default if column_default else "" ,
668
- )
669
- if match :
670
- identity = {
671
- "start" : int (match .group ("start" )),
672
- "increment" : int (match .group ("increment" )),
673
- "order_type" : match .group ("order_type" ),
674
- }
675
- is_identity = identity is not None
676
-
677
- ans .append (
678
- {
679
- "name" : column_name ,
680
- "type" : type_instance ,
681
- "nullable" : is_nullable == "Y" ,
682
- "default" : None if is_identity else column_default ,
683
- "autoincrement" : is_identity ,
684
- "comment" : comment if comment != "" else None ,
685
- "primary_key" : primary_key == "Y" ,
686
- }
687
- )
688
-
689
- if is_identity :
690
- ans [- 1 ]["identity" ] = identity
691
-
692
- # If we didn't find any columns for the table, the table doesn't exist.
693
- if len (ans ) == 0 :
694
- raise sa_exc .NoSuchTableError ()
695
- return ans
696
-
697
577
def get_columns (self , connection , table_name , schema = None , ** kw ):
698
578
"""
699
579
Gets all column info given the table info
@@ -704,8 +584,11 @@ def get_columns(self, connection, table_name, schema=None, **kw):
704
584
705
585
schema_columns = self ._get_schema_columns (connection , schema , ** kw )
706
586
if schema_columns is None :
587
+ column_info_manager = _StructuredTypeInfoManager (
588
+ connection , self .name_utils , self .default_schema_name
589
+ )
707
590
# Too many results, fall back to only query about single table
708
- return self . _get_table_columns ( connection , table_name , schema , ** kw )
591
+ return column_info_manager . get_table_columns ( table_name , schema )
709
592
normalized_table_name = self .normalize_name (table_name )
710
593
if normalized_table_name not in schema_columns :
711
594
raise sa_exc .NoSuchTableError ()
@@ -719,6 +602,37 @@ def get_prefixes_from_data(self, name_to_index_map, row, **kw):
719
602
prefixes_found .append (valid_prefix .name )
720
603
return prefixes_found
721
604
605
+ @reflection .cache
606
+ def _query_all_columns_info (self , connection , schema_name , ** kw ):
607
+ try :
608
+ return connection .execute (
609
+ text (
610
+ """
611
+ SELECT /* sqlalchemy:_get_schema_columns */
612
+ ic.table_name,
613
+ ic.column_name,
614
+ ic.data_type,
615
+ ic.character_maximum_length,
616
+ ic.numeric_precision,
617
+ ic.numeric_scale,
618
+ ic.is_nullable,
619
+ ic.column_default,
620
+ ic.is_identity,
621
+ ic.comment,
622
+ ic.identity_start,
623
+ ic.identity_increment
624
+ FROM information_schema.columns ic
625
+ WHERE ic.table_schema=:table_schema
626
+ ORDER BY ic.ordinal_position"""
627
+ ),
628
+ {"table_schema" : schema_name },
629
+ )
630
+ except sa_exc .ProgrammingError as pe :
631
+ if pe .orig .errno == 90030 :
632
+ # This means that there are too many tables in the schema, we need to go more granular
633
+ return None # None triggers get_table_columns while staying cacheable
634
+ raise
635
+
722
636
@reflection .cache
723
637
def _get_schema_tables_info (self , connection , schema = None , ** kw ):
724
638
"""
0 commit comments