36
36
GCS_FILE_HEADER_DIGEST = "gcs-file-header-digest"
37
37
GCS_FILE_HEADER_CONTENT_LENGTH = "gcs-file-header-content-length"
38
38
GCS_FILE_HEADER_ENCRYPTION_METADATA = "gcs-file-header-encryption-metadata"
39
+ GCS_REGION_ME_CENTRAL_2 = "me-central2"
39
40
CONTENT_CHUNK_SIZE = 10 * kilobyte
40
41
ACCESS_TOKEN = "GCS_ACCESS_TOKEN"
41
42
42
43
43
44
class GcsLocation (NamedTuple ):
44
45
bucket_name : str
45
46
path : str
47
+ endpoint : str = "https://storage.googleapis.com"
46
48
47
49
48
50
class SnowflakeGCSRestClient (SnowflakeStorageClient ):
@@ -53,7 +55,6 @@ def __init__(
53
55
stage_info : dict [str , Any ],
54
56
cnx : SnowflakeConnection ,
55
57
command : str ,
56
- use_s3_regional_url : bool = False ,
57
58
unsafe_file_write : bool = False ,
58
59
) -> None :
59
60
"""Creates a client object with given stage credentials.
@@ -79,6 +80,15 @@ def __init__(
79
80
# presigned_url in meta is for downloading
80
81
self .presigned_url : str = meta .presigned_url or stage_info .get ("presignedUrl" )
81
82
self .security_token = credentials .creds .get ("GCS_ACCESS_TOKEN" )
83
+ self .use_regional_url = (
84
+ "region" in stage_info
85
+ and stage_info ["region" ].lower () == GCS_REGION_ME_CENTRAL_2
86
+ or "useRegionalUrl" in stage_info
87
+ and stage_info ["useRegionalUrl" ]
88
+ )
89
+ self .endpoint : str | None = (
90
+ None if "endPoint" not in stage_info else stage_info ["endPoint" ]
91
+ )
82
92
83
93
if self .security_token :
84
94
logger .debug (f"len(GCS_ACCESS_TOKEN): { len (self .security_token )} " )
@@ -91,7 +101,7 @@ def _has_expired_token(self, response: requests.Response) -> bool:
91
101
92
102
def _has_expired_presigned_url (self , response : requests .Response ) -> bool :
93
103
# Presigned urls can be generated for any xml-api operation
94
- # offered by GCS. Hence the error codes expected are similar
104
+ # offered by GCS. Hence, the error codes expected are similar
95
105
# to xml api.
96
106
# https://cloud.google.com/storage/docs/xml-api/reference-status
97
107
@@ -152,7 +162,14 @@ def generate_url_and_rest_args() -> (
152
162
):
153
163
if not self .presigned_url :
154
164
upload_url = self .generate_file_url (
155
- self .stage_info ["location" ], meta .dst_file_name .lstrip ("/" )
165
+ self .stage_info ["location" ],
166
+ meta .dst_file_name .lstrip ("/" ),
167
+ self .use_regional_url ,
168
+ (
169
+ None
170
+ if "region" not in self .stage_info
171
+ else self .stage_info ["region" ]
172
+ ),
156
173
)
157
174
access_token = self .security_token
158
175
else :
@@ -182,7 +199,15 @@ def generate_url_and_rest_args() -> (
182
199
gcs_headers = {}
183
200
if not self .presigned_url :
184
201
download_url = self .generate_file_url (
185
- self .stage_info ["location" ], meta .src_file_name .lstrip ("/" )
202
+ self .stage_info ["location" ],
203
+ meta .src_file_name .lstrip ("/" ),
204
+ self .use_regional_url ,
205
+ (
206
+ None
207
+ if "region" not in self .stage_info
208
+ else self .stage_info ["region" ]
209
+ ),
210
+ self .endpoint ,
186
211
)
187
212
access_token = self .security_token
188
213
gcs_headers ["Authorization" ] = f"Bearer { access_token } "
@@ -339,7 +364,14 @@ def get_file_header(self, filename: str) -> FileHeader | None:
339
364
340
365
def generate_url_and_authenticated_headers ():
341
366
url = self .generate_file_url (
342
- self .stage_info ["location" ], filename .lstrip ("/" )
367
+ self .stage_info ["location" ],
368
+ filename .lstrip ("/" ),
369
+ self .use_regional_url ,
370
+ (
371
+ None
372
+ if "region" not in self .stage_info
373
+ else self .stage_info ["region" ]
374
+ ),
343
375
)
344
376
gcs_headers = {"Authorization" : f"Bearer { self .security_token } " }
345
377
rest_args = {"headers" : gcs_headers }
@@ -383,7 +415,13 @@ def generate_url_and_authenticated_headers():
383
415
return None
384
416
385
417
@staticmethod
386
- def extract_bucket_name_and_path (stage_location : str ) -> GcsLocation :
418
+ def get_location (
419
+ stage_location : str ,
420
+ use_regional_url : str = False ,
421
+ region : str = None ,
422
+ endpoint : str = None ,
423
+ use_virtual_endpoints : bool = False ,
424
+ ) -> GcsLocation :
387
425
container_name = stage_location
388
426
path = ""
389
427
@@ -393,13 +431,40 @@ def extract_bucket_name_and_path(stage_location: str) -> GcsLocation:
393
431
path = stage_location [stage_location .index ("/" ) + 1 :]
394
432
if path and not path .endswith ("/" ):
395
433
path += "/"
396
-
397
- return GcsLocation (bucket_name = container_name , path = path )
434
+ if endpoint :
435
+ if endpoint .endswith ("/" ):
436
+ endpoint = endpoint [:- 1 ]
437
+ return GcsLocation (bucket_name = container_name , path = path , endpoint = endpoint )
438
+ elif use_virtual_endpoints :
439
+ return GcsLocation (
440
+ bucket_name = container_name ,
441
+ path = path ,
442
+ endpoint = f"https://{ container_name } .storage.googleapis.com" ,
443
+ )
444
+ elif use_regional_url :
445
+ return GcsLocation (
446
+ bucket_name = container_name ,
447
+ path = path ,
448
+ endpoint = f"https://storage.{ region .lower ()} .rep.googleapis.com" ,
449
+ )
450
+ else :
451
+ return GcsLocation (bucket_name = container_name , path = path )
398
452
399
453
@staticmethod
400
- def generate_file_url (stage_location : str , filename : str ) -> str :
401
- gcs_location = SnowflakeGCSRestClient .extract_bucket_name_and_path (
402
- stage_location
454
+ def generate_file_url (
455
+ stage_location : str ,
456
+ filename : str ,
457
+ use_regional_url : str = False ,
458
+ region : str = None ,
459
+ endpoint : str = None ,
460
+ use_virtual_endpoints : bool = False ,
461
+ ) -> str :
462
+ gcs_location = SnowflakeGCSRestClient .get_location (
463
+ stage_location , use_regional_url , region , endpoint
403
464
)
404
465
full_file_path = f"{ gcs_location .path } { filename } "
405
- return f"https://storage.googleapis.com/{ gcs_location .bucket_name } /{ quote (full_file_path )} "
466
+
467
+ if use_virtual_endpoints :
468
+ return f"{ gcs_location .endpoint } /{ quote (full_file_path )} "
469
+ else :
470
+ return f"{ gcs_location .endpoint } /{ gcs_location .bucket_name } /{ quote (full_file_path )} "
0 commit comments