Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions api/python/quilt3/data_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,26 @@ class S3ClientProvider:
We assume that public buckets are read-only: write operations should always use S3ClientProvider.standard_client
"""

_client_map = {}

@classmethod
def set_s3_client(cls, bucket: str, client):
assert bucket is not None
cls._client_map[bucket] = client

def __init__(self):
self._use_unsigned_client = {} # f'{action}/{bucket}' -> use_unsigned_client_bool
self._standard_client = None
self._unsigned_client = None

@property
def standard_client(self):
if self._standard_client is None:
self._build_standard_client()
return self._standard_client
def get_standard_client(self, bucket):
mapped_client = self.__class__._client_map.get(bucket)
if mapped_client is not None:
return mapped_client
else:
if self._standard_client is None:
self._build_standard_client()
return self._standard_client

@property
def unsigned_client(self):
Expand All @@ -111,7 +121,7 @@ def get_correct_client(self, action: S3Api, bucket: str):
if self.should_use_unsigned_client(action, bucket):
return self.unsigned_client
else:
return self.standard_client
return self.get_standard_client(bucket)

def key(self, action: S3Api, bucket: str):
return f"{action}/{bucket}"
Expand Down Expand Up @@ -140,9 +150,9 @@ def find_correct_client(self, api_type, bucket, param_dict):
f"API '{api_type}' is not current supported. You may want to use S3ClientProvider.standard_client " \
f"instead "
check_fn = check_fn_mapper[api_type]
if check_fn(self.standard_client, param_dict):
if check_fn(self.get_standard_client(bucket), param_dict):
self.set_cache(api_type, bucket, use_unsigned=False)
return self.standard_client
return self.get_standard_client(bucket)
else:
if check_fn(self.unsigned_client, param_dict):
self.set_cache(api_type, bucket, use_unsigned=True)
Expand Down Expand Up @@ -313,7 +323,7 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: st


def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
s3_client = ctx.s3_client_provider.standard_client
s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket)

if not is_mpu(size):
with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
Expand Down Expand Up @@ -470,7 +480,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s
VersionId=src_version
)

s3_client = ctx.s3_client_provider.standard_client
s3_client = ctx.s3_client_provider.get_standard_client(dest_bucket)

if not is_mpu(size):
params: Dict[str, Any] = dict(
Expand Down Expand Up @@ -746,7 +756,7 @@ def _calculate_etag(file_path):


def delete_object(bucket, key):
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(bucke)

s3_client.head_object(Bucket=bucket, Key=key) # Make sure it exists
s3_client.delete_object(Bucket=bucket, Key=key) # Actually delete it
Expand Down Expand Up @@ -862,7 +872,7 @@ def delete_url(src: PhysicalKey):
except FileNotFoundError:
pass
else:
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(src.bucket)
s3_client.delete_object(Bucket=src.bucket, Key=src.path)


Expand Down Expand Up @@ -925,7 +935,7 @@ def put_bytes(data: bytes, dest: PhysicalKey):
else:
if dest.version_id is not None:
raise ValueError("Cannot set VersionId on destination")
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(dest.bucket)
s3_client.put_object(
Bucket=dest.bucket,
Key=dest.path,
Expand Down Expand Up @@ -1419,7 +1429,7 @@ def select(src, query, meta=None, raw=False, **kwargs):

# S3 Select does not support anonymous access (as of Jan 2019)
# https://docs.aws.amazon.com/AmazonS3/latest/API/API_SelectObjectContent.html
s3_client = S3ClientProvider().standard_client
s3_client = S3ClientProvider().get_standard_client(src.bucket)
response = s3_client.select_object_content(**select_kwargs)

# we don't want multiple copies of large chunks of data hanging around.
Expand Down
Loading