11import csv
22import datetime
33import re
4+ from typing import Optional
45
56import smart_open
67from model_engine_server .common .dtos .llms import (
@@ -48,7 +49,7 @@ def ensure_model_name_is_valid_k8s_label(model_name: str):
4849
4950def read_csv_headers (file_location : str ):
5051 """
51- Read the headers of a csv file. Assumes the file exists and is valid.
52+ Read the headers of a csv file.
5253 """
5354 with smart_open .open (file_location , transport_params = dict (buffer_size = 1024 )) as file :
5455 csv_reader = csv .DictReader (file )
@@ -63,6 +64,26 @@ def are_dataset_headers_valid(file_location: str):
6364 return all (required_header in current_headers for required_header in REQUIRED_COLUMNS )
6465
6566
def check_file_is_valid(file_name: Optional[str], file_type: str):
    """
    Validate a fine-tuning dataset file before it is used.

    The file must exist, be parseable as CSV, and contain every column
    listed in REQUIRED_COLUMNS. A file_name of None means no file was
    supplied and the check is skipped.

    Args:
        file_name: Location of the dataset, passed to
            are_dataset_headers_valid; None skips validation.
        file_type: 'training' or 'validation'. Only used to build the
            error messages.

    Raises:
        InvalidRequestException: If the file cannot be found, cannot be
            parsed as CSV, or lacks the required column headers. For the
            first two cases the underlying exception is attached as
            __cause__.
    """
    if file_name is None:
        return

    # Keep the try body limited to the call that performs I/O and parsing,
    # so only its failures are translated into request errors.
    try:
        headers_valid = are_dataset_headers_valid(file_name)
    except FileNotFoundError as exc:
        # Chain the original error so the traceback keeps the real cause.
        raise InvalidRequestException(
            f"Cannot find the {file_type} file. Verify the path and file name are correct."
        ) from exc
    except csv.Error as exc:
        raise InvalidRequestException(
            f"Cannot parse the {file_type} dataset as CSV. Details: {exc} "
        ) from exc

    if not headers_valid:
        raise InvalidRequestException(
            f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in {file_type} dataset"
        )
85+
86+
6687class CreateFineTuneV1UseCase :
6788 def __init__ (
6889 self ,
@@ -140,14 +161,8 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin
140161 else :
141162 validation_file = request .validation_file
142163
143- if training_file is not None and not are_dataset_headers_valid (training_file ):
144- raise InvalidRequestException (
145- f"Required column headers { ',' .join (REQUIRED_COLUMNS )} not found in training dataset"
146- )
147- if validation_file is not None and not are_dataset_headers_valid (validation_file ):
148- raise InvalidRequestException (
149- f"Required column headers { ',' .join (REQUIRED_COLUMNS )} not found in validation dataset"
150- )
164+ check_file_is_valid (training_file , "training" )
165+ check_file_is_valid (validation_file , "validation" )
151166
152167 await self .llm_fine_tune_events_repository .initialize_events (user .team_id , fine_tuned_model )
153168 fine_tune_id = await self .llm_fine_tuning_service .create_fine_tune (
0 commit comments