 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
 
-from src.service.constants import METADATA_SUFFIX, OUTPUT_SUFFIX
+from src.service.constants import INPUT_SUFFIX, METADATA_SUFFIX, OUTPUT_SUFFIX
 from src.service.data.modelmesh_parser import ModelMeshPayloadParser
 from src.service.data.storage import get_storage_interface
 from src.service.utils.upload import (
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
-storage = get_storage_interface()
-
 
 class UploadPayload(BaseModel):
     model_name: str
@@ -35,69 +33,81 @@ class UploadPayload(BaseModel):
 @router.post("/data/upload")
 async def upload(payload: UploadPayload) -> Dict[str, str]:
     """Upload model data - regular or ground truth."""
-    model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
-    if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
-        raise HTTPException(400, error)
-    inputs = payload.request.get("inputs", [])
-    outputs = payload.response.get("outputs", [])
-    if not inputs or not outputs:
-        raise HTTPException(400, "Missing input or output tensors")
-    input_arrays, input_names, _, execution_ids = process_tensors(inputs)
-    output_arrays, output_names, _, _ = process_tensors(outputs)
-    if error := validate_input_shapes(input_arrays, input_names):
-        raise HTTPException(400, f"One or more errors in input tensors: {error}")
-    if payload.is_ground_truth:
-        if not execution_ids:
-            raise HTTPException(400, "Ground truth requires execution IDs")
-        result = await handle_ground_truths(
-            model_name,
-            input_arrays,
-            input_names,
-            output_arrays,
-            output_names,
-            [sanitize_id(id) for id in execution_ids],
-        )
-        if not result.success:
-            raise HTTPException(400, result.message)
-        result_data = result.data
-        if result_data is None:
-            raise HTTPException(500, "Ground truth processing failed")
-        gt_name = f"{model_name}_ground_truth"
-        await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
-        await storage.write_data(
-            gt_name + METADATA_SUFFIX,
-            result_data["metadata"],
-            result_data["metadata_names"],
-        )
-        return {"message": result.message}
-    else:
-        n_rows = input_arrays[0].shape[0]
-        exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
+    try:
+        # Get fresh storage interface for each request
+        storage = get_storage_interface()
+
+        model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
+        if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
+            raise HTTPException(400, error)
+        inputs = payload.request.get("inputs", [])
+        outputs = payload.response.get("outputs", [])
+        if not inputs or not outputs:
+            raise HTTPException(400, "Missing input or output tensors")
+        input_arrays, input_names, _, execution_ids = process_tensors(inputs)
+        output_arrays, output_names, _, _ = process_tensors(outputs)
+        if error := validate_input_shapes(input_arrays, input_names):
+            raise HTTPException(400, f"One or more errors in input tensors: {error}")
+
+        if payload.is_ground_truth:
+            if not execution_ids:
+                raise HTTPException(400, "Ground truth requires execution IDs")
+            result = await handle_ground_truths(
+                model_name,
+                input_arrays,
+                input_names,
+                output_arrays,
+                output_names,
+                [sanitize_id(id) for id in execution_ids],
+            )
+            if not result.success:
+                raise HTTPException(400, result.message)
+            result_data = result.data
+            if result_data is None:
+                raise HTTPException(500, "Ground truth processing failed")
+            gt_name = f"{model_name}_ground_truth"
+            await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
+            await storage.write_data(
+                gt_name + METADATA_SUFFIX,
+                result_data["metadata"],
+                result_data["metadata_names"],
+            )
+            return {"message": result.message}
+        else:
+            n_rows = input_arrays[0].shape[0]
+            exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
 
-        def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
-            return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]
+            def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
+                return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]
 
-        input_data = [flatten(input_arrays, i) for i in range(n_rows)]
-        output_data = [flatten(output_arrays, i) for i in range(n_rows)]
-        cols = ["id", "model_id", "timestamp", "tag"]
-        current_timestamp = datetime.now().isoformat()
-        metadata_rows = [
-            [
-                str(eid),
-                str(model_name),
-                str(current_timestamp),
-                str(payload.data_tag or ""),
+            input_data = [flatten(input_arrays, i) for i in range(n_rows)]
+            output_data = [flatten(output_arrays, i) for i in range(n_rows)]
+            cols = ["id", "model_id", "timestamp", "tag"]
+            current_timestamp = datetime.now().isoformat()
+            metadata_rows = [
+                [
+                    str(eid),
+                    str(model_name),
+                    str(current_timestamp),
+                    str(payload.data_tag or ""),
+                ]
+                for eid in exec_ids
             ]
-            for eid in exec_ids
-        ]
-        metadata = np.array(metadata_rows, dtype="<U100")
-        await save_model_data(
-            model_name,
-            np.array(input_data),
-            input_names,
-            np.array(output_data),
-            output_names,
-            metadata,
-            cols,
-        )
-        return {"message": f"{n_rows} datapoints added to {model_name}"}
+            metadata = np.array(metadata_rows, dtype="<U100")
+            await save_model_data(
+                model_name,
+                np.array(input_data),
+                input_names,
+                np.array(output_data),
+                output_names,
+                metadata,
+                cols,
+            )
+            return {"message": f"{n_rows} datapoints added to {model_name}"}
+
+    except HTTPException:
+        # Re-raise HTTP exceptions as-is
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
+        raise HTTPException(500, f"Internal server error: {str(e)}")
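
For context, a minimal client-side sketch of exercising the /data/upload route changed above. The base URL, the tag value, and the exact tensor layout accepted by process_tensors are assumptions (a KServe-v2-style name/shape/datatype/data layout is used purely for illustration), not something this diff specifies:

# Hypothetical call against the /data/upload endpoint shown in the diff above.
# The service URL and the tensor schema are assumptions; adapt them to the real
# deployment and to whatever process_tensors actually expects.
import requests

payload = {
    "model_name": "example-model",
    "data_tag": "TRAINING",
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {"name": "input", "shape": [2, 3], "datatype": "FP64",
             "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]},
        ]
    },
    "response": {
        "outputs": [
            {"name": "output", "shape": [2, 1], "datatype": "FP64",
             "data": [0.0, 1.0]},
        ]
    },
}

resp = requests.post("http://localhost:8080/data/upload", json=payload, timeout=30)
resp.raise_for_status()
print(resp.json()["message"])  # e.g. "2 datapoints added to <standardized model name>"

With the new except branch, a failure inside get_storage_interface() or the write path reaches such a client as a 500 carrying the error text, instead of an unhandled stack trace.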