3636GCS_FILE_HEADER_DIGEST = "gcs-file-header-digest"
3737GCS_FILE_HEADER_CONTENT_LENGTH = "gcs-file-header-content-length"
3838GCS_FILE_HEADER_ENCRYPTION_METADATA = "gcs-file-header-encryption-metadata"
39+ GCS_REGION_ME_CENTRAL_2 = "me-central2"
3940CONTENT_CHUNK_SIZE = 10 * kilobyte
4041ACCESS_TOKEN = "GCS_ACCESS_TOKEN"
4142
4243
4344class GcsLocation (NamedTuple ):
4445 bucket_name : str
4546 path : str
47+ endpoint : str = "https://storage.googleapis.com"
4648
4749
4850class SnowflakeGCSRestClient (SnowflakeStorageClient ):
@@ -53,7 +55,6 @@ def __init__(
5355 stage_info : dict [str , Any ],
5456 cnx : SnowflakeConnection ,
5557 command : str ,
56- use_s3_regional_url : bool = False ,
5758 unsafe_file_write : bool = False ,
5859 ) -> None :
5960 """Creates a client object with given stage credentials.
@@ -79,6 +80,15 @@ def __init__(
7980 # presigned_url in meta is for downloading
8081 self .presigned_url : str = meta .presigned_url or stage_info .get ("presignedUrl" )
8182 self .security_token = credentials .creds .get ("GCS_ACCESS_TOKEN" )
83+ self .use_regional_url = (
84+ "region" in stage_info
85+ and stage_info ["region" ].lower () == GCS_REGION_ME_CENTRAL_2
86+ or "useRegionalUrl" in stage_info
87+ and stage_info ["useRegionalUrl" ]
88+ )
89+ self .endpoint : str | None = (
90+ None if "endPoint" not in stage_info else stage_info ["endPoint" ]
91+ )
8292
8393 if self .security_token :
8494 logger .debug (f"len(GCS_ACCESS_TOKEN): { len (self .security_token )} " )
@@ -91,7 +101,7 @@ def _has_expired_token(self, response: requests.Response) -> bool:
91101
92102 def _has_expired_presigned_url (self , response : requests .Response ) -> bool :
93103 # Presigned urls can be generated for any xml-api operation
94- # offered by GCS. Hence the error codes expected are similar
104+ # offered by GCS. Hence, the error codes expected are similar
95105 # to xml api.
96106 # https://cloud.google.com/storage/docs/xml-api/reference-status
97107
@@ -152,7 +162,14 @@ def generate_url_and_rest_args() -> (
152162 ):
153163 if not self .presigned_url :
154164 upload_url = self .generate_file_url (
155- self .stage_info ["location" ], meta .dst_file_name .lstrip ("/" )
165+ self .stage_info ["location" ],
166+ meta .dst_file_name .lstrip ("/" ),
167+ self .use_regional_url ,
168+ (
169+ None
170+ if "region" not in self .stage_info
171+ else self .stage_info ["region" ]
172+ ),
156173 )
157174 access_token = self .security_token
158175 else :
@@ -182,7 +199,15 @@ def generate_url_and_rest_args() -> (
182199 gcs_headers = {}
183200 if not self .presigned_url :
184201 download_url = self .generate_file_url (
185- self .stage_info ["location" ], meta .src_file_name .lstrip ("/" )
202+ self .stage_info ["location" ],
203+ meta .src_file_name .lstrip ("/" ),
204+ self .use_regional_url ,
205+ (
206+ None
207+ if "region" not in self .stage_info
208+ else self .stage_info ["region" ]
209+ ),
210+ self .endpoint ,
186211 )
187212 access_token = self .security_token
188213 gcs_headers ["Authorization" ] = f"Bearer { access_token } "
@@ -339,7 +364,14 @@ def get_file_header(self, filename: str) -> FileHeader | None:
339364
340365 def generate_url_and_authenticated_headers ():
341366 url = self .generate_file_url (
342- self .stage_info ["location" ], filename .lstrip ("/" )
367+ self .stage_info ["location" ],
368+ filename .lstrip ("/" ),
369+ self .use_regional_url ,
370+ (
371+ None
372+ if "region" not in self .stage_info
373+ else self .stage_info ["region" ]
374+ ),
343375 )
344376 gcs_headers = {"Authorization" : f"Bearer { self .security_token } " }
345377 rest_args = {"headers" : gcs_headers }
@@ -383,7 +415,13 @@ def generate_url_and_authenticated_headers():
383415 return None
384416
385417 @staticmethod
386- def extract_bucket_name_and_path (stage_location : str ) -> GcsLocation :
418+ def get_location (
419+ stage_location : str ,
420+ use_regional_url : str = False ,
421+ region : str = None ,
422+ endpoint : str = None ,
423+ use_virtual_endpoints : bool = False ,
424+ ) -> GcsLocation :
387425 container_name = stage_location
388426 path = ""
389427
@@ -393,13 +431,40 @@ def extract_bucket_name_and_path(stage_location: str) -> GcsLocation:
393431 path = stage_location [stage_location .index ("/" ) + 1 :]
394432 if path and not path .endswith ("/" ):
395433 path += "/"
396-
397- return GcsLocation (bucket_name = container_name , path = path )
434+ if endpoint :
435+ if endpoint .endswith ("/" ):
436+ endpoint = endpoint [:- 1 ]
437+ return GcsLocation (bucket_name = container_name , path = path , endpoint = endpoint )
438+ elif use_virtual_endpoints :
439+ return GcsLocation (
440+ bucket_name = container_name ,
441+ path = path ,
442+ endpoint = f"https://{ container_name } .storage.googleapis.com" ,
443+ )
444+ elif use_regional_url :
445+ return GcsLocation (
446+ bucket_name = container_name ,
447+ path = path ,
448+ endpoint = f"https://storage.{ region .lower ()} .rep.googleapis.com" ,
449+ )
450+ else :
451+ return GcsLocation (bucket_name = container_name , path = path )
398452
399453 @staticmethod
400- def generate_file_url (stage_location : str , filename : str ) -> str :
401- gcs_location = SnowflakeGCSRestClient .extract_bucket_name_and_path (
402- stage_location
454+ def generate_file_url (
455+ stage_location : str ,
456+ filename : str ,
457+ use_regional_url : str = False ,
458+ region : str = None ,
459+ endpoint : str = None ,
460+ use_virtual_endpoints : bool = False ,
461+ ) -> str :
462+ gcs_location = SnowflakeGCSRestClient .get_location (
463+ stage_location , use_regional_url , region , endpoint
403464 )
404465 full_file_path = f"{ gcs_location .path } { filename } "
405- return f"https://storage.googleapis.com/{ gcs_location .bucket_name } /{ quote (full_file_path )} "
466+
467+ if use_virtual_endpoints :
468+ return f"{ gcs_location .endpoint } /{ quote (full_file_path )} "
469+ else :
470+ return f"{ gcs_location .endpoint } /{ gcs_location .bucket_name } /{ quote (full_file_path )} "
0 commit comments