-
Notifications
You must be signed in to change notification settings - Fork 146
Expand file tree
/
Copy pathutils.py
More file actions
2188 lines (1899 loc) · 92 KB
/
utils.py
File metadata and controls
2188 lines (1899 loc) · 92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import json
import logging
import re
import traceback
from collections.abc import Hashable, Iterable, Sequence
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Union
import modin.pandas as pd
import numpy as np
import pandas as native_pd
from pandas._typing import AnyArrayLike, Scalar
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import (
is_bool_dtype,
is_integer_dtype,
is_object_dtype,
is_scalar,
)
from pandas.core.dtypes.inference import is_list_like
import snowflake.snowpark.modin.plugin._internal.statement_params_constants as STATEMENT_PARAMS
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
DOUBLE_QUOTE,
EMPTY_STRING,
quote_name_without_upper_casing,
)
from snowflake.snowpark._internal.analyzer.expression import Literal
from snowflake.snowpark._internal.type_utils import LiteralType
from snowflake.snowpark._internal.utils import (
SNOWFLAKE_OBJECT_RE_PATTERN,
TempObjectType,
generate_random_alphanumeric,
get_temp_type_for_object,
random_name_for_temp_object,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import (
col,
equal_nan,
floor,
iff,
to_char,
to_timestamp_ntz,
to_timestamp_tz,
typeof,
)
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
DataFrameReference,
OrderedDataFrame,
OrderingColumn,
)
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasType,
TimedeltaType,
ensure_snowpark_python_type,
)
from snowflake.snowpark.modin.plugin._typing import LabelTuple
from snowflake.snowpark.modin.plugin.utils.exceptions import (
SnowparkPandasErrorCode,
SnowparkPandasException,
)
from snowflake.snowpark.modin.plugin.utils.warning_message import (
ORDER_BY_IN_SQL_QUERY_NOT_GUARANTEED_WARNING,
WarningMessage,
)
from snowflake.snowpark.types import (
ArrayType,
DataType,
DecimalType,
DoubleType,
LongType,
MapType,
StringType,
StructField,
StructType,
TimestampTimeZone,
TimestampType,
VariantType,
_FractionalType,
)
if TYPE_CHECKING:
from snowflake.snowpark.modin.plugin._internal import frame
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)
# Label used when generating the row position column of an OrderedDataFrame.
ROW_POSITION_COLUMN_LABEL = "row_position"
MAX_ROW_POSITION_COLUMN_LABEL = f"MAX_{ROW_POSITION_COLUMN_LABEL}"
SAMPLED_ROW_POSITION_COLUMN_LABEL = f"sampled_{ROW_POSITION_COLUMN_LABEL}"
INDEX_LABEL = "index"
# label used for data column to create the snowflake quoted identifier when the pandas
# label for the column is None
DEFAULT_DATA_COLUMN_LABEL = "data"
LEVEL_LABEL = "level"
ITEM_VALUE_LABEL = "item_value"
ORDERING_COLUMN_LABEL = "ordering"
# Suffix appended to the generated temp-object name when creating read only tables.
READ_ONLY_TABLE_SUFFIX = "READONLY"
# Snowflake metadata column exposing the physical row position of each row in a table.
METADATA_ROW_POSITION_COLUMN = "METADATA$ROW_POSITION"
# read only table is only supported for base table or base temporary table
READ_ONLY_TABLE_SUPPORTED_TABLE_KINDS = ["LOCAL TEMPORARY", "BASE TABLE"]
UNDEFINED = "undefined"
# number of digits used to generate a random suffix
_NUM_SUFFIX_DIGITS = 4
ROW_COUNT_COLUMN_LABEL = "row_count"
# max number of retries used when generating conflict free quoted identifiers
_MAX_NUM_RETRIES = 3
# identifiers longer than this are replaced by a fully random identifier when
# resolving naming conflicts, instead of having a suffix appended
_MAX_IDENTIFIER_LENGTH = 32
# module-level logger for this file
_logger = logging.getLogger(__name__)
def get_default_snowpark_pandas_statement_params() -> dict[str, str]:
    """
    Return the default statement parameters for queries issued by the
    Snowpark pandas API.

    Tagging every query with the pandas API marker provides the fine grain
    metric for the server to track all pandas API usage.
    """
    default_params: dict[str, str] = {
        STATEMENT_PARAMS.SNOWPARK_API: STATEMENT_PARAMS.PANDAS_API
    }
    return default_params
class FillNAMethod(Enum):
    """
    Enum that defines the fillna methods - ffill for forward filling, and bfill for backfilling.
    """

    FFILL_METHOD = "ffill"
    BFILL_METHOD = "bfill"

    @classmethod
    def get_enum_for_string_method(cls, method_name: str) -> "FillNAMethod":
        """
        Return the Enum member corresponding to the given method name.

        Args:
            method_name : str
                The name of the method to use for fillna.

        Returns:
            FillNAMethod : The member of the Enum corresponding to the specified
            method name; raises ValueError if none match.

        Notes:
            The two methods have aliases - `pad` for ffill and `backfill` for bfill.
            Rather than adding two more Enum members, the aliases are normalized to
            their canonical names first, so callers never have to compare against
            two functionally identical members.
        """
        # Normalize the documented aliases onto the canonical member values.
        aliases = {"pad": "ffill", "backfill": "bfill"}
        canonical_name = aliases.get(method_name, method_name)
        try:
            return cls(canonical_name)
        except ValueError:
            raise ValueError(
                f"Invalid fillna method: {method_name}. Expected one of ['ffill', 'pad', 'bfill', 'backfill']"
            )
