Skip to content

Commit 7110081

Browse files
authored
invalid CSV input returns InvalidRequestException (#258)
* invalid CSV input returns InvalidRequestException * adding helper function for error handling * . * fixing logging statements
1 parent 4bb1d70 commit 7110081

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import csv
22
import datetime
33
import re
4+
from typing import Optional
45

56
import smart_open
67
from model_engine_server.common.dtos.llms import (
@@ -48,7 +49,7 @@ def ensure_model_name_is_valid_k8s_label(model_name: str):
4849

4950
def read_csv_headers(file_location: str):
5051
"""
51-
Read the headers of a csv file. Assumes the file exists and is valid.
52+
Read the headers of a csv file.
5253
"""
5354
with smart_open.open(file_location, transport_params=dict(buffer_size=1024)) as file:
5455
csv_reader = csv.DictReader(file)
@@ -63,6 +64,26 @@ def are_dataset_headers_valid(file_location: str):
6364
return all(required_header in current_headers for required_header in REQUIRED_COLUMNS)
6465

6566

67+
def check_file_is_valid(file_name: Optional[str], file_type: str):
    """
    Validate a fine-tuning dataset file and report problems as InvalidRequestException.

    A file_name of None means no file was supplied; there is nothing to validate and the
    function returns without raising.

    file_name: location of the dataset file, or None when no file was supplied
    file_type: 'training' or 'validation'; only used to word the error messages

    Raises InvalidRequestException when the file cannot be found, cannot be parsed as CSV,
    or does not contain every header listed in REQUIRED_COLUMNS.
    """
    if file_name is None:
        return

    # Keep the try body to the one call that can fail with I/O or parsing errors, and
    # chain the original exception so the root cause stays visible in tracebacks.
    try:
        headers_valid = are_dataset_headers_valid(file_name)
    except FileNotFoundError as exc:
        raise InvalidRequestException(
            f"Cannot find the {file_type} file. Verify the path and file name are correct."
        ) from exc
    except csv.Error as exc:
        raise InvalidRequestException(
            f"Cannot parse the {file_type} dataset as CSV. Details: {exc}"
        ) from exc

    if not headers_valid:
        raise InvalidRequestException(
            f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in {file_type} dataset"
        )
85+
86+
6687
class CreateFineTuneV1UseCase:
6788
def __init__(
6889
self,
@@ -140,14 +161,8 @@ async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFin
140161
else:
141162
validation_file = request.validation_file
142163

143-
if training_file is not None and not are_dataset_headers_valid(training_file):
144-
raise InvalidRequestException(
145-
f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in training dataset"
146-
)
147-
if validation_file is not None and not are_dataset_headers_valid(validation_file):
148-
raise InvalidRequestException(
149-
f"Required column headers {','.join(REQUIRED_COLUMNS)} not found in validation dataset"
150-
)
164+
check_file_is_valid(training_file, "training")
165+
check_file_is_valid(validation_file, "validation")
151166

152167
await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model)
153168
fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune(

0 commit comments

Comments
 (0)