# endpoints/consumer.py
2- from fastapi import APIRouter , HTTPException
2+ import asyncio
3+ import time
4+ from datetime import datetime
5+
6+ import numpy as np
7+ from fastapi import APIRouter , HTTPException , Header
38from pydantic import BaseModel
4- from typing import Dict , Optional , Literal
9+ from typing import Dict , Optional , Literal , List , Union , Callable , Annotated
510import logging
611
12+ from src .service .constants import *
13+ from src .service .data .model_data import ModelData
14+ from src .service .data .storage import get_storage_interface
15+ from src .service .utils import list_utils
16+
# FastAPI router and module logger for the TrustyAI consumer endpoints.
router = APIRouter()
logger = logging.getLogger(__name__)

# A partial payload is one half of an inference: the "request" (inputs) or the
# "response" (outputs); the two halves are reconciled once both have arrived.
PartialKind = Literal["request", "response"]
# NOTE(review): "inferface" is a typo for "interface"; every reference below
# uses this spelling, so renaming it must be done file-wide in one change.
storage_inferface = get_storage_interface()
# In-memory holding areas for unmatched payload halves.
# NOTE(review): neither dict is referenced in the visible code — confirm they
# are used elsewhere before removing.
unreconciled_inputs = {}
unreconciled_outputs = {}

1125
class PartialPayloadId(BaseModel):
    """Identifier for a partial (input- or output-only) inference payload.

    Currently declares no fields; serves as a typed placeholder referenced by
    InferencePartialPayload.
    """
    pass
1428
29+
1530class InferencePartialPayload (BaseModel ):
1631 partialPayloadId : Optional [PartialPayloadId ] = None
1732 metadata : Optional [Dict [str , str ]] = None
@@ -21,6 +36,29 @@ class InferencePartialPayload(BaseModel):
2136 modelid : Optional [str ] = None
2237
2338
class KServeData(BaseModel):
    """One named tensor from a KServe v2 protocol inference payload."""

    # Tensor name; used downstream to derive stored column names.
    name: str
    # Declared tensor shape, e.g. [nrows] or [nrows, ncols].
    shape: List[int]
    # KServe datatype string (e.g. "FP32", "BYTES").
    datatype: str
    parameters: Optional[Dict[str, str]] = None
    # Tensor values as a (possibly nested) list.
    data: List
45+
46+
class KServeInferenceRequest(BaseModel):
    """Body of a KServe v2 inference request (the model's input tensors)."""

    # Inference id; may be overwritten from the cloud-event header on receipt.
    id: Optional[str] = None
    parameters: Optional[Dict[str, str]] = None
    # Input tensors; at least one is required by the consumer endpoint.
    inputs: List[KServeData]
    # Optional requested-output descriptors.
    outputs: Optional[List[KServeData]] = None
52+
53+
class KServeInferenceResponse(BaseModel):
    """Body of a KServe v2 inference response (the model's output tensors)."""

    # Name of the model that produced the outputs; used to derive dataset names.
    model_name: str
    model_version: Optional[str] = None
    # Inference id; matched against the corresponding request's id.
    id: Optional[str] = None
    parameters: Optional[Dict[str, str]] = None
    # Output tensors; at least one is required by the consumer endpoint.
    outputs: List[KServeData]
60+
61+
2462@router .post ("/consumer/kserve/v2" )
2563async def consume_inference_payload (payload : InferencePartialPayload ):
2664 """Send a single input or output payload to TrustyAI."""
@@ -32,4 +70,131 @@ async def consume_inference_payload(payload: InferencePartialPayload):
3270 logger .error (f"Error processing inference payload: { str (e )} " )
3371 raise HTTPException (
3472 status_code = 500 , detail = f"Error processing payload: { str (e )} "
35- )
73+ )
74+
75+
def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id):
    """Log and raise a 400 error for a payload whose tensor shapes disagree.

    Args:
        shape_tuples: list of (tensor_name, shape) pairs found in the payload.
        payload_type: "input" or "output"; used only in the message text.
        payload_id: id of the KServe inference being reconciled.

    Raises:
        HTTPException: always, with status 400 and the assembled message.
    """
    # BUG FIX: the original message ran two sentences together
    # ("...must match.However, ..."); a separating space is added.
    msg = (f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. "
           f"When using multiple {payload_type}s to describe data columns, all shapes must match. "
           f"However, the following tensor shapes were found:")
    msg += "".join(f"\n{i}:\t{name}:\t{shape}"
                   for i, (name, shape) in enumerate(shape_tuples))
    logger.error(msg)
    raise HTTPException(status_code=400, detail=msg)