def _is_table_name(table_name_or_query: str) -> bool:
    """
    Return True when the provided string is a valid Snowflake table name.

    Args:
        table_name_or_query: the string to check

    Returns:
        True if it is a valid table name.
    """
    # SNOWFLAKE_OBJECT_RE_PATTERN is the pattern for identifiers in Snowflake.
    # A string that does not match it is not a valid identifier, and therefore
    # cannot be a table name (it is presumably a SQL query instead).
    match = SNOWFLAKE_OBJECT_RE_PATTERN.match(table_name_or_query)
    return match is not None
def _check_if_sql_query_contains_order_by_and_warn_user(
    sql_query_text: str,
) -> bool:
    """
    Check whether the given SQL query contains an ORDER BY clause, and warn
    the user that the ORDER BY will currently be ignored.

    Args:
        sql_query_text: The SQL Query to check.

    Returns:
        Whether or not the query contains an order by.
    """
    # We previously inspected the logical plan (via `EXPLAIN`) to detect an
    # ORDER BY, but the output schema of `EXPLAIN` is not stable across
    # releases, so we fall back to matching the query text itself.
    # Note: string matching can produce false positives, e.g.:
    #   SELECT "ORDER BY COLUMN", * FROM TABLE;
    # would raise a warning even though it has no ORDER BY clause.
    has_order_by = bool(re.search(r"\s+order\s+by\s+", sql_query_text.lower()))
    if has_order_by:
        # The ordering induced by the ORDER BY is not guaranteed to be
        # preserved in the returned DataFrame; users should call sort_values
        # on the result to guarantee an ordering.
        WarningMessage.single_warning(ORDER_BY_IN_SQL_QUERY_NOT_GUARANTEED_WARNING)
    return has_order_by
def _extract_base_table_from_simple_select_star_query(sql_query: str) -> str:
    """
    Reduce a SQL query of the form ``SELECT * FROM table`` to its base table
    name; return the input unchanged for anything else (including inputs that
    are already table names).

    Returns:
        str
            The base table name or the original SQL Query.
    """
    if _is_table_name(sql_query):
        return sql_query
    # Check whether the (possibly nested) query starts with `select * from `.
    prefix_match = re.match(r"select \* from ", sql_query.lower())
    if prefix_match is None:
        return sql_query
    remainder = sql_query[prefix_match.end() :]
    # The remainder is a base table name only if it fully matches the regex
    # for a Snowflake object. It will not match for nested queries such as
    # `select * from (select * from OBJECT)`, in which case we keep the
    # original query.
    object_match = re.fullmatch(SNOWFLAKE_OBJECT_RE_PATTERN, remainder)
    if object_match is None:
        return sql_query
    return object_match.group()
def _create_read_only_table(
    table_name: str,
    materialize_into_temp_table: bool,
    materialization_reason: Optional[str] = None,
) -> str:
    """
    create read only table for the given table.

    Args:
        table_name: the table to create read only table on top of.
        materialize_into_temp_table: whether to create a temp table of the given table. If true, the
                read only table will be created on top of the temp table instead of
                the original table.
                Read only table creation is only supported for temporary table or regular
                table at this moment. If the table is not those two types, you can set
                materialize_into_temp_table to True to create a temporary table out of it for
                read only table creation. Otherwise, creation will fail.
        materialization_reason: why materialization into temp table is needed for creation of the read
                only table. This is only needed when materialization_into_temp_table is true;
                it is recorded in the statement parameters for server-side tracking.

    Returns:
        The name of the read only table created.
    """
    session = pd.session
    # use random_name_for_temp_object, there is a check at server side for
    # temp object used in snowpark stored proc, which needs to match the following regExpr
    # "^SNOWPARK_TEMP_(TABLE|VIEW|STAGE|FUNCTION|TABLE_FUNCTION|FILE_FORMAT|PROCEDURE)_[0-9A-Z]+$")
    readonly_table_name = (
        f"{random_name_for_temp_object(TempObjectType.TABLE)}{READ_ONLY_TABLE_SUFFIX}"
    )
    use_scoped_temp_table = session._use_scoped_temp_read_only_table
    # If we need to materialize into a temp table our create table expression
    # needs to be SELECT * FROM (object).
    if materialize_into_temp_table:
        ctas_query = f"SELECT * FROM {table_name}"
        temp_table_name = random_name_for_temp_object(TempObjectType.TABLE)
        # Warn the user up front: copying the source data can be slow for large tables.
        _logger.warning(
            f"Snapshot source table/view '{table_name}' failed due to reason: `{materialization_reason}'. Data from "
            f"source table/view '{table_name}' is being copied into a new "
            f"temporary table '{temp_table_name}' for snapshotting. DataFrame creation might take some time."
        )
        statement_params = get_default_snowpark_pandas_statement_params()
        # record 1) original table name (which may not be an actual table)
        # 2) the name for the new temp table that has been created/materialized
        # 3) the reason why materialization happens
        new_params = {
            STATEMENT_PARAMS.MATERIALIZATION_TABLE_NAME: temp_table_name,
            STATEMENT_PARAMS.MATERIALIZATION_REASON: materialization_reason
            if materialization_reason is not None
            else STATEMENT_PARAMS.UNKNOWN,
        }
        statement_params.update(new_params)
        session.sql(
            f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} TABLE {temp_table_name} AS {ctas_query}"
        ).collect(statement_params=statement_params)
        # From here on, the read only table is cloned from the materialized copy.
        table_name = temp_table_name
    statement_params = get_default_snowpark_pandas_statement_params()
    # record the actual table that the read only table is created on top of, and also the name of the
    # read only table that is created.
    statement_params.update(
        {
            STATEMENT_PARAMS.READONLY_SOURCE_TABLE_NAME: table_name,
            STATEMENT_PARAMS.READONLY_TABLE_NAME: readonly_table_name,
        }
    )
    # TODO (SNOW-1669224): pushing read only table creation down to snowpark for general usage
    session.sql(
        f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}",
        _emit_ast=False,
    ).collect(statement_params=statement_params, _emit_ast=False)
    return readonly_table_name
