diff --git a/pyproject.toml b/pyproject.toml index 271a2c7..8fd7bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,10 @@ dev = [ "pytest-cov>=4.1.0,<7", "httpx>=0.25.0,<0.29", ] -eval = ["lm-eval[api]==0.4.4", "fastapi-utils>=0.8.0", "typing-inspect==0.9.0"] +eval = ["lm-eval[api]==0.4.4", "typing-inspect==0.9.0"] protobuf = ["numpy>=1.24.0,<3", "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2"] mariadb = ["mariadb>=1.1.12", "javaobj-py3==0.4.4"] - [tool.hatch.build.targets.sdist] include = ["src"] diff --git a/scripts/test_upload_endpoint.sh b/scripts/test_upload_endpoint.sh new file mode 100644 index 0000000..ee814d1 --- /dev/null +++ b/scripts/test_upload_endpoint.sh @@ -0,0 +1,203 @@ +#!/usr/bin/env bash +# scripts/test_upload_endpoint.sh +# +# KServe-strict endpoint test for /data/upload +# +# +# Usage: +# ENDPOINT="https:///data/upload" \ +# MODEL="gaussian-credit-model" \ +# TAG="TRAINING" \ +# ./scripts/test_upload_endpoint.sh + +set -uo pipefail + +# --- Config via env vars (no secrets hardcoded) --- +: "${ENDPOINT:?ENDPOINT is required, e.g. https://.../data/upload}" +MODEL="${MODEL:-gaussian-credit-model}" +# Separate model for BYTES to avoid mixing with an existing numeric dataset +MODEL_BYTES="${MODEL_BYTES:-${MODEL}-bytes}" +TAG="${TAG:-TRAINING}" +AUTH_HEADER="${AUTH_HEADER:-}" # e.g. 'Authorization: Bearer ' + +CURL_OPTS=( --silent --show-error -H "Content-Type: application/json" ) +[[ -n "$AUTH_HEADER" ]] && CURL_OPTS+=( -H "$AUTH_HEADER" ) + +RED=$'\033[31m'; GREEN=$'\033[32m'; YELLOW=$'\033[33m'; CYAN=$'\033[36m'; RESET=$'\033[0m' +pass_cnt=0; fail_cnt=0; results=() +have_jq=1; command -v jq >/dev/null 2>&1 || have_jq=0 + +line(){ printf '%s\n' "--------------------------------------------------------------------------------"; } +snippet(){ if (( have_jq )); then echo "$1" | jq -r 'tostring' 2>/dev/null | head -c 240; else echo "$1" | head -c 240; fi; } + +# ---------- payload builders ---------- +mk_inputs_2x4_int32() { + cat < uses MODEL + local req="$1" out="$2" + cat < 0 )); then + echo "${YELLOW}Details for failures:${RESET}" + for r in "${results[@]}"; do + IFS='|' read -r status name http body <<<"$r" + if [[ "$status" == "FAIL" ]]; then + printf "%s[FAIL]%s %s (HTTP %s)\n" "$RED" "$RESET" "$name" "$http" + printf " body: %s\n" "$body" + line + fi + done +fi \ No newline at end of file diff --git a/src/endpoints/consumer/__init__.py b/src/endpoints/consumer/__init__.py new file mode 100644 index 0000000..2c4c891 --- /dev/null +++ b/src/endpoints/consumer/__init__.py @@ -0,0 +1,156 @@ +from typing import Optional, Dict, List, Literal, Any +from enum import Enum +from pydantic import BaseModel, model_validator, ConfigDict +import numpy as np + + +PartialKind = Literal["request", "response"] + +class PartialPayloadId(BaseModel): + prediction_id: Optional[str] = None + kind: Optional[PartialKind] = None + + def get_prediction_id(self) -> str: + return self.prediction_id + + def set_prediction_id(self, id: str): + self.prediction_id = id + + def get_kind(self) -> PartialKind: + return self.kind + + def set_kind(self, kind: PartialKind): + self.kind = kind + + +class InferencePartialPayload(BaseModel): + partialPayloadId: Optional[PartialPayloadId] = None + metadata: Optional[Dict[str, str]] = {} + data: Optional[str] = None + modelid: Optional[str] = None + + def get_id(self) -> str: + return self.partialPayloadId.prediction_id if self.partialPayloadId else None + + def set_id(self, id: str): + if not self.partialPayloadId: + self.partialPayloadId = 
PartialPayloadId() + self.partialPayloadId.prediction_id = id + + def get_kind(self) -> PartialKind: + return self.partialPayloadId.kind if self.partialPayloadId else None + + def set_kind(self, kind: PartialKind): + if not self.partialPayloadId: + self.partialPayloadId = PartialPayloadId() + self.partialPayloadId.kind = kind + + def get_model_id(self) -> str: + return self.modelid + + def set_model_id(self, model_id: str): + self.modelid = model_id + + +class KServeDataType(str, Enum): + BOOL = "BOOL" + INT8 = "INT8" + INT16 = "INT16" + INT32 = "INT32" + INT64 = "INT64" + UINT8 = "UINT8" + UINT16 = "UINT16" + UINT32 = "UINT32" + UINT64 = "UINT64" + FP16 = "FP16" + FP32 = "FP32" + FP64 = "FP64" + BYTES = "BYTES" + +K_SERVE_NUMPY_DTYPES = { + KServeDataType.INT8: np.int8, + KServeDataType.INT16: np.int16, + KServeDataType.INT32: np.int32, + KServeDataType.INT64: np.int64, + KServeDataType.UINT8: np.uint8, + KServeDataType.UINT16: np.uint16, + KServeDataType.UINT32: np.uint32, + KServeDataType.UINT64: np.uint64, + KServeDataType.FP16: np.float16, + KServeDataType.FP32: np.float32, + KServeDataType.FP64: np.float64, +} + +class KServeData(BaseModel): + + model_config = ConfigDict(use_enum_values=True) + + name: str + shape: List[int] + datatype: KServeDataType + parameters: Optional[Dict[str, str]] = None + data: List[Any] + + @model_validator(mode="after") + def _validate_shape(self) -> "KServeData": + raw = np.array(self.data, dtype=object) + actual = tuple(raw.shape) + declared = tuple(self.shape) + if declared != actual: + raise ValueError( + f"Declared shape {declared} does not match data shape {actual}" + ) + return self + + @model_validator(mode="after") + def validate_data_matches_type(self) -> "KServeData": + flat = np.array(self.data, dtype=object).flatten() + + if self.datatype == KServeDataType.BYTES: + for v in flat: + if not isinstance(v, str): + raise ValueError( + f"All values must be JSON strings for datatype {self.datatype}; " + f"found {type(v).__name__}: {v}" + ) + return self + + if self.datatype == KServeDataType.BOOL: + for v in flat: + if not (isinstance(v, (bool, int)) and v in (0, 1, True, False)): + raise ValueError( + f"All values must be bool or 0/1 for datatype {self.datatype}; found {v}" + ) + return self + + np_dtype = K_SERVE_NUMPY_DTYPES.get(self.datatype) + if np_dtype is None: + raise ValueError(f"Unsupported datatype: {self.datatype}") + + if np.dtype(np_dtype).kind == "u": + for v in flat: + if isinstance(v, (int, float)) and v < 0: + raise ValueError( + f"Negative value {v} not allowed for unsigned type {self.datatype}" + ) + + try: + np.array(flat, dtype=np_dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Data cannot be cast to {self.datatype}: {e}") + + return self + +class KServeInferenceRequest(BaseModel): + id: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + inputs: List[KServeData] + outputs: Optional[List[KServeData]] = None + + +class KServeInferenceResponse(BaseModel): + model_name: str = None + model_version: Optional[str] = None + id: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + outputs: List[KServeData] diff --git a/src/endpoints/consumer/consumer_endpoint.py b/src/endpoints/consumer/consumer_endpoint.py index 1f17647..c0e4a84 100644 --- a/src/endpoints/consumer/consumer_endpoint.py +++ b/src/endpoints/consumer/consumer_endpoint.py @@ -1,17 +1,17 @@ # endpoints/consumer.py import asyncio import time -from datetime import datetime +from datetime import datetime, timezone import 
numpy as np from fastapi import APIRouter, HTTPException, Header -from pydantic import BaseModel -from typing import Dict, Optional, Literal, List, Union, Callable, Annotated +from typing import Literal, Union, Callable, Annotated import logging +from src.endpoints.consumer import InferencePartialPayload, KServeData, KServeInferenceRequest, KServeInferenceResponse # Import local dependencies from src.service.data.model_data import ModelData -from src.service.data.storage import get_storage_interface +from src.service.data.storage import get_global_storage_interface from src.service.utils import list_utils from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload @@ -26,81 +26,10 @@ router = APIRouter() logger = logging.getLogger(__name__) -PartialKind = Literal["request", "response"] -storage_interface = get_storage_interface() unreconciled_inputs = {} unreconciled_outputs = {} -class PartialPayloadId(BaseModel): - prediction_id: Optional[str] = None - kind: Optional[PartialKind] = None - - def get_prediction_id(self) -> str: - return self.prediction_id - - def set_prediction_id(self, id: str): - self.prediction_id = id - - def get_kind(self) -> PartialKind: - return self.kind - - def set_kind(self, kind: PartialKind): - self.kind = kind - - -class InferencePartialPayload(BaseModel): - partialPayloadId: Optional[PartialPayloadId] = None - metadata: Optional[Dict[str, str]] = {} - data: Optional[str] = None - modelid: Optional[str] = None - - def get_id(self) -> str: - return self.partialPayloadId.prediction_id if self.partialPayloadId else None - - def set_id(self, id: str): - if not self.partialPayloadId: - self.partialPayloadId = PartialPayloadId() - self.partialPayloadId.prediction_id = id - - def get_kind(self) -> PartialKind: - return self.partialPayloadId.kind if self.partialPayloadId else None - - def set_kind(self, kind: PartialKind): - if not self.partialPayloadId: - self.partialPayloadId = PartialPayloadId() - self.partialPayloadId.kind = kind - - def get_model_id(self) -> str: - return self.modelid - - def set_model_id(self, model_id: str): - self.modelid = model_id - - -class KServeData(BaseModel): - name: str - shape: List[int] - datatype: str - parameters: Optional[Dict[str, str]] = None - data: List - - -class KServeInferenceRequest(BaseModel): - id: Optional[str] = None - parameters: Optional[Dict[str, str]] = None - inputs: List[KServeData] - outputs: Optional[List[KServeData]] = None - - -class KServeInferenceResponse(BaseModel): - model_name: str - model_version: Optional[str] = None - id: Optional[str] = None - parameters: Optional[Dict[str, str]] = None - outputs: List[KServeData] - - @router.post("/consumer/kserve/v2") async def consume_inference_payload( payload: InferencePartialPayload, @@ -118,6 +47,8 @@ async def consume_inference_payload( Returns: A JSON response indicating success or failure """ + storage_interface = get_global_storage_interface() + try: if not payload.modelid: raise HTTPException( @@ -144,7 +75,6 @@ async def consume_inference_payload( ) partial_payload = PartialPayload(data=payload.data) - if payload_kind == "request": logger.info( f"Received partial input payload from model={model_id}, id={payload_id}" @@ -160,12 +90,12 @@ async def consume_inference_payload( ) from e # Store the input payload - await storage_interface.persist_modelmesh_payload( - partial_payload, payload_id, is_input + await storage_interface.persist_partial_payload( + partial_payload, payload_id=payload_id, is_input=is_input ) - output_payload = 
await storage_interface.get_modelmesh_payload( - payload_id, False + output_payload = await storage_interface.get_partial_payload( + payload_id, is_input=False, is_modelmesh=True ) if output_payload: @@ -188,12 +118,12 @@ async def consume_inference_payload( ) from e # Store the output payload - await storage_interface.persist_modelmesh_payload( - partial_payload, payload_id, is_input + await storage_interface.persist_partial_payload( + payload=partial_payload, payload_id=payload_id, is_input=is_input ) - input_payload = await storage_interface.get_modelmesh_payload( - payload_id, True + input_payload = await storage_interface.get_partial_payload( + payload_id, is_input=True, is_modelmesh=True ) if input_payload: @@ -221,6 +151,47 @@ async def consume_inference_payload( ) from e +async def write_reconciled_data( + input_array, input_names, + output_array, output_names, + model_id, tags, id_): + storage_interface = get_global_storage_interface() + + iso_time = datetime.now(timezone.utc).isoformat() + unix_timestamp = time.time() + metadata = np.array( + [[None, iso_time, unix_timestamp, tags]] * len(input_array), dtype="O" + ) + metadata[:, 0] = [f"{id_}_{i}" for i in range(len(input_array))] + metadata_names = ["id", "iso_time", "unix_timestamp", "tags"] + + input_dataset = model_id + INPUT_SUFFIX + output_dataset = model_id + OUTPUT_SUFFIX + metadata_dataset = model_id + METADATA_SUFFIX + + await asyncio.gather( + storage_interface.write_data(input_dataset, input_array, input_names), + storage_interface.write_data(output_dataset, output_array, output_names), + storage_interface.write_data(metadata_dataset, metadata, metadata_names), + ) + + shapes = await ModelData(model_id).shapes() + logger.info( + f"Successfully reconciled inference {id_}, " + f"consisting of {len(input_array):,} rows from {model_id}." + ) + logger.debug( + f"Current storage shapes for {model_id}: " + f"Inputs={shapes[0]}, " + f"Outputs={shapes[1]}, " + f"Metadata={shapes[2]}" + ) + + # Clean up + await storage_interface.delete_partial_payload(id_, True) + await storage_interface.delete_partial_payload(id_, False) + + async def reconcile_modelmesh_payloads( input_payload: PartialPayload, output_payload: PartialPayload, @@ -241,46 +212,41 @@ async def reconcile_modelmesh_payloads( # Create metadata array tags = [SYNTHETIC_TAG] if any(df["synthetic"]) else [UNLABELED_TAG] - iso_time = datetime.isoformat(datetime.utcnow()) - unix_timestamp = time.time() - metadata = np.array([[iso_time, unix_timestamp, tags]] * len(df), dtype="O") - # Get dataset names - input_dataset = model_id + INPUT_SUFFIX - output_dataset = model_id + OUTPUT_SUFFIX - metadata_dataset = model_id + METADATA_SUFFIX + await write_reconciled_data( + df[input_cols].values, input_cols, + df[output_cols].values, output_cols, + model_id=model_id, tags=tags, id_=request_id + ) - metadata_cols = ["iso_time", "unix_timestamp", "tags"] - await asyncio.gather( - storage_interface.write_data(input_dataset, df[input_cols].values, input_cols), - storage_interface.write_data( - output_dataset, df[output_cols].values, output_cols - ), - storage_interface.write_data(metadata_dataset, metadata, metadata_cols), - ) - shapes = await ModelData(model_id).shapes() - logger.info( - f"Successfully reconciled ModelMesh inference {request_id}, " - f"consisting of {len(df):,} rows from {model_id}." 
- ) - logger.debug( - f"Current storage shapes for {model_id}: " - f"Inputs={shapes[0]}, " - f"Outputs={shapes[1]}, " - f"Metadata={shapes[2]}" +async def reconcile_kserve( + input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse, tag: str): + input_array, input_names = process_payload(input_payload, lambda p: p.inputs) + output_array, output_names = process_payload( + output_payload, lambda p: p.outputs, input_array.shape[0] ) - # Clean up - await storage_interface.delete_modelmesh_payload(request_id, True) - await storage_interface.delete_modelmesh_payload(request_id, False) + if tag is not None: + tags = [tag] + elif (input_payload.parameters is not None and + input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true"): + tags = [SYNTHETIC_TAG] + else: + tags = [UNLABELED_TAG] + + await write_reconciled_data( + input_array, input_names, + output_array, output_names, + model_id=output_payload.model_name, tags=tags, id_=input_payload.id + ) def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id): msg = ( - f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. " - f"When using multiple {payload_type}s to describe data columns, all shapes must match." + f"Could not reconcile_kserve KServe Inference {payload_id}, because {payload_type} shapes were mismatched. " + f"When using multiple {payload_type}s to describe data columns, all shapes must match. " f"However, the following tensor shapes were found:" ) for i, (name, shape) in enumerate(shape_tuples): @@ -291,7 +257,7 @@ def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id): def reconcile_mismatching_row_count_error(payload_id, input_shape, output_shape): msg = ( - f"Could not reconcile KServe Inference {payload_id}, because the number of " + f"Could not reconcile_kserve KServe Inference {payload_id}, because the number of " f"output rows ({output_shape}) did not match the number of input rows " f"({input_shape})." 
) @@ -309,9 +275,9 @@ def process_payload(payload, get_data: Callable, enforced_first_shape: int = Non column_names = [] for kserve_data in get_data(payload): data.append(kserve_data.data) - shapes.add(tuple(kserve_data.data.shape)) + shapes.add(tuple(kserve_data.shape)) column_names.append(kserve_data.name) - shape_tuples.append((kserve_data.data.name, kserve_data.data.shape)) + shape_tuples.append((kserve_data.name, kserve_data.shape)) if len(shapes) == 1: row_count = list(shapes)[0][0] if enforced_first_shape is not None and row_count != enforced_first_shape: @@ -350,59 +316,18 @@ def process_payload(payload, get_data: Callable, enforced_first_shape: int = Non return np.array(kserve_data.data), column_names -async def reconcile( - input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse -): - input_array, input_names = process_payload(input_payload, lambda p: p.inputs) - output_array, output_names = process_payload( - output_payload, lambda p: p.outputs, input_array.shape[0] - ) - - metadata_names = ["iso_time", "unix_timestamp", "tags"] - if ( - input_payload.parameters is not None - and input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true" - ): - tags = [SYNTHETIC_TAG] - else: - tags = [UNLABELED_TAG] - iso_time = datetime.isoformat(datetime.utcnow()) - unix_timestamp = time.time() - metadata = np.array( - [[iso_time, unix_timestamp, tags]] * len(input_array), dtype="O" - ) - - input_dataset = output_payload.model_name + INPUT_SUFFIX - output_dataset = output_payload.model_name + OUTPUT_SUFFIX - metadata_dataset = output_payload.model_name + METADATA_SUFFIX - - await asyncio.gather( - storage_interface.write_data(input_dataset, input_array, input_names), - storage_interface.write_data(output_dataset, output_array, output_names), - storage_interface.write_data(metadata_dataset, metadata, metadata_names), - ) - - shapes = await ModelData(output_payload.model_name).shapes() - logger.info( - f"Successfully reconciled KServe inference {input_payload.id}, " - f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}." - ) - logger.debug( - f"Current storage shapes for {output_payload.model_name}: " - f"Inputs={shapes[0]}, " - f"Outputs={shapes[1]}, " - f"Metadata={shapes[2]}" - ) - - @router.post("/") async def consume_cloud_event( payload: Union[KServeInferenceRequest, KServeInferenceResponse], ce_id: Annotated[str | None, Header()] = None, + tag: str = None ): # set payload if from cloud event header payload.id = ce_id + # get global storage interface + storage_interface = get_global_storage_interface() + if isinstance(payload, KServeInferenceRequest): if len(payload.inputs) == 0: msg = f"KServe Inference Input {payload.id} received, but data field was empty. Payload will not be saved." 
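Note: the request body posted to this route now has to satisfy the KServeData validators added in src/endpoints/consumer/__init__.py — the declared shape must equal the nested shape of data, and every value must be castable to the declared datatype. A minimal sketch of exercising the route with httpx (the local base URL, tensor name, and ce-id value are illustrative assumptions, not part of this patch):

import httpx

# a 2x4 FP64 tensor; `shape` must match the nesting of `data` exactly
payload = {
    "inputs": [
        {
            "name": "credit_inputs",          # illustrative tensor name
            "shape": [2, 4],
            "datatype": "FP64",
            "data": [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
        }
    ]
}

# the `ce-id` header populates payload.id via the `ce_id` Header() parameter
response = httpx.post(
    "http://localhost:8080/",                 # assumed local service address
    json=payload,
    headers={"ce-id": "example-request-id"},
)
print(response.status_code, response.json())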
@@ -412,12 +337,12 @@ async def consume_cloud_event( logger.info(f"KServe Inference Input {payload.id} received.") # if a match is found, the payload is auto-deleted from data partial_output = await storage_interface.get_partial_payload( - payload.id, is_input=False + payload.id, is_input=False, is_modelmesh=False ) if partial_output is not None: - await reconcile(payload, partial_output) + await reconcile_kserve(payload, partial_output, tag) else: - await storage_interface.persist_partial_payload(payload, is_input=True) + await storage_interface.persist_partial_payload(payload, payload_id=payload.id, is_input=True) return { "status": "success", "message": f"Input payload {payload.id} processed successfully", @@ -436,12 +361,12 @@ async def consume_cloud_event( f"KServe Inference Output {payload.id} received from model={payload.model_name}." ) partial_input = await storage_interface.get_partial_payload( - payload.id, is_input=True + payload.id, is_input=True, is_modelmesh=False ) if partial_input is not None: - await reconcile(partial_input, payload) + await reconcile_kserve(partial_input, payload, tag) else: - await storage_interface.persist_partial_payload(payload, is_input=False) + await storage_interface.persist_partial_payload(payload, payload_id=payload.id, is_input=False) return { "status": "success", diff --git a/src/endpoints/data/data_download.py b/src/endpoints/data/data_download.py deleted file mode 100644 index eb65111..0000000 --- a/src/endpoints/data/data_download.py +++ /dev/null @@ -1,32 +0,0 @@ -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel -from typing import List, Any, Optional -import logging - -router = APIRouter() -logger = logging.getLogger(__name__) - - -class RowMatcher(BaseModel): - columnName: str - operation: str - values: List[Any] - - -class DataRequestPayload(BaseModel): - modelId: str - matchAny: Optional[List[RowMatcher]] = None - matchAll: Optional[List[RowMatcher]] = None - matchNone: Optional[List[RowMatcher]] = None - - -@router.post("/data/download") -async def download_data(payload: DataRequestPayload): - """Download model data.""" - try: - logger.info(f"Received data download request for model: {payload.modelId}") - # TODO: Implement - return {"status": "success", "data": []} - except Exception as e: - logger.error(f"Error downloading data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}") diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 49889db..845012a 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -1,27 +1,83 @@ +import logging +from typing import Dict, Optional + +import uuid from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from typing import Dict, Any -import logging + +from src.endpoints.consumer.consumer_endpoint import consume_cloud_event +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse, KServeData +from src.service.constants import TRUSTYAI_TAG_PREFIX +from src.service.data.model_data import ModelData + router = APIRouter() logger = logging.getLogger(__name__) -class ModelInferJointPayload(BaseModel): +class UploadPayload(BaseModel): model_name: str - data_tag: str = None + data_tag: Optional[str] = None is_ground_truth: bool = False - request: Dict[str, Any] - response: Dict[str, Any] + request: KServeInferenceRequest + response: KServeInferenceResponse + + +def validate_data_tag(tag: str) -> Optional[str]: + """Validate 
data tag format and content.""" + if not tag: + return None + if tag.startswith(TRUSTYAI_TAG_PREFIX): + return ( + f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. " + f"Provided tag '{tag}' violates this restriction." + ) + return None + @router.post("/data/upload") -async def upload_data(payload: ModelInferJointPayload): - """Upload a batch of model data to TrustyAI.""" +async def upload(payload: UploadPayload) -> Dict[str, str]: + """Upload model data""" + + # validate tag + tag_validation_msg = validate_data_tag(payload.data_tag) + if tag_validation_msg: + raise HTTPException(status_code=400, detail=tag_validation_msg) try: - logger.info(f"Received data upload for model: {payload.model_name}") - # TODO: Implement - return {"status": "success", "message": "Data uploaded successfully"} + logger.info(f"Received upload request for model: {payload.model_name}") + + # overwrite response model name with provided model name + if payload.response.model_name != payload.model_name: + logger.warning(f"Response model name '{payload.response.model_name}' differs from request model name '{payload.model_name}'. Using '{payload.model_name}'.") + payload.response.model_name = payload.model_name + + req_id = str(uuid.uuid4()) + model_data = ModelData(payload.model_name) + datasets_exist = await model_data.datasets_exist() + + if all(datasets_exist): + previous_data_points = (await model_data.row_counts())[0] + else: + previous_data_points = 0 + + await consume_cloud_event(payload.response, req_id) + await consume_cloud_event(payload.request, req_id, tag=payload.data_tag) + + model_data = ModelData(payload.model_name) + new_data_points = (await model_data.row_counts())[0] + + logger.info(f"Upload completed for model: {payload.model_name}") + + return { + "status": "success", + "message": f"{new_data_points-previous_data_points} datapoints successfully added to {payload.model_name} data." 
+ } + + except HTTPException as e: + if "Could not reconcile_kserve KServe Inference" in str(e): + raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}: {str(e)}") from e + raise e except Exception as e: - logger.error(f"Error uploading data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Error uploading data: {str(e)}") + logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") diff --git a/src/main.py b/src/main.py index b569b81..4efba0e 100644 --- a/src/main.py +++ b/src/main.py @@ -11,7 +11,6 @@ # Endpoint routers from src.endpoints.consumer.consumer_endpoint import router as consumer_router -from src.endpoints.data.data_download import router as data_download_router from src.endpoints.data.data_upload import router as data_upload_router # from src.endpoints.explainers import router as explainers_router @@ -139,7 +138,6 @@ async def lifespan(app: FastAPI): app.include_router(identity_router, tags=["Identity Endpoint"]) app.include_router(metadata_router, tags=["Service Metadata"]) app.include_router(metrics_info_router, tags=["Metrics Information Endpoint"]) -app.include_router(data_download_router, tags=["Download Endpoint"]) if lm_evaluation_harness_available: app.include_router( @@ -179,6 +177,8 @@ async def liveness_probe(): return JSONResponse(content={"status": "live"}, status_code=200) + + if __name__ == "__main__": # SERVICE_STORAGE_FORMAT=PVC; STORAGE_DATA_FOLDER=/tmp; STORAGE_DATA_FILENAME=trustyai_test.hdf5 uvicorn.run(app=app, host="0.0.0.0", port=8080) diff --git a/src/service/constants.py b/src/service/constants.py index 517e700..5138691 100644 --- a/src/service/constants.py +++ b/src/service/constants.py @@ -2,15 +2,18 @@ INPUT_SUFFIX = "_inputs" OUTPUT_SUFFIX = "_outputs" METADATA_SUFFIX = "_metadata" +GROUND_TRUTH_SUFFIX = "_ground_truth" PROTECTED_DATASET_SUFFIX = "trustyai_internal_" PARTIAL_PAYLOAD_DATASET_NAME = "partial_payloads" GROUND_TRUTH_SUFFIX = "-ground-truths" METADATA_FILENAME = "metadata.json" INTERNAL_DATA_FILENAME = "internal_data.csv" + # Payload parsing TRUSTYAI_TAG_PREFIX = "_trustyai" SYNTHETIC_TAG = f"{TRUSTYAI_TAG_PREFIX}_synthetic" UNLABELED_TAG = f"{TRUSTYAI_TAG_PREFIX}_unlabeled" BIAS_IGNORE_PARAM = "bias-ignore" + # Prometheus constants PROMETHEUS_METRIC_PREFIX = "trustyai_" diff --git a/src/service/data/model_data.py b/src/service/data/model_data.py index 9bbb423..bc813b0 100644 --- a/src/service/data/model_data.py +++ b/src/service/data/model_data.py @@ -1,12 +1,13 @@ +import logging from typing import List, Optional import numpy as np +import pandas as pd -from src.service.data.storage import get_storage_interface from src.service.constants import * +from src.service.data.storage import get_global_storage_interface -storage_interface = get_storage_interface() - +logger = logging.getLogger(__name__) class ModelDataContainer: def __init__(self, model_name: str, input_data: np.ndarray, input_names: List[str], output_data: np.ndarray, @@ -33,10 +34,30 @@ def __init__(self, model_name): self.output_dataset = self.model_name+OUTPUT_SUFFIX self.metadata_dataset = self.model_name+METADATA_SUFFIX + async def datasets_exist(self) -> tuple[bool, bool, bool]: + """ + Checks if the requested model exists + """ + storage_interface = get_global_storage_interface() + input_exists = await storage_interface.dataset_exists(self.input_dataset) + output_exists = await 
storage_interface.dataset_exists(self.output_dataset) + metadata_exists = await storage_interface.dataset_exists(self.metadata_dataset) + + # warn if we're missing one of the expected datasets + dataset_checks = (input_exists, output_exists, metadata_exists) + if not all(dataset_checks): + expected_datasets = [self.input_dataset, self.output_dataset, self.metadata_dataset] + missing_datasets = [dataset for idx, dataset in enumerate(expected_datasets) if not dataset_checks[idx]] + logger.warning(f"Not all datasets present for model {self.model_name}: missing {missing_datasets}. This could be indicative of storage corruption or" + f"improper saving of previous model data.") + return dataset_checks + + async def row_counts(self) -> tuple[int, int, int]: """ Get the number of input, output, and metadata rows that exist in a model dataset """ + storage_interface = get_global_storage_interface() input_rows = await storage_interface.dataset_rows(self.input_dataset) output_rows = await storage_interface.dataset_rows(self.output_dataset) metadata_rows = await storage_interface.dataset_rows(self.metadata_dataset) @@ -46,28 +67,30 @@ async def shapes(self) -> tuple[List[int], List[int], List[int]]: """ Get the shapes of the input, output, and metadata datasets that exist in a model dataset """ + storage_interface = get_global_storage_interface() input_shape = await storage_interface.dataset_shape(self.input_dataset) output_shape = await storage_interface.dataset_shape(self.output_dataset) metadata_shape = await storage_interface.dataset_shape(self.metadata_dataset) return input_shape, output_shape, metadata_shape async def column_names(self) -> tuple[List[str], List[str], List[str]]: + storage_interface = get_global_storage_interface() input_names = await storage_interface.get_aliased_column_names(self.input_dataset) output_names = await storage_interface.get_aliased_column_names(self.output_dataset) - # these can't be aliased metadata_names = await storage_interface.get_original_column_names(self.metadata_dataset) return input_names, output_names, metadata_names async def original_column_names(self) -> tuple[List[str], List[str], List[str]]: + storage_interface = get_global_storage_interface() input_names = await storage_interface.get_original_column_names(self.input_dataset) output_names = await storage_interface.get_original_column_names(self.output_dataset) metadata_names = await storage_interface.get_original_column_names(self.metadata_dataset) return input_names, output_names, metadata_names - async def data(self, start_row=None, n_rows=None, get_input=True, get_output=True, get_metadata=True) \ + async def data(self, start_row=0, n_rows=None, get_input=True, get_output=True, get_metadata=True) \ -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: """ Get data from a saved model @@ -78,6 +101,8 @@ async def data(self, start_row=None, n_rows=None, get_input=True, get_output=Tru * get_output: whether to retrieve output data <- use this to reduce file reads * get_metadata: whether to retrieve metadata <- use this to reduce file reads """ + storage_interface = get_global_storage_interface() + if get_input: input_data = await storage_interface.read_data(self.input_dataset, start_row, n_rows) else: @@ -93,6 +118,13 @@ async def data(self, start_row=None, n_rows=None, get_input=True, get_output=Tru return input_data, output_data, metadata + async def get_metadata_as_df(self): + _, _, metadata = await self.data(get_input=False, get_output=False) + metadata_cols = (await 
self.column_names())[2] + return pd.DataFrame(metadata, columns=metadata_cols) + + + async def summary_string(self): out = f"=== {self.model_name} Data ===" diff --git a/src/service/data/storage/__init__.py b/src/service/data/storage/__init__.py index d914fad..869cd48 100644 --- a/src/service/data/storage/__init__.py +++ b/src/service/data/storage/__init__.py @@ -1,9 +1,21 @@ +import asyncio import os from src.service.data.storage.maria.legacy_maria_reader import LegacyMariaDBStorageReader from src.service.data.storage.maria.maria import MariaDBStorage from src.service.data.storage.pvc import PVCStorage +global_storage_interface = None +_storage_lock = asyncio.Lock() + +def get_global_storage_interface(force_reload=False): + global global_storage_interface + + if global_storage_interface is None or force_reload: + global_storage_interface = get_storage_interface() + return global_storage_interface + + def get_storage_interface(): storage_format = os.environ.get("SERVICE_STORAGE_FORMAT", "PVC") if storage_format == "PVC": @@ -15,7 +27,7 @@ def get_storage_interface(): host=os.environ.get("DATABASE_HOST"), port=int(os.environ.get("DATABASE_PORT")), database=os.environ.get("DATABASE_DATABASE"), - attempt_migration=True + attempt_migration=bool(int((os.environ.get("DATABASE_ATTEMPT_MIGRATION", "0")))), ) else: raise ValueError(f"Storage format={storage_format} not yet supported by the Python implementation of the service.") diff --git a/src/service/data/storage/maria/legacy_maria_reader.py b/src/service/data/storage/maria/legacy_maria_reader.py index e944fb3..c9bc3df 100644 --- a/src/service/data/storage/maria/legacy_maria_reader.py +++ b/src/service/data/storage/maria/legacy_maria_reader.py @@ -1,4 +1,3 @@ -import asyncio import javaobj import logging import pandas as pd @@ -174,13 +173,15 @@ async def migrate_data(self, new_maria_storage: StorageInterface): output_df.columns.values) if not input_has_migrated: - await new_maria_storage.write_data(input_dataset, input_df.to_numpy(), list(input_df.columns.values)) - new_maria_storage.apply_name_mapping(input_dataset, input_mapping) + await new_maria_storage.write_data(input_dataset, input_df.to_numpy(), input_df.columns.to_list()) + await new_maria_storage.apply_name_mapping(input_dataset, input_mapping) if not output_has_migrated: - await new_maria_storage.write_data(output_dataset, output_df.to_numpy(), list(output_df.columns.values)) - new_maria_storage.apply_name_mapping(output_dataset, output_mapping) + await new_maria_storage.write_data(output_dataset, output_df.to_numpy(), + output_df.columns.to_list()) + await new_maria_storage.apply_name_mapping(output_dataset, output_mapping) if not metadata_has_migrated: - await new_maria_storage.write_data(metadata_dataset, metadata_df.to_numpy(), list(metadata_df.columns.values)) + await new_maria_storage.write_data(metadata_dataset, metadata_df.to_numpy(), + metadata_df.columns.to_list()) migrations.append(True) logger.info(f"Dataset {dataset_name} successfully migrated.") diff --git a/src/service/data/storage/maria/maria.py b/src/service/data/storage/maria/maria.py index 51c665c..0bbdcc6 100644 --- a/src/service/data/storage/maria/maria.py +++ b/src/service/data/storage/maria/maria.py @@ -2,11 +2,13 @@ import io import json import logging + import mariadb import numpy as np import pickle as pkl -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse from 
src.service.data.modelmesh_parser import PartialPayload from src.service.data.storage import LegacyMariaDBStorageReader from src.service.data.storage.maria.utils import MariaConnectionManager, require_existing_dataset, \ @@ -65,14 +67,17 @@ def __init__(self, user: str, password: str, host: str, port: int, database: str cursor.execute(f"CREATE TABLE IF NOT EXISTS `{self.partial_payload_table}` (payload_id varchar(255), is_input BOOLEAN, payload_data LONGBLOB)") if attempt_migration: - self._migrate_from_legacy_db() + # Schedule the migration to run asynchronously + import asyncio + loop = asyncio.get_event_loop() + loop.create_task(self._migrate_from_legacy_db()) # === MIGRATORS ================================================================================ - def _migrate_from_legacy_db(self): + async def _migrate_from_legacy_db(self): legacy_reader = LegacyMariaDBStorageReader(user=self.user, password=self.password, host=self.host, port=self.port, database=self.database) if legacy_reader.legacy_data_exists(): logger.info("Legacy TrustyAI v1 data exists in database, checking if a migration is necessary.") - asyncio.run(legacy_reader.migrate_data(self)) + await legacy_reader.migrate_data(self) # === INTERNAL HELPER FUNCTIONS ================================================================ @@ -80,7 +85,7 @@ def _build_table_name(self, index): return f"{self.schema_prefix}_dataset_{index}" @require_existing_dataset - def _get_clean_table_name(self, dataset_name: str) -> str: + async def _get_clean_table_name(self, dataset_name: str) -> str: """ Get a generated table name corresponding to a particular dataset. This avoids possible SQL injection from within the model names. @@ -91,7 +96,7 @@ def _get_clean_table_name(self, dataset_name: str) -> str: @require_existing_dataset - def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: + async def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: """ Return the metadata field from a particular dataset within the dataset_reference_table. """ @@ -102,7 +107,7 @@ def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: #=== DATASET QUERYING ========================================================================== - def dataset_exists(self, dataset_name: str) -> bool: + async def dataset_exists(self, dataset_name: str) -> bool: """ Check if a dataset exists within the TrustyAI model data. """ @@ -113,18 +118,20 @@ def dataset_exists(self, dataset_name: str) -> bool: except mariadb.ProgrammingError: return False - def list_all_datasets(self): - """ - List all available datasets in the database. - """ + def _list_all_datasets_sync(self): with self.connection_manager as (conn, cursor): cursor.execute(f"SELECT dataset_name FROM `{self.dataset_reference_table}`") - results = [x[0] for x in cursor.fetchall()] - return results + return [x[0] for x in cursor.fetchall()] + + async def list_all_datasets(self): + """ + List all datasets in the database. 
+ """ + return await asyncio.to_thread(self._list_all_datasets_sync) @require_existing_dataset - def dataset_rows(self, dataset_name: str) -> int: + async def dataset_rows(self, dataset_name: str) -> int: """ Get the number of rows in a stored dataset (equivalent to data.shape[0]) """ @@ -135,23 +142,23 @@ def dataset_rows(self, dataset_name: str) -> int: @require_existing_dataset - def dataset_cols(self, dataset_name: str) -> int: + async def dataset_cols(self, dataset_name: str) -> int: """ Get the number of columns in a stored dataset (equivalent to data.shape[1]) """ - table_name = self._get_clean_table_name(dataset_name) + table_name = await self._get_clean_table_name(dataset_name) with self.connection_manager as (conn, cursor): cursor.execute(f"SHOW COLUMNS FROM {table_name}") return len(cursor.fetchall()) - 1 @require_existing_dataset - def dataset_shape(self, dataset_name: str) -> tuple[int]: + async def dataset_shape(self, dataset_name: str) -> tuple[int]: """ Get the whole shape of a stored dataset (equivalent to data.shape) """ - rows = self.dataset_rows(dataset_name) - shape = self._get_dataset_metadata(dataset_name)["shape"] + rows = await self.dataset_rows(dataset_name) + shape = (await self._get_dataset_metadata(dataset_name))["shape"] shape[0] = rows return tuple(shape) @@ -172,9 +179,9 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names if len(new_rows) == 0: raise ValueError(f"No data provided! `new_rows`=={new_rows}.") - # if received a single row, reshape into a single-row matrix + # if received a single row, reshape into a single-column matrix if new_rows.ndim < 2: - new_rows = new_rows.reshape(1, -1) + new_rows = new_rows.reshape(-1, 1) # validate that the number of provided column names matches the shape of the provided array if new_rows.shape[1] != len(column_names): @@ -182,7 +189,7 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names f"Shape mismatch: Number of provided column names ({len(column_names)}) does not match number of columns in provided array ({new_rows.shape[1]}).") # if this is the first time we've seen this dataset, set up its tables inside the DB - if not self.dataset_exists(dataset_name): + if not await self.dataset_exists(dataset_name): with self.connection_manager as (conn, cursor): # create an entry in `trustyai_v2_table_reference` @@ -212,10 +219,10 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names nrows = 0 else: # if dataset already exists, grab its current shape and information - stored_shape = self.dataset_shape(dataset_name) + stored_shape = await self.dataset_shape(dataset_name) ncols = stored_shape[1] - nrows = self.dataset_rows(dataset_name) - table_name = self._get_clean_table_name(dataset_name) + nrows = await self.dataset_rows(dataset_name) + table_name = await self._get_clean_table_name(dataset_name) cleaned_names = get_clean_column_names(column_names) # validate that the number of columns in the saved DB matched the provided column names @@ -255,9 +262,9 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names @require_existing_dataset - def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None): + async def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None): """ - Read saved data from the database, from `start_row` to `start_row + n_rows` (inclusive) + Read saved data from the database, from `start_row` to `start_row + n_rows` (inclusive)wait 
storage.dataset_exists(dataset_name): `dataset_name`: the name of the dataset to read. This is NOT the table name; see `trustyai_v2_table_reference.dataset_name` or use list_all_datasets() for the available dataset_names. @@ -265,16 +272,16 @@ def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None): `n_rows`: The total number of rows to read. If not specified, read all rows. """ - table_name = self._get_clean_table_name(dataset_name) + table_name = await self._get_clean_table_name(dataset_name) if n_rows is None: - n_rows = self.dataset_rows(dataset_name) + n_rows = await self.dataset_rows(dataset_name) with self.connection_manager as (conn, cursor): # grab matching data cursor.execute( - f"SELECT * FROM `{table_name}` WHERE row_idx>? AND row_idx<=?", - (start_row, start_row+n_rows) + f"SELECT * FROM `{table_name}` ORDER BY row_idx ASC LIMIT ? OFFSET ?", + (n_rows, start_row) ) # parse saved data back to Numpy array @@ -297,15 +304,15 @@ def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None): # === COLUMN NAMES ============================================================================= @require_existing_dataset - def get_original_column_names(self, dataset_name: str) -> Optional[List[str]]: - return self._get_dataset_metadata(dataset_name).get("column_names") + async def get_original_column_names(self, dataset_name: str) -> Optional[List[str]]: + return (await self._get_dataset_metadata(dataset_name)).get("column_names") @require_existing_dataset - def get_aliased_column_names(self, dataset_name: str) -> List[str]: - return self._get_dataset_metadata(dataset_name).get("aliased_names") + async def get_aliased_column_names(self, dataset_name: str) -> List[str]: + return (await self._get_dataset_metadata(dataset_name)).get("aliased_names") @require_existing_dataset - def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): + async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): """Apply a name mapping to a dataset. `dataset_name`: the name of the dataset to read. This is NOT the table name; @@ -314,8 +321,8 @@ def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): to original column names and values should correspond to the desired new names. 
""" - original_names = self.get_original_column_names(dataset_name) - aliased_names = self.get_aliased_column_names(dataset_name) + original_names = await self.get_original_column_names(dataset_name) + aliased_names = await self.get_aliased_column_names(dataset_name) # get the new set of optionaly-aliased column names for col_idx, original_name in enumerate(original_names): @@ -334,67 +341,56 @@ def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): # === PARTIAL PAYLOADS ========================================================================= - async def _persist_payload(self, payload, is_input: bool, request_id: Optional[str] = None): + async def persist_partial_payload(self, + payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], + payload_id, is_input: bool): """Save a partial payload to the database.""" with self.connection_manager as (conn, cursor): - if request_id is None: - request_id = payload.id - cursor.execute( f"INSERT INTO `{self.partial_payload_table}` (payload_id, is_input, payload_data) VALUES (?, ?, ?)", - (request_id, is_input, pkl.dumps(payload))) + (payload_id, is_input, pkl.dumps(payload.model_dump()))) conn.commit() - async def _get_partial_payload(self, payload_id: str, is_input: bool): + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]: """Retrieve a partial payload from the database.""" with self.connection_manager as (conn, cursor): cursor.execute(f"SELECT payload_data FROM `{self.partial_payload_table}` WHERE payload_id=? AND is_input=?", (payload_id, is_input)) result = cursor.fetchone() if result is None or len(result) == 0: return None - payload_data = result[0] - return pkl.loads(payload_data) - - - async def persist_partial_payload(self, payload, is_input: bool): - await self._persist_payload(payload, is_input) - - - async def persist_modelmesh_payload(self, payload: PartialPayload, request_id: str, is_input: bool): - await self._persist_payload(payload, is_input, request_id=request_id) - - - async def get_partial_payload(self, payload_id: str, is_input: bool): - return await self._get_partial_payload(payload_id, is_input) - - - async def get_modelmesh_payload(self, request_id: str, is_input: bool) -> Optional[PartialPayload]: - return await self._get_partial_payload(request_id, is_input) + payload_dict = pkl.loads(result[0]) + if is_modelmesh: + return PartialPayload(**payload_dict) + elif is_input: # kserve input + return KServeInferenceRequest(**payload_dict) + else: # kserve output + return KServeInferenceResponse(**payload_dict) - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, payload_id: str, is_input: bool): with self.connection_manager as (conn, cursor): - cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? AND is_input=?", (request_id, is_input)) + cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? 
AND is_input=?", (payload_id, is_input)) conn.commit() # === DATABASE CLEANUP ========================================================================= @require_existing_dataset - def delete_dataset(self, dataset_name: str): - table_name = self._get_clean_table_name(dataset_name) + async def delete_dataset(self, dataset_name: str): + table_name = await self._get_clean_table_name(dataset_name) + logger.info(f"Deleting table={table_name} to delete dataset={dataset_name}.") with self.connection_manager as (conn, cursor): cursor.execute(f"DELETE FROM `{self.dataset_reference_table}` WHERE dataset_name=?", (dataset_name,)) cursor.execute(f"DROP TABLE IF EXISTS `{table_name}`") conn.commit() - def delete_all_datasets(self): - for dataset_name in self.list_all_datasets(): + async def delete_all_datasets(self): + for dataset_name in await self.list_all_datasets(): logger.warning(f"Deleting dataset {dataset_name}") - self.delete_dataset(dataset_name) + await self.delete_dataset(dataset_name) - def reset_database(self): + async def reset_database(self): logger.warning(f"Fully resetting TrustyAI V2 database.") - self.delete_all_datasets() + await self.delete_all_datasets() with self.connection_manager as (conn, cursor): cursor.execute(f"DROP TABLE IF EXISTS `{self.dataset_reference_table}`") cursor.execute(f"DROP TABLE IF EXISTS `{self.partial_payload_table}`") diff --git a/src/service/data/storage/maria/utils.py b/src/service/data/storage/maria/utils.py index af57978..97bfb47 100644 --- a/src/service/data/storage/maria/utils.py +++ b/src/service/data/storage/maria/utils.py @@ -5,15 +5,16 @@ def require_existing_dataset(func): """Annotation to assert that a given function requires a valid dataset name as the first non-self argument""" - def validate_dataset_exists(*args, **kwargs): + async def validate_dataset_exists(*args, **kwargs): storage, dataset_name = args[0], args[1] - if not storage.dataset_exists(dataset_name): + if not await storage.dataset_exists(dataset_name): raise ValueError(f"Error when calling {func.__name__}: Dataset '{dataset_name}' does not exist.") - return func(*args, **kwargs) + return await func(*args, **kwargs) return validate_dataset_exists + def get_clean_column_names(column_names) -> List[str]: """ Programmatically generate the column names in a model data table. 
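Since require_existing_dataset now wraps coroutines, every storage method it decorates must be declared async and awaited by its callers. A minimal usage sketch (the DemoStorage class below is an illustrative stand-in, not part of this patch):

import asyncio

from src.service.data.storage.maria.utils import require_existing_dataset


class DemoStorage:
    """Illustrative stand-in for a StorageInterface implementation."""

    def __init__(self):
        self._rows = {"demo-model_inputs": 3}

    async def dataset_exists(self, dataset_name: str) -> bool:
        # awaited by the decorator before the wrapped method runs
        return dataset_name in self._rows

    @require_existing_dataset
    async def dataset_rows(self, dataset_name: str) -> int:
        return self._rows[dataset_name]


async def main():
    storage = DemoStorage()
    print(await storage.dataset_rows("demo-model_inputs"))  # -> 3
    try:
        await storage.dataset_rows("missing_dataset")
    except ValueError as err:
        print(err)  # raised by the decorator: dataset does not exist


if __name__ == "__main__":
    asyncio.run(main())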
diff --git a/src/service/data/storage/pvc.py b/src/service/data/storage/pvc.py index 813ed9b..23dc24b 100644 --- a/src/service/data/storage/pvc.py +++ b/src/service/data/storage/pvc.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union import numpy as np import os @@ -7,6 +7,7 @@ import logging import pickle as pkl +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse from src.service.utils import list_utils from .storage_interface import StorageInterface from src.service.constants import PROTECTED_DATASET_SUFFIX, PARTIAL_PAYLOAD_DATASET_NAME @@ -21,13 +22,7 @@ PARTIAL_OUTPUT_NAME = ( PROTECTED_DATASET_SUFFIX + PARTIAL_PAYLOAD_DATASET_NAME + "_outputs" ) -MODELMESH_INPUT_NAME = ( - f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_inputs" -) -MODELMESH_OUTPUT_NAME = ( - f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_outputs" -) - +MAX_VOID_TYPE_LENGTH=1024 class H5PYContext: """Open the corresponding H5PY file for a dataset and manage its context`""" @@ -111,10 +106,15 @@ async def dataset_exists(self, dataset_name: str) -> bool: return False - def list_all_datasets(self) -> List[str]: - """ List all datasets known by the dataset """ - return [fname.replace(f"_{self.data_file}", "") for fname in os.listdir(self.data_directory) if self.data_file in fname] + def _list_all_datasets_sync(self) -> List[str]: + return [ + fname.replace(f"_{self.data_file}", "") + for fname in os.listdir(self.data_directory) + if self.data_file in fname + ] + async def list_all_datasets(self) -> List[str]: + return await asyncio.to_thread(self._list_all_datasets_sync) async def dataset_rows(self, dataset_name: str) -> int: """Number of data rows in dataset, returns a FileNotFoundError if the dataset does not exist""" @@ -156,6 +156,13 @@ async def _write_raw_data( existing_shape = None dataset_exists = False + # standardize serialized rows to prevent metadata serialization failures + if isinstance(new_rows.dtype, np.dtypes.VoidDType): + if new_rows.dtype.itemsize > MAX_VOID_TYPE_LENGTH: + raise ValueError(f"The datatype of the array to be serialized is {new_rows.dtype}- the largest serializable void type is V{MAX_VOID_TYPE_LENGTH}") + new_rows = new_rows.astype(f"V{MAX_VOID_TYPE_LENGTH}") # this might cause bugs later + + if dataset_exists: # if we've already got saved inferences for this model if existing_shape[1:] == inbound_shape[1:]: # shapes match async with self.get_lock(allocated_dataset_name): @@ -203,11 +210,17 @@ async def _write_raw_data( max_shape = [None] + list(new_rows.shape)[ 1: ] # to-do: tune this value? 
+ + + # if isinstance(new_rows.data, np.dtypes.VoidDType): + # new_rows = new_rows.astype("V400") + dataset = db.create_dataset( allocated_dataset_name, data=new_rows, maxshape=max_shape, chunks=True, + dtype=new_rows.dtype # use the dtype of the data ) dataset.attrs[COLUMN_NAMES_ATTRIBUTE] = column_names dataset.attrs[BYTES_ATTRIBUTE] = is_bytes @@ -224,9 +237,16 @@ async def write_data(self, dataset_name: str, new_rows, column_names: List[str]) or not isinstance(new_rows, np.ndarray) and list_utils.contains_non_numeric(new_rows) ): + serialized = list_utils.serialize_rows(new_rows, MAX_VOID_TYPE_LENGTH) + arr = np.array(serialized) + if arr.ndim == 1: + arr = arr.reshape(-1, 1) await self._write_raw_data( - dataset_name, list_utils.serialize_rows(new_rows), column_names - ) + dataset_name, + arr, + column_names, + is_bytes=True, +) else: await self._write_raw_data(dataset_name, np.array(new_rows), column_names) @@ -247,20 +267,18 @@ async def _read_raw_data( f"Requested a data read from start_row={start_row}, but dataset " f"only has {dataset.shape[0]} rows. An empty array will be returned." ) - return ( - dataset[start_row:end_row], - dataset.attrs[COLUMN_NAMES_ATTRIBUTE], - ) + return dataset[start_row:end_row] + async def read_data( - self, dataset_name: str, start_row: int = None, n_rows: int = None + self, dataset_name: str, start_row: int = 0, n_rows: int = None ) -> (np.ndarray, List[str]): """Read data from a dataset, automatically deserializing any byte data""" - read, column_names = await self._read_raw_data(dataset_name, start_row, n_rows) + read = await self._read_raw_data(dataset_name, start_row, n_rows) if len(read) and read[0].dtype.type in {np.bytes_, np.void}: - return list_utils.deserialize_rows(read), column_names + return list_utils.deserialize_rows(read) else: - return read, column_names + return read async def delete_dataset(self, dataset_name: str): """Delete dataset data, ignoring non-existent datasets""" @@ -307,77 +325,34 @@ async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, st aliased_names = [name_mapping.get(name, name) for name in curr_names] db[allocated_dataset_name].attrs[COLUMN_ALIAS_ATTRIBUTE] = aliased_names - async def persist_partial_payload(self, payload, is_input: bool): - """Save a partial payload to disk. Returns None if no matching id exists""" - - # lock to prevent simultaneous read/writes - partial_dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME - async with self.get_lock(partial_dataset_name): - with H5PYContext( - self, - partial_dataset_name, - "a", - ) as db: - if partial_dataset_name not in db: - dataset = db.create_dataset( - partial_dataset_name, dtype="f", track_order=True - ) - else: - dataset = db[partial_dataset_name] - dataset.attrs[payload.id] = np.void(pkl.dumps(payload)) - async def persist_modelmesh_payload( - self, payload: PartialPayload, request_id: str, is_input: bool - ): + async def persist_partial_payload(self, payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], payload_id: str, is_input: bool): """ - Persist a ModelMesh payload. - - Args: - payload: The payload to persist - request_id: The unique identifier for the inference request - is_input: Whether this is an input payload (True) or output payload (False) + Save a KServe or ModelMesh payload to disk. 
""" - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME - + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME serialized_data = pkl.dumps(payload.model_dump()) + is_modelmesh = isinstance(payload, PartialPayload) async with self.get_lock(dataset_name): try: with H5PYContext(self, dataset_name, "a") as db: if dataset_name not in db: - dataset = db.create_dataset(dataset_name, data=np.array([0])) - dataset.attrs["request_ids"] = [] - - dataset = db[dataset_name] - request_ids = list(dataset.attrs["request_ids"]) - - dataset.attrs[request_id] = np.void(serialized_data) - - if request_id not in request_ids: - request_ids.append(request_id) - dataset.attrs["request_ids"] = request_ids + dataset = db.create_dataset(dataset_name, dtype="f", track_order=True) + else: + dataset = db[dataset_name] + dataset.attrs[payload_id] = np.void(serialized_data) logger.debug( - f"Stored ModelMesh {'input' if is_input else 'output'} payload for request ID: {request_id}" + f"Stored {'ModelMesh' if is_modelmesh else 'KServe'} {'input' if is_input else 'output'} payload for request ID: {payload_id}" ) except Exception as e: - logger.error(f"Error storing ModelMesh payload: {str(e)}") + logger.error(f"Error storing {'ModelMesh' if is_modelmesh else 'KServe'} payload: {str(e)}") raise - async def get_modelmesh_payload( - self, request_id: str, is_input: bool - ) -> Optional[PartialPayload]: - """ - Retrieve a stored ModelMesh payload by request ID. - - Args: - request_id: The unique identifier for the inference request - is_input: Whether to retrieve an input payload (True) or output payload (False) - - Returns: - The retrieved payload, or None if not found - """ - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Optional[ + Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]]: + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME try: async with self.get_lock(dataset_name): @@ -386,51 +361,38 @@ async def get_modelmesh_payload( return None dataset = db[dataset_name] - if request_id not in dataset.attrs: + if payload_id not in dataset.attrs: return None - serialized_data = dataset.attrs[request_id] + serialized_data = dataset.attrs[payload_id] try: payload_dict = pkl.loads(serialized_data) - return PartialPayload(**payload_dict) + if is_modelmesh: + return PartialPayload(**payload_dict) + elif is_input: # kserve input + return KServeInferenceRequest(**payload_dict) + else: # kserve output + return KServeInferenceResponse(**payload_dict) except Exception as e: logger.error(f"Error unpickling payload: {str(e)}") return None except MissingH5PYDataException: return None except Exception as e: - logger.error(f"Error retrieving ModelMesh payload: {str(e)}") + logger.error(f"Error retrieving {'ModelMesh' if is_modelmesh else 'KServe'} payload: {str(e)}") return None - async def get_partial_payload(self, payload_id: str, is_input: bool): - """Looks up a partial payload by id. 
Returns None if no matching id exists""" - - # lock to prevent simultaneous read/writes - partial_dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME - async with self.get_lock(partial_dataset_name): - try: - with H5PYContext(self, partial_dataset_name, "r") as db: - if partial_dataset_name not in db: - return None - recovered_bytes = db[partial_dataset_name].attrs.get(payload_id) - return ( - None - if recovered_bytes is None - else pkl.loads(recovered_bytes) - ) - except MissingH5PYDataException: - return None - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, request_id: str, is_input: bool): """ - Delete a stored ModelMesh payload. + Delete a stored partial payload. Args: request_id: The unique identifier for the inference request is_input: Whether to delete an input payload (True) or output payload (False) """ - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME try: async with self.get_lock(dataset_name): @@ -439,24 +401,20 @@ async def delete_modelmesh_payload(self, request_id: str, is_input: bool): return dataset = db[dataset_name] - request_ids = list(dataset.attrs["request_ids"]) - if request_id not in request_ids: + if request_id not in dataset.attrs: return if request_id in dataset.attrs: del dataset.attrs[request_id] - request_ids.remove(request_id) - dataset.attrs["request_ids"] = request_ids - - if not request_ids: + if not dataset.attrs: del db[dataset_name] logger.debug( - f"Deleted ModelMesh {'input' if is_input else 'output'} payload for request ID: {request_id}" + f"Deleted {'input' if is_input else 'output'} payload for request ID: {request_id}" ) except MissingH5PYDataException: return except Exception as e: - logger.error(f"Error deleting ModelMesh payload: {str(e)}") + logger.error(f"Error deleting payload: {str(e)}") diff --git a/src/service/data/storage/storage_interface.py b/src/service/data/storage/storage_interface.py index 28ad65b..84c005d 100644 --- a/src/service/data/storage/storage_interface.py +++ b/src/service/data/storage/storage_interface.py @@ -1,23 +1,25 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union + +from src.endpoints.consumer import KServeInferenceResponse, KServeInferenceRequest from src.service.data.modelmesh_parser import PartialPayload class StorageInterface(ABC): @abstractmethod - def dataset_exists(self, dataset_name: str) -> bool: + async def dataset_exists(self, dataset_name: str) -> bool: pass @abstractmethod - def list_all_datasets(self) -> List[str]: + async def list_all_datasets(self) -> List[str]: pass @abstractmethod - def dataset_rows(self, dataset_name: str) -> int: + async def dataset_rows(self, dataset_name: str) -> int: pass @abstractmethod - def dataset_shape(self, dataset_name: str) -> tuple[int]: + async def dataset_shape(self, dataset_name: str) -> tuple[int]: pass @abstractmethod @@ -25,68 +27,41 @@ async def write_data(self, dataset_name: str, new_rows, column_names: List[str]) pass @abstractmethod - def read_data(self, dataset_name: str, start_row: int = None, n_rows: int = None): + async def read_data(self, dataset_name: str, start_row: int = None, n_rows: int = None): pass @abstractmethod - def get_original_column_names(self, dataset_name: str) -> List[str]: + async def get_original_column_names(self, dataset_name: str) -> List[str]: pass @abstractmethod - def 
get_aliased_column_names(self, dataset_name: str) -> List[str]: + async def get_aliased_column_names(self, dataset_name: str) -> List[str]: pass @abstractmethod - def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): + async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): pass @abstractmethod - def delete_dataset(self, dataset_name: str): + async def delete_dataset(self, dataset_name: str): pass @abstractmethod - async def persist_partial_payload(self, payload, is_input: bool): + async def persist_partial_payload(self, + payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], + payload_id, is_input: bool): pass @abstractmethod - async def get_partial_payload(self, payload_id: str, is_input: bool): - pass - - - @abstractmethod - async def persist_modelmesh_payload( - self, payload: PartialPayload, request_id: str, is_input: bool - ): - """ - Store a ModelMesh partial payload (either input or output) for later reconciliation. - - Args: - payload: The partial payload to store - request_id: A unique identifier for this inference request - is_input: Whether this is an input payload (True) or output payload (False) - """ + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Optional[ + Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]]: pass - @abstractmethod - async def get_modelmesh_payload( - self, request_id: str, is_input: bool - ) -> Optional[PartialPayload]: - """ - Retrieve a stored ModelMesh payload by request ID. - - Args: - request_id: The unique identifier for the inference request - is_input: Whether to retrieve an input payload (True) or output payload (False) - - Returns: - The retrieved payload, or None if not found - """ - pass @abstractmethod - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, payload_id: str, is_input: bool): """ - Delete a stored ModelMesh payload. + Delete a stored partial payload. 
Args: request_id: The unique identifier for the inference request diff --git a/src/service/utils/list_utils.py b/src/service/utils/list_utils.py index b278227..8ca398f 100644 --- a/src/service/utils/list_utils.py +++ b/src/service/utils/list_utils.py @@ -15,10 +15,10 @@ def contains_non_numeric(l: list) -> bool: return isinstance(l, (bool, str)) -def serialize_rows(l: list): +def serialize_rows(l: list, max_void_type_length): """Convert a nested list to a 1D numpy array, where the nth element contains a bytes serialization of the nth row""" serialized = [np.void(pickle.dumps(row)) for row in l] - return np.array(serialized) + return np.array(serialized, dtype=f"V{max_void_type_length}") def deserialize_rows(serialized: np.ndarray): diff --git a/tests/endpoints/test_upload_endpoint_maria.py b/tests/endpoints/test_upload_endpoint_maria.py new file mode 100644 index 0000000..255b41b --- /dev/null +++ b/tests/endpoints/test_upload_endpoint_maria.py @@ -0,0 +1,59 @@ +import asyncio +import os +import unittest + +from fastapi.testclient import TestClient + +from tests.endpoints.test_upload_endpoint_pvc import TestUploadEndpointPVC + + +class TestUploadEndpointMaria(TestUploadEndpointPVC): + + def setUp(self): + os.environ["SERVICE_STORAGE_FORMAT"] = 'MARIA' + os.environ["DATABASE_USERNAME"] = "trustyai" + os.environ["DATABASE_PASSWORD"] = "trustyai" + os.environ["DATABASE_HOST"] = "127.0.0.1" + os.environ["DATABASE_PORT"] = "3306" + os.environ["DATABASE_DATABASE"] = "trustyai-database" + + # Force reload of the global storage interface to use the new temp dir + from src.service.data import storage + self.storage_interface = storage.get_global_storage_interface(force_reload=True) + + # Re-create the FastAPI app to ensure it uses the new storage interface + from importlib import reload + import src.main + reload(src.main) + from src.main import app + self.client = TestClient(app) + + self.original_datasets = set(self.storage_interface.list_all_datasets()) + + def tearDown(self): + # delete any datasets we've created + new_datasets = set(self.storage_interface.list_all_datasets()) + for ds in new_datasets.difference(self.original_datasets): + asyncio.run(self.storage_interface.delete_dataset(ds)) + + def test_upload_data(self): + super().test_upload_data() + + def test_upload_multi_input_data(self): + super().test_upload_multi_input_data() + + def test_upload_multi_input_data_no_unique_name(self): + super().test_upload_multi_input_data_no_unique_name() + + def test_upload_multiple_tagging(self): + super().test_upload_multiple_tagging() + + def test_upload_tag_that_uses_protected_name(self): + super().test_upload_tag_that_uses_protected_name() + + def test_upload_gaussian_data(self): + super().test_upload_gaussian_data() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint_pvc.py b/tests/endpoints/test_upload_endpoint_pvc.py new file mode 100644 index 0000000..8ee223a --- /dev/null +++ b/tests/endpoints/test_upload_endpoint_pvc.py @@ -0,0 +1,364 @@ +import asyncio +import itertools +import os +import shutil +import tempfile +import unittest +import uuid + +from fastapi.testclient import TestClient + +from src.service.data.model_data import ModelData +from src.service.constants import ( + TRUSTYAI_TAG_PREFIX, +) + +MODEL_ID = "example1" + + +def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0): + """Generate a test payload with specific dimensions and data types.""" + 
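serialize_rows now pins its output to a fixed-width void dtype, V{max_void_type_length}, so every pickled row occupies the same number of bytes and byte datasets can be created with a single declared dtype. A small round-trip sketch under that scheme (the 400-byte width below is illustrative, not the project's constant, and a pickle longer than the declared width would not fit):

import pickle

import numpy as np

MAX_VOID_TYPE_LENGTH = 400  # illustrative width only

rows = [[1.0, "i'm text-0", True], [2.0, "i'm text-1", False]]

# Serialize: each row becomes one fixed-width void scalar holding its pickle bytes,
# zero-padded up to MAX_VOID_TYPE_LENGTH.
serialized = np.array(
    [np.void(pickle.dumps(row)) for row in rows],
    dtype=f"V{MAX_VOID_TYPE_LENGTH}",
)
assert serialized.dtype == np.dtype(f"V{MAX_VOID_TYPE_LENGTH}")

# Deserialize: pickle stops at its end marker, so the zero padding is ignored.
recovered = [pickle.loads(item.tobytes()) for item in serialized]
assert recovered == rows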
model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for i in range(n_rows): + if n_input_cols == 1: + input_data.append(i + input_offset) + else: + row = [i + j + input_offset for j in range(n_input_cols)] + input_data.append(row) + output_data = [] + for i in range(n_rows): + if n_output_cols == 1: + output_data.append(i * 2 + output_offset) + else: + row = [i * 2 + j + output_offset for j in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "input", + "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "output", + "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a test payload with multi-dimensional tensors like real data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for row_idx in range(n_rows): + row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)] + input_data.append(row) + output_data = [] + for row_idx in range(n_rows): + row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "multi_input", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a payload with mismatched shapes and non-unique names.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)] + mismatched_rows = n_rows - 1 if n_rows > 1 else 1 + input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)] + output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)] + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "same_name", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data_1, + }, + { + "name": "same_name", + "shape": [mismatched_rows, n_input_cols], + "datatype": datatype, + "data": input_data_2, + }, + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def count_rows_with_tag(model_name, tag): + """Count rows with a specific tag in metadata.""" + metadata_df = asyncio.run(ModelData(model_name).get_metadata_as_df()) + return metadata_df['tags'].apply(lambda tags: tag in tags).sum() + + +def get_metadata_ids(model_name): + """Count rows with a specific tag in metadata.""" + metadata_df = asyncio.run(ModelData(model_name).get_metadata_as_df()) + return metadata_df['id'].tolist() + + +class TestUploadEndpointPVC(unittest.TestCase): + + def setUp(self): 
+ self.TEMP_DIR = tempfile.mkdtemp() + os.environ["STORAGE_DATA_FOLDER"] = self.TEMP_DIR + + # Force reload of the global storage interface to use the new temp dir + from src.service.data import storage + storage.get_global_storage_interface(force_reload=True) + + # Re-create the FastAPI app to ensure it uses the new storage interface + from importlib import reload + import src.main + reload(src.main) + from src.main import app + self.client = TestClient(app) + + def tearDown(self): + if os.path.exists(self.TEMP_DIR): + shutil.rmtree(self.TEMP_DIR) + + + def post_test(self, payload, expected_status_code, check_msgs): + """Post a payload and check the response.""" + response = self.client.post("/data/upload", json=payload) + if response.status_code != expected_status_code: + print(f"\n=== DEBUG INFO ===") + print(f"Expected status: {expected_status_code}") + print(f"Actual status: {response.status_code}") + print(f"Response text: {response.text}") + print(f"Response headers: {dict(response.headers)}") + if hasattr(response, "json"): + try: + print(f"Response JSON: {response.json()}") + except: + pass + print("==================") + + self.assertEqual(response.status_code, expected_status_code) + return response + + + # data upload tests + def test_upload_data(self): + n_input_rows_options = [1, 5, 250] + n_input_cols_options = [1, 4] + n_output_cols_options = [1, 2] + datatype_options = ["INT64", "INT32", "FP32", "FP64", "BOOL"] + + for idx, (n_input_rows, n_input_cols, n_output_cols, datatype) in enumerate(itertools.product( + n_input_rows_options, n_input_cols_options, n_output_cols_options, datatype_options + )): + with self.subTest( + f"subtest-{idx}", + n_input_rows=n_input_rows, + n_input_cols=n_input_cols, + n_output_cols=n_output_cols, + datatype=datatype, + ): + """Test uploading data with various dimensions and datatypes.""" + data_tag = "TRAINING" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = self.post_test(payload, 200, [f"{n_input_rows} datapoints"]) + + inputs, outputs, metadata = asyncio.run(ModelData(payload["model_name"]).data()) + + self.assertEqual(response.status_code, 200) + self.assertIn(str(n_input_rows), response.text) + self.assertIsNotNone(inputs, "Input data not found in storage") + self.assertIsNotNone(outputs, "Output data not found in storage") + + self.assertEqual(len(inputs), n_input_rows, "Incorrect number of input rows") + self.assertEqual(len(outputs), n_input_rows, "Incorrect number of output rows") + + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + self.assertEqual(tag_count, n_input_rows, "Not all rows have the correct tag") + + + def test_upload_multi_input_data(self): + """Test uploading data with multiple input tensors.""" + n_rows_options = [1, 3, 5, 250] + n_input_cols_options = [2, 6] + n_output_cols_options = [4] + datatype_options = ["INT64", "INT32", "FP32", "FP64", "BOOL"] + + for n_rows, n_input_cols, n_output_cols, datatype in itertools.product( + n_rows_options, n_input_cols_options, n_output_cols_options, datatype_options + ): + with self.subTest( + n_rows=n_rows, + n_input_cols=n_input_cols, + n_output_cols=n_output_cols, + datatype=datatype, + ): + # Arrange + data_tag = "TRAINING" + payload = generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, data_tag) + + # Act + self.post_test(payload, 200, [f"{n_rows} datapoints"]) + + model_data = ModelData(payload["model_name"]) + inputs, outputs, metadata = asyncio.run(model_data.data()) + 
input_column_names, output_column_names, metadata_column_names = asyncio.run( + model_data.original_column_names()) + + # Assert + self.assertIsNotNone(inputs, "Input data not found in storage") + self.assertIsNotNone(outputs, "Output data not found in storage") + self.assertEqual(len(inputs), n_rows, "Incorrect number of input rows") + self.assertEqual(len(outputs), n_rows, "Incorrect number of output rows") + self.assertEqual(len(input_column_names), n_input_cols, "Incorrect number of input columns") + self.assertEqual(len(output_column_names), n_output_cols, "Incorrect number of output columns") + self.assertGreaterEqual(len(input_column_names), 2, "Should have at least 2 input column names") + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + self.assertEqual(tag_count, n_rows, "Not all rows have the correct tag") + + + def test_upload_multi_input_data_no_unique_name(self): + """Test error case for non-unique tensor names.""" + payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") + response = self.client.post("/data/upload", json=payload) + self.assertEqual(response.status_code,400) + print(response.text) + self.assertIn("input shapes were mismatched", response.text) + self.assertIn("[250, 4]", response.text) + + + def test_upload_multiple_tagging(self): + """Test uploading data with multiple tags.""" + n_payload1 = 50 + n_payload2 = 51 + tag1 = "TRAINING" + tag2 = "NOT TRAINING " + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + + payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1) + payload1["model_name"] = model_name + self.post_test(payload1, 200, [f"{n_payload1} datapoints"]) + + payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2) + payload2["model_name"] = model_name + self.post_test(payload2, 200, [f"{n_payload2} datapoints"]) + + tag1_count = count_rows_with_tag(model_name, tag1) + tag2_count = count_rows_with_tag(model_name, tag2) + + self.assertEqual(tag1_count, n_payload1, f"Expected {n_payload1} rows with tag {tag1}") + self.assertEqual(tag2_count, n_payload2, f"Expected {n_payload2} rows with tag {tag2}") + + input_rows, _, _ = asyncio.run(ModelData(payload1["model_name"]).row_counts()) + self.assertEqual(input_rows, n_payload1 + n_payload2, "Incorrect total number of rows") + + + def test_upload_tag_that_uses_protected_name(self): + """Test error when using a protected tag name.""" + invalid_tag = f"{TRUSTYAI_TAG_PREFIX}_something" + payload = generate_payload(5, 10, 1, "INT64", invalid_tag) + response = self.post_test(payload, 400, ["reserved for internal TrustyAI use only"]) + expected_msg = f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. Provided tag '{invalid_tag}' violates this restriction." 
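The protected-tag test only asserts on the returned error text, so for reference the server-side check it exercises presumably reduces to a prefix comparison along these lines (a hypothetical sketch, not the upload endpoint's actual implementation):

from fastapi import HTTPException

from src.service.constants import TRUSTYAI_TAG_PREFIX


def validate_data_tag(data_tag: str) -> None:
    # Hypothetical helper: reject tags that collide with the internal TrustyAI prefix.
    if data_tag.startswith(TRUSTYAI_TAG_PREFIX):
        raise HTTPException(
            status_code=400,
            detail=(
                f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI "
                f"use only. Provided tag '{data_tag}' violates this restriction."
            ),
        )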
+ self.assertIn(expected_msg, response.text) + + + def test_upload_gaussian_data(self): + """Test uploading realistic Gaussian data.""" + payload = { + "model_name": "gaussian-credit-model", + "data_tag": "TRAINING", + "request": { + "inputs": [ + { + "name": "credit_inputs", + "shape": [2, 4], + "datatype": "FP64", + "data": [ + [ + 47.45380690750797, + 478.6846214843319, + 13.462184703540503, + 20.764525303373535, + ], + [ + 47.468246185717554, + 575.6911203538863, + 10.844143722475575, + 14.81343667761101, + ], + ], + } + ] + }, + "response": { + "model_name": "gaussian-credit-model__isvc-d79a7d395d", + "model_version": "1", + "outputs": [ + { + "name": "predict", + "datatype": "FP32", + "shape": [2, 1], + "data": [0.19013395683309373, 0.2754730253205645], + } + ], + }, + } + self.post_test(payload, 200, ["2 datapoints"]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/service/data/test_async_contract.py b/tests/service/data/test_async_contract.py new file mode 100644 index 0000000..99688df --- /dev/null +++ b/tests/service/data/test_async_contract.py @@ -0,0 +1,36 @@ +import inspect +import importlib +import pytest + + +CLASSES = [ + ("src.service.data.storage.maria.maria", "MariaDBStorage"), + ("src.service.data.storage.pvc", "PVCStorage"), +] + +SYNC_ALLOWED = { + "PVCStorage": {"allocate_valid_dataset_name", "get_lock"}, # no resource usage, non blocking +} + +def public_methods(cls): + for name, fn in inspect.getmembers(cls, predicate=inspect.isfunction): + if name.startswith("_"): + continue + if name in {"__init__", "__new__", "__class_getitem__"}: + continue + yield name, fn + +@pytest.mark.parametrize("module_path,class_name", CLASSES) +def test_storage_methods_are_async(module_path, class_name): + mod = importlib.import_module(module_path) + cls = getattr(mod, class_name) + allowed = SYNC_ALLOWED.get(class_name, set()) + + offenders = [] + for name, fn in public_methods(cls): + if name in allowed: + continue + if not inspect.iscoroutinefunction(fn): + offenders.append(name) + + assert not offenders, f"{class_name} has non-async methods: {offenders}" \ No newline at end of file diff --git a/tests/service/data/test_mariadb_migration.py b/tests/service/data/test_mariadb_migration.py index 1eb4326..f85ffdc 100644 --- a/tests/service/data/test_mariadb_migration.py +++ b/tests/service/data/test_mariadb_migration.py @@ -5,8 +5,6 @@ import asyncio import datetime import unittest -import os -import logging import numpy as np from src.service.data.storage.maria.maria import MariaDBStorage @@ -30,17 +28,20 @@ def setUp(self): def tearDown(self): - self.storage.reset_database() + asyncio.run(self.storage.reset_database()) async def _test_retrieve_data(self): # total data checks available_datasets = self.storage.list_all_datasets() - self.assertEqual(len(available_datasets), 12) + for i in [1,2,3,4]: + for split in ["inputs", "outputs", "metadata"]: + self.assertIn(f"model{i}_{split}", available_datasets) + self.assertGreaterEqual(len(available_datasets), 12) # model 1 checks - model_1_inputs = self.storage.read_data("model1_inputs", 0, 100) - model_1_metadata = self.storage.read_data("model1_metadata", 0, 1) + model_1_inputs = await self.storage.read_data("model1_inputs", 0, 100) + model_1_metadata = await self.storage.read_data("model1_metadata", 0, 1) self.assertTrue(np.array_equal(np.array([[0,1,2,3,4]]*100), model_1_inputs)) self.assertTrue(np.array_equal(np.array([0, 1, 2, 3, 4]), model_1_inputs[0])) 
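The async-contract test above is parametrized over (module_path, class_name) pairs, so covering an additional backend is a matter of appending to CLASSES and, if the class has genuinely non-blocking synchronous helpers, allow-listing them in SYNC_ALLOWED. A hypothetical extension (the S3Storage entry is a placeholder, not an existing backend):

# Hypothetical additions to tests/service/data/test_async_contract.py
CLASSES = [
    ("src.service.data.storage.maria.maria", "MariaDBStorage"),
    ("src.service.data.storage.pvc", "PVCStorage"),
    ("src.service.data.storage.s3", "S3Storage"),  # placeholder backend
]

SYNC_ALLOWED = {
    "PVCStorage": {"allocate_valid_dataset_name", "get_lock"},
    "S3Storage": {"get_lock"},  # non-blocking helpers may stay synchronous
}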
self.assertEqual(model_1_metadata[0][0], datetime.datetime.fromisoformat("2025-06-09 12:19:06.074828")) @@ -48,10 +49,10 @@ async def _test_retrieve_data(self): self.assertEqual(model_1_metadata[0][2], "_trustyai_unlabeled") # model 3 checks - self.assertEqual(self.storage.get_aliased_column_names("model3_inputs"), ["year mapped", "make mapped", "color mapped"]) + self.assertEqual(await self.storage.get_aliased_column_names("model3_inputs"), ["year mapped", "make mapped", "color mapped"]) # model 4 checks - model_4_inputs_row0 = self.storage.read_data("model4_inputs", 0, 5) + model_4_inputs_row0 = await self.storage.read_data("model4_inputs", 0, 5) self.assertEqual(model_4_inputs_row0[0].tolist(), [0.0, "i'm text-0", True, 0]) self.assertEqual(model_4_inputs_row0[4].tolist(), [4.0, "i'm text-4", True, 8]) diff --git a/tests/service/data/test_mariadb_storage.py b/tests/service/data/test_mariadb_storage.py index 32be2fd..b762300 100644 --- a/tests/service/data/test_mariadb_storage.py +++ b/tests/service/data/test_mariadb_storage.py @@ -4,7 +4,6 @@ import asyncio import unittest -import os import numpy as np from src.service.data.storage.maria.maria import MariaDBStorage @@ -25,10 +24,12 @@ def setUp(self): 3306, "trustyai-database", attempt_migration=False) + self.original_datasets = set(self.storage.list_all_datasets()) def tearDown(self): - self.storage.reset_database() + asyncio.run(self.storage.reset_database()) + async def _store_dataset(self, seed, n_rows=None, n_cols=None): n_rows = seed * 3 if n_rows is None else n_rows @@ -45,81 +46,80 @@ async def _test_retrieve_data(self): start_idx = dataset_idx n_rows = dataset_idx * 2 - retrieved_full_dataset = self.storage.read_data(dataset_name) - retrieved_partial_dataset = self.storage.read_data(dataset_name, start_idx, n_rows) + retrieved_full_dataset = await self.storage.read_data(dataset_name) + retrieved_partial_dataset = await self.storage.read_data(dataset_name, start_idx, n_rows) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) self.assertTrue(np.array_equal(retrieved_partial_dataset, original_dataset[start_idx:start_idx+n_rows])) async def _test_big_insert(self): original_dataset, _, dataset_name = await self._store_dataset(0, 5000, 10) - retrieved_full_dataset = self.storage.read_data(dataset_name) + retrieved_full_dataset = await self.storage.read_data(dataset_name) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) async def 
_test_single_row_insert(self): original_dataset, _, dataset_name = await self._store_dataset(0, 1, 10) - retrieved_full_dataset = self.storage.read_data(dataset_name, 0, 1) + retrieved_full_dataset = await self.storage.read_data(dataset_name, 0, 1) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) - async def _test_single_row_retrieval(self): - original_dataset = np.arange(0, 10).reshape(1, 10) - column_names = [alphabet[i] for i in range(original_dataset.shape[0])] + async def _test_vector_retrieval(self): + original_dataset = np.arange(0, 10) + column_names = ["single_column"] dataset_name = "dataset_single_row" await self.storage.write_data(dataset_name, original_dataset, column_names) - retrieved_full_dataset = self.storage.read_data(dataset_name)[0] + retrieved_full_dataset = await self.storage.read_data(dataset_name) + transposed_dataset = retrieved_full_dataset.reshape(-1) - self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(1, self.storage.dataset_rows(dataset_name)) - self.assertEqual(10, self.storage.dataset_cols(dataset_name)) + self.assertTrue(np.array_equal(transposed_dataset, original_dataset)) + self.assertEqual(10, await self.storage.dataset_rows(dataset_name)) + self.assertEqual(1, await self.storage.dataset_cols(dataset_name)) async def _test_name_mapping(self): for dataset_idx in range(1, 10): original_dataset, column_names, dataset_name = await self._store_dataset(dataset_idx) name_mapping = {name: "aliased_" + name for i, name in enumerate(column_names) if i % 2 == 0} expected_mapping = [name_mapping.get(name, name) for name in column_names] - self.storage.apply_name_mapping(dataset_name, name_mapping) + await self.storage.apply_name_mapping(dataset_name, name_mapping) - self.assertEqual(column_names, self.storage.get_original_column_names(dataset_name)) - self.assertEqual(expected_mapping, self.storage.get_aliased_column_names(dataset_name)) + retrieved_original_names = await self.storage.get_original_column_names(dataset_name) + retrieved_aliased_names = await self.storage.get_aliased_column_names(dataset_name) -def run_async_test(coro): - """Helper function to run async tests.""" - loop = asyncio.new_event_loop() - return loop.run_until_complete(coro) + self.assertEqual(column_names, retrieved_original_names) + self.assertEqual(expected_mapping, retrieved_aliased_names) + def test_retrieve_data(self): + run_async_test(self._test_retrieve_data()) -TestMariaDBStorage.test_retrieve_data = lambda self: run_async_test( - self._test_retrieve_data() -) -TestMariaDBStorage.test_name_mapping = lambda self: run_async_test( - self._test_name_mapping() -) + def test_name_mapping(self): + run_async_test(self._test_name_mapping()) -TestMariaDBStorage.test_big_insert = lambda self: run_async_test( - self._test_big_insert() -) + def test_big_insert(self): + run_async_test(self._test_big_insert()) -TestMariaDBStorage.test_single_row_insert = lambda self: run_async_test( - 
self._test_single_row_insert() -) + def test_single_row_insert(self): + run_async_test(self._test_single_row_insert()) -TestMariaDBStorage._test_single_row_retrieval = lambda self: run_async_test( - self._test_single_row_retrieval() -) + def test_single_row_retrieval(self): + run_async_test(self._test_vector_retrieval()) +def run_async_test(coro): + """Helper function to run async tests.""" + loop = asyncio.new_event_loop() + return loop.run_until_complete(coro) + if __name__ == "__main__": diff --git a/tests/service/data/test_modelmesh_parser.py b/tests/service/data/test_modelmesh_parser.py index 17eabfe..a77fb6f 100644 --- a/tests/service/data/test_modelmesh_parser.py +++ b/tests/service/data/test_modelmesh_parser.py @@ -3,10 +3,7 @@ """ import unittest -from typing import Dict, Optional - import pandas as pd -from pydantic import BaseModel from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload from tests.service.data.test_utils import ModelMeshTestData diff --git a/tests/service/data/test_payload_reconciliation_maria.py b/tests/service/data/test_payload_reconciliation_maria.py index 025a7c4..715ec28 100644 --- a/tests/service/data/test_payload_reconciliation_maria.py +++ b/tests/service/data/test_payload_reconciliation_maria.py @@ -1,27 +1,18 @@ """ Tests for ModelMesh payload reconciliation.MariaDBStorage("root", "root", "127.0.0.1", 3306, "trustyai_database_v2") """ - import asyncio import unittest -import tempfile -import os -import base64 -import time -from datetime import datetime -from unittest import mock import uuid -import pandas as pd -import numpy as np -from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload +from src.service.data.modelmesh_parser import PartialPayload from src.service.data.storage.maria.maria import MariaDBStorage -from src.service.data.storage.pvc import PVCStorage +from tests.service.data.test_payload_reconciliation_pvc import TestPayloadReconciliation from tests.service.data.test_utils import ModelMeshTestData -class TestMariaPayloadReconciliation(unittest.TestCase): +class TestMariaPayloadReconciliation(TestPayloadReconciliation): """ Test class for ModelMesh payload reconciliation. 
""" @@ -56,185 +47,7 @@ def setUp(self): def tearDown(self): """Clean up after tests.""" - self.storage.reset_database() - - async def _test_persist_input_payload(self): - """Test persisting an input payload.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True - ) - - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - - self.assertIsNotNone(retrieved_payload) - self.assertEqual(retrieved_payload.data, self.input_payload.data) - - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - self.assertIsNone(output_payload) - - async def _test_persist_output_payload(self): - """Test persisting an output payload.""" - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False - ) - - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNotNone(retrieved_payload) - self.assertEqual(retrieved_payload.data, self.output_payload.data) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - self.assertIsNone(input_payload) - - async def _test_full_reconciliation(self): - """Test the full payload reconciliation process.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True - ) - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False - ) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNotNone(input_payload) - self.assertIsNotNone(output_payload) - - df = ModelMeshPayloadParser.payloads_to_dataframe( - input_payload, output_payload, self.request_id, self.model_name - ) - - self.assertIsInstance(df, pd.DataFrame) - self.assertIn("input", df.columns) - self.assertIn("output_output", df.columns) - self.assertEqual(len(df), 5) # Based on our test data with 5 rows - - self.assertIn("id", df.columns) - self.assertEqual(df["id"].iloc[0], self.request_id) - - self.assertIn("model_id", df.columns) - self.assertEqual(df["model_id"].iloc[0], self.model_name) - - # Clean up - await self.storage.delete_modelmesh_payload(self.request_id, is_input=True) - await self.storage.delete_modelmesh_payload(self.request_id, is_input=False) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNone(input_payload) - self.assertIsNone(output_payload) - - async def _test_reconciliation_with_real_data(self): - """Test reconciliation with sample b64 encoded data from files.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - test_data_dir = os.path.join( - os.path.dirname(os.path.dirname(current_dir)), "data" - ) - - with open(os.path.join(test_data_dir, "input-sample.b64"), "r") as f: - sample_input_data = f.read().strip() - - with open(os.path.join(test_data_dir, "output-sample.b64"), "r") as f: - sample_output_data = f.read().strip() - - input_payload = PartialPayload(data=sample_input_data) - output_payload = PartialPayload(data=sample_output_data) - - request_id = str(uuid.uuid4()) - model_id = "sample-model" - - await self.storage.persist_modelmesh_payload( - input_payload, request_id, is_input=True - ) - await 
self.storage.persist_modelmesh_payload( - output_payload, request_id, is_input=False - ) - - stored_input = await self.storage.get_modelmesh_payload( - request_id, is_input=True - ) - stored_output = await self.storage.get_modelmesh_payload( - request_id, is_input=False - ) - - self.assertIsNotNone(stored_input) - self.assertIsNotNone(stored_output) - self.assertEqual(stored_input.data, sample_input_data) - self.assertEqual(stored_output.data, sample_output_data) - - with mock.patch.object( - ModelMeshPayloadParser, "payloads_to_dataframe" - ) as mock_to_df: - sample_df = pd.DataFrame( - { - "input_feature": [1, 2, 3], - "output_output_feature": [4, 5, 6], - "id": [request_id] * 3, - "model_id": [model_id] * 3, - "synthetic": [False] * 3, - } - ) - mock_to_df.return_value = sample_df - - df = ModelMeshPayloadParser.payloads_to_dataframe( - stored_input, stored_output, request_id, model_id - ) - - self.assertIsInstance(df, pd.DataFrame) - self.assertEqual(len(df), 3) - - mock_to_df.assert_called_once_with( - stored_input, stored_output, request_id, model_id - ) - - # Clean up - await self.storage.delete_modelmesh_payload(request_id, is_input=True) - await self.storage.delete_modelmesh_payload(request_id, is_input=False) - - self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=True) - ) - self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=False) - ) - - -def run_async_test(coro): - """Helper function to run async tests.""" - loop = asyncio.new_event_loop() - return loop.run_until_complete(coro) - - -TestMariaPayloadReconciliation.test_persist_input_payload = lambda self: run_async_test( - self._test_persist_input_payload() -) -TestMariaPayloadReconciliation.test_persist_output_payload = lambda self: run_async_test( - self._test_persist_output_payload() -) -TestMariaPayloadReconciliation.test_full_reconciliation = lambda self: run_async_test( - self._test_full_reconciliation() -) -TestMariaPayloadReconciliation.test_reconciliation_with_real_data = ( - lambda self: run_async_test(self._test_reconciliation_with_real_data()) -) + asyncio.run(self.storage.reset_database()) if __name__ == "__main__": diff --git a/tests/service/data/test_payload_reconciliation_pvc.py b/tests/service/data/test_payload_reconciliation_pvc.py index 23c22bc..b7553af 100644 --- a/tests/service/data/test_payload_reconciliation_pvc.py +++ b/tests/service/data/test_payload_reconciliation_pvc.py @@ -6,14 +6,10 @@ import unittest import tempfile import os -import base64 -import time -from datetime import datetime from unittest import mock import uuid import pandas as pd -import numpy as np from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload from src.service.data.storage.pvc import PVCStorage @@ -53,54 +49,54 @@ def tearDown(self): async def _test_persist_input_payload(self): """Test persisting an input payload.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True + await self.storage.persist_partial_payload( + self.input_payload, payload_id=self.request_id, is_input=True ) - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + retrieved_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) self.assertIsNotNone(retrieved_payload) self.assertEqual(retrieved_payload.data, self.input_payload.data) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + 
output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNone(output_payload) async def _test_persist_output_payload(self): """Test persisting an output payload.""" - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False + await self.storage.persist_partial_payload( + self.output_payload, payload_id=self.request_id, is_input=False ) - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + retrieved_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNotNone(retrieved_payload) self.assertEqual(retrieved_payload.data, self.output_payload.data) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) self.assertIsNone(input_payload) async def _test_full_reconciliation(self): """Test the full payload reconciliation process.""" - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( self.input_payload, self.request_id, is_input=True ) - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( self.output_payload, self.request_id, is_input=False ) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNotNone(input_payload) @@ -122,14 +118,14 @@ async def _test_full_reconciliation(self): self.assertEqual(df["model_id"].iloc[0], self.model_name) # Clean up - await self.storage.delete_modelmesh_payload(self.request_id, is_input=True) - await self.storage.delete_modelmesh_payload(self.request_id, is_input=False) + await self.storage.delete_partial_payload(self.request_id, is_input=True) + await self.storage.delete_partial_payload(self.request_id, is_input=False) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNone(input_payload) @@ -154,18 +150,18 @@ async def _test_reconciliation_with_real_data(self): request_id = str(uuid.uuid4()) model_id = "sample-model" - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( input_payload, request_id, is_input=True ) - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( output_payload, request_id, is_input=False ) - stored_input = await self.storage.get_modelmesh_payload( - request_id, is_input=True + stored_input = await self.storage.get_partial_payload( + request_id, is_input=True, is_modelmesh=True ) - stored_output = await self.storage.get_modelmesh_payload( - request_id, is_input=False + stored_output = await self.storage.get_partial_payload( + request_id, is_input=False, is_modelmesh=True ) 
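With the PVC and MariaDB reconciliation tests now sharing the renamed API, updating any remaining callers of the old ModelMesh-specific methods is mechanical. A sketch of the equivalent calls (storage is any StorageInterface implementation and payload is a PartialPayload; the helper name is illustrative):

# persist_modelmesh_payload(payload, rid, True) -> persist_partial_payload(payload, payload_id=rid, is_input=True)
# get_modelmesh_payload(rid, True)              -> get_partial_payload(rid, is_input=True, is_modelmesh=True)
# delete_modelmesh_payload(rid, True)           -> delete_partial_payload(rid, is_input=True)

async def migrate_calls_example(storage, payload, request_id: str) -> None:
    """Persist, fetch, and delete one ModelMesh input payload via the unified interface."""
    await storage.persist_partial_payload(payload, payload_id=request_id, is_input=True)
    stored = await storage.get_partial_payload(request_id, is_input=True, is_modelmesh=True)
    assert stored is not None and stored.data == payload.data
    await storage.delete_partial_payload(request_id, is_input=True)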
self.assertIsNotNone(stored_input) @@ -199,14 +195,14 @@ async def _test_reconciliation_with_real_data(self): ) # Clean up - await self.storage.delete_modelmesh_payload(request_id, is_input=True) - await self.storage.delete_modelmesh_payload(request_id, is_input=False) + await self.storage.delete_partial_payload(request_id, is_input=True) + await self.storage.delete_partial_payload(request_id, is_input=False) self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=True) + await self.storage.get_partial_payload(request_id, is_input=True, is_modelmesh=True) ) self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=False) + await self.storage.get_partial_payload(request_id, is_input=False, is_modelmesh=True) ) diff --git a/tests/service/test_consumer_endpoint_reconciliation.py b/tests/service/test_consumer_endpoint_reconciliation.py index 5d79797..4b6ee9b 100644 --- a/tests/service/test_consumer_endpoint_reconciliation.py +++ b/tests/service/test_consumer_endpoint_reconciliation.py @@ -26,9 +26,11 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.storage_patch = mock.patch( - "src.endpoints.consumer.consumer_endpoint.storage_interface" + "src.endpoints.consumer.consumer_endpoint.get_global_storage_interface" ) - self.mock_storage = self.storage_patch.start() + self.mock_get_storage = self.storage_patch.start() + self.mock_storage = mock.AsyncMock() + self.mock_get_storage.return_value = self.mock_storage self.parser_patch = mock.patch.object( ModelMeshPayloadParser, "parse_input_payload" @@ -92,8 +94,8 @@ def tearDown(self): async def _test_consume_input_payload(self): """Test consuming an input payload.""" - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None) + self.mock_storage.persist_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None) self.mock_parse_input.return_value = True self.mock_parse_output.side_effect = ValueError("Not an output payload") @@ -104,6 +106,7 @@ async def _test_consume_input_payload(self): } response = self.client.post("/consumer/kserve/v2", json=inference_payload) + print(response.text) self.assertEqual(response.status_code, 200) self.assertEqual( @@ -114,15 +117,15 @@ async def _test_consume_input_payload(self): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertTrue(call_args[2]) # is_input=True + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertTrue(call_kwargs["is_input"]) # is_input=True async def _test_consume_output_payload(self): """Test consuming an output payload.""" - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None) + self.mock_storage.persist_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None) self.mock_parse_input.side_effect = ValueError("Not an input payload") self.mock_parse_output.return_value = True @@ -143,23 +146,23 @@ async def _test_consume_output_payload(self): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() - call_args = 
self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertFalse(call_args[2]) # is_input=False + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertFalse(call_kwargs["is_input"]) # is_input=True async def _test_reconcile_payloads(self): """Test reconciling both input and output payloads.""" # Setup mocks for correct interactions - self.mock_storage.get_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload.side_effect = [ + self.mock_storage.get_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload.side_effect = [ None, self.input_payload, ] - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() + self.mock_storage.persist_partial_payload = mock.AsyncMock() self.mock_storage.write_data = mock.AsyncMock() - self.mock_storage.delete_modelmesh_payload = mock.AsyncMock() + self.mock_storage.delete_partial_payload = mock.AsyncMock() with ( mock.patch( @@ -213,13 +216,13 @@ async def mock_gather_impl(*args, **kwargs): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertTrue(call_args[2]) # is_input=True + self.assertEqual(call_kwargs['payload_id'], self.request_id) + self.assertTrue(call_kwargs['is_input']) # is_input=True - self.mock_storage.persist_modelmesh_payload.reset_mock() + self.mock_storage.persist_partial_payload.reset_mock() mock_parse_input.side_effect = ValueError("Not an input") mock_parse_output.side_effect = lambda x: True @@ -246,9 +249,9 @@ async def mock_gather_impl(*args, **kwargs): }, ) - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertFalse(call_args[2]) # is_input=False + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertFalse(call_kwargs["is_input"]) # is_input=False mock_reconcile.assert_called_once() reconcile_args = mock_reconcile.call_args[0]
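The consumer endpoint tests now patch the storage getter rather than a module-level storage object, so new tests need the same two-step mock: patch get_global_storage_interface where the endpoint imports it, then hand back an AsyncMock. A condensed sketch of that setUp pattern (the class name is illustrative; only the patch target and client setup come from the tests above):

import unittest
from unittest import mock

from fastapi.testclient import TestClient


class ConsumerStorageMockingExample(unittest.TestCase):
    """Illustrative only: the storage-mocking pattern used by the reconciliation tests."""

    def setUp(self):
        # Patch the getter in the endpoint module; handlers then resolve the mock at call time.
        self.storage_patch = mock.patch(
            "src.endpoints.consumer.consumer_endpoint.get_global_storage_interface"
        )
        self.mock_get_storage = self.storage_patch.start()
        self.mock_storage = mock.AsyncMock()
        self.mock_get_storage.return_value = self.mock_storage

        from src.main import app
        self.client = TestClient(app)

    def tearDown(self):
        self.storage_patch.stop()

Assertions against the unified API then read keyword arguments, e.g. self.mock_storage.persist_partial_payload.call_args[1]["is_input"], as in the tests above.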