84+
85+
def reconcile_mismatching_row_count_error(payload_id, input_shape, output_shape):
    """Log and raise a 400 error when output row count differs from input row count.

    Args:
        payload_id: id of the KServe inference being reconciled.
        input_shape: number of rows found in the input payload.
        output_shape: number of rows found in the output payload.

    Raises:
        HTTPException: always, with status 400 and the assembled message.
    """
    detail = (
        f"Could not reconcile KServe Inference {payload_id}, because the number of "
        f"output rows ({output_shape}) did not match the number of input rows "
        f"({input_shape})."
    )
    logger.error(detail)
    raise HTTPException(status_code=400, detail=detail)
92+
93+
def process_payload(payload, get_data: Callable, enforced_first_shape: int = None):
    """Convert a KServe request/response into a numpy array plus column names.

    Args:
        payload: a KServeInferenceRequest or KServeInferenceResponse.
        get_data: selector returning the payload's tensor list,
            e.g. ``lambda p: p.inputs``.
        enforced_first_shape: if given, the row count this payload must have
            (used to validate outputs against the already-processed inputs).

    Returns:
        Tuple of (array, column_names); the array is row-major [nrows, ...] and
        uses dtype "O" when any value is non-numeric.

    Raises:
        HTTPException: 400 on mismatched tensor shapes or row counts (via the
            reconcile_mismatching_* helpers).
    """
    tensors = get_data(payload)
    if len(tensors) > 1:  # multi tensor case: we have ncols of data of shape [nrows]
        data = []
        shapes = set()
        shape_tuples = []
        column_names = []
        for kserve_data in tensors:
            data.append(kserve_data.data)
            # BUG FIX: the original read ``kserve_data.data.shape`` and
            # ``kserve_data.data.name`` — but ``data`` is a plain list with
            # neither attribute (AttributeError at runtime). The declared
            # fields on the KServeData model are used instead.
            shapes.add(tuple(kserve_data.shape))
            column_names.append(kserve_data.name)
            shape_tuples.append((kserve_data.name, kserve_data.shape))
        if len(shapes) == 1:
            row_count = next(iter(shapes))[0]
            if enforced_first_shape is not None and row_count != enforced_first_shape:
                reconcile_mismatching_row_count_error(payload.id, enforced_first_shape, row_count)
            # transpose column-major data into rows
            if list_utils.contains_non_numeric(data):
                return np.array(data, dtype="O").T, column_names
            else:
                return np.array(data).T, column_names
        else:
            reconcile_mismatching_shape_error(
                shape_tuples,
                # inputs are processed without an enforced shape; outputs with one
                "input" if enforced_first_shape is None else "output",
                payload.id
            )
    else:  # single tensor case: we have one tensor of shape [nrows, d1, d2, ...., dN]
        kserve_data: KServeData = tensors[0]
        if enforced_first_shape is not None and kserve_data.shape[0] != enforced_first_shape:
            reconcile_mismatching_row_count_error(payload.id, enforced_first_shape, kserve_data.shape[0])

        if len(kserve_data.shape) > 1:
            # one derived column name per trailing dimension entry
            column_names = ["{}-{}".format(kserve_data.name, i) for i in range(kserve_data.shape[1])]
        else:
            column_names = [kserve_data.name]
        if list_utils.contains_non_numeric(kserve_data.data):
            return np.array(kserve_data.data, dtype="O"), column_names
        else:
            return np.array(kserve_data.data), column_names