def create_initial_ordered_dataframe(
    table_name_or_query: Union[str, Iterable[str]],
    enforce_ordering: bool,
) -> tuple[OrderedDataFrame, str]:
    """
    create read only temp table on top of the existing table or Snowflake query if required, and create a OrderedDataFrame
    with row position column using the read only temp table created or directly using the existing table.

    Args:
        table_name_or_query: A string or list of strings that specify the table name or
            fully-qualified object identifier (database name, schema name, and table name) or SQL query.
        enforce_ordering: If True, create a read only temp table on top of the existing table or Snowflake query,
            and create the OrderedDataFrame using the read only temp table created.
            Otherwise, directly using the existing table.

    Returns:
        OrderedDataFrame with row position column.
        snowflake quoted identifier for the row position column.

    Raises:
        SnowparkPandasException: if the read only table or the Snowpark pandas
            DataFrame cannot be created from the given table name or query.
    """
    # A non-string iterable is treated as the parts of a fully-qualified
    # object identifier (database, schema, table) and joined with dots.
    if not isinstance(table_name_or_query, str) and isinstance(
        table_name_or_query, Iterable
    ):
        table_name_or_query = ".".join(table_name_or_query)
    session = pd.session
    # `table_name_or_query` can be either a table name or a query. If it is a query of the form
    # SELECT * FROM table, we can parse out the base table name, and treat it as though the user
    # called `pd.read_snowflake("table")` instead of treating it as a SQL query, which will result
    # in the materialization of an additional temporary table. Since that is the case, we first
    # see if the coercion can happen, before determining if we are dealing with a query or not.
    table_name_or_query = _extract_base_table_from_simple_select_star_query(
        table_name_or_query
    )
    is_query = not _is_table_name(table_name_or_query)
    # Fast path: either the input is a plain table, or ordering does not need
    # to be enforced. The slow path (query + enforce_ordering) is in the else
    # branch below and goes through Snowpark's `to_snowpark_pandas`.
    if not is_query or not enforce_ordering:
        if enforce_ordering:
            try:
                readonly_table_name = _create_read_only_table(
                    table_name=table_name_or_query,
                    materialize_into_temp_table=False,
                )
            except SnowparkSQLException as ex:
                _logger.debug(
                    f"Failed to create read only table for {table_name_or_query}: {ex}"
                )
                # Creation of read only table fails for following cases which are not possible
                # (or very difficult) to detect on client side in advance. We explicitly check
                # for these errors and create a temporary table by copying the content of the
                # original table and then create the read only table on the top of this
                # temporary table.
                # 1. Row access Policy:
                #    If the table has row access policy associated, read only table creation will
                #    fail. SNOW-850878 is created to support the query for row access policy on
                #    server side.
                # 2. Table can not be cloned:
                #    Clone is not supported for tables that are imported from a share, views etc.
                # 3. Table doesn't support read only table creation:
                #    Includes iceberg table, hybrid table etc.
                known_errors = (
                    "Row access policy is not supported on read only table",  # case 1
                    "Cannot clone",  # case 2
                    "Unsupported feature",  # case 3
                    "Clone Iceberg table should use CREATE ICEBERG TABLE CLONE command",  # case 3
                )
                if any(error in ex.message for error in known_errors):
                    # Known, recoverable failure: retry with materialization
                    # into a temp table first.
                    readonly_table_name = _create_read_only_table(
                        table_name=table_name_or_query,
                        materialize_into_temp_table=True,
                        materialization_reason=ex.message,
                    )
                else:
                    raise SnowparkPandasException(
                        f"Failed to create Snowpark pandas DataFrame out of table {table_name_or_query} with error {ex}",
                        error_code=SnowparkPandasErrorCode.GENERAL_SQL_EXCEPTION.value,
                    ) from ex
        if is_query:
            # If the string passed in to `pd.read_snowflake` is a SQL query, we can simply create
            # a Snowpark DataFrame, and convert that to a Snowpark pandas DataFrame, and extract
            # the OrderedDataFrame and row_position_snowflake_quoted_identifier from there.
            # If there is an ORDER BY in the query, we should log it.
            contains_order_by = _check_if_sql_query_contains_order_by_and_warn_user(
                table_name_or_query
            )
            statement_params = get_default_snowpark_pandas_statement_params()
            statement_params[STATEMENT_PARAMS.CONTAINS_ORDER_BY] = str(
                contains_order_by
            ).upper()
        # Pick the source of the initial dataframe: the read only clone when
        # ordering is enforced, otherwise the query or the table directly.
        initial_ordered_dataframe = OrderedDataFrame(
            DataFrameReference(session.table(readonly_table_name, _emit_ast=False))
            if enforce_ordering
            else DataFrameReference(session.sql(table_name_or_query, _emit_ast=False))
            if is_query
            else DataFrameReference(session.table(table_name_or_query, _emit_ast=False))
        )
        # generate a snowflake quoted identifier for row position column that can be used for aliasing
        snowflake_quoted_identifiers = (
            initial_ordered_dataframe.projected_column_snowflake_quoted_identifiers
        )
        row_position_snowflake_quoted_identifier = (
            initial_ordered_dataframe.generate_snowflake_quoted_identifiers(
                pandas_labels=[ROW_POSITION_COLUMN_LABEL],
                wrap_double_underscore=True,
            )[0]
        )
        # create snowpark dataframe with columns: row_position_snowflake_quoted_identifier + snowflake_quoted_identifiers
        # if no snowflake_quoted_identifiers is specified, all columns will be selected
        if enforce_ordering:
            # The read only table exposes the METADATA$ROW_POSITION column.
            row_position_column_str = f"{METADATA_ROW_POSITION_COLUMN} as {row_position_snowflake_quoted_identifier}"
        else:
            # Without a read only table we synthesize a 0-based row position.
            row_position_column_str = f"ROW_NUMBER() OVER (ORDER BY 1) - 1 as {row_position_snowflake_quoted_identifier}"
        columns_to_select = ", ".join(
            [row_position_column_str] + snowflake_quoted_identifiers
        )
        # Create or get the row position columns requires access to the metadata column of the table.
        # However, snowpark_df = session().table(table_name) generates query (SELECT * from <table_name>),
        # which creates a view without metadata column, we won't be able to access the metadata columns
        # with the created snowpark dataframe. In order to get the metadata column access in the created
        # dataframe, we create dataframe through sql which access the corresponding metadata column.
        if enforce_ordering:
            dataframe_sql = f"SELECT {columns_to_select} FROM {readonly_table_name}"
        else:
            dataframe_sql = f"SELECT {columns_to_select} FROM ({table_name_or_query})"
        snowpark_df = session.sql(dataframe_sql, _emit_ast=False)
        result_columns_quoted_identifiers = [
            row_position_snowflake_quoted_identifier
        ] + snowflake_quoted_identifiers
        ordered_dataframe = OrderedDataFrame(
            DataFrameReference(snowpark_df, result_columns_quoted_identifiers),
            projected_column_snowflake_quoted_identifiers=result_columns_quoted_identifiers,
            ordering_columns=[OrderingColumn(row_position_snowflake_quoted_identifier)],
            row_position_snowflake_quoted_identifier=row_position_snowflake_quoted_identifier,
        )
    else:
        assert is_query and enforce_ordering
        # If the string passed in to `pd.read_snowflake` is a SQL query, we can simply create
        # a Snowpark DataFrame, and convert that to a Snowpark pandas DataFrame, and extract
        # the OrderedDataFrame and row_position_snowflake_quoted_identifier from there.
        # If there is an ORDER BY in the query, we should log it.
        contains_order_by = _check_if_sql_query_contains_order_by_and_warn_user(
            table_name_or_query
        )
        statement_params = get_default_snowpark_pandas_statement_params()
        statement_params[STATEMENT_PARAMS.CONTAINS_ORDER_BY] = str(
            contains_order_by
        ).upper()
        try:
            # When we call `to_snowpark_pandas`, Snowpark will create a temporary table out of the
            # Snowpark DataFrame, and then call `pd.read_snowflake` on it, which will create a
            # Read only clone of that temporary table. We need to create the second table (instead
            # of just using the temporary table Snowpark creates with a row position column as our
            # backing table) because there are no guarantees that a temporary table cannot be modified
            # so we lose the data isolation quality of pandas that we are attempting to replicate. By
            # creating a read only clone, we ensure that the underlying data cannot be modified by anyone
            # else.
            snowpark_pandas_df = session.sql(table_name_or_query).to_snowpark_pandas(
                enforce_ordering=enforce_ordering
            )
        except SnowparkSQLException as ex:
            raise SnowparkPandasException(
                f"Failed to create Snowpark pandas DataFrame out of query {table_name_or_query} with error {ex}",
                error_code=SnowparkPandasErrorCode.GENERAL_SQL_EXCEPTION.value,
            ) from ex
        ordered_dataframe = (
            snowpark_pandas_df._query_compiler._modin_frame.ordered_dataframe
        )
        row_position_snowflake_quoted_identifier = (
            ordered_dataframe.row_position_snowflake_quoted_identifier
        )
        # Set the materialized row count
        materialized_row_count = ordered_dataframe._dataframe_ref.snowpark_dataframe.count(
            _emit_ast=False
        )
        ordered_dataframe.row_count = materialized_row_count
        ordered_dataframe.row_count_upper_bound = materialized_row_count
    return ordered_dataframe, row_position_snowflake_quoted_identifier
