diff --git a/api/python/quilt3/data_transfer.py b/api/python/quilt3/data_transfer.py index 2900f1c34ee..193bdb6debb 100644 --- a/api/python/quilt3/data_transfer.py +++ b/api/python/quilt3/data_transfer.py @@ -86,16 +86,26 @@ class S3ClientProvider: We assume that public buckets are read-only: write operations should always use S3ClientProvider.standard_client """ + _client_map = {} + + @classmethod + def set_s3_client(cls, bucket: str, client): + assert bucket is not None + cls._client_map[bucket] = client + def __init__(self): self._use_unsigned_client = {} # f'{action}/{bucket}' -> use_unsigned_client_bool self._standard_client = None self._unsigned_client = None - @property - def standard_client(self): - if self._standard_client is None: - self._build_standard_client() - return self._standard_client + def get_standard_client(self, bucket): + mapped_client = self.__class__._client_map.get(bucket) + if mapped_client is not None: + return mapped_client + else: + if self._standard_client is None: + self._build_standard_client() + return self._standard_client @property def unsigned_client(self): @@ -111,7 +121,7 @@ def get_correct_client(self, action: S3Api, bucket: str): if self.should_use_unsigned_client(action, bucket): return self.unsigned_client else: - return self.standard_client + return self.get_standard_client(bucket) def key(self, action: S3Api, bucket: str): return f"{action}/{bucket}" @@ -140,9 +150,9 @@ def find_correct_client(self, api_type, bucket, param_dict): f"API '{api_type}' is not current supported. You may want to use S3ClientProvider.standard_client " \ f"instead " check_fn = check_fn_mapper[api_type] - if check_fn(self.standard_client, param_dict): + if check_fn(self.get_standard_client(bucket), param_dict): self.set_cache(api_type, bucket, use_unsigned=False) - return self.standard_client + return self.get_standard_client(bucket) else: if check_fn(self.unsigned_client, param_dict): self.set_cache(api_type, bucket, use_unsigned=True) @@ -313,7 +323,7 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: st def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str): - s3_client = ctx.s3_client_provider.standard_client + s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket) if not is_mpu(size): with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd: @@ -470,7 +480,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s VersionId=src_version ) - s3_client = ctx.s3_client_provider.standard_client + s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket) if not is_mpu(size): params: Dict[str, Any] = dict( @@ -746,7 +756,7 @@ def _calculate_etag(file_path): def delete_object(bucket, key): - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(bucke) s3_client.head_object(Bucket=bucket, Key=key) # Make sure it exists s3_client.delete_object(Bucket=bucket, Key=key) # Actually delete it @@ -862,7 +872,7 @@ def delete_url(src: PhysicalKey): except FileNotFoundError: pass else: - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(src.bucket) s3_client.delete_object(Bucket=src.bucket, Key=src.path) @@ -925,7 +935,7 @@ def put_bytes(data: bytes, dest: PhysicalKey): else: if dest.version_id is not None: raise ValueError("Cannot set VersionId on destination") - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(dest.bucket) s3_client.put_object( Bucket=dest.bucket, Key=dest.path, @@ -1419,7 +1429,7 @@ def select(src, query, meta=None, raw=False, **kwargs): # S3 Select does not support anonymous access (as of Jan 2019) # https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html - s3_client = S3ClientProvider().standard_client + s3_client = S3ClientProvider().get_standard_client(src.bucket) response = s3_client.select_object_content(**select_kwargs) # we don't want multiple copies of large chunks of data hanging around.