Skip to content

Commit f58e480

Browse files
committed
refactor
1 parent 8a84df2 commit f58e480

File tree

4 files changed

+26
-24
lines changed

4 files changed

+26
-24
lines changed

src/together/filemanager.py

Lines changed: 14 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from concurrent.futures import ThreadPoolExecutor, as_completed
1010
from functools import partial
1111
from pathlib import Path
12-
from typing import Any, Dict, List, Tuple, cast
12+
from typing import Any, Dict, List, Tuple
1313

1414
import requests
1515
from filelock import FileLock
@@ -34,6 +34,7 @@
3434
AuthenticationError,
3535
DownloadError,
3636
FileTypeError,
37+
ResponseError,
3738
)
3839
from together.together_response import TogetherResponse
3940
from together.types import (
@@ -352,7 +353,7 @@ def upload(
352353
)
353354
redirect_url, file_id = self.get_upload_url(url, file, purpose, filetype)
354355

355-
file_size = os.stat(file.as_posix()).st_size
356+
file_size = os.stat(file).st_size
356357

357358
with tqdm(
358359
total=file_size,
@@ -415,9 +416,8 @@ def upload(
415416
) -> FileResponse:
416417
"""Upload large file using multipart upload"""
417418

418-
file_size = os.stat(file.as_posix()).st_size
419+
file_size = os.stat(file).st_size
419420

420-
# Validate file size limits
421421
file_size_gb = file_size / NUM_BYTES_IN_GB
422422
if file_size_gb > MAX_FILE_SIZE_GB:
423423
raise FileTypeError(
@@ -427,41 +427,42 @@ def upload(
427427
part_size, num_parts = self._calculate_parts(file_size)
428428

429429
file_type = self._get_file_type(file)
430+
upload_info = None
430431

431432
try:
432-
# Phase 1: Initiate multipart upload
433433
upload_info = self._initiate_upload(
434434
url, file, file_size, num_parts, purpose, file_type
435435
)
436436

437-
# Phase 2: Upload parts concurrently
438437
completed_parts = self._upload_parts_concurrent(
439438
file, upload_info, part_size
440439
)
441440

442-
# Phase 3: Complete upload
443441
return self._complete_upload(
444442
url, upload_info["upload_id"], upload_info["file_id"], completed_parts
445443
)
446444

447445
except Exception as e:
448446
# Cleanup on failure
449-
if "upload_info" in locals():
447+
if upload_info is not None:
450448
self._abort_upload(
451449
url, upload_info["upload_id"], upload_info["file_id"]
452450
)
453451
raise e
454452

455453
def _get_file_type(self, file: Path) -> str:
456-
"""Get file type from extension, defaulting to jsonl as discussed in feedback"""
454+
"""Get file type from extension, raising ValueError for unsupported extensions"""
457455
if file.suffix == ".jsonl":
458456
return "jsonl"
459457
elif file.suffix == ".parquet":
460458
return "parquet"
461459
elif file.suffix == ".csv":
462460
return "csv"
463461
else:
464-
return "jsonl"
462+
raise ValueError(
463+
f"Unsupported file extension: '{file.suffix}'. "
464+
f"Supported extensions: .jsonl, .parquet, .csv"
465+
)
465466

466467
def _calculate_parts(self, file_size: int) -> tuple[int, int]:
467468
"""Calculate optimal part size and count"""
@@ -474,7 +475,6 @@ def _calculate_parts(self, file_size: int) -> tuple[int, int]:
474475
num_parts = min(MAX_MULTIPART_PARTS, math.ceil(file_size / target_part_size))
475476
part_size = math.ceil(file_size / num_parts)
476477

477-
# Ensure minimum part size
478478
if part_size < min_part_size:
479479
part_size = min_part_size
480480
num_parts = math.ceil(file_size / part_size)
@@ -489,7 +489,7 @@ def _initiate_upload(
489489
num_parts: int,
490490
purpose: FilePurpose,
491491
file_type: str,
492-
) -> Dict[str, Any]:
492+
) -> Any:
493493
"""Initiate multipart upload with backend"""
494494

495495
requestor = api_requestor.APIRequestor(client=self._client)
@@ -510,7 +510,7 @@ def _initiate_upload(
510510
),
511511
)
512512

513-
return cast(Dict[str, Any], response.data)
513+
return response.data
514514

515515
def _upload_parts_concurrent(
516516
self, file: Path, upload_info: Dict[str, Any], part_size: int
@@ -526,11 +526,9 @@ def _upload_parts_concurrent(
526526

527527
with open(file, "rb") as f:
528528
for part_info in parts:
529-
# Read part data
530529
f.seek((part_info["part_number"] - 1) * part_size)
531530
part_data = f.read(part_size)
532531

533-
# Submit upload task
534532
future = executor.submit(
535533
self._upload_single_part, part_info, part_data
536534
)
@@ -564,7 +562,7 @@ def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> st
564562

565563
etag = response.headers.get("ETag", "").strip('"')
566564
if not etag:
567-
raise Exception(f"No ETag returned for part {part_info['part_number']}")
565+
raise ResponseError(f"No ETag returned for part {part_info['part_number']}")
568566

569567
return etag
570568

src/together/resources/files.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,7 @@ def upload(
4848

4949
assert isinstance(purpose, FilePurpose)
5050

51-
# Size-based routing: use multipart for files > 5GB
52-
file_size = os.stat(file.as_posix()).st_size
51+
file_size = os.stat(file).st_size
5352
file_size_gb = file_size / NUM_BYTES_IN_GB
5453

5554
if file_size_gb > MULTIPART_THRESHOLD_GB:

src/together/utils/files.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@ def check_file(
6565
else:
6666
report_dict["found"] = True
6767

68-
file_size = os.stat(file.as_posix()).st_size
68+
file_size = os.stat(file).st_size
6969

7070
if file_size > MAX_FILE_SIZE_GB * NUM_BYTES_IN_GB:
7171
report_dict["message"] = (

tests/unit/test_multipart_upload_manager.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
from together.types import FilePurpose, FileResponse
1515
from together.types.common import ObjectType
1616
from together.together_response import TogetherResponse
17-
from together.error import FileTypeError
17+
from together.error import FileTypeError, ResponseError
1818

1919

2020
class TestMultipartUploadManager:
@@ -51,10 +51,15 @@ def test_get_file_type_csv(self, manager):
5151
file = Path("test.csv")
5252
assert manager._get_file_type(file) == "csv"
5353

54-
def test_get_file_type_unknown_defaults_to_jsonl(self, manager):
55-
"""Test that unknown file types default to jsonl"""
54+
def test_get_file_type_unknown_raises_error(self, manager):
55+
"""Test that unknown file types raise ValueError"""
5656
file = Path("test.txt")
57-
assert manager._get_file_type(file) == "jsonl"
57+
with pytest.raises(ValueError) as exc_info:
58+
manager._get_file_type(file)
59+
60+
error_message = str(exc_info.value)
61+
assert "Unsupported file extension: '.txt'" in error_message
62+
assert "Supported extensions: .jsonl, .parquet, .csv" in error_message
5863

5964
def test_calculate_parts_small_file(self, manager):
6065
"""Test part calculation for files smaller than target part size"""
@@ -186,7 +191,7 @@ def test_upload_single_part_no_etag_error(self, mock_put, manager):
186191
part_data = b"test data"
187192

188193
# Test
189-
with pytest.raises(Exception, match="No ETag returned for part 1"):
194+
with pytest.raises(ResponseError, match="No ETag returned for part 1"):
190195
manager._upload_single_part(part_info, part_data)
191196

192197
@patch("together.filemanager.api_requestor.APIRequestor")

0 commit comments

Comments
 (0)