def generate_snowflake_quoted_identifiers_helper(
    *,
    pandas_labels: list[Hashable],
    excluded: Optional[list[str]] = None,
    wrap_double_underscore: Optional[bool] = False,
) -> list[str]:
    """
    Generate one snowflake quoted identifier per label in ``pandas_labels``.

    When ``excluded`` is given, generated identifiers do not conflict with it
    or with each other; conflicts are resolved by appending a random suffix
    (of _NUM_SUFFIX_DIGITS digits) and retrying up to _MAX_NUM_RETRIES times.
    When ``excluded`` is None, no conflict resolution happens at all, so the
    generated identifiers may conflict.

    Args:
        pandas_labels: pandas labels to generate snowflake quoted identifiers for.
        excluded: snowflake quoted identifiers (as strings) the results must
            not collide with. All entries must already be quoted.
        wrap_double_underscore: wrap the resolved prefix to produce a name
            '__<resolved prefix>__'.

    Examples:
        generate_snowflake_quoted_identifiers(excluded=['"A"', '"B"', '"C"'], pandas_label='X')
            returns ['"X"'] because it doesn't conflict with 'A', 'B', 'C'
        generate_snowflake_quoted_identifiers(excluded=['"A"', '"B"', '"C"'], pandas_label='A')
            returns ['"A_<random_suffix>"'] because '"A"' already exists
        generate_snowflake_quoted_identifiers(excluded=['"__A__"', '"B"', '"C"'], pandas_label='A')
            returns ['"A"']
        generate_snowflake_quoted_identifiers(excluded=['"__A__"', '"B"', '"C"'], pandas_label='A', wrap_double_underscore=True)
            returns ['"__A_<random_suffix>__"'] (e.g., <random_suffix> can be "a1b2") because
            wrapping 'A' with __ would conflict with '"__A__"'

    Raises:
        ValueError: if 'excluded' contains an unquoted identifier, or if a
            conflict cannot be resolved after the retry budget is exhausted.

    Returns:
        A list of Snowflake quoted identifiers, that are conflict free.

    Note:
        There is a similar version of function inside OrderedDataFrame, which should be
        used in general. This method should only be used when an OrderedDataFrame is not available,
        e.g., in `from_pandas()`.
    """
    resolve_conflicts = excluded is not None
    if resolve_conflicts:
        # Every entry in 'excluded' must already be a valid quoted identifier.
        for identifier in excluded:  # type: ignore[union-attr]
            if not is_valid_snowflake_quoted_identifier(identifier):
                raise ValueError(
                    "'excluded' must have quoted identifiers."
                    f" Found unquoted identifier='{identifier}'"
                )
        taken = set(excluded)  # type: ignore[arg-type]
    else:
        taken = set()

    result: list[str] = []
    for pandas_label in pandas_labels:
        candidate = quote_name_without_upper_casing(
            f"__{pandas_label}__" if wrap_double_underscore else f"{pandas_label}"
        )
        if resolve_conflicts:
            attempts = 0
            while candidate in taken and attempts < _MAX_NUM_RETRIES:
                if len(candidate) > _MAX_IDENTIFIER_LENGTH:
                    # Too long to keep appending suffixes; switch to a fully
                    # random identifier of bounded length.
                    candidate = quote_name_without_upper_casing(
                        generate_column_identifier_random(_MAX_IDENTIFIER_LENGTH)
                    )
                else:
                    suffix = generate_column_identifier_random()
                    candidate = quote_name_without_upper_casing(
                        f"__{pandas_label}_{suffix}__"
                        if wrap_double_underscore
                        else f"{pandas_label}_{suffix}"
                    )
                attempts += 1
            if candidate in taken:
                raise ValueError(
                    f"Failed to generate quoted identifier for pandas label: '{pandas_label}' "
                    f"the generated identifier '{candidate}' conflicts with {taken}"
                )
        result.append(candidate)
        # Reserve the identifier so later labels cannot collide with it.
        taken.add(candidate)
    return result
