Commit 5a3d20c

Add KServe payload parsing and data storage
1 parent cf397ba commit 5a3d20c

16 files changed: +881 −77 lines

Dockerfile

Lines changed: 7 additions & 7 deletions
@@ -6,13 +6,13 @@ WORKDIR /app
 
 COPY pyproject.toml poetry.lock* ./
 
-RUN pip install poetry==1.6.1
-
-RUN poetry export -f requirements.txt --without dev > requirements.txt && \
+USER root
+RUN pip install poetry==1.6.1 && \
+    poetry export -f requirements.txt --without dev > requirements.txt && \
     pip install --no-cache-dir -r requirements.txt
-
 COPY . .
+USER 1001
+EXPOSE 4443
 
-EXPOSE 8080
-
-CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8080"]
+#CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "4443", "--ssl-keyfile", "/etc/tls/internal/tls.key", "--ssl-certfile", "/etc/tls/internal/tls.crt", "--log-level", "trace"]
+CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "4443", "--ssl-keyfile", "/etc/tls/internal/tls.key", "--ssl-certfile", "/etc/tls/internal/tls.crt"]

poetry.lock

Lines changed: 202 additions & 61 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -13,8 +13,10 @@ pandas = "^2.2.3"
 prometheus-client = "^0.21.1"
 pydantic = "^2.4.2"
 uvicorn = "^0.34.0"
+cryptography = "^44.0.2"
 protobuf = "^4.24.4"
 requests = "^2.31.0"
+h5py = "^3.13.0"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.2"

src/__init__.py

Whitespace-only changes.

src/core/__init__.py

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/endpoints/consumer.py

Lines changed: 168 additions & 3 deletions
@@ -1,17 +1,32 @@
 # endpoints/consumer.py
-from fastapi import APIRouter, HTTPException
+import asyncio
+import time
+from datetime import datetime
+
+import numpy as np
+from fastapi import APIRouter, HTTPException, Header
 from pydantic import BaseModel
-from typing import Dict, Optional, Literal
+from typing import Dict, Optional, Literal, List, Union, Callable, Annotated
 import logging
 
+from src.service.constants import *
+from src.service.data.model_data import ModelData
+from src.service.data.storage import get_storage_interface
+from src.service.utils import list_utils
+
 router = APIRouter()
 logger = logging.getLogger(__name__)
 
 PartialKind = Literal["request", "response"]
+storage_inferface = get_storage_interface()
+unreconciled_inputs = {}
+unreconciled_outputs = {}
+
 
 class PartialPayloadId(BaseModel):
     pass
 
+
 class InferencePartialPayload(BaseModel):
     partialPayloadId: Optional[PartialPayloadId] = None
     metadata: Optional[Dict[str, str]] = None
@@ -21,6 +36,29 @@ class InferencePartialPayload(BaseModel):
     modelid: Optional[str] = None
 
 
+class KServeData(BaseModel):
+    name: str
+    shape: List[int]
+    datatype: str
+    parameters: Optional[Dict[str, str]] = None
+    data: List
+
+
+class KServeInferenceRequest(BaseModel):
+    id: Optional[str] = None
+    parameters: Optional[Dict[str, str]] = None
+    inputs: List[KServeData]
+    outputs: Optional[List[KServeData]] = None
+
+
+class KServeInferenceResponse(BaseModel):
+    model_name: str
+    model_version: Optional[str] = None
+    id: Optional[str] = None
+    parameters: Optional[Dict[str, str]] = None
+    outputs: List[KServeData]
+
+
 @router.post("/consumer/kserve/v2")
 async def consume_inference_payload(payload: InferencePartialPayload):
     """Send a single input or output payload to TrustyAI."""
@@ -32,4 +70,131 @@ async def consume_inference_payload(payload: InferencePartialPayload):
         logger.error(f"Error processing inference payload: {str(e)}")
         raise HTTPException(
             status_code=500, detail=f"Error processing payload: {str(e)}"
-        )
+        )
+
+
+def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id):
+    msg = (f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. "
+           f"When using multiple {payload_type}s to describe data columns, all shapes must match. "
+           f"However, the following tensor shapes were found:")
+    for i, (name, shape) in enumerate(shape_tuples):
+        msg += f"\n{i}:\t{name}:\t{shape}"
+    logger.error(msg)
+    raise HTTPException(status_code=400, detail=msg)
+
+
+def reconcile_mismatching_row_count_error(payload_id, input_shape, output_shape):
+    msg = (f"Could not reconcile KServe Inference {payload_id}, because the number of "
+           f"output rows ({output_shape}) did not match the number of input rows "
+           f"({input_shape}).")
+    logger.error(msg)
+    raise HTTPException(status_code=400, detail=msg)
+
+
+def process_payload(payload, get_data: Callable, enforced_first_shape: int = None):
+    if len(get_data(payload)) > 1:  # multi-tensor case: we have ncols of data of shape [nrows]
+        data = []
+        shapes = set()
+        shape_tuples = []
+        column_names = []
+        for kserve_data in get_data(payload):
+            data.append(kserve_data.data)
+            shapes.add(tuple(kserve_data.shape))
+            column_names.append(kserve_data.name)
+            shape_tuples.append((kserve_data.name, kserve_data.shape))
+        if len(shapes) == 1:
+            row_count = list(shapes)[0][0]
+            if enforced_first_shape is not None and row_count != enforced_first_shape:
+                reconcile_mismatching_row_count_error(payload.id, enforced_first_shape, row_count)
+            if list_utils.contains_non_numeric(data):
+                return np.array(data, dtype="O").T, column_names
+            else:
+                return np.array(data).T, column_names
+        else:
+            reconcile_mismatching_shape_error(
+                shape_tuples,
+                "input" if enforced_first_shape is None else "output",
+                payload.id
+            )
+    else:  # single-tensor case: we have one tensor of shape [nrows, d1, d2, ..., dN]
+        kserve_data: KServeData = get_data(payload)[0]
+        if enforced_first_shape is not None and kserve_data.shape[0] != enforced_first_shape:
+            reconcile_mismatching_row_count_error(payload.id, enforced_first_shape, kserve_data.shape[0])
+
+        if len(kserve_data.shape) > 1:
+            column_names = ["{}-{}".format(kserve_data.name, i) for i in range(kserve_data.shape[1])]
+        else:
+            column_names = [kserve_data.name]
+        if list_utils.contains_non_numeric(kserve_data.data):
+            return np.array(kserve_data.data, dtype="O"), column_names
+        else:
+            return np.array(kserve_data.data), column_names
+
+
+async def reconcile(input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse):
+    input_array, input_names = process_payload(input_payload, lambda p: p.inputs)
+    output_array, output_names = process_payload(output_payload, lambda p: p.outputs, input_array.shape[0])
+
+    metadata_names = ["iso_time", "unix_timestamp", "tags"]
+    if input_payload.parameters is not None and input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true":
+        tags = [SYNTHETIC_TAG]
+    else:
+        tags = [UNLABELED_TAG]
+    iso_time = datetime.isoformat(datetime.utcnow())
+    unix_timestamp = time.time()
+    metadata = np.array([[iso_time, unix_timestamp, tags]] * len(input_array), dtype="O")
+
+    input_dataset = output_payload.model_name + INPUT_SUFFIX
+    output_dataset = output_payload.model_name + OUTPUT_SUFFIX
+    metadata_dataset = output_payload.model_name + METADATA_SUFFIX
+
+    async with asyncio.TaskGroup() as tg:
+        tg.create_task(storage_inferface.write_data(input_dataset, input_array, input_names))
+        tg.create_task(storage_inferface.write_data(output_dataset, output_array, output_names))
+        tg.create_task(storage_inferface.write_data(metadata_dataset, metadata, metadata_names))
+
+    shapes = await ModelData(output_payload.model_name).shapes()
+    logger.info(f"Successfully reconciled KServe inference {input_payload.id}, "
+                f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}.")
+    logger.debug(f"Current storage shapes for {output_payload.model_name}: "
+                 f"Inputs={shapes[0]}, "
+                 f"Outputs={shapes[1]}, "
+                 f"Metadata={shapes[2]}")
+
+
+@router.post("/")
+async def consume_cloud_event(payload: Union[KServeInferenceRequest, KServeInferenceResponse],
+                              ce_id: Annotated[str | None, Header()] = None):
+    # set payload id from the cloud event header
+    payload.id = ce_id
+
+    if isinstance(payload, KServeInferenceRequest):
+        if len(payload.inputs) == 0:
+            msg = f"KServe Inference Input {payload.id} received, but data field was empty. Payload will not be saved."
+            logger.error(msg)
+            raise HTTPException(status_code=400, detail=msg)
+        else:
+            logger.info(f"KServe Inference Input {payload.id} received.")
+            # if a match is found, the payload is auto-deleted from storage
+            partial_output = await storage_inferface.get_partial_payload(payload.id, is_input=False)
+            if partial_output is not None:
+                await reconcile(payload, partial_output)
+            else:
+                await storage_inferface.persist_partial_payload(payload, is_input=True)
+            return {"status": "success", "message": f"Input payload {payload.id} processed successfully"}
+
+    elif isinstance(payload, KServeInferenceResponse):
+        if len(payload.outputs) == 0:
+            msg = (f"KServe Inference Output {payload.id} received from model={payload.model_name}, "
+                   f"but data field was empty. Payload will not be saved.")
+            logger.error(msg)
+            raise HTTPException(status_code=400, detail=msg)
+        else:
+            logger.info(f"KServe Inference Output {payload.id} received from model={payload.model_name}.")
+            partial_input = await storage_inferface.get_partial_payload(payload.id, is_input=True)
+            if partial_input is not None:
+                await reconcile(partial_input, payload)
+            else:
+                await storage_inferface.persist_partial_payload(payload, is_input=False)
+
+            return {"status": "success", "message": f"Output payload {payload.id} processed successfully"}

src/main.py

Lines changed: 10 additions & 2 deletions
@@ -1,3 +1,6 @@
+import os
+
+import uvicorn
 from fastapi import FastAPI, Request, Response
 from fastapi.responses import JSONResponse
 from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
@@ -17,7 +20,7 @@
 from src.endpoints.data_download import router as data_download_router
 
 logging.basicConfig(
-    level=logging.INFO,
+    level=logging.DEBUG,
     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 )
 logger = logging.getLogger(__name__)
@@ -76,7 +79,7 @@ async def root():
     return {"message": "Welcome to TrustyAI Explainability Service"}
 
 
-@app.get("/p/metrics")
+@app.get("/q/metrics")
 async def metrics(request: Request):
     return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
 
@@ -91,3 +94,8 @@ async def readiness_probe():
 @app.get("/q/health/live")
 async def liveness_probe():
     return JSONResponse(content={"status": "live"}, status_code=200)
+
+
+if __name__ == "__main__":
+    # SERVICE_STORAGE_FORMAT=PVC; STORAGE_DATA_FOLDER=/tmp; STORAGE_DATA_FILENAME=trustyai_test.hdf5
+    uvicorn.run(app=app, host="0.0.0.0", port=8080)
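
Note: the __main__ block gives a plain-HTTP local entrypoint on 8080 (TLS is only applied by the container CMD), and its comment lists the storage settings as environment variables. A sketch of a local launch that sets those variables before importing the app, assuming the storage interface reads them at import time (the commit only shows them as a comment):

import os

# values mirror the hints in the __main__ comment above; they are examples
os.environ.setdefault("SERVICE_STORAGE_FORMAT", "PVC")
os.environ.setdefault("STORAGE_DATA_FOLDER", "/tmp")
os.environ.setdefault("STORAGE_DATA_FILENAME", "trustyai_test.hdf5")

import uvicorn
from src.main import app

uvicorn.run(app=app, host="0.0.0.0", port=8080)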

src/service/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+# Place all external packages and service algorithms here
+# The main REST server should only call
+# - modules directly related to the service HTTP, data, prometheus, etc.
+# - This service module

src/service/constants.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# data constants
+INPUT_SUFFIX = "_inputs"
+OUTPUT_SUFFIX = "_outputs"
+METADATA_SUFFIX = "_metadata"
+PROTECTED_DATASET_SUFFIX = "trustyai_internal_"
+PARTIAL_PAYLOAD_DATASET_NAME = "partial_payloads"
+
+# payload parsing
+TRUSTYAI_TAG_PREFIX = "_trustyai"
+SYNTHETIC_TAG = TRUSTYAI_TAG_PREFIX + "_synthetic"
+UNLABELED_TAG = TRUSTYAI_TAG_PREFIX + "_unlabeled"
+BIAS_IGNORE_PARAM = "bias-ignore"
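
Note: these constants compose the per-model dataset names and the row tags used by reconcile() in src/endpoints/consumer.py, e.g.:

from src.service.constants import (
    INPUT_SUFFIX, OUTPUT_SUFFIX, METADATA_SUFFIX, SYNTHETIC_TAG, UNLABELED_TAG
)

model_name = "demo"  # illustrative model name
print(model_name + INPUT_SUFFIX)      # demo_inputs
print(model_name + OUTPUT_SUFFIX)     # demo_outputs
print(model_name + METADATA_SUFFIX)   # demo_metadata
print(SYNTHETIC_TAG, UNLABELED_TAG)   # _trustyai_synthetic _trustyai_unlabeled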

src/service/data/__init__.py

Whitespace-only changes.
