99from concurrent .futures import ThreadPoolExecutor , as_completed
1010from functools import partial
1111from pathlib import Path
12- from typing import Any , Dict , List , Tuple , cast
12+ from typing import Any , Dict , List , Tuple
1313
1414import requests
1515from filelock import FileLock
3434 AuthenticationError ,
3535 DownloadError ,
3636 FileTypeError ,
37+ ResponseError ,
3738)
3839from together .together_response import TogetherResponse
3940from together .types import (
@@ -352,7 +353,7 @@ def upload(
352353 )
353354 redirect_url , file_id = self .get_upload_url (url , file , purpose , filetype )
354355
355- file_size = os .stat (file . as_posix () ).st_size
356+ file_size = os .stat (file ).st_size
356357
357358 with tqdm (
358359 total = file_size ,
@@ -415,9 +416,8 @@ def upload(
415416 ) -> FileResponse :
416417 """Upload large file using multipart upload"""
417418
418- file_size = os .stat (file . as_posix () ).st_size
419+ file_size = os .stat (file ).st_size
419420
420- # Validate file size limits
421421 file_size_gb = file_size / NUM_BYTES_IN_GB
422422 if file_size_gb > MAX_FILE_SIZE_GB :
423423 raise FileTypeError (
@@ -427,41 +427,42 @@ def upload(
427427 part_size , num_parts = self ._calculate_parts (file_size )
428428
429429 file_type = self ._get_file_type (file )
430+ upload_info = None
430431
431432 try :
432- # Phase 1: Initiate multipart upload
433433 upload_info = self ._initiate_upload (
434434 url , file , file_size , num_parts , purpose , file_type
435435 )
436436
437- # Phase 2: Upload parts concurrently
438437 completed_parts = self ._upload_parts_concurrent (
439438 file , upload_info , part_size
440439 )
441440
442- # Phase 3: Complete upload
443441 return self ._complete_upload (
444442 url , upload_info ["upload_id" ], upload_info ["file_id" ], completed_parts
445443 )
446444
447445 except Exception as e :
448446 # Cleanup on failure
449- if " upload_info" in locals () :
447+ if upload_info is not None :
450448 self ._abort_upload (
451449 url , upload_info ["upload_id" ], upload_info ["file_id" ]
452450 )
453451 raise e
454452
455453 def _get_file_type (self , file : Path ) -> str :
456- """Get file type from extension, defaulting to jsonl as discussed in feedback """
454+ """Get file type from extension, raising ValueError for unsupported extensions """
457455 if file .suffix == ".jsonl" :
458456 return "jsonl"
459457 elif file .suffix == ".parquet" :
460458 return "parquet"
461459 elif file .suffix == ".csv" :
462460 return "csv"
463461 else :
464- return "jsonl"
462+ raise ValueError (
463+ f"Unsupported file extension: '{ file .suffix } '. "
464+ f"Supported extensions: .jsonl, .parquet, .csv"
465+ )
465466
466467 def _calculate_parts (self , file_size : int ) -> tuple [int , int ]:
467468 """Calculate optimal part size and count"""
@@ -474,7 +475,6 @@ def _calculate_parts(self, file_size: int) -> tuple[int, int]:
474475 num_parts = min (MAX_MULTIPART_PARTS , math .ceil (file_size / target_part_size ))
475476 part_size = math .ceil (file_size / num_parts )
476477
477- # Ensure minimum part size
478478 if part_size < min_part_size :
479479 part_size = min_part_size
480480 num_parts = math .ceil (file_size / part_size )
@@ -489,7 +489,7 @@ def _initiate_upload(
489489 num_parts : int ,
490490 purpose : FilePurpose ,
491491 file_type : str ,
492- ) -> Dict [ str , Any ] :
492+ ) -> Any :
493493 """Initiate multipart upload with backend"""
494494
495495 requestor = api_requestor .APIRequestor (client = self ._client )
@@ -510,7 +510,7 @@ def _initiate_upload(
510510 ),
511511 )
512512
513- return cast ( Dict [ str , Any ], response .data )
513+ return response .data
514514
515515 def _upload_parts_concurrent (
516516 self , file : Path , upload_info : Dict [str , Any ], part_size : int
@@ -526,11 +526,9 @@ def _upload_parts_concurrent(
526526
527527 with open (file , "rb" ) as f :
528528 for part_info in parts :
529- # Read part data
530529 f .seek ((part_info ["part_number" ] - 1 ) * part_size )
531530 part_data = f .read (part_size )
532531
533- # Submit upload task
534532 future = executor .submit (
535533 self ._upload_single_part , part_info , part_data
536534 )
@@ -564,7 +562,7 @@ def _upload_single_part(self, part_info: Dict[str, Any], part_data: bytes) -> st
564562
565563 etag = response .headers .get ("ETag" , "" ).strip ('"' )
566564 if not etag :
567- raise Exception (f"No ETag returned for part { part_info ['part_number' ]} " )
565+ raise ResponseError (f"No ETag returned for part { part_info ['part_number' ]} " )
568566
569567 return etag
570568
0 commit comments