def generate_new_labels(
    *, pandas_labels: list[Hashable], excluded: Optional[list[str]] = None
) -> list[str]:
    """
    Generate new (string) pandas labels which do not conflict with the strings
    in ``excluded``.

    Args:
        pandas_labels: pandas labels to generate new string labels for.
        excluded: pandas string labels to exclude. When None (or empty), no
            conflict resolution happens, so the generated string labels may
            have conflicts.

    Raises:
        ValueError: if a conflict cannot be resolved within the retry budget.
    """
    if not excluded:
        return [str(label) for label in pandas_labels]

    taken = set(excluded)
    result = []
    for pandas_label in pandas_labels:
        candidate = f"{pandas_label}"
        attempts = 0
        while candidate in taken and attempts < _MAX_NUM_RETRIES:
            if len(candidate) > _MAX_IDENTIFIER_LENGTH:
                # Too long to keep appending suffixes; fall back to a fully
                # random label of bounded length.
                candidate = generate_column_identifier_random(_MAX_IDENTIFIER_LENGTH)
            else:
                candidate = f"{pandas_label}_{generate_column_identifier_random()}"
            attempts += 1
        if candidate in taken:
            raise ValueError(  # pragma: no cover
                f"Failed to generate string label {pandas_label} "
                f"the generated label '{candidate}' conflicts with {taken}"
            )
        result.append(candidate)
        # Reserve the label so later labels cannot collide with it.
        taken.add(candidate)
    return result
