Skip to content

Commit d5e3f81

Browse files
committed
implement core multipart upload + test
1 parent 8745e78 commit d5e3f81

File tree

5 files changed

+894
-9
lines changed

5 files changed

+894
-9
lines changed

src/together/constants.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,20 @@
1515
DOWNLOAD_BLOCK_SIZE = 10 * 1024 * 1024 # 10 MB
1616
DISABLE_TQDM = False
1717

18+
# Upload defaults
MAX_CONCURRENT_PARTS = 4  # Maximum concurrent parts for multipart upload


# Multipart upload constants
MIN_PART_SIZE_MB = 5  # Minimum part size (S3 requirement)
TARGET_PART_SIZE_MB = 100  # Target part size for optimal performance
MAX_MULTIPART_PARTS = 250  # Maximum parts per upload (S3 limit)
MULTIPART_UPLOAD_TIMEOUT = 300  # Timeout in seconds for uploading each part
MULTIPART_THRESHOLD_GB = 5.0  # Threshold (GB) for switching to multipart upload


# maximum number of GB sized files we support finetuning for
MAX_FILE_SIZE_GB = 25.0
30+
31+
1832
# Messages
1933
MISSING_API_KEY_MESSAGE = """TOGETHER_API_KEY not found.
2034
Please set it as an environment variable or set it as together.api_key
@@ -26,8 +40,6 @@
2640
# the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
2741
NUM_BYTES_IN_GB = 2**30
2842

29-
# maximum number of GB sized files we support finetuning for
30-
MAX_FILE_SIZE_GB = 4.9
3143

3244
# expected columns for Parquet files
3345
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]

src/together/filemanager.py

Lines changed: 227 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,34 @@
11
from __future__ import annotations
22

3+
import math
34
import os
45
import shutil
56
import stat
67
import tempfile
78
import uuid
9+
from concurrent.futures import ThreadPoolExecutor, as_completed
810
from functools import partial
911
from pathlib import Path
10-
from typing import Tuple
12+
from typing import Any, Dict, List, Tuple, cast
1113

1214
import requests
1315
from filelock import FileLock
1416
from requests.structures import CaseInsensitiveDict
1517
from tqdm import tqdm
16-
from tqdm.utils import CallbackIOWrapper
1718

18-
import together.utils
1919
from together.abstract import api_requestor
20-
from together.constants import DISABLE_TQDM, DOWNLOAD_BLOCK_SIZE, MAX_RETRIES
20+
from together.constants import (
21+
DISABLE_TQDM,
22+
DOWNLOAD_BLOCK_SIZE,
23+
MAX_CONCURRENT_PARTS,
24+
MAX_FILE_SIZE_GB,
25+
MAX_RETRIES,
26+
MIN_PART_SIZE_MB,
27+
NUM_BYTES_IN_GB,
28+
TARGET_PART_SIZE_MB,
29+
MAX_MULTIPART_PARTS,
30+
MULTIPART_UPLOAD_TIMEOUT,
31+
)
2132
from together.error import (
2233
APIError,
2334
AuthenticationError,
@@ -32,6 +43,8 @@
3243
TogetherClient,
3344
TogetherRequest,
3445
)
46+
from tqdm.utils import CallbackIOWrapper
47+
import together.utils
3548

3649

3750
def chmod_and_replace(src: Path, dst: Path) -> None:
@@ -385,3 +398,213 @@ def upload(
385398
assert isinstance(response, TogetherResponse)
386399

387400
return FileResponse(**response.data)
401+
402+
403+
class MultipartUploadManager:
    """Handles multipart uploads for large files.

    Splits a file into parts, uploads the parts concurrently to
    backend-provided presigned URLs, then asks the backend to assemble
    them. On any failure the server-side upload is aborted (best-effort)
    and the original error is re-raised.
    """

    def __init__(self, client: TogetherClient) -> None:
        self._client = client
        # Bounds parallel part uploads; with lazy per-worker reads (see
        # _upload_parts_concurrent) this also bounds peak memory to
        # roughly max_concurrent_parts * part_size bytes.
        self.max_concurrent_parts = MAX_CONCURRENT_PARTS

    def upload(
        self,
        url: str,
        file: Path,
        purpose: FilePurpose,
    ) -> FileResponse:
        """Upload a large file using multipart upload.

        Args:
            url: Base endpoint (kept for interface parity; the multipart
                endpoints themselves are fixed ``files/multipart/*`` paths).
            file: Local file to upload.
            purpose: Purpose tag forwarded to the backend.

        Returns:
            FileResponse describing the uploaded file.

        Raises:
            FileTypeError: If the file exceeds MAX_FILE_SIZE_GB.
        """
        file_size = file.stat().st_size

        # Validate file size limits before doing any network work.
        file_size_gb = file_size / NUM_BYTES_IN_GB
        if file_size_gb > MAX_FILE_SIZE_GB:
            raise FileTypeError(
                f"File size {file_size_gb:.1f}GB exceeds maximum supported size of {MAX_FILE_SIZE_GB}GB"
            )

        part_size, num_parts = self._calculate_parts(file_size)
        file_type = self._get_file_type(file)

        upload_info: Dict[str, Any] | None = None
        try:
            # Phase 1: Initiate multipart upload (upload id + presigned part URLs).
            upload_info = self._initiate_upload(
                url, file, file_size, num_parts, purpose, file_type
            )

            # Phase 2: Upload parts concurrently.
            completed_parts = self._upload_parts_concurrent(
                file, upload_info, part_size
            )

            # Phase 3: Complete upload.
            return self._complete_upload(
                url, upload_info["upload_id"], upload_info["file_id"], completed_parts
            )
        except Exception:
            # Cleanup on failure. Abort is best-effort: a failure while
            # aborting must not mask the original exception.
            if upload_info is not None:
                try:
                    self._abort_upload(
                        url, upload_info["upload_id"], upload_info["file_id"]
                    )
                except Exception:
                    pass
            # Bare raise preserves the original traceback (unlike `raise e`).
            raise

    def _get_file_type(self, file: Path) -> str:
        """Get file type from extension, defaulting to jsonl as discussed in feedback."""
        known = {".jsonl": "jsonl", ".parquet": "parquet", ".csv": "csv"}
        return known.get(file.suffix, "jsonl")

    def _calculate_parts(self, file_size: int) -> Tuple[int, int]:
        """Calculate optimal (part_size, num_parts) for a file of ``file_size`` bytes.

        Targets TARGET_PART_SIZE_MB per part, capped at MAX_MULTIPART_PARTS
        parts, and never below the S3-required MIN_PART_SIZE_MB minimum.
        """
        min_part_size = MIN_PART_SIZE_MB * 1024 * 1024  # 5MB
        target_part_size = TARGET_PART_SIZE_MB * 1024 * 1024  # 100MB

        # Small files go up in a single part.
        if file_size <= target_part_size:
            return file_size, 1

        num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
        part_size = math.ceil(file_size / num_parts)

        # Ensure minimum part size (recompute the part count to match).
        if part_size < min_part_size:
            part_size = min_part_size
            num_parts = math.ceil(file_size / part_size)

        return part_size, num_parts

    def _initiate_upload(
        self,
        url: str,
        file: Path,
        file_size: int,
        num_parts: int,
        purpose: FilePurpose,
        file_type: str,
    ) -> Dict[str, Any]:
        """Initiate multipart upload with the backend.

        Returns the backend payload, expected to contain ``upload_id``,
        ``file_id`` and a ``parts`` list of presigned-URL descriptors.
        """
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "file_name": file.name,
            "file_size": file_size,
            "num_parts": num_parts,
            "purpose": purpose.value,
            "file_type": file_type,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/initiate",
                params=payload,
            ),
        )

        return cast(Dict[str, Any], response.data)

    def _upload_parts_concurrent(
        self, file: Path, upload_info: Dict[str, Any], part_size: int
    ) -> List[Dict[str, Any]]:
        """Upload file parts concurrently with progress tracking.

        Each worker reads its own slice of the file on demand, so peak
        memory stays near ``max_concurrent_parts * part_size``. (Reading
        every part up front before the pool drains them would buffer up
        to the entire file — potentially tens of GB — in memory.)
        """
        parts = upload_info["parts"]
        completed_parts: List[Dict[str, Any]] = []

        with ThreadPoolExecutor(max_workers=self.max_concurrent_parts) as executor:
            with tqdm(total=len(parts), desc="Uploading parts", unit="part") as pbar:
                future_to_part = {
                    executor.submit(
                        self._upload_part_from_file, file, part_info, part_size
                    ): part_info["part_number"]
                    for part_info in parts
                }

                # Collect results as they finish.
                for future in as_completed(future_to_part):
                    part_number = future_to_part[future]
                    try:
                        etag = future.result()
                    except Exception as e:
                        # Chain the cause so the underlying HTTP error is visible.
                        raise Exception(
                            f"Failed to upload part {part_number}: {e}"
                        ) from e
                    completed_parts.append({"part_number": part_number, "etag": etag})
                    pbar.update(1)

        # Completion order is nondeterministic; the backend expects ascending parts.
        completed_parts.sort(key=lambda x: x["part_number"])
        return completed_parts

    def _upload_part_from_file(
        self, file: Path, part_info: Dict[str, Any], part_size: int
    ) -> str:
        """Read one part's bytes from disk and upload it (runs in a worker thread).

        Opens a private file handle per call so concurrent workers never
        share seek position.
        """
        with open(file, "rb") as f:
            f.seek((part_info["part_number"] - 1) * part_size)
            part_data = f.read(part_size)
        return self._upload_single_part(part_info, part_data)

    def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> str:
        """Upload a single part to its presigned URL and return its ETag."""
        response = requests.put(
            part_info["url"],
            data=part_data,
            headers=part_info.get("headers", {}),
            timeout=MULTIPART_UPLOAD_TIMEOUT,
        )
        response.raise_for_status()

        # The ETag is required by the completion call; S3 wraps it in quotes.
        etag = response.headers.get("ETag", "").strip('"')
        if not etag:
            raise Exception(f"No ETag returned for part {part_info['part_number']}")

        return etag

    def _complete_upload(
        self, url: str, upload_id: str, file_id: str, completed_parts: List[Dict[str, Any]]
    ) -> FileResponse:
        """Complete the multipart upload by sending the collected part ETags."""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
            "parts": completed_parts,
        }

        response, _, _ = requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/complete",
                params=payload,
            ),
        )

        return FileResponse(**response.data["file"])

    def _abort_upload(self, url: str, upload_id: str, file_id: str) -> None:
        """Abort the multipart upload so partial server-side state is discarded."""
        requestor = api_requestor.APIRequestor(client=self._client)

        payload = {
            "upload_id": upload_id,
            "file_id": file_id,
        }

        requestor.request(
            options=TogetherRequest(
                method="POST",
                url="files/multipart/abort",
                params=payload,
            ),
        )

src/together/resources/files.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
import os
34
from pathlib import Path
45
from pprint import pformat
56

67
from together.abstract import api_requestor
8+
from together.constants import MULTIPART_THRESHOLD_GB, NUM_BYTES_IN_GB
79
from together.error import FileTypeError
8-
from together.filemanager import DownloadManager, UploadManager
10+
from together.filemanager import DownloadManager, UploadManager, MultipartUploadManager
911
from together.together_response import TogetherResponse
1012
from together.types import (
1113
FileDeleteResponse,
@@ -30,7 +32,6 @@ def upload(
3032
purpose: FilePurpose | str = FilePurpose.FineTune,
3133
check: bool = True,
3234
) -> FileResponse:
33-
upload_manager = UploadManager(self._client)
3435

3536
if check and purpose == FilePurpose.FineTune:
3637
report_dict = check_file(file)
@@ -47,7 +48,16 @@ def upload(
4748

4849
assert isinstance(purpose, FilePurpose)
4950

50-
return upload_manager.upload("files", file, purpose=purpose, redirect=True)
51+
# Size-based routing: use multipart for files > 5GB
52+
file_size = os.stat(file.as_posix()).st_size
53+
file_size_gb = file_size / NUM_BYTES_IN_GB
54+
55+
if file_size_gb > MULTIPART_THRESHOLD_GB:
56+
multipart_manager = MultipartUploadManager(self._client)
57+
return multipart_manager.upload("files", file, purpose)
58+
else:
59+
upload_manager = UploadManager(self._client)
60+
return upload_manager.upload("files", file, purpose=purpose, redirect=True)
5161

5262
def list(self) -> FileList:
5363
requestor = api_requestor.APIRequestor(

0 commit comments

Comments
 (0)