1010from src .service .constants import INPUT_SUFFIX , METADATA_SUFFIX , OUTPUT_SUFFIX
1111from src .service .data .modelmesh_parser import ModelMeshPayloadParser
1212from src .service .data .storage import get_storage_interface
13- from src .service .utils .upload import (
14- handle_ground_truths ,
15- process_tensors ,
16- sanitize_id ,
17- save_model_data ,
18- validate_data_tag ,
19- validate_input_shapes ,
20- )
13+ from src .service .utils .upload import process_upload_request
2114
2215router = APIRouter ()
2316logger = logging .getLogger (__name__ )
2417
18+
2519class UploadPayload (BaseModel ):
2620 model_name : str
2721 data_tag : Optional [str ] = None
2822 is_ground_truth : bool = False
2923 request : Dict [str , Any ]
30- response : Dict [str , Any ]
24+ response : Optional [ Dict [str , Any ]] = None
3125
3226
3327@router .post ("/data/upload" )
3428async def upload (payload : UploadPayload ) -> Dict [str , str ]:
3529 """Upload model data - regular or ground truth."""
3630 try :
37- # Get fresh storage interface for each request
38- storage = get_storage_interface ()
39-
40- model_name = ModelMeshPayloadParser .standardize_model_id (payload .model_name )
41- if payload .data_tag and (error := validate_data_tag (payload .data_tag )):
42- raise HTTPException (400 , error )
43- inputs = payload .request .get ("inputs" , [])
44- outputs = payload .response .get ("outputs" , [])
45- if not inputs or not outputs :
46- raise HTTPException (400 , "Missing input or output tensors" )
47- input_arrays , input_names , _ , execution_ids = process_tensors (inputs )
48- output_arrays , output_names , _ , _ = process_tensors (outputs )
49- if error := validate_input_shapes (input_arrays , input_names ):
50- raise HTTPException (400 , f"One or more errors in input tensors: { error } " )
51-
52- if payload .is_ground_truth :
53- if not execution_ids :
54- raise HTTPException (400 , "Ground truth requires execution IDs" )
55- result = await handle_ground_truths (
56- model_name ,
57- input_arrays ,
58- input_names ,
59- output_arrays ,
60- output_names ,
61- [sanitize_id (id ) for id in execution_ids ],
62- )
63- if not result .success :
64- raise HTTPException (400 , result .message )
65- result_data = result .data
66- if result_data is None :
67- raise HTTPException (500 , "Ground truth processing failed" )
68- gt_name = f"{ model_name } _ground_truth"
69- await storage .write_data (gt_name + OUTPUT_SUFFIX , result_data ["outputs" ], result_data ["output_names" ])
70- await storage .write_data (
71- gt_name + METADATA_SUFFIX ,
72- result_data ["metadata" ],
73- result_data ["metadata_names" ],
74- )
75- return {"message" : result .message }
76- else :
77- n_rows = input_arrays [0 ].shape [0 ]
78- exec_ids = execution_ids or [str (uuid .uuid4 ()) for _ in range (n_rows )]
79-
80- def flatten (arrays : List [np .ndarray ], row : int ) -> List [Any ]:
81- return [x for arr in arrays for x in (arr [row ].flatten () if arr .ndim > 1 else [arr [row ]])]
82-
83- input_data = [flatten (input_arrays , i ) for i in range (n_rows )]
84- output_data = [flatten (output_arrays , i ) for i in range (n_rows )]
85- cols = ["id" , "model_id" , "timestamp" , "tag" ]
86- current_timestamp = datetime .now ().isoformat ()
87- metadata_rows = [
88- [
89- str (eid ),
90- str (model_name ),
91- str (current_timestamp ),
92- str (payload .data_tag or "" ),
93- ]
94- for eid in exec_ids
95- ]
96- metadata = np .array (metadata_rows , dtype = "<U100" )
97- await save_model_data (
98- model_name ,
99- np .array (input_data ),
100- input_names ,
101- np .array (output_data ),
102- output_names ,
103- metadata ,
104- cols ,
105- )
106- return {"message" : f"{ n_rows } datapoints added to { model_name } " }
107-
31+ logger .info (f"Received upload request for model: { payload .model_name } " )
32+ result = await process_upload_request (payload )
33+ logger .info (f"Upload completed for model: { payload .model_name } " )
34+ return result
10835 except HTTPException :
109- # Re-raise HTTP exceptions as-is
11036 raise
11137 except Exception as e :
11238 logger .error (f"Unexpected error in upload endpoint for model { payload .model_name } : { str (e )} " , exc_info = True )
113- raise HTTPException (500 , f"Internal server error: { str (e )} " )
39+ raise HTTPException (500 , f"Internal server error: { str (e )} " )
0 commit comments