def serialize_pandas_labels(pandas_labels: list[Hashable]) -> list[str]:
    """
    Serialize the hashable pandas labels into strings.

    Tuple labels (multi-level) are JSON-serialized on a best-effort basis for
    better readability; if serialization fails, the regular string
    representation is used instead.

    Args:
        pandas_labels: labels to serialize.

    Returns:
        One string per input label.
    """
    serialized_pandas_labels = []
    for pandas_label in pandas_labels:
        if isinstance(pandas_label, tuple):
            try:
                # We prefer a json compatible serialization of the pandas label for the column header so that we
                # split the labels in SQL if needed. For example, in transpose operation a multi-level pandas label
                # column name like (a,b) would become transposed into 2 new index columns (a) and (b) values. This
                # currently does not handle cases where pandas label is not json serializable, so if there is a failure
                # we will use the python string representation.
                # TODO (SNOW-886400) Multi-level non-str pandas label not handled.
                pandas_label = json.dumps(list(pandas_label))
            except (TypeError, ValueError):
                # json.dumps raises TypeError for unsupported object types and
                # ValueError for unencodable values (e.g. circular references).
                # (The previously-caught json.JSONDecodeError is only raised
                # while decoding, never by dumps.) Fall back to str() below.
                pass
        pandas_label = str(pandas_label)
        serialized_pandas_labels.append(pandas_label)
    return serialized_pandas_labels
def is_json_serializable_pandas_labels(pandas_labels: list[Hashable]) -> bool:
    """
    Return True if every label in ``pandas_labels`` can be json serialized.
    """
    try:
        for pandas_label in pandas_labels:
            json.dumps(pandas_label)
    except (TypeError, ValueError):
        return False
    return True
def unquote_name_if_quoted(name: str) -> str:
    """
    Strip the surrounding double quotes from ``name`` if it is quoted, and
    unescape the doubled quotes inside it.
    """
    is_quoted = name.startswith(DOUBLE_QUOTE) and name.endswith(DOUBLE_QUOTE)
    if is_quoted:
        name = name[1:-1]
    # A literal double quote inside an identifier is escaped as two
    # consecutive double quotes; collapse them back to one.
    return name.replace(DOUBLE_QUOTE + DOUBLE_QUOTE, DOUBLE_QUOTE)
def extract_pandas_label_from_snowflake_quoted_identifier(
    snowflake_identifier: str,
) -> str:
    """
    Extract the pandas label from the given snowflake quoted identifier by
    removing the surrounding double quotes and unescaping doubled quotes.

    Args:
        snowflake_identifier: a quoted snowflake identifier, must be a quoted string.

    Examples:
        extract_pandas_label_from_snowflake_quoted_identifier('"abc"') -> 'abc'
        extract_pandas_label_from_snowflake_quoted_identifier('"a""bc"') -> 'a"bc'

    Returns:
        pandas label.
    """
    assert is_valid_snowflake_quoted_identifier(
        snowflake_identifier
    ), f"invalid snowflake_identifier {snowflake_identifier}"
    # Drop the surrounding quotes, then unescape doubled quotes.
    unquoted = snowflake_identifier[1:-1]
    return unquoted.replace(DOUBLE_QUOTE + DOUBLE_QUOTE, DOUBLE_QUOTE)
def get_snowflake_quoted_identifier_to_pandas_label_mapping(
    original_snowflake_quoted_identifiers_list: list[str],
    original_pandas_labels: list[str],
    new_snowflake_quoted_identifiers_list: list[str],
) -> dict[str, str]:
    """
    This function maps a list of snowflake_quoted_identifiers to the corresponding pandas label.
    If the snowflake_quoted_identifier is found in the original identifier list, the matching
    pandas label from ``original_pandas_labels`` is used. Otherwise the pandas label is derived
    by parsing the quoted identifier itself.

    Args:
        original_snowflake_quoted_identifiers_list: list of quoted snowflake identifiers with
            known pandas labels, must be a list of quoted strings.
        original_pandas_labels: list of pandas labels corresponding to the original identifiers.
        new_snowflake_quoted_identifiers_list: list of quoted snowflake identifiers to map.

    Returns:
        map of snowflake_quoted_identifier to corresponding pandas label.
    """
    # Known identifier -> label pairs from the original query compiler columns.
    known_labels = dict(
        zip(original_snowflake_quoted_identifiers_list, original_pandas_labels)
    )
    mapping: dict[str, str] = {}
    for identifier in new_snowflake_quoted_identifiers_list:
        if identifier in known_labels:
            mapping[identifier] = known_labels[identifier]
        else:
            # Fall back to deriving the label from the quoted identifier text.
            mapping[
                identifier
            ] = extract_pandas_label_from_snowflake_quoted_identifier(identifier)
    return mapping
def parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label(
    object_construct_snowflake_quoted_identifier: str,
    num_levels: int,
) -> tuple[Hashable, dict[str, Any]]:
    """
    Split an object_construct style snowflake quoted identifier into its pandas
    label tuple and the remaining (non-label) key/value pairs.

    The identifier encodes a json map whose 0-based integer-string keys hold the
    pandas label levels; keys outside range(0, num_levels) are returned as a
    separate key/value map.

    For example, '{"0":"abc","2":"ghi", "row": 10}' would return
    (("abc", None, "ghi"), {"row": 10}).

    For other examples of the indexed object construct extraction, see documentation for
    extract_pandas_label_tuple_from_object_construct_snowflake_quoted_identifier

    Arguments:
        object_construct_snowflake_quoted_identifier: The snowflake quoted identifier.
        num_levels: Number of levels in expected pandas labels

    Returns:
        Tuple containing the corresponding pandas labels and any other additional key, value pairs.
    """
    parsed_map = parse_snowflake_object_construct_identifier_to_map(
        object_construct_snowflake_quoted_identifier
    )
    return (
        extract_pandas_label_from_object_construct_map(parsed_map, num_levels),
        extract_non_pandas_label_from_object_construct_map(parsed_map, num_levels),
    )
def parse_snowflake_object_construct_identifier_to_map(
    object_construct_snowflake_quoted_identifier: str,
) -> dict[Any, Any]:
    """
    Parse a snowflake identifier based on object_construct (json object) into a
    map of its constituent parts.

    Examples:
        parse_snowflake_object_construct_identifier_to_map('"{""0"":""a"", ""2"":""b"", ""row"": 10}"')
        returns {"0": "a", "2": "b", "row": 10}

    Arguments:
        object_construct_snowflake_quoted_identifier: snowflake quoted identifier

    Returns:
        dict of object_construct snowflake quoted identifier
    """
    # Unwrap the outer quoted identifier, then normalize the snowflake string
    # constant to a plain python string before json decoding.
    json_text = extract_pandas_label_from_snowflake_quoted_identifier(
        object_construct_snowflake_quoted_identifier
    )
    json_text = convert_snowflake_string_constant_to_python_string(json_text)
    # SNOW-853416: There are other encoding issues to handle, but this double slash decode is to undo
    # escaping that happens as part of the pivot snowflake to display column name. For instance, a raw
    # data value "shi\'ny" would pivot to column name "'shi\\''ny', we need to get back to the original.
    json_text = json_text.replace("\\\\", "\\")
    return json.loads(json_text)
def extract_pandas_label_from_object_construct_map(
    obj_construct_map: dict[Any, Any], num_levels: int
) -> Hashable:
    """
    Extract the pandas label from an object_construct map.

    Levels are looked up by their stringified 0-based position; a missing level
    maps to None.

    Examples:
        extract_pandas_label_tuple_from_object_construct_map({"0": "a"}, 1) returns ("a")
        extract_pandas_label_tuple_from_object_construct_map({"0": "a", "1": "b", "row": 10}, 2) returns ("a", "b")
        extract_pandas_label_tuple_from_object_construct_map({"0": "a", "1": "b", "2": "c"}, 3) returns ("a", "b", "c")
        extract_pandas_label_tuple_from_object_construct_map({"0": "a", "2": "b", "row": 10}, 4)
            returns ("a", None, "b", None)

    Arguments:
        obj_construct_map: Map of object_construct key, values including positional pandas label
        num_levels: Number of pandas label levels

    Returns:
        pandas label extracted from object_construct map
    """
    # dict.get defaults to None for levels absent from the map.
    level_values = tuple(
        obj_construct_map.get(str(level)) for level in range(num_levels)
    )
    return to_pandas_label(level_values)
def extract_non_pandas_label_from_object_construct_map(
    obj_construct_map: dict[Any, Any], num_levels: int
) -> dict[Any, Any]:
    """
    Extract the non-pandas-label entries from an object_construct map.

    Any key that is not the string form of a level position in
    range(num_levels) is treated as extra metadata and returned.

    Examples:
        extract_non_pandas_label_from_object_construct_map({"0": "a", "1": "b"}, 2) returns {}
        extract_non_pandas_label_from_object_construct_map({"0": "a", "foo": "bar"}, 2) returns {"foo": "bar"}
        extract_non_pandas_label_from_object_construct_map({"0": "a", "1": "b", "foo": "bar"}, 2) returns {"foo": "bar"}
        extract_non_pandas_label_from_object_construct_map({"2": "val"}, 2) returns {"2": "val"}

    Arguments:
        obj_construct_map: Map of object_construct key, values including positional pandas label as well as other
            key, value pairs.
        num_levels: Number of pandas label levels

    Returns:
        Key value information from object_construct map not related to pandas label
    """
    # Stringified level positions that belong to the pandas label, not metadata.
    positional_keys = {str(position) for position in range(num_levels)}
    return {
        key: value
        for key, value in obj_construct_map.items()
        if key not in positional_keys
    }