132+
133+
async def reconcile(input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse):
    """Join a matched input/output payload pair and persist it to storage.

    Converts both payloads to arrays, builds one metadata row per data row,
    writes the three datasets concurrently, then logs the resulting shapes.

    Raises:
        HTTPException: 400 (from process_payload) on shape/row-count mismatch.
    """
    input_array, input_names = process_payload(input_payload, lambda p: p.inputs)
    # outputs must have the same number of rows as the inputs
    output_array, output_names = process_payload(output_payload, lambda p: p.outputs, input_array.shape[0])

    metadata_names = ["iso_time", "unix_timestamp", "tags"]
    # payloads flagged with the bias-ignore parameter are tagged as synthetic
    if input_payload.parameters is not None and input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true":
        tags = [SYNTHETIC_TAG]
    else:
        tags = [UNLABELED_TAG]
    # NOTE(review): datetime.utcnow() is deprecated (3.12+) and yields a naive
    # timestamp; switching to datetime.now(timezone.utc) would change the
    # stored ISO string, so it is deliberately left unchanged here.
    iso_time = datetime.isoformat(datetime.utcnow())
    unix_timestamp = time.time()
    # one identical metadata row per input row
    metadata = np.array([[iso_time, unix_timestamp, tags]] * len(input_array), dtype="O")

    input_dataset = output_payload.model_name + INPUT_SUFFIX
    output_dataset = output_payload.model_name + OUTPUT_SUFFIX
    metadata_dataset = output_payload.model_name + METADATA_SUFFIX

    # write all three datasets concurrently; TaskGroup (3.11+) waits for all
    # tasks and propagates any failure as an ExceptionGroup
    async with asyncio.TaskGroup() as tg:
        tg.create_task(storage_inferface.write_data(input_dataset, input_array, input_names))
        tg.create_task(storage_inferface.write_data(output_dataset, output_array, output_names))
        tg.create_task(storage_inferface.write_data(metadata_dataset, metadata, metadata_names))

    # fetch post-write dataset shapes for logging only
    shapes = await (ModelData(output_payload.model_name).shapes())
    logger.info(f"Successfully reconciled KServe inference {input_payload.id}, "
                f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}.")
    logger.debug(f"Current storage shapes for {output_payload.model_name}: "
                 f"Inputs={shapes[0]}, "
                 f"Outputs={shapes[1]}, "
                 f"Metadata={shapes[2]}")
163+
164+
@router.post("/")
async def consume_cloud_event(payload: Union[KServeInferenceRequest, KServeInferenceResponse],
                              ce_id: Annotated[str | None, Header()] = None):
    """Receive one half (input or output) of a KServe v2 inference cloud event.

    If the matching partner payload is already stored, the pair is reconciled
    and written to storage; otherwise this half is persisted until its partner
    arrives.

    Args:
        payload: parsed request (inputs) or response (outputs) body.
        ce_id: the cloud event's ``ce-id`` header, used as the payload id.

    Returns:
        A small status dict confirming the payload was processed.

    Raises:
        HTTPException: 400 when the payload's data field is empty.
    """
    # set payload id from the cloud event header
    # BUG FIX: only override when the header is actually present, so a
    # body-supplied id is not clobbered with None.
    if ce_id is not None:
        payload.id = ce_id

    if isinstance(payload, KServeInferenceRequest):
        if len(payload.inputs) == 0:
            msg = f"KServe Inference Input {payload.id} received, but data field was empty. Payload will not be saved."
            logger.error(msg)
            raise HTTPException(status_code=400, detail=msg)
        else:
            logger.info(f"KServe Inference Input {payload.id} received.")
            # if a match is found, the payload is auto-deleted from data
            partial_output = await storage_inferface.get_partial_payload(payload.id, is_input=False)
            if partial_output is not None:
                await reconcile(payload, partial_output)
            else:
                await storage_inferface.persist_partial_payload(payload, is_input=True)
            return {"status": "success", "message": f"Input payload {payload.id} processed successfully"}

    elif isinstance(payload, KServeInferenceResponse):
        if len(payload.outputs) == 0:
            msg = (f"KServe Inference Output {payload.id} received from model={payload.model_name}, "
                   f"but data field was empty. Payload will not be saved.")
            logger.error(msg)
            raise HTTPException(status_code=400, detail=msg)
        else:
            logger.info(f"KServe Inference Output {payload.id} received from model={payload.model_name}.")
            partial_input = await storage_inferface.get_partial_payload(payload.id, is_input=True)
            if partial_input is not None:
                await reconcile(partial_input, payload)
            else:
                await storage_inferface.persist_partial_payload(payload, is_input=False)

            return {"status": "success", "message": f"Output payload {payload.id} processed successfully"}