def extract_pandas_label_from_object_construct_snowflake_quoted_identifier(
    object_construct_snowflake_quoted_identifier: str,
    num_levels: int,
) -> Hashable:
    """
    This function extracts the corresponding pandas labels from a snowflake quoted identifier which was constructed
    via object_construct key:value mapping using a 0-based integer index value. The snowflake quoted identifier
    is expected to be a valid json encoding. For example, '{"0":"abc","2":"ghi"}' would extract pandas labels
    as ("abc", None, "ghi"), here are more examples:

    Examples:
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"0":"abc","1":"def"}', 2)
            -> ("abc", "def")
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"0":"ab\\"c","1":"def"}', 2)
            -> ('ab"c', "def")
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"0":"abc\\"\\"","1":"def"}', 2)
            -> ('abc""', "def")
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"0":"\\",\\"abc","1":"def"}', 2)
            -> ('","abc', "def")
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"0":"abc","2":"ghi"}', 3)
            -> ("abc", None, "ghi")
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier('{"1":"def"}', 3) -> (None, "def", None)
        extract_pandas_label_from_object_construct_snowflake_quoted_identifier("{}", 3) -> (None, None, None)

    Arguments:
        object_construct_snowflake_quoted_identifier: The snowflake quoted identifier.
        num_levels: Number of levels in expected pandas labels

    Returns:
        Tuple containing the corresponding pandas labels.
    """
    # Only the label portion is needed; extra key/value metadata is discarded.
    pandas_label, _ = (
        parse_object_construct_snowflake_quoted_identifier_and_extract_pandas_label(
            object_construct_snowflake_quoted_identifier, num_levels
        )
    )
    return pandas_label
def is_valid_snowflake_quoted_identifier(identifier: str) -> bool:
    """
    Check whether identifier is a properly quoted Snowflake identifier or not

    Performs following checks:
        1. Length must be > 2.
        2. Must have surrounding quotes.
        3. Double quotes which are part of the identifier must be properly escaped.

    Args:
        identifier: string representing a Snowflake identifier

    Returns:
        True if input string is properly quoted snowflake identifier, False otherwise
    """
    if len(identifier) <= 2:
        return False
    if identifier[0] != DOUBLE_QUOTE or identifier[-1] != DOUBLE_QUOTE:
        return False
    # After collapsing all escaped quote pairs, no lone double quote may remain
    # inside the identifier body.
    body = identifier[1:-1].replace(DOUBLE_QUOTE + DOUBLE_QUOTE, EMPTY_STRING)
    return DOUBLE_QUOTE not in body
def traceback_as_str_from_exception(exception_object: Exception) -> str:
    """
    Render an exception's python traceback as a single string.

    Args:
        exception_object: exception object

    Returns: string containing description usually printed by interpreter to stderr
    """
    return "".join(
        traceback.format_exception(
            None, exception_object, exception_object.__traceback__
        )
    )
def extract_all_duplicates(elements: Sequence[Hashable]) -> Sequence[Hashable]:
    """
    Find duplicated elements for the given list of elements.

    Args:
        elements: the list of elements to check for duplications

    Returns:
        List[Hashable]
            list of unique elements that contains duplication in the original list,
            in order of first occurrence
    """
    # Count occurrences in one O(n) pass instead of calling Sequence.count per
    # element, which is O(n) each and O(n^2) overall.
    counts: dict[Hashable, int] = {}
    for element in elements:
        counts[element] = counts.get(element, 0) + 1
    # dict preserves insertion order, so this keeps first-occurrence order,
    # matching the previous dict.fromkeys-based de-duplication.
    return [element for element, count in counts.items() if count > 1]
def is_duplicate_free(names: Sequence[Hashable]) -> bool:
    """
    check whether names contains duplicates

    Args:
        names: sequence of hashable objects to check for duplicates based on __hash__ method. I.e., two
            elements are considered duplicates if their hashes match.

    Returns:
        True if no duplicates, False else
    """
    # A set collapses duplicates, so comparing sizes detects duplication in
    # O(n) instead of the O(n^2) per-element count scan used previously.
    return len(set(names)) == len(names)
def assert_duplicate_free(names: Sequence[str], type: str) -> None:
    """
    Raise an AssertionError if ``names`` contains any duplicated entries.

    The error message lists the element (or elements) of ``names`` that occur
    more than once, described by ``type``.

    Args:
        names: A sequence of strings to check for duplicates
        type: a string to describe the elements when producing the assertion error.

    Returns:
        None
    """
    duplicates = extract_all_duplicates(names)
    if not duplicates:
        return
    if len(duplicates) == 1:
        raise AssertionError(f"Found duplicate of type {type}: {duplicates[0]}")
    raise AssertionError(f"Found duplicates of type {type}: {duplicates}")
def to_pandas_label(label: LabelTuple) -> Hashable:
"""
get the pandas label used for identify pandas column/rows in pandas dataframe
"""
assert (
len(label) >= 1
), "label in Snowpark pandas must have at least one label component"
if len(label) == 1:
return label[0]