From 34e127b872b7c8855c765bee6084e77a9e630585 Mon Sep 17 00:00:00 2001 From: m-misiura Date: Wed, 18 Jun 2025 13:13:11 +0100 Subject: [PATCH 01/10] :construction: working on the /data/download and /data/upload endpoints --- src/endpoints/data/data_download.py | 42 +- src/endpoints/data/data_upload.py | 36 +- src/service/utils/download.py | 310 ++++++++++++ src/service/utils/upload.py | 545 ++++++++++++++++++++++ tests/endpoints/test_download_endpoint.py | 404 ++++++++++++++++ tests/endpoints/test_upload_endpoint.py | 484 +++++++++++++++++++ 6 files changed, 1788 insertions(+), 33 deletions(-) create mode 100644 src/service/utils/download.py create mode 100644 src/service/utils/upload.py create mode 100644 tests/endpoints/test_download_endpoint.py create mode 100644 tests/endpoints/test_upload_endpoint.py diff --git a/src/endpoints/data/data_download.py b/src/endpoints/data/data_download.py index eb65111..4e39858 100644 --- a/src/endpoints/data/data_download.py +++ b/src/endpoints/data/data_download.py @@ -1,32 +1,32 @@ -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel -from typing import List, Any, Optional import logging -router = APIRouter() -logger = logging.getLogger(__name__) - - -class RowMatcher(BaseModel): - columnName: str - operation: str - values: List[Any] +import pandas as pd +from fastapi import APIRouter, HTTPException +from src.service.utils.download import ( + DataRequestPayload, + DataResponsePayload, + apply_filters, # ← New utility function + load_model_dataframe, +) -class DataRequestPayload(BaseModel): - modelId: str - matchAny: Optional[List[RowMatcher]] = None - matchAll: Optional[List[RowMatcher]] = None - matchNone: Optional[List[RowMatcher]] = None +router = APIRouter() +logger = logging.getLogger(__name__) @router.post("/data/download") -async def download_data(payload: DataRequestPayload): - """Download model data.""" +async def download_data(payload: DataRequestPayload) -> DataResponsePayload: + """Download model data with filtering.""" try: logger.info(f"Received data download request for model: {payload.modelId}") - # TODO: Implement - return {"status": "success", "data": []} + df = await load_model_dataframe(payload.modelId) + if df.empty: + return DataResponsePayload(dataCSV="") + df = apply_filters(df, payload) + csv_data = df.to_csv(index=False) + return DataResponsePayload(dataCSV=csv_data) + except HTTPException: + raise except Exception as e: logger.error(f"Error downloading data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}") \ No newline at end of file diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 49889db..6f20bb2 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -1,27 +1,39 @@ +import logging +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional + +import numpy as np from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from typing import Dict, Any -import logging + +from src.service.constants import INPUT_SUFFIX, METADATA_SUFFIX, OUTPUT_SUFFIX +from src.service.data.modelmesh_parser import ModelMeshPayloadParser +from src.service.data.storage import get_storage_interface +from src.service.utils.upload import process_upload_request router = APIRouter() logger = logging.getLogger(__name__) -class ModelInferJointPayload(BaseModel): +class 
UploadPayload(BaseModel): model_name: str - data_tag: str = None + data_tag: Optional[str] = None is_ground_truth: bool = False request: Dict[str, Any] - response: Dict[str, Any] + response: Optional[Dict[str, Any]] = None @router.post("/data/upload") -async def upload_data(payload: ModelInferJointPayload): - """Upload a batch of model data to TrustyAI.""" +async def upload(payload: UploadPayload) -> Dict[str, str]: + """Upload model data - regular or ground truth.""" try: - logger.info(f"Received data upload for model: {payload.model_name}") - # TODO: Implement - return {"status": "success", "message": "Data uploaded successfully"} + logger.info(f"Received upload request for model: {payload.model_name}") + result = await process_upload_request(payload) + logger.info(f"Upload completed for model: {payload.model_name}") + return result + except HTTPException: + raise except Exception as e: - logger.error(f"Error uploading data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Error uploading data: {str(e)}") + logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True) + raise HTTPException(500, f"Internal server error: {str(e)}") \ No newline at end of file diff --git a/src/service/utils/download.py b/src/service/utils/download.py new file mode 100644 index 0000000..a6deea3 --- /dev/null +++ b/src/service/utils/download.py @@ -0,0 +1,310 @@ +import logging +import numbers +import pickle +from typing import Any, List + +import pandas as pd +from fastapi import HTTPException +from pydantic import BaseModel, Field + +from src.service.data.storage import get_storage_interface + +logger = logging.getLogger(__name__) + + +class RowMatcher(BaseModel): + """Represents a row matching condition for data filtering.""" + + columnName: str + operation: str # "EQUALS" or "BETWEEN" + values: List[Any] + + +class DataRequestPayload(BaseModel): + """Request payload for data download operations.""" + + modelId: str + matchAny: List[RowMatcher] = Field(default_factory=list) + matchAll: List[RowMatcher] = Field(default_factory=list) + matchNone: List[RowMatcher] = Field(default_factory=list) + + +class DataResponsePayload(BaseModel): + """Response payload containing filtered data as CSV.""" + + dataCSV: str + + +def get_storage() -> Any: + """Get storage interface instance.""" + return get_storage_interface() + + +async def load_model_dataframe(model_id: str) -> pd.DataFrame: + try: + storage = get_storage_interface() + print(f"DEBUG: storage type = {type(storage)}") + input_data, input_cols = await storage.read_data(f"{model_id}_inputs") + output_data, output_cols = await storage.read_data(f"{model_id}_outputs") + metadata_data, metadata_cols = await storage.read_data(f"{model_id}_metadata") + if input_data is None or output_data is None or metadata_data is None: + raise HTTPException(404, f"Model {model_id} not found") + df = pd.DataFrame() + if len(input_data) > 0: + input_df = pd.DataFrame(input_data, columns=input_cols) + df = pd.concat([df, input_df], axis=1) + if len(output_data) > 0: + output_df = pd.DataFrame(output_data, columns=output_cols) + df = pd.concat([df, output_df], axis=1) + if len(metadata_data) > 0: + logger.debug(f"Metadata data type: {type(metadata_data)}") + logger.debug(f"First row type: {type(metadata_data[0]) if len(metadata_data) > 0 else 'empty'}") + logger.debug( + f"First row dtype: {metadata_data[0].dtype if hasattr(metadata_data[0], 'dtype') else 'no dtype'}" + ) + metadata_df = pd.DataFrame(metadata_data, 
columns=metadata_cols) + trusty_mapping = { + "ID": "trustyai.ID", + "MODEL_ID": "trustyai.MODEL_ID", + "TIMESTAMP": "trustyai.TIMESTAMP", + "TAG": "trustyai.TAG", + "INDEX": "trustyai.INDEX", + } + for orig_col in metadata_cols: + trusty_col = trusty_mapping.get(orig_col, orig_col) + df[trusty_col] = metadata_df[orig_col] + return df + except HTTPException: + raise + except Exception as e: + logger.error(f"Error loading model dataframe: {e}") + raise HTTPException(500, f"Error loading model data: {str(e)}") + + +def apply_filters(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: + """ + Apply all filters to DataFrame with performance optimization. + """ + if not any([payload.matchAll, payload.matchAny, payload.matchNone]): + return df + has_timestamp_filter = _has_timestamp_filters(payload) + if has_timestamp_filter: + logger.debug("Using boolean mask approach for timestamp filters") + return _apply_filters_with_boolean_masks(df, payload) + else: + logger.debug("Using query approach for non-timestamp filters") + return _apply_filters_with_query(df, payload) + + +def _has_timestamp_filters(payload: DataRequestPayload) -> bool: + """Check if payload contains any timestamp filters.""" + for matcher_list in [payload.matchAll or [], payload.matchAny or [], payload.matchNone or []]: + for matcher in matcher_list: + if matcher.columnName == "trustyai.TIMESTAMP": + return True + return False + + +def _apply_filters_with_query(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: + """Apply filters using pandas query (optimized for non-timestamp filters).""" + query_expr = _build_query_expression(df, payload) + if query_expr: + logger.debug(f"Executing query: {query_expr}") + try: + df = df.query(query_expr) + except Exception as e: + logger.error(f"Query execution failed: {query_expr}") + raise HTTPException(status_code=400, detail=f"Filter execution failed: {str(e)}") + return df + + +def _apply_filters_with_boolean_masks(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: + """Apply filters using boolean masks (optimized for timestamp filters).""" + final_mask = pd.Series(True, index=df.index) + if payload.matchAll: + for matcher in payload.matchAll: + matcher_mask = _get_matcher_mask(df, matcher, negate=False) + final_mask &= matcher_mask + if payload.matchNone: + for matcher in payload.matchNone: + matcher_mask = _get_matcher_mask(df, matcher, negate=True) + final_mask &= matcher_mask + if payload.matchAny: + any_mask = pd.Series(False, index=df.index) + for matcher in payload.matchAny: + matcher_mask = _get_matcher_mask(df, matcher, negate=False) + any_mask |= matcher_mask + final_mask &= any_mask + return df[final_mask] + + +def _get_matcher_mask(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.Series: + """ + Get boolean mask for a single matcher with comprehensive validation. 
+ """ + column_name = matcher.columnName + values = matcher.values + if matcher.operation not in ["EQUALS", "BETWEEN"]: + raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]") + if column_name not in df.columns: + raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}") + if matcher.operation == "EQUALS": + mask = df[column_name].isin(values) + elif matcher.operation == "BETWEEN": + mask = _create_between_mask(df, column_name, values) + if negate: + mask = ~mask + + return mask + + +def _create_between_mask(df: pd.DataFrame, column_name: str, values: List[Any]) -> pd.Series: + """Create boolean mask for BETWEEN operation with type-specific handling.""" + errors = [] + if len(values) != 2: + errors.append( + f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received {len(values)} values" + ) + if column_name == "trustyai.TIMESTAMP": + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + try: + start_time = pd.to_datetime(str(values[0])) + end_time = pd.to_datetime(str(values[1])) + df_times = pd.to_datetime(df[column_name]) + return (df_times >= start_time) & (df_times < end_time) + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}" + ) + elif column_name == "trustyai.INDEX": + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + min_val, max_val = sorted([int(v) for v in values]) + return (df[column_name] >= min_val) & (df[column_name] < max_val) + else: + if not all(isinstance(v, numbers.Number) for v in values): + errors.append( + "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. 
Received non-numeric values" + ) + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + min_val, max_val = sorted(values) + try: + if df[column_name].dtype in ["int64", "float64", "int32", "float32"]: + return (df[column_name] >= min_val) & (df[column_name] < max_val) + else: + numeric_column = pd.to_numeric(df[column_name], errors="raise") + return (numeric_column >= min_val) & (numeric_column < max_val) + except (ValueError, TypeError): + raise HTTPException( + status_code=400, + detail=f"Column '{column_name}' contains non-numeric values that cannot be compared with BETWEEN operation.", + ) + + +def _build_query_expression(df: pd.DataFrame, payload: DataRequestPayload) -> str: + """Build optimized pandas query expression for all filters.""" + conditions = [] + if payload.matchAll: + for matcher in payload.matchAll: + condition = _build_condition(df, matcher, negate=False) + if condition: + conditions.append(condition) + if payload.matchNone: + for matcher in payload.matchNone: + condition = _build_condition(df, matcher, negate=True) + if condition: + conditions.append(condition) + if payload.matchAny: + any_conditions = [] + for matcher in payload.matchAny: + condition = _build_condition(df, matcher, negate=False) + if condition: + any_conditions.append(condition) + if any_conditions: + any_expr = " | ".join(f"({cond})" for cond in any_conditions) + conditions.append(f"({any_expr})") + return " & ".join(f"({cond})" for cond in conditions) if conditions else "" + + +def _build_condition(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> str: + """Build a single condition for pandas query.""" + column_name = matcher.columnName + values = matcher.values + if matcher.operation not in ["EQUALS", "BETWEEN"]: + raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]") + if column_name not in df.columns: + raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}") + safe_column = _sanitize_column_name(column_name) + if matcher.operation == "EQUALS": + condition = _build_equals_condition(safe_column, values, df[column_name].dtype) + elif matcher.operation == "BETWEEN": + condition = _build_between_condition(safe_column, values, column_name, df[column_name].dtype) + if negate: + condition = f"~({condition})" + return condition + + +def _sanitize_column_name(column_name: str) -> str: + """Sanitize column name for pandas query syntax.""" + if "." in column_name or column_name.startswith("trustyai"): + return f"`{column_name}`" + return column_name + + +def _build_equals_condition(safe_column: str, values: List[Any], dtype) -> str: + """Build EQUALS condition for query with optimization.""" + if len(values) == 1: + val = _format_value_for_query(values[0], dtype) + return f"{safe_column} == {val}" + else: + formatted_values = [_format_value_for_query(v, dtype) for v in values] + values_str = "[" + ", ".join(formatted_values) + "]" + return f"{safe_column}.isin({values_str})" + + +def _build_between_condition(safe_column: str, values: List[Any], original_column: str, dtype) -> str: + """Build BETWEEN condition for query with comprehensive validation.""" + errors = [] + if len(values) != 2: + errors.append( + f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. 
Received {len(values)} values" + ) + if original_column == "trustyai.TIMESTAMP": + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + try: + start_time = pd.to_datetime(str(values[0])) + end_time = pd.to_datetime(str(values[1])) + return f"'{start_time}' <= {safe_column} < '{end_time}'" + except Exception as e: + raise HTTPException( + status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}" + ) + elif original_column == "trustyai.INDEX": + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + min_val, max_val = sorted([int(v) for v in values]) + return f"{min_val} <= {safe_column} < {max_val}" + else: + if not all(isinstance(v, numbers.Number) for v in values): + errors.append( + "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values" + ) + if errors: + raise HTTPException(status_code=400, detail=", ".join(errors)) + min_val, max_val = sorted(values) + return f"{min_val} <= {safe_column} < {max_val}" + + +def _format_value_for_query(value: Any, dtype) -> str: + """Format value appropriately for pandas query syntax.""" + if isinstance(value, str): + escaped = value.replace("'", "\\'") + return f"'{escaped}'" + elif isinstance(value, (int, float)): + return str(value) + else: + escaped = str(value).replace("'", "\\'") + return f"'{escaped}'" \ No newline at end of file diff --git a/src/service/utils/upload.py b/src/service/utils/upload.py new file mode 100644 index 0000000..58f3942 --- /dev/null +++ b/src/service/utils/upload.py @@ -0,0 +1,545 @@ +import logging +import uuid +from datetime import datetime +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from fastapi import HTTPException + +from src.service.constants import ( + INPUT_SUFFIX, + METADATA_SUFFIX, + OUTPUT_SUFFIX, + TRUSTYAI_TAG_PREFIX, +) +from src.service.data.modelmesh_parser import ModelMeshPayloadParser +from src.service.data.storage import get_storage_interface +from src.service.utils import list_utils +from src.endpoints.consumer.consumer_endpoint import process_payload + + +logger = logging.getLogger(__name__) + + +METADATA_STRING_MAX_LENGTH = 100 + + +class KServeDataAdapter: + """ + Convert upload tensors to consumer endpoint format. + """ + + def __init__(self, tensor_dict: Dict[str, Any], numpy_array: np.ndarray): + """Initialize adapter with validated data.""" + self._name = tensor_dict.get("name", "unknown") + self._shape = tensor_dict.get("shape", []) + self._datatype = tensor_dict.get("datatype", "FP64") + self._data = numpy_array # Keep numpy array intact + + @property + def name(self) -> str: + return self._name + + @property + def shape(self) -> List[int]: + return self._shape + + @property + def datatype(self) -> str: + return self._datatype + + @property + def data(self) -> np.ndarray: + """Returns numpy array with .shape attribute as expected by consumer endpoint.""" + return self._data + + +class ConsumerEndpointAdapter: + """ + Consumer endpoint's expected structure. + """ + + def __init__(self, adapted_tensors: List[KServeDataAdapter]): + self.tensors = adapted_tensors + self.id = f"upload_request_{uuid.uuid4().hex[:8]}" + + +async def process_upload_request(payload: Any) -> Dict[str, str]: + """ + Process complete upload request with validation and data handling. 
+ """ + try: + model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name) + if payload.data_tag: + error = validate_data_tag(payload.data_tag) + if error: + raise HTTPException(400, error) + inputs = payload.request.get("inputs", []) + outputs = payload.response.get("outputs", []) if payload.response else [] + if not inputs: + raise HTTPException(400, "Missing input tensors") + if payload.is_ground_truth and not outputs: + raise HTTPException(400, "Ground truth uploads require output tensors") + + input_arrays, input_names, _, execution_ids = process_tensors_using_kserve_logic(inputs) + if outputs: + output_arrays, output_names, _, _ = process_tensors_using_kserve_logic(outputs) + else: + output_arrays, output_names = [], [] + error = validate_input_shapes(input_arrays, input_names) + if error: + raise HTTPException(400, f"One or more errors in input tensors: {error}") + if payload.is_ground_truth: + return await _process_ground_truth_data( + model_name, input_arrays, input_names, output_arrays, output_names, execution_ids + ) + else: + return await _process_regular_data( + model_name, input_arrays, input_names, output_arrays, output_names, execution_ids, payload.data_tag + ) + except ProcessingError as e: + raise HTTPException(400, str(e)) + except ValidationError as e: + raise HTTPException(400, str(e)) + + +async def _process_ground_truth_data( + model_name: str, + input_arrays: List[np.ndarray], + input_names: List[str], + output_arrays: List[np.ndarray], + output_names: List[str], + execution_ids: Optional[List[str]], +) -> Dict[str, str]: + """Process ground truth data upload.""" + if not execution_ids: + raise HTTPException(400, "Ground truth requires execution IDs") + result = await handle_ground_truths( + model_name, + input_arrays, + input_names, + output_arrays, + output_names, + [sanitize_id(id) for id in execution_ids], + ) + if not result.success: + raise HTTPException(400, result.message) + result_data = result.data + if result_data is None: + raise HTTPException(500, "Ground truth processing failed") + gt_name = f"{model_name}_ground_truth" + storage_interface = get_storage_interface() + await storage_interface.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"]) + await storage_interface.write_data( + gt_name + METADATA_SUFFIX, + result_data["metadata"], + result_data["metadata_names"], + ) + logger.info(f"Ground truth data saved for model: {model_name}") + return {"message": result.message} + + +async def _process_regular_data( + model_name: str, + input_arrays: List[np.ndarray], + input_names: List[str], + output_arrays: List[np.ndarray], + output_names: List[str], + execution_ids: Optional[List[str]], + data_tag: Optional[str], +) -> Dict[str, str]: + """Process regular model data upload.""" + n_rows = input_arrays[0].shape[0] + exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)] + input_data = _flatten_tensor_data(input_arrays, n_rows) + output_data = _flatten_tensor_data(output_arrays, n_rows) + metadata, metadata_cols = _create_metadata(exec_ids, model_name, data_tag) + await save_model_data( + model_name, + np.array(input_data), + input_names, + np.array(output_data), + output_names, + metadata, + metadata_cols, + ) + logger.info(f"Regular data saved for model: {model_name}, rows: {n_rows}") + return {"message": f"{n_rows} datapoints added to {model_name}"} + + +def _flatten_tensor_data(arrays: List[np.ndarray], n_rows: int) -> List[List[Any]]: + """ + Flatten tensor arrays into row-based 
format for storage. + """ + + def flatten_row(arrays: List[np.ndarray], row: int) -> List[Any]: + """Flatten arrays for a single row.""" + return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])] + + return [flatten_row(arrays, i) for i in range(n_rows)] + + +def _create_metadata( + execution_ids: List[str], model_name: str, data_tag: Optional[str] +) -> Tuple[np.ndarray, List[str]]: + """ + Create metadata array for model data storage. + """ + current_timestamp = datetime.now().isoformat() + metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG"] + metadata_rows = [ + [ + str(eid), + str(model_name), + str(current_timestamp), + str(data_tag or ""), + ] + for eid in execution_ids + ] + _validate_metadata_lengths(metadata_rows, metadata_cols) + metadata = np.array(metadata_rows, dtype=f" None: + """ + Validate that all metadata values fit within the defined string length limit. + """ + for row_idx, row in enumerate(metadata_rows): + for col_idx, value in enumerate(row): + value_str = str(value) + if len(value_str) > METADATA_STRING_MAX_LENGTH: + col_name = column_names[col_idx] if col_idx < len(column_names) else f"column_{col_idx}" + raise ValidationError( + f"Metadata field '{col_name}' in row {row_idx} exceeds maximum length " + f"of {METADATA_STRING_MAX_LENGTH} characters (got {len(value_str)} chars): " + f"'{value_str[:50]}{'...' if len(value_str) > 50 else ''}'" + ) + + +class ValidationError(Exception): + """Validation errors.""" + + pass + + +class ProcessingError(Exception): + """Processing errors.""" + + pass + + +@dataclass +class GroundTruthValidationResult: + """Result of ground truth validation.""" + + success: bool + message: str + data: Optional[Dict[str, Any]] = None + errors: List[str] = field(default_factory=list) + + +TYPE_MAP = { + np.int64: "Long", + np.int32: "Integer", + np.float32: "Float", + np.float64: "Double", + np.bool_: "Boolean", + int: "Long", + float: "Double", + bool: "Boolean", + str: "String", +} + + +def get_type_name(val: Any) -> str: + """Get Java-style type name for a value (used in ground truth validation).""" + if hasattr(val, "dtype"): + return TYPE_MAP.get(val.dtype.type, "String") + return TYPE_MAP.get(type(val), "String") + + +def sanitize_id(execution_id: str) -> str: + """Sanitize execution ID.""" + return str(execution_id).strip() + + +def extract_row_data(arrays: List[np.ndarray], row_index: int) -> List[Any]: + """Extract data from arrays for a specific row.""" + row_data = [] + for arr in arrays: + if arr.ndim > 1: + row_data.extend(arr[row_index].flatten()) + else: + row_data.append(arr[row_index]) + return row_data + + +def process_tensors_using_kserve_logic( + tensors: List[Dict[str, Any]], +) -> Tuple[List[np.ndarray], List[str], List[str], Optional[List[str]]]: + """ + Process tensor data using consumer endpoint logic via clean adapter pattern. + """ + if not tensors: + return [], [], [], None + validation_errors = _validate_tensor_inputs(tensors) + if validation_errors: + error_message = "One or more errors occurred: " + ". 
".join(validation_errors) + raise HTTPException(400, error_message) + adapted_tensors = [] + execution_ids = None + datatypes = [] + for tensor in tensors: + if execution_ids is None: + execution_ids = tensor.get("execution_ids") + numpy_array = _convert_tensor_to_numpy(tensor) + adapter = KServeDataAdapter(tensor, numpy_array) + adapted_tensors.append(adapter) + datatypes.append(adapter.datatype) + try: + adapter_payload = ConsumerEndpointAdapter(adapted_tensors) + tensor_array, column_names = process_payload(adapter_payload, lambda payload: payload.tensors) + arrays, all_names = _convert_consumer_results_to_upload_format(tensor_array, column_names, adapted_tensors) + return arrays, all_names, datatypes, execution_ids + except Exception as e: + logger.error(f"Consumer endpoint processing failed: {e}") + raise HTTPException(400, f"Tensor processing error: {str(e)}") + + +def _validate_tensor_inputs(tensors: List[Dict[str, Any]]) -> List[str]: + """Validate tensor inputs and return list of error messages.""" + errors = [] + tensor_names = [tensor.get("name", f"tensor_{i}") for i, tensor in enumerate(tensors)] + if len(tensor_names) != len(set(tensor_names)): + errors.append("Input tensors must have unique names") + shapes = [tensor.get("shape", []) for tensor in tensors] + if len(shapes) > 1: + first_dims = [shape[0] if shape else 0 for shape in shapes] + if len(set(first_dims)) > 1: + errors.append(f"Input tensors must have consistent first dimension. Found: {first_dims}") + return errors + + +def _convert_tensor_to_numpy(tensor: Dict[str, Any]) -> np.ndarray: + """Convert tensor dictionary to numpy array with proper dtype.""" + raw_data = tensor.get("data", []) + + if list_utils.contains_non_numeric(raw_data): + return np.array(raw_data, dtype="O") + dtype_map = {"INT64": np.int64, "INT32": np.int32, "FP32": np.float32, "FP64": np.float64, "BOOL": np.bool_} + datatype = tensor.get("datatype", "FP64") + np_dtype = dtype_map.get(datatype, np.float64) + return np.array(raw_data, dtype=np_dtype) + + +def _convert_consumer_results_to_upload_format( + tensor_array: np.ndarray, column_names: List[str], adapted_tensors: List[KServeDataAdapter] +) -> Tuple[List[np.ndarray], List[str]]: + """Convert consumer endpoint results back to upload format.""" + if len(adapted_tensors) == 1: + # Single tensor case + return [tensor_array], column_names + arrays = [] + all_names = [] + col_start = 0 + for adapter in adapted_tensors: + if len(adapter.shape) > 1: + n_cols = adapter.shape[1] + tensor_names = [f"{adapter.name}-{i}" for i in range(n_cols)] + else: + n_cols = 1 + tensor_names = [adapter.name] + if tensor_array.ndim == 2: + tensor_data = tensor_array[:, col_start : col_start + n_cols] + else: + tensor_data = tensor_array[col_start : col_start + n_cols] + arrays.append(tensor_data) + all_names.extend(tensor_names) + col_start += n_cols + return arrays, all_names + + +def validate_input_shapes(input_arrays: List[np.ndarray], input_names: List[str]) -> Optional[str]: + """Validate input array shapes and names - collect ALL errors.""" + if not input_arrays: + return None + errors = [] + if len(set(input_names)) != len(input_names): + errors.append("Input tensors must have unique names") + first_dim = input_arrays[0].shape[0] + for i, arr in enumerate(input_arrays[1:], 1): + if arr.shape[0] != first_dim: + errors.append( + f"Input tensor '{input_names[i]}' has first dimension {arr.shape[0]}, " + f"which doesn't match the first dimension {first_dim} of '{input_names[0]}'" + ) + if errors: + return ". 
".join(errors) + "." + return None + + +def validate_data_tag(tag: str) -> Optional[str]: + """Validate data tag format and content.""" + if not tag: + return None + if tag.startswith(TRUSTYAI_TAG_PREFIX): + return ( + f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. " + f"Provided tag '{tag}' violates this restriction." + ) + return None + + +class GroundTruthValidator: + """Ground truth validator.""" + + def __init__(self, model_name: str): + self.model_name = model_name + self.id_to_row: Dict[str, int] = {} + self.inputs: Optional[np.ndarray] = None + self.outputs: Optional[np.ndarray] = None + self.metadata: Optional[np.ndarray] = None + + async def initialize(self) -> None: + """Load existing data.""" + storage_interface = get_storage_interface() + self.inputs, _ = await storage_interface.read_data(self.model_name + INPUT_SUFFIX) + self.outputs, _ = await storage_interface.read_data(self.model_name + OUTPUT_SUFFIX) + self.metadata, _ = await storage_interface.read_data(self.model_name + METADATA_SUFFIX) + metadata_cols = await storage_interface.get_original_column_names(self.model_name + METADATA_SUFFIX) + id_col = next((i for i, name in enumerate(metadata_cols) if name.upper() == "ID"), 0) + if self.metadata is not None: + for j, row in enumerate(self.metadata): + id_val = row[id_col] + self.id_to_row[str(id_val)] = j + + def find_row(self, exec_id: str) -> Optional[int]: + """Find row index for execution ID.""" + return self.id_to_row.get(str(exec_id)) + + async def validate_data( + self, + exec_id: str, + uploaded_inputs: List[Any], + uploaded_outputs: List[Any], + row_idx: int, + input_names: Optional[List[str]] = None, + output_names: Optional[List[str]] = None, + ) -> Optional[str]: + """Validate inputs and outputs.""" + if self.inputs is None or self.outputs is None: + return f"ID={exec_id} no existing data found" + existing_inputs = self.inputs[row_idx] + existing_outputs = self.outputs[row_idx] + for i, (existing, uploaded) in enumerate(zip(existing_inputs[:3], uploaded_inputs[:3])): + if hasattr(existing, "dtype"): + print( + f" Input {i}: existing.dtype={existing.dtype}, uploaded.dtype={getattr(uploaded, 'dtype', 'no dtype')}" + ) + print(f" Input {i}: existing={existing}, uploaded={uploaded}") + for i, (existing, uploaded) in enumerate(zip(existing_outputs[:2], uploaded_outputs[:2])): + if hasattr(existing, "dtype"): + print( + f" Output {i}: existing.dtype={existing.dtype}, uploaded.dtype={getattr(uploaded, 'dtype', 'no dtype')}" + ) + print(f" Output {i}: existing={existing}, uploaded={uploaded}") + if len(existing_inputs) != len(uploaded_inputs): + return f"ID={exec_id} input shapes do not match. Observed inputs have length={len(existing_inputs)} while uploaded inputs have length={len(uploaded_inputs)}" + for i, (existing, uploaded) in enumerate(zip(existing_inputs, uploaded_inputs)): + existing_type = get_type_name(existing) + uploaded_type = get_type_name(uploaded) + print(f" Input {i}: existing_type='{existing_type}', uploaded_type='{uploaded_type}'") + if existing_type != uploaded_type: + return f"ID={exec_id} input type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}" + if existing != uploaded: + return f"ID={exec_id} inputs are not identical: value mismatch at position {i + 1}" + if len(existing_outputs) != len(uploaded_outputs): + return f"ID={exec_id} output shapes do not match. 
Observed outputs have length={len(existing_outputs)} while uploaded ground-truths have length={len(uploaded_outputs)}" + for i, (existing, uploaded) in enumerate(zip(existing_outputs, uploaded_outputs)): + existing_type = get_type_name(existing) + uploaded_type = get_type_name(uploaded) + print(f" Output {i}: existing_type='{existing_type}', uploaded_type='{uploaded_type}'") + if existing_type != uploaded_type: + return f"ID={exec_id} output type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}" + return None + + +async def handle_ground_truths( + model_name: str, + input_arrays: List[np.ndarray], + input_names: List[str], + output_arrays: List[np.ndarray], + output_names: List[str], + execution_ids: List[str], + config: Optional[Any] = None, +) -> GroundTruthValidationResult: + """Handle ground truth validation.""" + if not execution_ids: + return GroundTruthValidationResult(success=False, message="No execution IDs provided.") + storage_interface = get_storage_interface() + if not await storage_interface.dataset_exists(model_name + INPUT_SUFFIX): + return GroundTruthValidationResult(success=False, message=f"Model {model_name} not found.") + validator = GroundTruthValidator(model_name) + await validator.initialize() + errors = [] + valid_outputs = [] + valid_metadata = [] + n_rows = input_arrays[0].shape[0] if input_arrays else 0 + for i, exec_id in enumerate(execution_ids): + if i >= n_rows: + errors.append(f"ID={exec_id} index out of bounds") + continue + row_idx = validator.find_row(exec_id) + if row_idx is None: + errors.append(f"ID={exec_id} not found") + continue + uploaded_inputs = extract_row_data(input_arrays, i) + uploaded_outputs = extract_row_data(output_arrays, i) + error = await validator.validate_data(exec_id, uploaded_inputs, uploaded_outputs, row_idx) + if error: + errors.append(error) + continue + valid_outputs.append(uploaded_outputs) + valid_metadata.append([exec_id]) + if errors: + return GroundTruthValidationResult( + success=False, + message="Found fatal mismatches between uploaded data and recorded inference data:\n" + + "\n".join(errors[:5]), + errors=errors, + ) + if not valid_outputs: + return GroundTruthValidationResult(success=False, message="No valid ground truths found.") + return GroundTruthValidationResult( + success=True, + message=f"{len(valid_outputs)} ground truths added.", + data={ + "outputs": np.array(valid_outputs), + "output_names": output_names, + "metadata": np.array(valid_metadata), + "metadata_names": ["ID"], + }, + ) + + +async def save_model_data( + model_name: str, + input_data: np.ndarray, + input_names: List[str], + output_data: np.ndarray, + output_names: List[str], + metadata_data: np.ndarray, + metadata_names: List[str], +) -> Dict[str, Any]: + """Save model data to storage.""" + storage_interface = get_storage_interface() + await storage_interface.write_data(model_name + INPUT_SUFFIX, input_data, input_names) + await storage_interface.write_data(model_name + OUTPUT_SUFFIX, output_data, output_names) + await storage_interface.write_data(model_name + METADATA_SUFFIX, metadata_data, metadata_names) + logger.info(f"Saved model data for {model_name}: {len(input_data)} rows") + return { + "model_name": model_name, + "rows": len(input_data), + } \ No newline at end of file diff --git a/tests/endpoints/test_download_endpoint.py b/tests/endpoints/test_download_endpoint.py new file mode 100644 index 0000000..647f694 --- /dev/null +++ b/tests/endpoints/test_download_endpoint.py @@ -0,0 +1,404 @@ +import uuid +from 
datetime import datetime, timedelta +from io import StringIO +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest +from fastapi.testclient import TestClient + +from src.main import app + +client = TestClient(app) + + +class DataframeGenerators: + """Python equivalent of Java DataframeGenerators""" + + @staticmethod + def generate_random_dataframe(observations: int, feature_diversity: int = 100) -> pd.DataFrame: + random = np.random.RandomState(0) + data = { + "age": [], + "gender": [], + "race": [], + "income": [], + "trustyai.ID": [], + "trustyai.MODEL_ID": [], + "trustyai.TIMESTAMP": [], + "trustyai.TAG": [], + "trustyai.INDEX": [], + } + for i in range(observations): + data["age"].append(i % feature_diversity) + data["gender"].append(1 if random.choice([True, False]) else 0) + data["race"].append(1 if random.choice([True, False]) else 0) + data["income"].append(1 if random.choice([True, False]) else 0) + data["trustyai.ID"].append(str(uuid.uuid4())) + data["trustyai.MODEL_ID"].append("example1") + data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) + data["trustyai.TAG"].append("") + data["trustyai.INDEX"].append(i) + return pd.DataFrame(data) + + @staticmethod + def generate_random_text_dataframe(observations: int, seed: int = 0) -> pd.DataFrame: + if seed < 0: + random = np.random.RandomState(0) + else: + random = np.random.RandomState(seed) + makes = ["Ford", "Chevy", "Dodge", "GMC", "Buick"] + colors = ["Red", "Blue", "White", "Black", "Purple", "Green", "Yellow"] + data = { + "year": [], + "make": [], + "color": [], + "value": [], + "trustyai.ID": [], + "trustyai.MODEL_ID": [], + "trustyai.TIMESTAMP": [], + "trustyai.TAG": [], + "trustyai.INDEX": [], + } + for i in range(observations): + data["year"].append(1970 + i % 50) + data["make"].append(makes[i % len(makes)]) + data["color"].append(colors[i % len(colors)]) + data["value"].append(random.random() * 50) + data["trustyai.ID"].append(str(uuid.uuid4())) + data["trustyai.MODEL_ID"].append("example1") + data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) + data["trustyai.TAG"].append("") + data["trustyai.INDEX"].append(i) + return pd.DataFrame(data) + + +# Mock storage for testing +class MockStorage: + def __init__(self): + self.data = {} + + async def read_data(self, dataset_name: str): + if dataset_name.endswith("_outputs"): + model_id = dataset_name.replace("_outputs", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + output_data = self.data[model_id].get("output") + output_cols = self.data[model_id].get("output_cols", []) + return output_data, output_cols + elif dataset_name.endswith("_metadata"): + model_id = dataset_name.replace("_metadata", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + metadata_data = self.data[model_id].get("metadata") + metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] + return metadata_data, metadata_cols + elif dataset_name.endswith("_inputs"): + model_id = dataset_name.replace("_inputs", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + input_data = self.data[model_id].get("input") + input_cols = self.data[model_id].get("input_cols", []) + return input_data, input_cols + else: + raise Exception(f"Unknown dataset: {dataset_name}") + + def save_dataframe(self, df: pd.DataFrame, model_id: str): + input_cols = [col for col in df.columns if not col.startswith("trustyai.") and 
col not in ["income", "value"]] + output_cols = [col for col in df.columns if col in ["income", "value"]] + metadata_cols = [col for col in df.columns if col.startswith("trustyai.")] + input_data = df[input_cols].values if input_cols else np.array([]) + output_data = df[output_cols].values if output_cols else np.array([]) + metadata_data_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] + metadata_values = [] + for _, row in df.iterrows(): + row_data = [] + for col in metadata_data_cols: + trusty_col = f"trustyai.{col}" + if trusty_col in df.columns: + value = row[trusty_col] + if col == "INDEX": + row_data.append(int(value)) + else: + row_data.append(str(value)) + else: + row_data.append("" if col != "INDEX" else 0) + metadata_values.append(row_data) + metadata_data = np.array(metadata_values, dtype=object) + self.data[model_id] = { + "dataframe": df, + "input": input_data, + "input_cols": input_cols, + "output": output_data, + "output_cols": output_cols, + "metadata": metadata_data, + } + + def reset(self): + self.data.clear() + + +mock_storage = MockStorage() + + +@pytest.fixture(autouse=True) +def setup_storage(): + """Setup mock storage for all tests""" + with patch("src.service.utils.download.get_storage_interface", return_value=mock_storage): + yield + + +@pytest.fixture(autouse=True) +def reset_storage(): + """Reset storage before each test""" + mock_storage.reset() + yield + + +# Test constants +MODEL_ID = "example1" + + +def test_download_data(): + """equivalent of Java downloadData() test""" + dataframe = DataframeGenerators.generate_random_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + + payload = { + "modelId": MODEL_ID, + "matchAll": [ + {"columnName": "gender", "operation": "EQUALS", "values": [0]}, + {"columnName": "race", "operation": "EQUALS", "values": [0]}, + {"columnName": "income", "operation": "EQUALS", "values": [0]}, + ], + "matchAny": [ + {"columnName": "age", "operation": "BETWEEN", "values": [5, 10]}, + {"columnName": "age", "operation": "BETWEEN", "values": [50, 70]}, + ], + "matchNone": [{"columnName": "age", "operation": "BETWEEN", "values": [55, 65]}], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df[(result_df["age"] > 55) & (result_df["age"] < 65)]) == 0 + assert len(result_df[result_df["gender"] == 1]) == 0 + assert len(result_df[result_df["race"] == 1]) == 0 + assert len(result_df[result_df["income"] == 1]) == 0 + assert len(result_df[(result_df["age"] >= 10) & (result_df["age"] < 50)]) == 0 + assert len(result_df[result_df["age"] > 70]) == 0 + + +def test_download_text_data(): + """equivalent of Java downloadTextData() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "make", + "operation": "EQUALS", + "values": ["Chevy", "Ford", "Dodge"], + }, + { + "columnName": "year", + "operation": "BETWEEN", + "values": [1990, 2050], + }, + ], + } + + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df[result_df["year"] < 1990]) == 0 + assert len(result_df[result_df["make"] == "GMC"]) == 0 + assert len(result_df[result_df["make"] == "Buick"]) == 0 + + +def 
test_download_text_data_between_error(): + """equivalent of Java downloadTextDataBetweenError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "make", + "operation": "BETWEEN", + "values": ["Chevy", "Ford", "Dodge"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert ( + "BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received 3 values" + in response.text + ) + assert ( + "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values" + in response.text + ) + + +def test_download_text_data_invalid_column_error(): + """equivalent of Java downloadTextDataInvalidColumnError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "mak123e", + "operation": "EQUALS", + "values": ["Chevy", "Ford"], + } + ], + } + + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "No feature or output found with name=" in response.text + + +def test_download_text_data_invalid_operation_error(): + """equivalent of Java downloadTextDataInvalidOperationError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "mak123e", + "operation": "DOESNOTEXIST", + "values": ["Chevy", "Ford"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "RowMatch operation must be one of [BETWEEN, EQUALS]" in response.text + + +def test_download_text_data_internal_column(): + """equivalent of Java downloadTextDataInternalColumn() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + dataframe.loc[0:499, "trustyai.TAG"] = "TRAINING" + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "trustyai.TAG", + "operation": "EQUALS", + "values": ["TRAINING"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 500 + + +def test_download_text_data_internal_column_index(): + """equivalent of Java downloadTextDataInternalColumnIndex() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + expected_rows = dataframe.iloc[0:10].copy() + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "trustyai.INDEX", + "operation": "BETWEEN", + "values": [0, 10], + } + ], + } + response = client.post("/data/download", json=payload) + print(f"Response status: {response.status_code}") + print(f"Response text: {response.text}") + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 10 + input_cols = ["year", "make", "color"] + for i in range(10): + for col in input_cols: + assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" + 
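The index filter tested above depends on BETWEEN treating the lower bound as inclusive and the upper bound as exclusive, per the `(col >= min) & (col < max)` masks in src/service/utils/download.py. A minimal sketch of a client call exercising the same filter, assuming the service is reachable at a local URL (the URL and the `requests` client are assumptions; only the endpoint path and the payload fields come from the code in this patch):

import requests  # hypothetical client-side usage; not defined in this patch

payload = {
    "modelId": "example1",
    "matchAll": [
        # selects rows 0-9: lower bound inclusive, upper bound exclusive
        {"columnName": "trustyai.INDEX", "operation": "BETWEEN", "values": [0, 10]},
    ],
}
resp = requests.post("http://localhost:8080/data/download", json=payload)  # base URL is an assumption
resp.raise_for_status()
csv_text = resp.json()["dataCSV"]  # filtered rows as CSV, per DataResponsePayload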
+ +def test_download_text_data_internal_column_timestamp(): + """equivalent of Java downloadTextDataInternalColumnTimestamp() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1, -1) + base_time = datetime.now() + for i in range(100): + new_row = DataframeGenerators.generate_random_text_dataframe(1, i) + # Use milliseconds to simulate Thread.sleep(1) and ensure ascending order + timestamp = (base_time + timedelta(milliseconds=i + 1)).isoformat() + # Fix this line - change to UPPERCASE + new_row["trustyai.TIMESTAMP"] = [timestamp] + dataframe = pd.concat([dataframe, new_row], ignore_index=True) + mock_storage.save_dataframe(dataframe, MODEL_ID) + extract_idx = 50 + n_to_get = 10 + expected_rows = dataframe.iloc[extract_idx : extract_idx + n_to_get].copy() + start_time = dataframe.iloc[extract_idx]["trustyai.TIMESTAMP"] + end_time = dataframe.iloc[extract_idx + n_to_get]["trustyai.TIMESTAMP"] + payload = { + "modelId": MODEL_ID, + "matchAny": [ + { + "columnName": "trustyai.TIMESTAMP", + "operation": "BETWEEN", + "values": [start_time, end_time], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 10 + input_cols = ["year", "make", "color"] + for i in range(10): + for col in input_cols: + assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" + + +def test_download_text_data_internal_column_timestamp_unparseable(): + """equivalent of Java downloadTextDataInternalColumnTimestampUnparseable() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAny": [ + { + "columnName": "trustyai.TIMESTAMP", + "operation": "BETWEEN", + "values": ["not a timestamp", "also not a timestamp"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "unparseable as an ISO_LOCAL_DATE_TIME" in response.text + + +def test_download_text_data_null_request(): + """equivalent of Java downloadTextDataNullRequest() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = {"modelId": MODEL_ID} + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 1000 \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint.py b/tests/endpoints/test_upload_endpoint.py new file mode 100644 index 0000000..12ccf09 --- /dev/null +++ b/tests/endpoints/test_upload_endpoint.py @@ -0,0 +1,484 @@ +import copy +import json +import os +import pickle +import shutil +import sys +import tempfile +import uuid + +import h5py +import numpy as np +import pytest + +TEMP_DIR = tempfile.mkdtemp() +os.environ["STORAGE_DATA_FOLDER"] = TEMP_DIR +from fastapi.testclient import TestClient + +from src.main import app +from src.service.constants import ( + INPUT_SUFFIX, + METADATA_SUFFIX, + OUTPUT_SUFFIX, + TRUSTYAI_TAG_PREFIX, +) +from src.service.data.storage import get_storage_interface + + +def pytest_sessionfinish(session, exitstatus): + """Clean up the temporary directory after all tests are done.""" + if os.path.exists(TEMP_DIR): + shutil.rmtree(TEMP_DIR) + + +pytest.hookimpl(pytest_sessionfinish) +client = 
TestClient(app) +MODEL_ID = "example1" + + +def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0): + """Generate a test payload with specific dimensions and data types.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for i in range(n_rows): + if n_input_cols == 1: + input_data.append(i + input_offset) + else: + row = [i + j + input_offset for j in range(n_input_cols)] + input_data.append(row) + output_data = [] + for i in range(n_rows): + if n_output_cols == 1: + output_data.append(i * 2 + output_offset) + else: + row = [i * 2 + j + output_offset for j in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "input", + "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "output", + "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a test payload with multi-dimensional tensors like real data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for row_idx in range(n_rows): + row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)] + input_data.append(row) + output_data = [] + for row_idx in range(n_rows): + row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "multi_input", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a payload with mismatched shapes and non-unique names.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)] + mismatched_rows = n_rows - 1 if n_rows > 1 else 1 + input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)] + output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)] + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "same_name", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data_1, + }, + { + "name": "same_name", + "shape": [mismatched_rows, n_input_cols], + "datatype": datatype, + "data": input_data_2, + }, + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def get_data_from_storage(model_name, suffix): + """Get data from storage file.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + suffix) + if not os.path.exists(filename): + return None + with h5py.File(filename, "r") as f: + if model_name + 
suffix in f: + data = f[model_name + suffix][:] + column_names = f[model_name + suffix].attrs.get("column_names", []) + return {"data": data, "column_names": column_names} + + +def get_metadata_ids(model_name): + """Extract actual IDs from metadata storage.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return [] + ids = [] + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + id_idx = next((i for i, name in enumerate(column_names) if name.lower() == "id"), None) + if id_idx is not None: + for row in metadata: + try: + if hasattr(row, "__getitem__") and len(row) > id_idx: + id_val = row[id_idx] + else: + row_data = pickle.loads(row.tobytes()) + id_val = row_data[id_idx] + if isinstance(id_val, np.ndarray): + ids.append(str(id_val)) + else: + ids.append(str(id_val)) + except Exception as e: + print(f"Error processing ID from row {len(ids)}: {e}") + continue + print(f"Successfully extracted {len(ids)} IDs: {ids}") + return ids + + +def get_metadata_from_storage(model_name): + """Get metadata directly from storage file.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return {"data": [], "column_names": []} + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + parsed_rows = [] + for row in metadata: + try: + row_data = pickle.loads(row.tobytes()) + parsed_rows.append(row_data) + except Exception as e: + print(f"Error unpickling metadata row: {e}") + + return {"data": parsed_rows, "column_names": column_names} + return {"data": [], "column_names": []} + + +def count_rows_with_tag(model_name, tag): + """Count rows with a specific tag in metadata.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return 0 + count = 0 + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + tag_idx = next( + (i for i, name in enumerate(column_names) if name.lower() == "tag"), + None, + ) + if tag_idx is not None: + for row in metadata: + try: + row_data = pickle.loads(row.tobytes()) + if tag_idx < len(row_data) and row_data[tag_idx] == tag: + count += 1 + except Exception as e: + print(f"Error unpickling tag: {e}") + return count + + +def post_test(payload, expected_status_code, check_msgs): + """Post a payload and check the response.""" + response = client.post("/data/upload", json=payload) + if response.status_code != expected_status_code: + print(f"\n=== DEBUG INFO ===") + print(f"Expected status: {expected_status_code}") + print(f"Actual status: {response.status_code}") + print(f"Response text: {response.text}") + print(f"Response headers: {dict(response.headers)}") + if hasattr(response, "json"): + try: + print(f"Response JSON: {response.json()}") + except: + pass + print(f"==================") + + assert response.status_code == expected_status_code + return response + + +# data upload tests +@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) +@pytest.mark.parametrize("n_input_cols", 
[1, 4]) +@pytest.mark.parametrize("n_output_cols", [1, 2]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_data(n_input_rows, n_input_cols, n_output_cols, datatype): + """Test uploading data with various dimensions and datatypes.""" + data_tag = "TRAINING" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = post_test(payload, 200, [f"{n_input_rows} datapoints"]) + inputs = get_data_from_storage(payload["model_name"], INPUT_SUFFIX) + outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) + assert inputs is not None, "Input data not found in storage" + assert outputs is not None, "Output data not found in storage" + assert len(inputs["data"]) == n_input_rows, "Incorrect number of input rows" + assert len(outputs["data"]) == n_input_rows, "Incorrect number of output rows" + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + assert tag_count == n_input_rows, "Not all rows have the correct tag" + + +@pytest.mark.parametrize("n_rows", [1, 3, 5, 250]) +@pytest.mark.parametrize("n_input_cols", [2, 6]) +@pytest.mark.parametrize("n_output_cols", [4]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_multi_input_data(n_rows, n_input_cols, n_output_cols, datatype): + """Test uploading data with multiple input tensors.""" + data_tag = "TRAINING" + payload = generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = post_test(payload, 200, [f"{n_rows} datapoints"]) + inputs = get_data_from_storage(payload["model_name"], INPUT_SUFFIX) + outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) + assert inputs is not None, "Input data not found in storage" + assert outputs is not None, "Output data not found in storage" + assert len(inputs["data"]) == n_rows, "Incorrect number of input rows" + assert len(outputs["data"]) == n_rows, "Incorrect number of output rows" + assert len(inputs["column_names"]) == n_input_cols, "Incorrect number of input columns" + assert len(outputs["column_names"]) == n_output_cols, "Incorrect number of output columns" + assert len(inputs["column_names"]) >= 2, "Should have at least 2 input column names" + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + assert tag_count == n_rows, "Not all rows have the correct tag" + + +def test_upload_multi_input_data_no_unique_name(): + """Test error case for non-unique tensor names.""" + payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") + response = client.post("/data/upload", json=payload) + assert response.status_code == 400 + assert "One or more errors" in response.text + assert "unique names" in response.text + assert "first dimension" in response.text + + +def test_upload_multiple_tagging(): + """Test uploading data with multiple tags.""" + n_payload1 = 50 + n_payload2 = 51 + tag1 = "TRAINING" + tag2 = "NOT TRAINING" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1) + payload1["model_name"] = model_name + post_test(payload1, 200, [f"{n_payload1} datapoints"]) + payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2) + payload2["model_name"] = model_name + post_test(payload2, 200, [f"{n_payload2} datapoints"]) + tag1_count = count_rows_with_tag(model_name, tag1) + tag2_count = count_rows_with_tag(model_name, tag2) + assert tag1_count == n_payload1, f"Expected 
{n_payload1} rows with tag {tag1}" + assert tag2_count == n_payload2, f"Expected {n_payload2} rows with tag {tag2}" + inputs = get_data_from_storage(model_name, INPUT_SUFFIX) + assert len(inputs["data"]) == n_payload1 + n_payload2, "Incorrect total number of rows" + + +def test_upload_tag_that_uses_protected_name(): + """Test error when using a protected tag name.""" + invalid_tag = f"{TRUSTYAI_TAG_PREFIX}_something" + payload = generate_payload(5, 10, 1, "INT64", invalid_tag) + response = post_test(payload, 400, ["reserved for internal TrustyAI use only"]) + expected_msg = f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. Provided tag '{invalid_tag}' violates this restriction." + assert expected_msg in response.text + + +@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) +@pytest.mark.parametrize("n_input_cols", [1, 4]) +@pytest.mark.parametrize("n_output_cols", [1, 2]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_data_and_ground_truth(n_input_rows, n_input_cols, n_output_cols, datatype): + """Test uploading model data and corresponding ground truth data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING") + payload["model_name"] = model_name + payload["is_ground_truth"] = False + post_test(payload, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload_gt = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING", 0, 1) + payload_gt["model_name"] = model_name + payload_gt["is_ground_truth"] = True + payload_gt["request"] = payload["request"] + payload_gt["request"]["inputs"][0]["execution_ids"] = ids + post_test(payload_gt, 200, [f"{n_input_rows} ground truths"]) + original_data = get_data_from_storage(model_name, OUTPUT_SUFFIX) + gt_data = get_data_from_storage(f"{model_name}_ground_truth", OUTPUT_SUFFIX) + assert len(original_data["data"]) == len(gt_data["data"]), "Row dimensions don't match" + assert len(original_data["column_names"]) == len(gt_data["column_names"]), "Column dimensions don't match" + original_ids = get_metadata_ids(model_name) + gt_ids = get_metadata_ids(f"{model_name}_ground_truth") + assert original_ids == gt_ids, "Ground truth IDs don't match original IDs" + + +def test_upload_mismatch_input_values(): + """Test error when ground truth inputs don't match original data.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload1 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING", 1, 0) + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text or "inputs are not identical" in response.text + + +def test_upload_mismatch_input_lengths(): + """Test error when ground truth has different input lengths.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + 
payload1 = generate_payload(n_input_rows, 11, 1, "INT64", "TRAINING") + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text + assert ( + "input shapes do not match. Observed inputs have length=10 while uploaded inputs have length=11" + in response.text + ) + + +def test_upload_mismatch_input_and_output_types(): + """Test error when ground truth has different data types.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 2, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload1 = generate_payload(n_input_rows, 10, 2, "FP32", "TRAINING", 0, 1) + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + print(f"Response status: {response.status_code}") + print(f"Response text: {response.text}") + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text + assert "Class=Long != Class=Float" in response.text or "inputs are not identical" in response.text + + +def test_upload_gaussian_data(): + """Test uploading realistic Gaussian data.""" + payload = { + "model_name": "gaussian-credit-model", + "data_tag": "TRAINING", + "request": { + "inputs": [ + { + "name": "credit_inputs", + "shape": [2, 4], + "datatype": "FP64", + "data": [ + [ + 47.45380690750797, + 478.6846214843319, + 13.462184703540503, + 20.764525303373535, + ], + [ + 47.468246185717554, + 575.6911203538863, + 10.844143722475575, + 14.81343667761101, + ], + ], + } + ] + }, + "response": { + "model_name": "gaussian-credit-model__isvc-d79a7d395d", + "model_version": "1", + "outputs": [ + { + "name": "predict", + "datatype": "FP32", + "shape": [2, 1], + "data": [0.19013395683309373, 0.2754730253205645], + } + ], + }, + } + post_test(payload, 200, ["2 datapoints"]) \ No newline at end of file From cb24c43dee869d11074762a8c0887b7c04b39b22 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 09:50:43 +0100 Subject: [PATCH 02/10] Refactor endpoints based on storage interface, add maria testing --- src/endpoints/consumer/__init__.py | 74 +++ src/endpoints/consumer/consumer_endpoint.py | 261 +++------ src/endpoints/data/data_download.py | 32 - src/endpoints/data/data_upload.py | 72 ++- src/main.py | 5 +- src/service/constants.py | 3 +- src/service/data/model_data.py | 23 +- src/service/data/storage/__init__.py | 14 +- .../data/storage/maria/legacy_maria_reader.py | 12 +- src/service/data/storage/maria/maria.py | 114 ++-- src/service/data/storage/maria/utils.py | 7 +- src/service/data/storage/pvc.py | 157 ++--- src/service/data/storage/storage_interface.py | 61 +- src/service/utils/download.py | 310 ---------- src/service/utils/list_utils.py | 4 +- src/service/utils/upload.py | 545 ------------------ tests/endpoints/test_download_endpoint.py | 404 ------------- tests/endpoints/test_upload_endpoint.py | 484 ---------------- tests/endpoints/test_upload_endpoint_maria.py | 59 ++ tests/endpoints/test_upload_endpoint_pvc.py | 364 ++++++++++++ tests/service/data/test_mariadb_migration.py | 10 +- tests/service/data/test_mariadb_storage.py | 88 +-- 
.../data/test_payload_reconciliation_maria.py | 181 +----- .../data/test_payload_reconciliation_pvc.py | 72 ++- 24 files changed, 906 insertions(+), 2450 deletions(-) create mode 100644 src/endpoints/consumer/__init__.py delete mode 100644 src/endpoints/data/data_download.py delete mode 100644 src/service/utils/download.py delete mode 100644 src/service/utils/upload.py delete mode 100644 tests/endpoints/test_download_endpoint.py delete mode 100644 tests/endpoints/test_upload_endpoint.py create mode 100644 tests/endpoints/test_upload_endpoint_maria.py create mode 100644 tests/endpoints/test_upload_endpoint_pvc.py diff --git a/src/endpoints/consumer/__init__.py b/src/endpoints/consumer/__init__.py new file mode 100644 index 0000000..15ad2a9 --- /dev/null +++ b/src/endpoints/consumer/__init__.py @@ -0,0 +1,74 @@ +from typing import Optional, Dict, List, Literal + +from pydantic import BaseModel + + +PartialKind = Literal["request", "response"] + +class PartialPayloadId(BaseModel): + prediction_id: Optional[str] = None + kind: Optional[PartialKind] = None + + def get_prediction_id(self) -> str: + return self.prediction_id + + def set_prediction_id(self, id: str): + self.prediction_id = id + + def get_kind(self) -> PartialKind: + return self.kind + + def set_kind(self, kind: PartialKind): + self.kind = kind + + +class InferencePartialPayload(BaseModel): + partialPayloadId: Optional[PartialPayloadId] = None + metadata: Optional[Dict[str, str]] = {} + data: Optional[str] = None + modelid: Optional[str] = None + + def get_id(self) -> str: + return self.partialPayloadId.prediction_id if self.partialPayloadId else None + + def set_id(self, id: str): + if not self.partialPayloadId: + self.partialPayloadId = PartialPayloadId() + self.partialPayloadId.prediction_id = id + + def get_kind(self) -> PartialKind: + return self.partialPayloadId.kind if self.partialPayloadId else None + + def set_kind(self, kind: PartialKind): + if not self.partialPayloadId: + self.partialPayloadId = PartialPayloadId() + self.partialPayloadId.kind = kind + + def get_model_id(self) -> str: + return self.modelid + + def set_model_id(self, model_id: str): + self.modelid = model_id + + +class KServeData(BaseModel): + name: str + shape: List[int] + datatype: str + parameters: Optional[Dict[str, str]] = None + data: List + + +class KServeInferenceRequest(BaseModel): + id: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + inputs: List[KServeData] + outputs: Optional[List[KServeData]] = None + + +class KServeInferenceResponse(BaseModel): + model_name: str = None + model_version: Optional[str] = None + id: Optional[str] = None + parameters: Optional[Dict[str, str]] = None + outputs: List[KServeData] diff --git a/src/endpoints/consumer/consumer_endpoint.py b/src/endpoints/consumer/consumer_endpoint.py index 1f17647..d1c45d1 100644 --- a/src/endpoints/consumer/consumer_endpoint.py +++ b/src/endpoints/consumer/consumer_endpoint.py @@ -5,13 +5,13 @@ import numpy as np from fastapi import APIRouter, HTTPException, Header -from pydantic import BaseModel -from typing import Dict, Optional, Literal, List, Union, Callable, Annotated +from typing import Literal, Union, Callable, Annotated import logging +from src.endpoints.consumer import InferencePartialPayload, KServeData, KServeInferenceRequest, KServeInferenceResponse # Import local dependencies from src.service.data.model_data import ModelData -from src.service.data.storage import get_storage_interface +from src.service.data.storage import get_global_storage_interface 
from src.service.utils import list_utils from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload @@ -26,81 +26,10 @@ router = APIRouter() logger = logging.getLogger(__name__) -PartialKind = Literal["request", "response"] -storage_interface = get_storage_interface() unreconciled_inputs = {} unreconciled_outputs = {} -class PartialPayloadId(BaseModel): - prediction_id: Optional[str] = None - kind: Optional[PartialKind] = None - - def get_prediction_id(self) -> str: - return self.prediction_id - - def set_prediction_id(self, id: str): - self.prediction_id = id - - def get_kind(self) -> PartialKind: - return self.kind - - def set_kind(self, kind: PartialKind): - self.kind = kind - - -class InferencePartialPayload(BaseModel): - partialPayloadId: Optional[PartialPayloadId] = None - metadata: Optional[Dict[str, str]] = {} - data: Optional[str] = None - modelid: Optional[str] = None - - def get_id(self) -> str: - return self.partialPayloadId.prediction_id if self.partialPayloadId else None - - def set_id(self, id: str): - if not self.partialPayloadId: - self.partialPayloadId = PartialPayloadId() - self.partialPayloadId.prediction_id = id - - def get_kind(self) -> PartialKind: - return self.partialPayloadId.kind if self.partialPayloadId else None - - def set_kind(self, kind: PartialKind): - if not self.partialPayloadId: - self.partialPayloadId = PartialPayloadId() - self.partialPayloadId.kind = kind - - def get_model_id(self) -> str: - return self.modelid - - def set_model_id(self, model_id: str): - self.modelid = model_id - - -class KServeData(BaseModel): - name: str - shape: List[int] - datatype: str - parameters: Optional[Dict[str, str]] = None - data: List - - -class KServeInferenceRequest(BaseModel): - id: Optional[str] = None - parameters: Optional[Dict[str, str]] = None - inputs: List[KServeData] - outputs: Optional[List[KServeData]] = None - - -class KServeInferenceResponse(BaseModel): - model_name: str - model_version: Optional[str] = None - id: Optional[str] = None - parameters: Optional[Dict[str, str]] = None - outputs: List[KServeData] - - @router.post("/consumer/kserve/v2") async def consume_inference_payload( payload: InferencePartialPayload, @@ -118,6 +47,8 @@ async def consume_inference_payload( Returns: A JSON response indicating success or failure """ + storage_interface = get_global_storage_interface() + try: if not payload.modelid: raise HTTPException( @@ -144,7 +75,6 @@ async def consume_inference_payload( ) partial_payload = PartialPayload(data=payload.data) - if payload_kind == "request": logger.info( f"Received partial input payload from model={model_id}, id={payload_id}" @@ -160,12 +90,12 @@ async def consume_inference_payload( ) from e # Store the input payload - await storage_interface.persist_modelmesh_payload( - partial_payload, payload_id, is_input + await storage_interface.persist_partial_payload( + partial_payload, payload_id=payload_id, is_input=is_input ) - output_payload = await storage_interface.get_modelmesh_payload( - payload_id, False + output_payload = await storage_interface.get_partial_payload( + payload_id, is_input=False, is_modelmesh=True ) if output_payload: @@ -188,12 +118,12 @@ async def consume_inference_payload( ) from e # Store the output payload - await storage_interface.persist_modelmesh_payload( - partial_payload, payload_id, is_input + await storage_interface.persist_partial_payload( + payload=partial_payload, payload_id=payload_id, is_input=is_input ) - input_payload = await storage_interface.get_modelmesh_payload( 
- payload_id, True + input_payload = await storage_interface.get_partial_payload( + payload_id, is_input=True, is_modelmesh=True ) if input_payload: @@ -221,12 +151,55 @@ async def consume_inference_payload( ) from e +async def write_reconciled_data( + input_array, input_names, + output_array, output_names, + model_id, tags, id_): + storage_interface = get_global_storage_interface() + + iso_time = datetime.isoformat(datetime.utcnow()) + unix_timestamp = time.time() + metadata = np.array( + [[None, iso_time, unix_timestamp, tags]] * len(input_array), dtype="O" + ) + metadata[:, 0] = [f"{id_}_{i}" for i in range(len(input_array))] + metadata_names = ["id", "iso_time", "unix_timestamp", "tags"] + + input_dataset = model_id + INPUT_SUFFIX + output_dataset = model_id + OUTPUT_SUFFIX + metadata_dataset = model_id + METADATA_SUFFIX + + await asyncio.gather( + storage_interface.write_data(input_dataset, input_array, input_names), + storage_interface.write_data(output_dataset, output_array, output_names), + storage_interface.write_data(metadata_dataset, metadata, metadata_names), + ) + + shapes = await ModelData(model_id).shapes() + logger.info( + f"Successfully reconciled inference {id_}, " + f"consisting of {len(input_array):,} rows from {model_id}." + ) + logger.debug( + f"Current storage shapes for {model_id}: " + f"Inputs={shapes[0]}, " + f"Outputs={shapes[1]}, " + f"Metadata={shapes[2]}" + ) + + # Clean up + await storage_interface.delete_partial_payload(id_, True) + await storage_interface.delete_partial_payload(id_, False) + + async def reconcile_modelmesh_payloads( input_payload: PartialPayload, output_payload: PartialPayload, request_id: str, model_id: str, ): + storage_interface = get_global_storage_interface() + """Reconcile the input and output ModelMesh payloads into dataset entries.""" df = ModelMeshPayloadParser.payloads_to_dataframe( input_payload, output_payload, request_id, model_id @@ -241,46 +214,41 @@ async def reconcile_modelmesh_payloads( # Create metadata array tags = [SYNTHETIC_TAG] if any(df["synthetic"]) else [UNLABELED_TAG] - iso_time = datetime.isoformat(datetime.utcnow()) - unix_timestamp = time.time() - metadata = np.array([[iso_time, unix_timestamp, tags]] * len(df), dtype="O") - # Get dataset names - input_dataset = model_id + INPUT_SUFFIX - output_dataset = model_id + OUTPUT_SUFFIX - metadata_dataset = model_id + METADATA_SUFFIX + await write_reconciled_data( + df[input_cols].values, input_cols, + df[output_cols].values, output_cols, + model_id=model_id, tags=tags, id_=request_id + ) - metadata_cols = ["iso_time", "unix_timestamp", "tags"] - await asyncio.gather( - storage_interface.write_data(input_dataset, df[input_cols].values, input_cols), - storage_interface.write_data( - output_dataset, df[output_cols].values, output_cols - ), - storage_interface.write_data(metadata_dataset, metadata, metadata_cols), - ) - shapes = await ModelData(model_id).shapes() - logger.info( - f"Successfully reconciled ModelMesh inference {request_id}, " - f"consisting of {len(df):,} rows from {model_id}." 
- ) - logger.debug( - f"Current storage shapes for {model_id}: " - f"Inputs={shapes[0]}, " - f"Outputs={shapes[1]}, " - f"Metadata={shapes[2]}" +async def reconcile_kserve( + input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse, tag: str): + input_array, input_names = process_payload(input_payload, lambda p: p.inputs) + output_array, output_names = process_payload( + output_payload, lambda p: p.outputs, input_array.shape[0] ) - # Clean up - await storage_interface.delete_modelmesh_payload(request_id, True) - await storage_interface.delete_modelmesh_payload(request_id, False) + if tag is not None: + tags = [tag] + elif (input_payload.parameters is not None and + input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true"): + tags = [SYNTHETIC_TAG] + else: + tags = [UNLABELED_TAG] + + await write_reconciled_data( + input_array, input_names, + output_array, output_names, + model_id=output_payload.model_name, tags=tags, id_=input_payload.id + ) def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id): msg = ( - f"Could not reconcile KServe Inference {payload_id}, because {payload_type} shapes were mismatched. " - f"When using multiple {payload_type}s to describe data columns, all shapes must match." + f"Could not reconcile_kserve KServe Inference {payload_id}, because {payload_type} shapes were mismatched. " + f"When using multiple {payload_type}s to describe data columns, all shapes must match. " f"However, the following tensor shapes were found:" ) for i, (name, shape) in enumerate(shape_tuples): @@ -291,7 +259,7 @@ def reconcile_mismatching_shape_error(shape_tuples, payload_type, payload_id): def reconcile_mismatching_row_count_error(payload_id, input_shape, output_shape): msg = ( - f"Could not reconcile KServe Inference {payload_id}, because the number of " + f"Could not reconcile_kserve KServe Inference {payload_id}, because the number of " f"output rows ({output_shape}) did not match the number of input rows " f"({input_shape})." 
) @@ -309,9 +277,9 @@ def process_payload(payload, get_data: Callable, enforced_first_shape: int = Non column_names = [] for kserve_data in get_data(payload): data.append(kserve_data.data) - shapes.add(tuple(kserve_data.data.shape)) + shapes.add(tuple(kserve_data.shape)) column_names.append(kserve_data.name) - shape_tuples.append((kserve_data.data.name, kserve_data.data.shape)) + shape_tuples.append((kserve_data.name, kserve_data.shape)) if len(shapes) == 1: row_count = list(shapes)[0][0] if enforced_first_shape is not None and row_count != enforced_first_shape: @@ -350,59 +318,18 @@ def process_payload(payload, get_data: Callable, enforced_first_shape: int = Non return np.array(kserve_data.data), column_names -async def reconcile( - input_payload: KServeInferenceRequest, output_payload: KServeInferenceResponse -): - input_array, input_names = process_payload(input_payload, lambda p: p.inputs) - output_array, output_names = process_payload( - output_payload, lambda p: p.outputs, input_array.shape[0] - ) - - metadata_names = ["iso_time", "unix_timestamp", "tags"] - if ( - input_payload.parameters is not None - and input_payload.parameters.get(BIAS_IGNORE_PARAM, "false") == "true" - ): - tags = [SYNTHETIC_TAG] - else: - tags = [UNLABELED_TAG] - iso_time = datetime.isoformat(datetime.utcnow()) - unix_timestamp = time.time() - metadata = np.array( - [[iso_time, unix_timestamp, tags]] * len(input_array), dtype="O" - ) - - input_dataset = output_payload.model_name + INPUT_SUFFIX - output_dataset = output_payload.model_name + OUTPUT_SUFFIX - metadata_dataset = output_payload.model_name + METADATA_SUFFIX - - await asyncio.gather( - storage_interface.write_data(input_dataset, input_array, input_names), - storage_interface.write_data(output_dataset, output_array, output_names), - storage_interface.write_data(metadata_dataset, metadata, metadata_names), - ) - - shapes = await ModelData(output_payload.model_name).shapes() - logger.info( - f"Successfully reconciled KServe inference {input_payload.id}, " - f"consisting of {input_array.shape[0]:,} rows from {output_payload.model_name}." - ) - logger.debug( - f"Current storage shapes for {output_payload.model_name}: " - f"Inputs={shapes[0]}, " - f"Outputs={shapes[1]}, " - f"Metadata={shapes[2]}" - ) - - @router.post("/") async def consume_cloud_event( payload: Union[KServeInferenceRequest, KServeInferenceResponse], ce_id: Annotated[str | None, Header()] = None, + tag: str = None ): # set payload if from cloud event header payload.id = ce_id + # get global storage interface + storage_interface = get_global_storage_interface() + if isinstance(payload, KServeInferenceRequest): if len(payload.inputs) == 0: msg = f"KServe Inference Input {payload.id} received, but data field was empty. Payload will not be saved." 
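A minimal usage sketch of the cloud-event pairing flow above. Assumptions, not guaranteed by this patch: the consumer router is mounted at the application root (as the @router.post("/") decorator suggests), `app` is importable from src.main as in the endpoint tests, storage is configured (e.g. SERVICE_STORAGE_FORMAT=PVC with a writable data folder), and all payload values are purely illustrative.

from fastapi.testclient import TestClient

from src.main import app  # assumed application entrypoint, as used by the endpoint tests

client = TestClient(app)

# KServe v2 input tensors: stored as a partial payload until the matching output arrives.
kserve_request = {
    "inputs": [
        {"name": "credit_inputs", "shape": [1, 2], "datatype": "FP64", "data": [[47.45, 478.68]]}
    ]
}

# Matching output: the shared ce-id header is what lets consume_cloud_event pair the two
# payloads and hand them to reconcile_kserve / write_reconciled_data.
kserve_response = {
    "model_name": "gaussian-credit-model",
    "outputs": [{"name": "predict", "shape": [1, 1], "datatype": "FP32", "data": [0.19]}],
}

client.post("/", json=kserve_request, headers={"ce-id": "example-inference-001"})
client.post("/", json=kserve_response, headers={"ce-id": "example-inference-001"})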
@@ -412,12 +339,12 @@ async def consume_cloud_event( logger.info(f"KServe Inference Input {payload.id} received.") # if a match is found, the payload is auto-deleted from data partial_output = await storage_interface.get_partial_payload( - payload.id, is_input=False + payload.id, is_input=False, is_modelmesh=False ) if partial_output is not None: - await reconcile(payload, partial_output) + await reconcile_kserve(payload, partial_output, tag) else: - await storage_interface.persist_partial_payload(payload, is_input=True) + await storage_interface.persist_partial_payload(payload, payload_id=payload.id, is_input=True) return { "status": "success", "message": f"Input payload {payload.id} processed successfully", @@ -436,12 +363,12 @@ async def consume_cloud_event( f"KServe Inference Output {payload.id} received from model={payload.model_name}." ) partial_input = await storage_interface.get_partial_payload( - payload.id, is_input=True + payload.id, is_input=True, is_modelmesh=False ) if partial_input is not None: - await reconcile(partial_input, payload) + await reconcile_kserve(partial_input, payload, tag) else: - await storage_interface.persist_partial_payload(payload, is_input=False) + await storage_interface.persist_partial_payload(payload, payload_id=payload.id, is_input=False) return { "status": "success", diff --git a/src/endpoints/data/data_download.py b/src/endpoints/data/data_download.py deleted file mode 100644 index 4e39858..0000000 --- a/src/endpoints/data/data_download.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging - -import pandas as pd -from fastapi import APIRouter, HTTPException - -from src.service.utils.download import ( - DataRequestPayload, - DataResponsePayload, - apply_filters, # ← New utility function - load_model_dataframe, -) - -router = APIRouter() -logger = logging.getLogger(__name__) - - -@router.post("/data/download") -async def download_data(payload: DataRequestPayload) -> DataResponsePayload: - """Download model data with filtering.""" - try: - logger.info(f"Received data download request for model: {payload.modelId}") - df = await load_model_dataframe(payload.modelId) - if df.empty: - return DataResponsePayload(dataCSV="") - df = apply_filters(df, payload) - csv_data = df.to_csv(index=False) - return DataResponsePayload(dataCSV=csv_data) - except HTTPException: - raise - except Exception as e: - logger.error(f"Error downloading data: {str(e)}") - raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}") \ No newline at end of file diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 6f20bb2..5e110a7 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -1,16 +1,18 @@ +import asyncio import logging -import uuid -from datetime import datetime +import traceback from typing import Any, Dict, List, Optional import numpy as np +import uuid from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from src.service.constants import INPUT_SUFFIX, METADATA_SUFFIX, OUTPUT_SUFFIX -from src.service.data.modelmesh_parser import ModelMeshPayloadParser -from src.service.data.storage import get_storage_interface -from src.service.utils.upload import process_upload_request +from src.endpoints.consumer.consumer_endpoint import (reconcile_kserve, consume_cloud_event) +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse +from src.service.constants import TRUSTYAI_TAG_PREFIX +from src.service.data.model_data import ModelData + router = 
APIRouter() logger = logging.getLogger(__name__) @@ -20,20 +22,64 @@ class UploadPayload(BaseModel): model_name: str data_tag: Optional[str] = None is_ground_truth: bool = False - request: Dict[str, Any] - response: Optional[Dict[str, Any]] = None + request: KServeInferenceRequest + response: KServeInferenceResponse +def validate_data_tag(tag: str) -> Optional[str]: + """Validate data tag format and content.""" + if not tag: + return None + if tag.startswith(TRUSTYAI_TAG_PREFIX): + return ( + f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. " + f"Provided tag '{tag}' violates this restriction." + ) + return None + @router.post("/data/upload") async def upload(payload: UploadPayload) -> Dict[str, str]: """Upload model data - regular or ground truth.""" + error_msgs = [] + + # validate tag + tag_validation_msg = validate_data_tag(payload.data_tag) + if tag_validation_msg: + raise HTTPException(status_code=400, detail=tag_validation_msg) try: logger.info(f"Received upload request for model: {payload.model_name}") - result = await process_upload_request(payload) + + # overwrite response model name with provided model name + payload.response.model_name = payload.model_name + + req_id = str(uuid.uuid4()) + try: + model_data = ModelData(payload.model_name) + previous_data_points = (await model_data.row_counts())[0] + except: + previous_data_points = 0 + + await consume_cloud_event(payload.response, req_id), + await consume_cloud_event(payload.request, req_id, tag=payload.data_tag) + + model_data = ModelData(payload.model_name) + new_data_points = (await model_data.row_counts())[0] + logger.info(f"Upload completed for model: {payload.model_name}") - return result - except HTTPException: - raise + + return { + "status": "success", + "message": f"{new_data_points-previous_data_points} datapoints successfully added to {payload.model_name} data." 
+ } + + except HTTPException as e: + if "Could not reconcile_kserve KServe Inference" in str(e): + new_msg = e.detail.replace("Could not reconcile_kserve KServe Inference", + "Could not upload payload") + raise HTTPException(status_code=400, detail=new_msg) + else: + raise except Exception as e: + traceback.print_exc() logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True) - raise HTTPException(500, f"Internal server error: {str(e)}") \ No newline at end of file + raise HTTPException(500, f"Internal server error: {str(e)}") diff --git a/src/main.py b/src/main.py index ad7bbd2..82bc3b7 100644 --- a/src/main.py +++ b/src/main.py @@ -27,7 +27,7 @@ from src.endpoints.metrics.identity.identity_endpoint import router as identity_router from src.endpoints.metadata import router as metadata_router from src.endpoints.metrics.metrics_info import router as metrics_info_router -from src.endpoints.data.data_download import router as data_download_router +from src.service.data.storage import get_global_storage_interface try: from src.endpoints.evaluation.lm_evaluation_harness import ( @@ -109,7 +109,6 @@ app.include_router(identity_router, tags=["Identity Endpoint"]) app.include_router(metadata_router, tags=["Service Metadata"]) app.include_router(metrics_info_router, tags=["Metrics Information Endpoint"]) -app.include_router(data_download_router, tags=["Download Endpoint"]) if lm_evaluation_harness_available: app.include_router( @@ -149,6 +148,8 @@ async def liveness_probe(): return JSONResponse(content={"status": "live"}, status_code=200) + + if __name__ == "__main__": # SERVICE_STORAGE_FORMAT=PVC; STORAGE_DATA_FOLDER=/tmp; STORAGE_DATA_FILENAME=trustyai_test.hdf5 uvicorn.run(app=app, host="0.0.0.0", port=8080) diff --git a/src/service/constants.py b/src/service/constants.py index e27d662..318e1d4 100644 --- a/src/service/constants.py +++ b/src/service/constants.py @@ -2,6 +2,7 @@ INPUT_SUFFIX = "_inputs" OUTPUT_SUFFIX = "_outputs" METADATA_SUFFIX = "_metadata" +GROUND_TRUTH_SUFFIX = "_ground_truth" PROTECTED_DATASET_SUFFIX = "trustyai_internal_" PARTIAL_PAYLOAD_DATASET_NAME = "partial_payloads" @@ -9,4 +10,4 @@ TRUSTYAI_TAG_PREFIX = "_trustyai" SYNTHETIC_TAG = f"{TRUSTYAI_TAG_PREFIX}_synthetic" UNLABELED_TAG = f"{TRUSTYAI_TAG_PREFIX}_unlabeled" -BIAS_IGNORE_PARAM = "bias-ignore" \ No newline at end of file +BIAS_IGNORE_PARAM = "bias-ignore" diff --git a/src/service/data/model_data.py b/src/service/data/model_data.py index 9bbb423..d83c9f3 100644 --- a/src/service/data/model_data.py +++ b/src/service/data/model_data.py @@ -1,12 +1,11 @@ +import asyncio from typing import List, Optional import numpy as np +import pandas as pd -from src.service.data.storage import get_storage_interface from src.service.constants import * - -storage_interface = get_storage_interface() - +from src.service.data.storage import get_global_storage_interface class ModelDataContainer: def __init__(self, model_name: str, input_data: np.ndarray, input_names: List[str], output_data: np.ndarray, @@ -37,6 +36,7 @@ async def row_counts(self) -> tuple[int, int, int]: """ Get the number of input, output, and metadata rows that exist in a model dataset """ + storage_interface = get_global_storage_interface() input_rows = await storage_interface.dataset_rows(self.input_dataset) output_rows = await storage_interface.dataset_rows(self.output_dataset) metadata_rows = await storage_interface.dataset_rows(self.metadata_dataset) @@ -46,28 +46,30 @@ async def shapes(self) -> tuple[List[int], 
List[int], List[int]]: """ Get the shapes of the input, output, and metadata datasets that exist in a model dataset """ + storage_interface = get_global_storage_interface() input_shape = await storage_interface.dataset_shape(self.input_dataset) output_shape = await storage_interface.dataset_shape(self.output_dataset) metadata_shape = await storage_interface.dataset_shape(self.metadata_dataset) return input_shape, output_shape, metadata_shape async def column_names(self) -> tuple[List[str], List[str], List[str]]: + storage_interface = get_global_storage_interface() input_names = await storage_interface.get_aliased_column_names(self.input_dataset) output_names = await storage_interface.get_aliased_column_names(self.output_dataset) - # these can't be aliased metadata_names = await storage_interface.get_original_column_names(self.metadata_dataset) return input_names, output_names, metadata_names async def original_column_names(self) -> tuple[List[str], List[str], List[str]]: + storage_interface = get_global_storage_interface() input_names = await storage_interface.get_original_column_names(self.input_dataset) output_names = await storage_interface.get_original_column_names(self.output_dataset) metadata_names = await storage_interface.get_original_column_names(self.metadata_dataset) return input_names, output_names, metadata_names - async def data(self, start_row=None, n_rows=None, get_input=True, get_output=True, get_metadata=True) \ + async def data(self, start_row=0, n_rows=None, get_input=True, get_output=True, get_metadata=True) \ -> tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: """ Get data from a saved model @@ -78,6 +80,8 @@ async def data(self, start_row=None, n_rows=None, get_input=True, get_output=Tru * get_output: whether to retrieve output data <- use this to reduce file reads * get_metadata: whether to retrieve metadata <- use this to reduce file reads """ + storage_interface = get_global_storage_interface() + if get_input: input_data = await storage_interface.read_data(self.input_dataset, start_row, n_rows) else: @@ -93,6 +97,13 @@ async def data(self, start_row=None, n_rows=None, get_input=True, get_output=Tru return input_data, output_data, metadata + async def get_metadata_as_df(self): + _, _, metadata = await self.data(get_input=False, get_output=False) + metadata_cols = (await self.column_names())[2] + return pd.DataFrame(metadata, columns=metadata_cols) + + + async def summary_string(self): out = f"=== {self.model_name} Data ===" diff --git a/src/service/data/storage/__init__.py b/src/service/data/storage/__init__.py index d914fad..869cd48 100644 --- a/src/service/data/storage/__init__.py +++ b/src/service/data/storage/__init__.py @@ -1,9 +1,21 @@ +import asyncio import os from src.service.data.storage.maria.legacy_maria_reader import LegacyMariaDBStorageReader from src.service.data.storage.maria.maria import MariaDBStorage from src.service.data.storage.pvc import PVCStorage +global_storage_interface = None +_storage_lock = asyncio.Lock() + +def get_global_storage_interface(force_reload=False): + global global_storage_interface + + if global_storage_interface is None or force_reload: + global_storage_interface = get_storage_interface() + return global_storage_interface + + def get_storage_interface(): storage_format = os.environ.get("SERVICE_STORAGE_FORMAT", "PVC") if storage_format == "PVC": @@ -15,7 +27,7 @@ def get_storage_interface(): host=os.environ.get("DATABASE_HOST"), port=int(os.environ.get("DATABASE_PORT")), 
database=os.environ.get("DATABASE_DATABASE"), - attempt_migration=True + attempt_migration=bool(int((os.environ.get("DATABASE_ATTEMPT_MIGRATION", "0")))), ) else: raise ValueError(f"Storage format={storage_format} not yet supported by the Python implementation of the service.") diff --git a/src/service/data/storage/maria/legacy_maria_reader.py b/src/service/data/storage/maria/legacy_maria_reader.py index e944fb3..0627d98 100644 --- a/src/service/data/storage/maria/legacy_maria_reader.py +++ b/src/service/data/storage/maria/legacy_maria_reader.py @@ -174,13 +174,15 @@ async def migrate_data(self, new_maria_storage: StorageInterface): output_df.columns.values) if not input_has_migrated: - await new_maria_storage.write_data(input_dataset, input_df.to_numpy(), list(input_df.columns.values)) - new_maria_storage.apply_name_mapping(input_dataset, input_mapping) + await new_maria_storage.write_data(input_dataset, input_df.to_numpy(), input_df.columns.to_list()) + await new_maria_storage.apply_name_mapping(input_dataset, input_mapping) if not output_has_migrated: - await new_maria_storage.write_data(output_dataset, output_df.to_numpy(), list(output_df.columns.values)) - new_maria_storage.apply_name_mapping(output_dataset, output_mapping) + await new_maria_storage.write_data(output_dataset, output_df.to_numpy(), + output_df.columns.to_list()) + await new_maria_storage.apply_name_mapping(output_dataset, output_mapping) if not metadata_has_migrated: - await new_maria_storage.write_data(metadata_dataset, metadata_df.to_numpy(), list(metadata_df.columns.values)) + await new_maria_storage.write_data(metadata_dataset, metadata_df.to_numpy(), + metadata_df.columns.to_list()) migrations.append(True) logger.info(f"Dataset {dataset_name} successfully migrated.") diff --git a/src/service/data/storage/maria/maria.py b/src/service/data/storage/maria/maria.py index 51c665c..7f57025 100644 --- a/src/service/data/storage/maria/maria.py +++ b/src/service/data/storage/maria/maria.py @@ -2,11 +2,14 @@ import io import json import logging +import time + import mariadb import numpy as np import pickle as pkl -from typing import Optional, Dict, List +from typing import Optional, Dict, List, Union +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse from src.service.data.modelmesh_parser import PartialPayload from src.service.data.storage import LegacyMariaDBStorageReader from src.service.data.storage.maria.utils import MariaConnectionManager, require_existing_dataset, \ @@ -80,7 +83,7 @@ def _build_table_name(self, index): return f"{self.schema_prefix}_dataset_{index}" @require_existing_dataset - def _get_clean_table_name(self, dataset_name: str) -> str: + async def _get_clean_table_name(self, dataset_name: str) -> str: """ Get a generated table name corresponding to a particular dataset. This avoids possible SQL injection from within the model names. @@ -91,7 +94,7 @@ def _get_clean_table_name(self, dataset_name: str) -> str: @require_existing_dataset - def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: + async def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: """ Return the metadata field from a particular dataset within the dataset_reference_table. 
""" @@ -102,7 +105,7 @@ def _get_dataset_metadata(self, dataset_name: str) -> Optional[Dict]: #=== DATASET QUERYING ========================================================================== - def dataset_exists(self, dataset_name: str) -> bool: + async def dataset_exists(self, dataset_name: str) -> bool: """ Check if a dataset exists within the TrustyAI model data. """ @@ -124,7 +127,7 @@ def list_all_datasets(self): @require_existing_dataset - def dataset_rows(self, dataset_name: str) -> int: + async def dataset_rows(self, dataset_name: str) -> int: """ Get the number of rows in a stored dataset (equivalent to data.shape[0]) """ @@ -135,23 +138,23 @@ def dataset_rows(self, dataset_name: str) -> int: @require_existing_dataset - def dataset_cols(self, dataset_name: str) -> int: + async def dataset_cols(self, dataset_name: str) -> int: """ Get the number of columns in a stored dataset (equivalent to data.shape[1]) """ - table_name = self._get_clean_table_name(dataset_name) + table_name = await self._get_clean_table_name(dataset_name) with self.connection_manager as (conn, cursor): cursor.execute(f"SHOW COLUMNS FROM {table_name}") return len(cursor.fetchall()) - 1 @require_existing_dataset - def dataset_shape(self, dataset_name: str) -> tuple[int]: + async def dataset_shape(self, dataset_name: str) -> tuple[int]: """ Get the whole shape of a stored dataset (equivalent to data.shape) """ - rows = self.dataset_rows(dataset_name) - shape = self._get_dataset_metadata(dataset_name)["shape"] + rows = await self.dataset_rows(dataset_name) + shape = (await self._get_dataset_metadata(dataset_name))["shape"] shape[0] = rows return tuple(shape) @@ -172,9 +175,9 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names if len(new_rows) == 0: raise ValueError(f"No data provided! 
`new_rows`=={new_rows}.")
 
-        # if received a single row, reshape into a single-row matrix
+        # if received a single row, reshape into a single-column matrix
         if new_rows.ndim < 2:
-            new_rows = new_rows.reshape(1, -1)
+            new_rows = new_rows.reshape(-1, 1)
 
         # validate that the number of provided column names matches the shape of the provided array
         if new_rows.shape[1] != len(column_names):
@@ -182,7 +185,7 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names
                 f"Shape mismatch: Number of provided column names ({len(column_names)}) does not match number of columns in provided array ({new_rows.shape[1]}).")
 
         # if this is the first time we've seen this dataset, set up its tables inside the DB
-        if not self.dataset_exists(dataset_name):
+        if not await self.dataset_exists(dataset_name):
             with self.connection_manager as (conn, cursor):
 
                 # create an entry in `trustyai_v2_table_reference`
@@ -212,10 +215,10 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names
                 nrows = 0
         else:
             # if dataset already exists, grab its current shape and information
-            stored_shape = self.dataset_shape(dataset_name)
+            stored_shape = await self.dataset_shape(dataset_name)
             ncols = stored_shape[1]
-            nrows = self.dataset_rows(dataset_name)
-            table_name = self._get_clean_table_name(dataset_name)
+            nrows = await self.dataset_rows(dataset_name)
+            table_name = await self._get_clean_table_name(dataset_name)
             cleaned_names = get_clean_column_names(column_names)
 
             # validate that the number of columns in the saved DB matched the provided column names
@@ -255,9 +258,9 @@ async def write_data(self, dataset_name: str, new_rows: np.ndarray, column_names
 
 
     @require_existing_dataset
-    def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None):
+    async def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None):
         """
         Read saved data from the database, from `start_row` to `start_row + n_rows` (inclusive)
 
         `dataset_name`: the name of the dataset to read. This is NOT the table name; see
             `trustyai_v2_table_reference.dataset_name` or use list_all_datasets() for the available dataset_names.
@@ -265,16 +268,16 @@ def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None):
         `n_rows`: The total number of rows to read. If not specified, read all rows.
         """
-        table_name = self._get_clean_table_name(dataset_name)
+        table_name = await self._get_clean_table_name(dataset_name)
 
         if n_rows is None:
-            n_rows = self.dataset_rows(dataset_name)
+            n_rows = await self.dataset_rows(dataset_name)
 
         with self.connection_manager as (conn, cursor):
             # grab matching data
             cursor.execute(
-                f"SELECT * FROM `{table_name}` WHERE row_idx>? AND row_idx<=?",
-                (start_row, start_row+n_rows)
+                f"SELECT * FROM `{table_name}` ORDER BY row_idx ASC LIMIT ? 
OFFSET ?", + (n_rows, start_row) ) # parse saved data back to Numpy array @@ -297,15 +300,15 @@ def read_data(self, dataset_name: str, start_row: int = 0, n_rows: int = None): # === COLUMN NAMES ============================================================================= @require_existing_dataset - def get_original_column_names(self, dataset_name: str) -> Optional[List[str]]: - return self._get_dataset_metadata(dataset_name).get("column_names") + async def get_original_column_names(self, dataset_name: str) -> Optional[List[str]]: + return (await self._get_dataset_metadata(dataset_name)).get("column_names") @require_existing_dataset - def get_aliased_column_names(self, dataset_name: str) -> List[str]: - return self._get_dataset_metadata(dataset_name).get("aliased_names") + async def get_aliased_column_names(self, dataset_name: str) -> List[str]: + return (await self._get_dataset_metadata(dataset_name)).get("aliased_names") @require_existing_dataset - def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): + async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): """Apply a name mapping to a dataset. `dataset_name`: the name of the dataset to read. This is NOT the table name; @@ -314,8 +317,8 @@ def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): to original column names and values should correspond to the desired new names. """ - original_names = self.get_original_column_names(dataset_name) - aliased_names = self.get_aliased_column_names(dataset_name) + original_names = await self.get_original_column_names(dataset_name) + aliased_names = await self.get_aliased_column_names(dataset_name) # get the new set of optionaly-aliased column names for col_idx, original_name in enumerate(original_names): @@ -334,67 +337,56 @@ def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): # === PARTIAL PAYLOADS ========================================================================= - async def _persist_payload(self, payload, is_input: bool, request_id: Optional[str] = None): + async def persist_partial_payload(self, + payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], + payload_id, is_input: bool): """Save a partial payload to the database.""" with self.connection_manager as (conn, cursor): - if request_id is None: - request_id = payload.id - cursor.execute( f"INSERT INTO `{self.partial_payload_table}` (payload_id, is_input, payload_data) VALUES (?, ?, ?)", - (request_id, is_input, pkl.dumps(payload))) + (payload_id, is_input, pkl.dumps(payload.model_dump()))) conn.commit() - async def _get_partial_payload(self, payload_id: str, is_input: bool): + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]: """Retrieve a partial payload from the database.""" with self.connection_manager as (conn, cursor): cursor.execute(f"SELECT payload_data FROM `{self.partial_payload_table}` WHERE payload_id=? 
AND is_input=?", (payload_id, is_input)) result = cursor.fetchone() if result is None or len(result) == 0: return None - payload_data = result[0] - return pkl.loads(payload_data) - - - async def persist_partial_payload(self, payload, is_input: bool): - await self._persist_payload(payload, is_input) - - - async def persist_modelmesh_payload(self, payload: PartialPayload, request_id: str, is_input: bool): - await self._persist_payload(payload, is_input, request_id=request_id) - - - async def get_partial_payload(self, payload_id: str, is_input: bool): - return await self._get_partial_payload(payload_id, is_input) - - - async def get_modelmesh_payload(self, request_id: str, is_input: bool) -> Optional[PartialPayload]: - return await self._get_partial_payload(request_id, is_input) + payload_dict = pkl.loads(result[0]) + if is_modelmesh: + return PartialPayload(**payload_dict) + elif is_input: # kserve input + return KServeInferenceRequest(**payload_dict) + else: # kserve output + return KServeInferenceResponse(**payload_dict) - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, payload_id: str, is_input: bool): with self.connection_manager as (conn, cursor): - cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? AND is_input=?", (request_id, is_input)) + cursor.execute(f"DELETE FROM {self.partial_payload_table} WHERE payload_id=? AND is_input=?", (payload_id, is_input)) conn.commit() # === DATABASE CLEANUP ========================================================================= @require_existing_dataset - def delete_dataset(self, dataset_name: str): - table_name = self._get_clean_table_name(dataset_name) + async def delete_dataset(self, dataset_name: str): + table_name = await self._get_clean_table_name(dataset_name) + logger.info(f"Deleting table={table_name} to delete dataset={dataset_name}.") with self.connection_manager as (conn, cursor): cursor.execute(f"DELETE FROM `{self.dataset_reference_table}` WHERE dataset_name=?", (dataset_name,)) cursor.execute(f"DROP TABLE IF EXISTS `{table_name}`") conn.commit() - def delete_all_datasets(self): + async def delete_all_datasets(self): for dataset_name in self.list_all_datasets(): logger.warning(f"Deleting dataset {dataset_name}") - self.delete_dataset(dataset_name) + await self.delete_dataset(dataset_name) - def reset_database(self): + async def reset_database(self): logger.warning(f"Fully resetting TrustyAI V2 database.") - self.delete_all_datasets() + await self.delete_all_datasets() with self.connection_manager as (conn, cursor): cursor.execute(f"DROP TABLE IF EXISTS `{self.dataset_reference_table}`") cursor.execute(f"DROP TABLE IF EXISTS `{self.partial_payload_table}`") diff --git a/src/service/data/storage/maria/utils.py b/src/service/data/storage/maria/utils.py index af57978..97bfb47 100644 --- a/src/service/data/storage/maria/utils.py +++ b/src/service/data/storage/maria/utils.py @@ -5,15 +5,16 @@ def require_existing_dataset(func): """Annotation to assert that a given function requires a valid dataset name as the first non-self argument""" - def validate_dataset_exists(*args, **kwargs): + async def validate_dataset_exists(*args, **kwargs): storage, dataset_name = args[0], args[1] - if not storage.dataset_exists(dataset_name): + if not await storage.dataset_exists(dataset_name): raise ValueError(f"Error when calling {func.__name__}: Dataset '{dataset_name}' does not exist.") - return func(*args, **kwargs) + return await func(*args, **kwargs) return 
validate_dataset_exists + def get_clean_column_names(column_names) -> List[str]: """ Programmatically generate the column names in a model data table. diff --git a/src/service/data/storage/pvc.py b/src/service/data/storage/pvc.py index 813ed9b..d81b405 100644 --- a/src/service/data/storage/pvc.py +++ b/src/service/data/storage/pvc.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union import numpy as np import os @@ -7,6 +7,7 @@ import logging import pickle as pkl +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse from src.service.utils import list_utils from .storage_interface import StorageInterface from src.service.constants import PROTECTED_DATASET_SUFFIX, PARTIAL_PAYLOAD_DATASET_NAME @@ -21,13 +22,7 @@ PARTIAL_OUTPUT_NAME = ( PROTECTED_DATASET_SUFFIX + PARTIAL_PAYLOAD_DATASET_NAME + "_outputs" ) -MODELMESH_INPUT_NAME = ( - f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_inputs" -) -MODELMESH_OUTPUT_NAME = ( - f"{PROTECTED_DATASET_SUFFIX}modelmesh_partial_payloads_outputs" -) - +MAX_VOID_TYPE_LENGTH=1024 class H5PYContext: """Open the corresponding H5PY file for a dataset and manage its context`""" @@ -156,6 +151,13 @@ async def _write_raw_data( existing_shape = None dataset_exists = False + # standardize serialized rows to prevent metadata serialization failures + if isinstance(new_rows.dtype, np.dtypes.VoidDType): + if new_rows.dtype.itemsize > MAX_VOID_TYPE_LENGTH: + raise ValueError(f"The datatype of the array to be serialized is {new_rows.dtype}- the largest serializable void type is V{MAX_VOID_TYPE_LENGTH}") + new_rows = new_rows.astype(f"V{MAX_VOID_TYPE_LENGTH}") # this might cause bugs later + + if dataset_exists: # if we've already got saved inferences for this model if existing_shape[1:] == inbound_shape[1:]: # shapes match async with self.get_lock(allocated_dataset_name): @@ -203,6 +205,11 @@ async def _write_raw_data( max_shape = [None] + list(new_rows.shape)[ 1: ] # to-do: tune this value? + + + if isinstance(new_rows.data, np.dtypes.VoidDType): + new_rows = new_rows.astype("V400") + dataset = db.create_dataset( allocated_dataset_name, data=new_rows, @@ -225,7 +232,7 @@ async def write_data(self, dataset_name: str, new_rows, column_names: List[str]) and list_utils.contains_non_numeric(new_rows) ): await self._write_raw_data( - dataset_name, list_utils.serialize_rows(new_rows), column_names + dataset_name, list_utils.serialize_rows(new_rows, MAX_VOID_TYPE_LENGTH), column_names ) else: await self._write_raw_data(dataset_name, np.array(new_rows), column_names) @@ -247,20 +254,18 @@ async def _read_raw_data( f"Requested a data read from start_row={start_row}, but dataset " f"only has {dataset.shape[0]} rows. An empty array will be returned." 
) - return ( - dataset[start_row:end_row], - dataset.attrs[COLUMN_NAMES_ATTRIBUTE], - ) + return dataset[start_row:end_row] + async def read_data( - self, dataset_name: str, start_row: int = None, n_rows: int = None + self, dataset_name: str, start_row: int = 0, n_rows: int = None ) -> (np.ndarray, List[str]): """Read data from a dataset, automatically deserializing any byte data""" - read, column_names = await self._read_raw_data(dataset_name, start_row, n_rows) + read = await self._read_raw_data(dataset_name, start_row, n_rows) if len(read) and read[0].dtype.type in {np.bytes_, np.void}: - return list_utils.deserialize_rows(read), column_names + return list_utils.deserialize_rows(read) else: - return read, column_names + return read async def delete_dataset(self, dataset_name: str): """Delete dataset data, ignoring non-existent datasets""" @@ -307,77 +312,34 @@ async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, st aliased_names = [name_mapping.get(name, name) for name in curr_names] db[allocated_dataset_name].attrs[COLUMN_ALIAS_ATTRIBUTE] = aliased_names - async def persist_partial_payload(self, payload, is_input: bool): - """Save a partial payload to disk. Returns None if no matching id exists""" - - # lock to prevent simultaneous read/writes - partial_dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME - async with self.get_lock(partial_dataset_name): - with H5PYContext( - self, - partial_dataset_name, - "a", - ) as db: - if partial_dataset_name not in db: - dataset = db.create_dataset( - partial_dataset_name, dtype="f", track_order=True - ) - else: - dataset = db[partial_dataset_name] - dataset.attrs[payload.id] = np.void(pkl.dumps(payload)) - async def persist_modelmesh_payload( - self, payload: PartialPayload, request_id: str, is_input: bool - ): + async def persist_partial_payload(self, payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], payload_id: str, is_input: bool): """ - Persist a ModelMesh payload. - - Args: - payload: The payload to persist - request_id: The unique identifier for the inference request - is_input: Whether this is an input payload (True) or output payload (False) + Save a KServe or ModelMesh payload to disk. 
""" - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME - + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME serialized_data = pkl.dumps(payload.model_dump()) + is_modelmesh = isinstance(payload, PartialPayload) async with self.get_lock(dataset_name): try: with H5PYContext(self, dataset_name, "a") as db: if dataset_name not in db: - dataset = db.create_dataset(dataset_name, data=np.array([0])) - dataset.attrs["request_ids"] = [] - - dataset = db[dataset_name] - request_ids = list(dataset.attrs["request_ids"]) - - dataset.attrs[request_id] = np.void(serialized_data) - - if request_id not in request_ids: - request_ids.append(request_id) - dataset.attrs["request_ids"] = request_ids + dataset = db.create_dataset(dataset_name, dtype="f", track_order=True) + else: + dataset = db[dataset_name] + dataset.attrs[payload_id] = np.void(serialized_data) logger.debug( - f"Stored ModelMesh {'input' if is_input else 'output'} payload for request ID: {request_id}" + f"Stored {'ModelMesh' if is_modelmesh else 'KServe'} {'input' if is_input else 'output'} payload for request ID: {payload_id}" ) except Exception as e: - logger.error(f"Error storing ModelMesh payload: {str(e)}") + logger.error(f"Error storing {'ModelMesh' if is_modelmesh else 'KServe'} payload: {str(e)}") raise - async def get_modelmesh_payload( - self, request_id: str, is_input: bool - ) -> Optional[PartialPayload]: - """ - Retrieve a stored ModelMesh payload by request ID. - - Args: - request_id: The unique identifier for the inference request - is_input: Whether to retrieve an input payload (True) or output payload (False) - - Returns: - The retrieved payload, or None if not found - """ - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Optional[ + Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]]: + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME try: async with self.get_lock(dataset_name): @@ -386,51 +348,38 @@ async def get_modelmesh_payload( return None dataset = db[dataset_name] - if request_id not in dataset.attrs: + if payload_id not in dataset.attrs: return None - serialized_data = dataset.attrs[request_id] + serialized_data = dataset.attrs[payload_id] try: payload_dict = pkl.loads(serialized_data) - return PartialPayload(**payload_dict) + if is_modelmesh: + return PartialPayload(**payload_dict) + elif is_input: # kserve input + return KServeInferenceRequest(**payload_dict) + else: # kserve output + return KServeInferenceResponse(**payload_dict) except Exception as e: logger.error(f"Error unpickling payload: {str(e)}") return None except MissingH5PYDataException: return None except Exception as e: - logger.error(f"Error retrieving ModelMesh payload: {str(e)}") + logger.error(f"Error retrieving {'ModelMesh' if is_modelmesh else 'KServe'} payload: {str(e)}") return None - async def get_partial_payload(self, payload_id: str, is_input: bool): - """Looks up a partial payload by id. 
Returns None if no matching id exists""" - # lock to prevent simultaneous read/writes - partial_dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME - async with self.get_lock(partial_dataset_name): - try: - with H5PYContext(self, partial_dataset_name, "r") as db: - if partial_dataset_name not in db: - return None - recovered_bytes = db[partial_dataset_name].attrs.get(payload_id) - return ( - None - if recovered_bytes is None - else pkl.loads(recovered_bytes) - ) - except MissingH5PYDataException: - return None - - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, request_id: str, is_input: bool): """ - Delete a stored ModelMesh payload. + Delete a stored partial payload. Args: request_id: The unique identifier for the inference request is_input: Whether to delete an input payload (True) or output payload (False) """ - dataset_name = MODELMESH_INPUT_NAME if is_input else MODELMESH_OUTPUT_NAME + dataset_name = PARTIAL_INPUT_NAME if is_input else PARTIAL_OUTPUT_NAME try: async with self.get_lock(dataset_name): @@ -439,24 +388,20 @@ async def delete_modelmesh_payload(self, request_id: str, is_input: bool): return dataset = db[dataset_name] - request_ids = list(dataset.attrs["request_ids"]) - if request_id not in request_ids: + if request_id not in dataset.attrs: return if request_id in dataset.attrs: del dataset.attrs[request_id] - request_ids.remove(request_id) - dataset.attrs["request_ids"] = request_ids - - if not request_ids: + if not dataset.attrs: del db[dataset_name] logger.debug( - f"Deleted ModelMesh {'input' if is_input else 'output'} payload for request ID: {request_id}" + f"Deleted {'input' if is_input else 'output'} payload for request ID: {request_id}" ) except MissingH5PYDataException: return except Exception as e: - logger.error(f"Error deleting ModelMesh payload: {str(e)}") + logger.error(f"Error deleting payload: {str(e)}") diff --git a/src/service/data/storage/storage_interface.py b/src/service/data/storage/storage_interface.py index 28ad65b..2b12182 100644 --- a/src/service/data/storage/storage_interface.py +++ b/src/service/data/storage/storage_interface.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union + +from src.endpoints.consumer import KServeInferenceResponse, KServeInferenceRequest from src.service.data.modelmesh_parser import PartialPayload class StorageInterface(ABC): @abstractmethod - def dataset_exists(self, dataset_name: str) -> bool: + async def dataset_exists(self, dataset_name: str) -> bool: pass @abstractmethod @@ -13,11 +15,11 @@ def list_all_datasets(self) -> List[str]: pass @abstractmethod - def dataset_rows(self, dataset_name: str) -> int: + async def dataset_rows(self, dataset_name: str) -> int: pass @abstractmethod - def dataset_shape(self, dataset_name: str) -> tuple[int]: + async def dataset_shape(self, dataset_name: str) -> tuple[int]: pass @abstractmethod @@ -25,68 +27,41 @@ async def write_data(self, dataset_name: str, new_rows, column_names: List[str]) pass @abstractmethod - def read_data(self, dataset_name: str, start_row: int = None, n_rows: int = None): + async def read_data(self, dataset_name: str, start_row: int = None, n_rows: int = None): pass @abstractmethod - def get_original_column_names(self, dataset_name: str) -> List[str]: + async def get_original_column_names(self, dataset_name: str) -> List[str]: pass @abstractmethod - def get_aliased_column_names(self, 
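The persist_partial_payload / get_partial_payload / delete_partial_payload trio above all share one attribute-based pattern: the payload's model_dump() is pickled into an np.void attribute keyed by the payload ID on an otherwise empty dataset, and the dataset is dropped once its last attribute is removed. A self-contained sketch of that pattern (plain dicts and direct h5py handles; the service's H5PYContext, locking and pydantic models are intentionally omitted, so the function names here are hypothetical):

import pickle

import h5py
import numpy as np


def store_partial(path, dataset_name, payload_id, payload_dict):
    with h5py.File(path, "a") as db:
        if dataset_name not in db:
            # empty dataset used purely as an attribute container
            dataset = db.create_dataset(dataset_name, dtype="f", track_order=True)
        else:
            dataset = db[dataset_name]
        dataset.attrs[payload_id] = np.void(pickle.dumps(payload_dict))


def load_partial(path, dataset_name, payload_id):
    # the real service also guards against missing files and takes a per-dataset lock
    with h5py.File(path, "a") as db:
        if dataset_name not in db or payload_id not in db[dataset_name].attrs:
            return None
        return pickle.loads(db[dataset_name].attrs[payload_id].tobytes())


def delete_partial(path, dataset_name, payload_id):
    with h5py.File(path, "a") as db:
        if dataset_name not in db:
            return
        dataset = db[dataset_name]
        if payload_id in dataset.attrs:
            del dataset.attrs[payload_id]
        if not dataset.attrs:  # last payload gone: drop the container dataset
            del db[dataset_name]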
dataset_name: str) -> List[str]: + async def get_aliased_column_names(self, dataset_name: str) -> List[str]: pass @abstractmethod - def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): + async def apply_name_mapping(self, dataset_name: str, name_mapping: Dict[str, str]): pass @abstractmethod - def delete_dataset(self, dataset_name: str): + async def delete_dataset(self, dataset_name: str): pass @abstractmethod - async def persist_partial_payload(self, payload, is_input: bool): + async def persist_partial_payload(self, + payload: Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse], + payload_id, is_input: bool): pass @abstractmethod - async def get_partial_payload(self, payload_id: str, is_input: bool): - pass - - - @abstractmethod - async def persist_modelmesh_payload( - self, payload: PartialPayload, request_id: str, is_input: bool - ): - """ - Store a ModelMesh partial payload (either input or output) for later reconciliation. - - Args: - payload: The partial payload to store - request_id: A unique identifier for this inference request - is_input: Whether this is an input payload (True) or output payload (False) - """ + async def get_partial_payload(self, payload_id: str, is_input: bool, is_modelmesh: bool) -> Optional[ + Union[PartialPayload, KServeInferenceRequest, KServeInferenceResponse]]: pass - @abstractmethod - async def get_modelmesh_payload( - self, request_id: str, is_input: bool - ) -> Optional[PartialPayload]: - """ - Retrieve a stored ModelMesh payload by request ID. - - Args: - request_id: The unique identifier for the inference request - is_input: Whether to retrieve an input payload (True) or output payload (False) - - Returns: - The retrieved payload, or None if not found - """ - pass @abstractmethod - async def delete_modelmesh_payload(self, request_id: str, is_input: bool): + async def delete_partial_payload(self, payload_id: str, is_input: bool): """ - Delete a stored ModelMesh payload. + Delete a stored partial payload. 
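With the StorageInterface accessors above now declared as coroutines, callers are expected to await each of them. A hypothetical caller-side helper (describe_dataset is not part of the patch) illustrating the intended usage against any concrete implementation:

from typing import Any, Dict, List

from src.service.data.storage.storage_interface import StorageInterface


async def describe_dataset(storage: StorageInterface, dataset_name: str) -> Dict[str, Any]:
    # every accessor is async now, so each call must be awaited
    if not await storage.dataset_exists(dataset_name):
        return {"exists": False}
    rows = await storage.dataset_rows(dataset_name)
    shape = await storage.dataset_shape(dataset_name)
    columns: List[str] = await storage.get_original_column_names(dataset_name)
    return {"exists": True, "rows": rows, "shape": shape, "columns": columns}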
Args: request_id: The unique identifier for the inference request diff --git a/src/service/utils/download.py b/src/service/utils/download.py deleted file mode 100644 index a6deea3..0000000 --- a/src/service/utils/download.py +++ /dev/null @@ -1,310 +0,0 @@ -import logging -import numbers -import pickle -from typing import Any, List - -import pandas as pd -from fastapi import HTTPException -from pydantic import BaseModel, Field - -from src.service.data.storage import get_storage_interface - -logger = logging.getLogger(__name__) - - -class RowMatcher(BaseModel): - """Represents a row matching condition for data filtering.""" - - columnName: str - operation: str # "EQUALS" or "BETWEEN" - values: List[Any] - - -class DataRequestPayload(BaseModel): - """Request payload for data download operations.""" - - modelId: str - matchAny: List[RowMatcher] = Field(default_factory=list) - matchAll: List[RowMatcher] = Field(default_factory=list) - matchNone: List[RowMatcher] = Field(default_factory=list) - - -class DataResponsePayload(BaseModel): - """Response payload containing filtered data as CSV.""" - - dataCSV: str - - -def get_storage() -> Any: - """Get storage interface instance.""" - return get_storage_interface() - - -async def load_model_dataframe(model_id: str) -> pd.DataFrame: - try: - storage = get_storage_interface() - print(f"DEBUG: storage type = {type(storage)}") - input_data, input_cols = await storage.read_data(f"{model_id}_inputs") - output_data, output_cols = await storage.read_data(f"{model_id}_outputs") - metadata_data, metadata_cols = await storage.read_data(f"{model_id}_metadata") - if input_data is None or output_data is None or metadata_data is None: - raise HTTPException(404, f"Model {model_id} not found") - df = pd.DataFrame() - if len(input_data) > 0: - input_df = pd.DataFrame(input_data, columns=input_cols) - df = pd.concat([df, input_df], axis=1) - if len(output_data) > 0: - output_df = pd.DataFrame(output_data, columns=output_cols) - df = pd.concat([df, output_df], axis=1) - if len(metadata_data) > 0: - logger.debug(f"Metadata data type: {type(metadata_data)}") - logger.debug(f"First row type: {type(metadata_data[0]) if len(metadata_data) > 0 else 'empty'}") - logger.debug( - f"First row dtype: {metadata_data[0].dtype if hasattr(metadata_data[0], 'dtype') else 'no dtype'}" - ) - metadata_df = pd.DataFrame(metadata_data, columns=metadata_cols) - trusty_mapping = { - "ID": "trustyai.ID", - "MODEL_ID": "trustyai.MODEL_ID", - "TIMESTAMP": "trustyai.TIMESTAMP", - "TAG": "trustyai.TAG", - "INDEX": "trustyai.INDEX", - } - for orig_col in metadata_cols: - trusty_col = trusty_mapping.get(orig_col, orig_col) - df[trusty_col] = metadata_df[orig_col] - return df - except HTTPException: - raise - except Exception as e: - logger.error(f"Error loading model dataframe: {e}") - raise HTTPException(500, f"Error loading model data: {str(e)}") - - -def apply_filters(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: - """ - Apply all filters to DataFrame with performance optimization. 
- """ - if not any([payload.matchAll, payload.matchAny, payload.matchNone]): - return df - has_timestamp_filter = _has_timestamp_filters(payload) - if has_timestamp_filter: - logger.debug("Using boolean mask approach for timestamp filters") - return _apply_filters_with_boolean_masks(df, payload) - else: - logger.debug("Using query approach for non-timestamp filters") - return _apply_filters_with_query(df, payload) - - -def _has_timestamp_filters(payload: DataRequestPayload) -> bool: - """Check if payload contains any timestamp filters.""" - for matcher_list in [payload.matchAll or [], payload.matchAny or [], payload.matchNone or []]: - for matcher in matcher_list: - if matcher.columnName == "trustyai.TIMESTAMP": - return True - return False - - -def _apply_filters_with_query(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: - """Apply filters using pandas query (optimized for non-timestamp filters).""" - query_expr = _build_query_expression(df, payload) - if query_expr: - logger.debug(f"Executing query: {query_expr}") - try: - df = df.query(query_expr) - except Exception as e: - logger.error(f"Query execution failed: {query_expr}") - raise HTTPException(status_code=400, detail=f"Filter execution failed: {str(e)}") - return df - - -def _apply_filters_with_boolean_masks(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame: - """Apply filters using boolean masks (optimized for timestamp filters).""" - final_mask = pd.Series(True, index=df.index) - if payload.matchAll: - for matcher in payload.matchAll: - matcher_mask = _get_matcher_mask(df, matcher, negate=False) - final_mask &= matcher_mask - if payload.matchNone: - for matcher in payload.matchNone: - matcher_mask = _get_matcher_mask(df, matcher, negate=True) - final_mask &= matcher_mask - if payload.matchAny: - any_mask = pd.Series(False, index=df.index) - for matcher in payload.matchAny: - matcher_mask = _get_matcher_mask(df, matcher, negate=False) - any_mask |= matcher_mask - final_mask &= any_mask - return df[final_mask] - - -def _get_matcher_mask(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.Series: - """ - Get boolean mask for a single matcher with comprehensive validation. - """ - column_name = matcher.columnName - values = matcher.values - if matcher.operation not in ["EQUALS", "BETWEEN"]: - raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]") - if column_name not in df.columns: - raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}") - if matcher.operation == "EQUALS": - mask = df[column_name].isin(values) - elif matcher.operation == "BETWEEN": - mask = _create_between_mask(df, column_name, values) - if negate: - mask = ~mask - - return mask - - -def _create_between_mask(df: pd.DataFrame, column_name: str, values: List[Any]) -> pd.Series: - """Create boolean mask for BETWEEN operation with type-specific handling.""" - errors = [] - if len(values) != 2: - errors.append( - f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. 
Received {len(values)} values" - ) - if column_name == "trustyai.TIMESTAMP": - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - try: - start_time = pd.to_datetime(str(values[0])) - end_time = pd.to_datetime(str(values[1])) - df_times = pd.to_datetime(df[column_name]) - return (df_times >= start_time) & (df_times < end_time) - except Exception as e: - raise HTTPException( - status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}" - ) - elif column_name == "trustyai.INDEX": - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - min_val, max_val = sorted([int(v) for v in values]) - return (df[column_name] >= min_val) & (df[column_name] < max_val) - else: - if not all(isinstance(v, numbers.Number) for v in values): - errors.append( - "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values" - ) - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - min_val, max_val = sorted(values) - try: - if df[column_name].dtype in ["int64", "float64", "int32", "float32"]: - return (df[column_name] >= min_val) & (df[column_name] < max_val) - else: - numeric_column = pd.to_numeric(df[column_name], errors="raise") - return (numeric_column >= min_val) & (numeric_column < max_val) - except (ValueError, TypeError): - raise HTTPException( - status_code=400, - detail=f"Column '{column_name}' contains non-numeric values that cannot be compared with BETWEEN operation.", - ) - - -def _build_query_expression(df: pd.DataFrame, payload: DataRequestPayload) -> str: - """Build optimized pandas query expression for all filters.""" - conditions = [] - if payload.matchAll: - for matcher in payload.matchAll: - condition = _build_condition(df, matcher, negate=False) - if condition: - conditions.append(condition) - if payload.matchNone: - for matcher in payload.matchNone: - condition = _build_condition(df, matcher, negate=True) - if condition: - conditions.append(condition) - if payload.matchAny: - any_conditions = [] - for matcher in payload.matchAny: - condition = _build_condition(df, matcher, negate=False) - if condition: - any_conditions.append(condition) - if any_conditions: - any_expr = " | ".join(f"({cond})" for cond in any_conditions) - conditions.append(f"({any_expr})") - return " & ".join(f"({cond})" for cond in conditions) if conditions else "" - - -def _build_condition(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> str: - """Build a single condition for pandas query.""" - column_name = matcher.columnName - values = matcher.values - if matcher.operation not in ["EQUALS", "BETWEEN"]: - raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]") - if column_name not in df.columns: - raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}") - safe_column = _sanitize_column_name(column_name) - if matcher.operation == "EQUALS": - condition = _build_equals_condition(safe_column, values, df[column_name].dtype) - elif matcher.operation == "BETWEEN": - condition = _build_between_condition(safe_column, values, column_name, df[column_name].dtype) - if negate: - condition = f"~({condition})" - return condition - - -def _sanitize_column_name(column_name: str) -> str: - """Sanitize column name for pandas query syntax.""" - if "." 
in column_name or column_name.startswith("trustyai"): - return f"`{column_name}`" - return column_name - - -def _build_equals_condition(safe_column: str, values: List[Any], dtype) -> str: - """Build EQUALS condition for query with optimization.""" - if len(values) == 1: - val = _format_value_for_query(values[0], dtype) - return f"{safe_column} == {val}" - else: - formatted_values = [_format_value_for_query(v, dtype) for v in values] - values_str = "[" + ", ".join(formatted_values) + "]" - return f"{safe_column}.isin({values_str})" - - -def _build_between_condition(safe_column: str, values: List[Any], original_column: str, dtype) -> str: - """Build BETWEEN condition for query with comprehensive validation.""" - errors = [] - if len(values) != 2: - errors.append( - f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received {len(values)} values" - ) - if original_column == "trustyai.TIMESTAMP": - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - try: - start_time = pd.to_datetime(str(values[0])) - end_time = pd.to_datetime(str(values[1])) - return f"'{start_time}' <= {safe_column} < '{end_time}'" - except Exception as e: - raise HTTPException( - status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}" - ) - elif original_column == "trustyai.INDEX": - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - min_val, max_val = sorted([int(v) for v in values]) - return f"{min_val} <= {safe_column} < {max_val}" - else: - if not all(isinstance(v, numbers.Number) for v in values): - errors.append( - "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values" - ) - if errors: - raise HTTPException(status_code=400, detail=", ".join(errors)) - min_val, max_val = sorted(values) - return f"{min_val} <= {safe_column} < {max_val}" - - -def _format_value_for_query(value: Any, dtype) -> str: - """Format value appropriately for pandas query syntax.""" - if isinstance(value, str): - escaped = value.replace("'", "\\'") - return f"'{escaped}'" - elif isinstance(value, (int, float)): - return str(value) - else: - escaped = str(value).replace("'", "\\'") - return f"'{escaped}'" \ No newline at end of file diff --git a/src/service/utils/list_utils.py b/src/service/utils/list_utils.py index b278227..8ca398f 100644 --- a/src/service/utils/list_utils.py +++ b/src/service/utils/list_utils.py @@ -15,10 +15,10 @@ def contains_non_numeric(l: list) -> bool: return isinstance(l, (bool, str)) -def serialize_rows(l: list): +def serialize_rows(l: list, max_void_type_length): """Convert a nested list to a 1D numpy array, where the nth element contains a bytes serialization of the nth row""" serialized = [np.void(pickle.dumps(row)) for row in l] - return np.array(serialized) + return np.array(serialized, dtype=f"V{max_void_type_length}") def deserialize_rows(serialized: np.ndarray): diff --git a/src/service/utils/upload.py b/src/service/utils/upload.py deleted file mode 100644 index 58f3942..0000000 --- a/src/service/utils/upload.py +++ /dev/null @@ -1,545 +0,0 @@ -import logging -import uuid -from datetime import datetime -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -from fastapi import HTTPException - -from src.service.constants import ( - INPUT_SUFFIX, - METADATA_SUFFIX, - OUTPUT_SUFFIX, - TRUSTYAI_TAG_PREFIX, -) -from 
src.service.data.modelmesh_parser import ModelMeshPayloadParser -from src.service.data.storage import get_storage_interface -from src.service.utils import list_utils -from src.endpoints.consumer.consumer_endpoint import process_payload - - -logger = logging.getLogger(__name__) - - -METADATA_STRING_MAX_LENGTH = 100 - - -class KServeDataAdapter: - """ - Convert upload tensors to consumer endpoint format. - """ - - def __init__(self, tensor_dict: Dict[str, Any], numpy_array: np.ndarray): - """Initialize adapter with validated data.""" - self._name = tensor_dict.get("name", "unknown") - self._shape = tensor_dict.get("shape", []) - self._datatype = tensor_dict.get("datatype", "FP64") - self._data = numpy_array # Keep numpy array intact - - @property - def name(self) -> str: - return self._name - - @property - def shape(self) -> List[int]: - return self._shape - - @property - def datatype(self) -> str: - return self._datatype - - @property - def data(self) -> np.ndarray: - """Returns numpy array with .shape attribute as expected by consumer endpoint.""" - return self._data - - -class ConsumerEndpointAdapter: - """ - Consumer endpoint's expected structure. - """ - - def __init__(self, adapted_tensors: List[KServeDataAdapter]): - self.tensors = adapted_tensors - self.id = f"upload_request_{uuid.uuid4().hex[:8]}" - - -async def process_upload_request(payload: Any) -> Dict[str, str]: - """ - Process complete upload request with validation and data handling. - """ - try: - model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name) - if payload.data_tag: - error = validate_data_tag(payload.data_tag) - if error: - raise HTTPException(400, error) - inputs = payload.request.get("inputs", []) - outputs = payload.response.get("outputs", []) if payload.response else [] - if not inputs: - raise HTTPException(400, "Missing input tensors") - if payload.is_ground_truth and not outputs: - raise HTTPException(400, "Ground truth uploads require output tensors") - - input_arrays, input_names, _, execution_ids = process_tensors_using_kserve_logic(inputs) - if outputs: - output_arrays, output_names, _, _ = process_tensors_using_kserve_logic(outputs) - else: - output_arrays, output_names = [], [] - error = validate_input_shapes(input_arrays, input_names) - if error: - raise HTTPException(400, f"One or more errors in input tensors: {error}") - if payload.is_ground_truth: - return await _process_ground_truth_data( - model_name, input_arrays, input_names, output_arrays, output_names, execution_ids - ) - else: - return await _process_regular_data( - model_name, input_arrays, input_names, output_arrays, output_names, execution_ids, payload.data_tag - ) - except ProcessingError as e: - raise HTTPException(400, str(e)) - except ValidationError as e: - raise HTTPException(400, str(e)) - - -async def _process_ground_truth_data( - model_name: str, - input_arrays: List[np.ndarray], - input_names: List[str], - output_arrays: List[np.ndarray], - output_names: List[str], - execution_ids: Optional[List[str]], -) -> Dict[str, str]: - """Process ground truth data upload.""" - if not execution_ids: - raise HTTPException(400, "Ground truth requires execution IDs") - result = await handle_ground_truths( - model_name, - input_arrays, - input_names, - output_arrays, - output_names, - [sanitize_id(id) for id in execution_ids], - ) - if not result.success: - raise HTTPException(400, result.message) - result_data = result.data - if result_data is None: - raise HTTPException(500, "Ground truth processing failed") - gt_name = 
f"{model_name}_ground_truth" - storage_interface = get_storage_interface() - await storage_interface.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"]) - await storage_interface.write_data( - gt_name + METADATA_SUFFIX, - result_data["metadata"], - result_data["metadata_names"], - ) - logger.info(f"Ground truth data saved for model: {model_name}") - return {"message": result.message} - - -async def _process_regular_data( - model_name: str, - input_arrays: List[np.ndarray], - input_names: List[str], - output_arrays: List[np.ndarray], - output_names: List[str], - execution_ids: Optional[List[str]], - data_tag: Optional[str], -) -> Dict[str, str]: - """Process regular model data upload.""" - n_rows = input_arrays[0].shape[0] - exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)] - input_data = _flatten_tensor_data(input_arrays, n_rows) - output_data = _flatten_tensor_data(output_arrays, n_rows) - metadata, metadata_cols = _create_metadata(exec_ids, model_name, data_tag) - await save_model_data( - model_name, - np.array(input_data), - input_names, - np.array(output_data), - output_names, - metadata, - metadata_cols, - ) - logger.info(f"Regular data saved for model: {model_name}, rows: {n_rows}") - return {"message": f"{n_rows} datapoints added to {model_name}"} - - -def _flatten_tensor_data(arrays: List[np.ndarray], n_rows: int) -> List[List[Any]]: - """ - Flatten tensor arrays into row-based format for storage. - """ - - def flatten_row(arrays: List[np.ndarray], row: int) -> List[Any]: - """Flatten arrays for a single row.""" - return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])] - - return [flatten_row(arrays, i) for i in range(n_rows)] - - -def _create_metadata( - execution_ids: List[str], model_name: str, data_tag: Optional[str] -) -> Tuple[np.ndarray, List[str]]: - """ - Create metadata array for model data storage. - """ - current_timestamp = datetime.now().isoformat() - metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG"] - metadata_rows = [ - [ - str(eid), - str(model_name), - str(current_timestamp), - str(data_tag or ""), - ] - for eid in execution_ids - ] - _validate_metadata_lengths(metadata_rows, metadata_cols) - metadata = np.array(metadata_rows, dtype=f" None: - """ - Validate that all metadata values fit within the defined string length limit. - """ - for row_idx, row in enumerate(metadata_rows): - for col_idx, value in enumerate(row): - value_str = str(value) - if len(value_str) > METADATA_STRING_MAX_LENGTH: - col_name = column_names[col_idx] if col_idx < len(column_names) else f"column_{col_idx}" - raise ValidationError( - f"Metadata field '{col_name}' in row {row_idx} exceeds maximum length " - f"of {METADATA_STRING_MAX_LENGTH} characters (got {len(value_str)} chars): " - f"'{value_str[:50]}{'...' 
if len(value_str) > 50 else ''}'" - ) - - -class ValidationError(Exception): - """Validation errors.""" - - pass - - -class ProcessingError(Exception): - """Processing errors.""" - - pass - - -@dataclass -class GroundTruthValidationResult: - """Result of ground truth validation.""" - - success: bool - message: str - data: Optional[Dict[str, Any]] = None - errors: List[str] = field(default_factory=list) - - -TYPE_MAP = { - np.int64: "Long", - np.int32: "Integer", - np.float32: "Float", - np.float64: "Double", - np.bool_: "Boolean", - int: "Long", - float: "Double", - bool: "Boolean", - str: "String", -} - - -def get_type_name(val: Any) -> str: - """Get Java-style type name for a value (used in ground truth validation).""" - if hasattr(val, "dtype"): - return TYPE_MAP.get(val.dtype.type, "String") - return TYPE_MAP.get(type(val), "String") - - -def sanitize_id(execution_id: str) -> str: - """Sanitize execution ID.""" - return str(execution_id).strip() - - -def extract_row_data(arrays: List[np.ndarray], row_index: int) -> List[Any]: - """Extract data from arrays for a specific row.""" - row_data = [] - for arr in arrays: - if arr.ndim > 1: - row_data.extend(arr[row_index].flatten()) - else: - row_data.append(arr[row_index]) - return row_data - - -def process_tensors_using_kserve_logic( - tensors: List[Dict[str, Any]], -) -> Tuple[List[np.ndarray], List[str], List[str], Optional[List[str]]]: - """ - Process tensor data using consumer endpoint logic via clean adapter pattern. - """ - if not tensors: - return [], [], [], None - validation_errors = _validate_tensor_inputs(tensors) - if validation_errors: - error_message = "One or more errors occurred: " + ". ".join(validation_errors) - raise HTTPException(400, error_message) - adapted_tensors = [] - execution_ids = None - datatypes = [] - for tensor in tensors: - if execution_ids is None: - execution_ids = tensor.get("execution_ids") - numpy_array = _convert_tensor_to_numpy(tensor) - adapter = KServeDataAdapter(tensor, numpy_array) - adapted_tensors.append(adapter) - datatypes.append(adapter.datatype) - try: - adapter_payload = ConsumerEndpointAdapter(adapted_tensors) - tensor_array, column_names = process_payload(adapter_payload, lambda payload: payload.tensors) - arrays, all_names = _convert_consumer_results_to_upload_format(tensor_array, column_names, adapted_tensors) - return arrays, all_names, datatypes, execution_ids - except Exception as e: - logger.error(f"Consumer endpoint processing failed: {e}") - raise HTTPException(400, f"Tensor processing error: {str(e)}") - - -def _validate_tensor_inputs(tensors: List[Dict[str, Any]]) -> List[str]: - """Validate tensor inputs and return list of error messages.""" - errors = [] - tensor_names = [tensor.get("name", f"tensor_{i}") for i, tensor in enumerate(tensors)] - if len(tensor_names) != len(set(tensor_names)): - errors.append("Input tensors must have unique names") - shapes = [tensor.get("shape", []) for tensor in tensors] - if len(shapes) > 1: - first_dims = [shape[0] if shape else 0 for shape in shapes] - if len(set(first_dims)) > 1: - errors.append(f"Input tensors must have consistent first dimension. 
Found: {first_dims}") - return errors - - -def _convert_tensor_to_numpy(tensor: Dict[str, Any]) -> np.ndarray: - """Convert tensor dictionary to numpy array with proper dtype.""" - raw_data = tensor.get("data", []) - - if list_utils.contains_non_numeric(raw_data): - return np.array(raw_data, dtype="O") - dtype_map = {"INT64": np.int64, "INT32": np.int32, "FP32": np.float32, "FP64": np.float64, "BOOL": np.bool_} - datatype = tensor.get("datatype", "FP64") - np_dtype = dtype_map.get(datatype, np.float64) - return np.array(raw_data, dtype=np_dtype) - - -def _convert_consumer_results_to_upload_format( - tensor_array: np.ndarray, column_names: List[str], adapted_tensors: List[KServeDataAdapter] -) -> Tuple[List[np.ndarray], List[str]]: - """Convert consumer endpoint results back to upload format.""" - if len(adapted_tensors) == 1: - # Single tensor case - return [tensor_array], column_names - arrays = [] - all_names = [] - col_start = 0 - for adapter in adapted_tensors: - if len(adapter.shape) > 1: - n_cols = adapter.shape[1] - tensor_names = [f"{adapter.name}-{i}" for i in range(n_cols)] - else: - n_cols = 1 - tensor_names = [adapter.name] - if tensor_array.ndim == 2: - tensor_data = tensor_array[:, col_start : col_start + n_cols] - else: - tensor_data = tensor_array[col_start : col_start + n_cols] - arrays.append(tensor_data) - all_names.extend(tensor_names) - col_start += n_cols - return arrays, all_names - - -def validate_input_shapes(input_arrays: List[np.ndarray], input_names: List[str]) -> Optional[str]: - """Validate input array shapes and names - collect ALL errors.""" - if not input_arrays: - return None - errors = [] - if len(set(input_names)) != len(input_names): - errors.append("Input tensors must have unique names") - first_dim = input_arrays[0].shape[0] - for i, arr in enumerate(input_arrays[1:], 1): - if arr.shape[0] != first_dim: - errors.append( - f"Input tensor '{input_names[i]}' has first dimension {arr.shape[0]}, " - f"which doesn't match the first dimension {first_dim} of '{input_names[0]}'" - ) - if errors: - return ". ".join(errors) + "." - return None - - -def validate_data_tag(tag: str) -> Optional[str]: - """Validate data tag format and content.""" - if not tag: - return None - if tag.startswith(TRUSTYAI_TAG_PREFIX): - return ( - f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. " - f"Provided tag '{tag}' violates this restriction." 
- ) - return None - - -class GroundTruthValidator: - """Ground truth validator.""" - - def __init__(self, model_name: str): - self.model_name = model_name - self.id_to_row: Dict[str, int] = {} - self.inputs: Optional[np.ndarray] = None - self.outputs: Optional[np.ndarray] = None - self.metadata: Optional[np.ndarray] = None - - async def initialize(self) -> None: - """Load existing data.""" - storage_interface = get_storage_interface() - self.inputs, _ = await storage_interface.read_data(self.model_name + INPUT_SUFFIX) - self.outputs, _ = await storage_interface.read_data(self.model_name + OUTPUT_SUFFIX) - self.metadata, _ = await storage_interface.read_data(self.model_name + METADATA_SUFFIX) - metadata_cols = await storage_interface.get_original_column_names(self.model_name + METADATA_SUFFIX) - id_col = next((i for i, name in enumerate(metadata_cols) if name.upper() == "ID"), 0) - if self.metadata is not None: - for j, row in enumerate(self.metadata): - id_val = row[id_col] - self.id_to_row[str(id_val)] = j - - def find_row(self, exec_id: str) -> Optional[int]: - """Find row index for execution ID.""" - return self.id_to_row.get(str(exec_id)) - - async def validate_data( - self, - exec_id: str, - uploaded_inputs: List[Any], - uploaded_outputs: List[Any], - row_idx: int, - input_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, - ) -> Optional[str]: - """Validate inputs and outputs.""" - if self.inputs is None or self.outputs is None: - return f"ID={exec_id} no existing data found" - existing_inputs = self.inputs[row_idx] - existing_outputs = self.outputs[row_idx] - for i, (existing, uploaded) in enumerate(zip(existing_inputs[:3], uploaded_inputs[:3])): - if hasattr(existing, "dtype"): - print( - f" Input {i}: existing.dtype={existing.dtype}, uploaded.dtype={getattr(uploaded, 'dtype', 'no dtype')}" - ) - print(f" Input {i}: existing={existing}, uploaded={uploaded}") - for i, (existing, uploaded) in enumerate(zip(existing_outputs[:2], uploaded_outputs[:2])): - if hasattr(existing, "dtype"): - print( - f" Output {i}: existing.dtype={existing.dtype}, uploaded.dtype={getattr(uploaded, 'dtype', 'no dtype')}" - ) - print(f" Output {i}: existing={existing}, uploaded={uploaded}") - if len(existing_inputs) != len(uploaded_inputs): - return f"ID={exec_id} input shapes do not match. Observed inputs have length={len(existing_inputs)} while uploaded inputs have length={len(uploaded_inputs)}" - for i, (existing, uploaded) in enumerate(zip(existing_inputs, uploaded_inputs)): - existing_type = get_type_name(existing) - uploaded_type = get_type_name(uploaded) - print(f" Input {i}: existing_type='{existing_type}', uploaded_type='{uploaded_type}'") - if existing_type != uploaded_type: - return f"ID={exec_id} input type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}" - if existing != uploaded: - return f"ID={exec_id} inputs are not identical: value mismatch at position {i + 1}" - if len(existing_outputs) != len(uploaded_outputs): - return f"ID={exec_id} output shapes do not match. 
Observed outputs have length={len(existing_outputs)} while uploaded ground-truths have length={len(uploaded_outputs)}" - for i, (existing, uploaded) in enumerate(zip(existing_outputs, uploaded_outputs)): - existing_type = get_type_name(existing) - uploaded_type = get_type_name(uploaded) - print(f" Output {i}: existing_type='{existing_type}', uploaded_type='{uploaded_type}'") - if existing_type != uploaded_type: - return f"ID={exec_id} output type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}" - return None - - -async def handle_ground_truths( - model_name: str, - input_arrays: List[np.ndarray], - input_names: List[str], - output_arrays: List[np.ndarray], - output_names: List[str], - execution_ids: List[str], - config: Optional[Any] = None, -) -> GroundTruthValidationResult: - """Handle ground truth validation.""" - if not execution_ids: - return GroundTruthValidationResult(success=False, message="No execution IDs provided.") - storage_interface = get_storage_interface() - if not await storage_interface.dataset_exists(model_name + INPUT_SUFFIX): - return GroundTruthValidationResult(success=False, message=f"Model {model_name} not found.") - validator = GroundTruthValidator(model_name) - await validator.initialize() - errors = [] - valid_outputs = [] - valid_metadata = [] - n_rows = input_arrays[0].shape[0] if input_arrays else 0 - for i, exec_id in enumerate(execution_ids): - if i >= n_rows: - errors.append(f"ID={exec_id} index out of bounds") - continue - row_idx = validator.find_row(exec_id) - if row_idx is None: - errors.append(f"ID={exec_id} not found") - continue - uploaded_inputs = extract_row_data(input_arrays, i) - uploaded_outputs = extract_row_data(output_arrays, i) - error = await validator.validate_data(exec_id, uploaded_inputs, uploaded_outputs, row_idx) - if error: - errors.append(error) - continue - valid_outputs.append(uploaded_outputs) - valid_metadata.append([exec_id]) - if errors: - return GroundTruthValidationResult( - success=False, - message="Found fatal mismatches between uploaded data and recorded inference data:\n" - + "\n".join(errors[:5]), - errors=errors, - ) - if not valid_outputs: - return GroundTruthValidationResult(success=False, message="No valid ground truths found.") - return GroundTruthValidationResult( - success=True, - message=f"{len(valid_outputs)} ground truths added.", - data={ - "outputs": np.array(valid_outputs), - "output_names": output_names, - "metadata": np.array(valid_metadata), - "metadata_names": ["ID"], - }, - ) - - -async def save_model_data( - model_name: str, - input_data: np.ndarray, - input_names: List[str], - output_data: np.ndarray, - output_names: List[str], - metadata_data: np.ndarray, - metadata_names: List[str], -) -> Dict[str, Any]: - """Save model data to storage.""" - storage_interface = get_storage_interface() - await storage_interface.write_data(model_name + INPUT_SUFFIX, input_data, input_names) - await storage_interface.write_data(model_name + OUTPUT_SUFFIX, output_data, output_names) - await storage_interface.write_data(model_name + METADATA_SUFFIX, metadata_data, metadata_names) - logger.info(f"Saved model data for {model_name}: {len(input_data)} rows") - return { - "model_name": model_name, - "rows": len(input_data), - } \ No newline at end of file diff --git a/tests/endpoints/test_download_endpoint.py b/tests/endpoints/test_download_endpoint.py deleted file mode 100644 index 647f694..0000000 --- a/tests/endpoints/test_download_endpoint.py +++ /dev/null @@ -1,404 +0,0 @@ -import uuid -from 
datetime import datetime, timedelta -from io import StringIO -from unittest.mock import patch - -import numpy as np -import pandas as pd -import pytest -from fastapi.testclient import TestClient - -from src.main import app - -client = TestClient(app) - - -class DataframeGenerators: - """Python equivalent of Java DataframeGenerators""" - - @staticmethod - def generate_random_dataframe(observations: int, feature_diversity: int = 100) -> pd.DataFrame: - random = np.random.RandomState(0) - data = { - "age": [], - "gender": [], - "race": [], - "income": [], - "trustyai.ID": [], - "trustyai.MODEL_ID": [], - "trustyai.TIMESTAMP": [], - "trustyai.TAG": [], - "trustyai.INDEX": [], - } - for i in range(observations): - data["age"].append(i % feature_diversity) - data["gender"].append(1 if random.choice([True, False]) else 0) - data["race"].append(1 if random.choice([True, False]) else 0) - data["income"].append(1 if random.choice([True, False]) else 0) - data["trustyai.ID"].append(str(uuid.uuid4())) - data["trustyai.MODEL_ID"].append("example1") - data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) - data["trustyai.TAG"].append("") - data["trustyai.INDEX"].append(i) - return pd.DataFrame(data) - - @staticmethod - def generate_random_text_dataframe(observations: int, seed: int = 0) -> pd.DataFrame: - if seed < 0: - random = np.random.RandomState(0) - else: - random = np.random.RandomState(seed) - makes = ["Ford", "Chevy", "Dodge", "GMC", "Buick"] - colors = ["Red", "Blue", "White", "Black", "Purple", "Green", "Yellow"] - data = { - "year": [], - "make": [], - "color": [], - "value": [], - "trustyai.ID": [], - "trustyai.MODEL_ID": [], - "trustyai.TIMESTAMP": [], - "trustyai.TAG": [], - "trustyai.INDEX": [], - } - for i in range(observations): - data["year"].append(1970 + i % 50) - data["make"].append(makes[i % len(makes)]) - data["color"].append(colors[i % len(colors)]) - data["value"].append(random.random() * 50) - data["trustyai.ID"].append(str(uuid.uuid4())) - data["trustyai.MODEL_ID"].append("example1") - data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) - data["trustyai.TAG"].append("") - data["trustyai.INDEX"].append(i) - return pd.DataFrame(data) - - -# Mock storage for testing -class MockStorage: - def __init__(self): - self.data = {} - - async def read_data(self, dataset_name: str): - if dataset_name.endswith("_outputs"): - model_id = dataset_name.replace("_outputs", "") - if model_id not in self.data: - raise Exception(f"Model {model_id} not found") - output_data = self.data[model_id].get("output") - output_cols = self.data[model_id].get("output_cols", []) - return output_data, output_cols - elif dataset_name.endswith("_metadata"): - model_id = dataset_name.replace("_metadata", "") - if model_id not in self.data: - raise Exception(f"Model {model_id} not found") - metadata_data = self.data[model_id].get("metadata") - metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] - return metadata_data, metadata_cols - elif dataset_name.endswith("_inputs"): - model_id = dataset_name.replace("_inputs", "") - if model_id not in self.data: - raise Exception(f"Model {model_id} not found") - input_data = self.data[model_id].get("input") - input_cols = self.data[model_id].get("input_cols", []) - return input_data, input_cols - else: - raise Exception(f"Unknown dataset: {dataset_name}") - - def save_dataframe(self, df: pd.DataFrame, model_id: str): - input_cols = [col for col in df.columns if not col.startswith("trustyai.") and 
col not in ["income", "value"]] - output_cols = [col for col in df.columns if col in ["income", "value"]] - metadata_cols = [col for col in df.columns if col.startswith("trustyai.")] - input_data = df[input_cols].values if input_cols else np.array([]) - output_data = df[output_cols].values if output_cols else np.array([]) - metadata_data_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] - metadata_values = [] - for _, row in df.iterrows(): - row_data = [] - for col in metadata_data_cols: - trusty_col = f"trustyai.{col}" - if trusty_col in df.columns: - value = row[trusty_col] - if col == "INDEX": - row_data.append(int(value)) - else: - row_data.append(str(value)) - else: - row_data.append("" if col != "INDEX" else 0) - metadata_values.append(row_data) - metadata_data = np.array(metadata_values, dtype=object) - self.data[model_id] = { - "dataframe": df, - "input": input_data, - "input_cols": input_cols, - "output": output_data, - "output_cols": output_cols, - "metadata": metadata_data, - } - - def reset(self): - self.data.clear() - - -mock_storage = MockStorage() - - -@pytest.fixture(autouse=True) -def setup_storage(): - """Setup mock storage for all tests""" - with patch("src.service.utils.download.get_storage_interface", return_value=mock_storage): - yield - - -@pytest.fixture(autouse=True) -def reset_storage(): - """Reset storage before each test""" - mock_storage.reset() - yield - - -# Test constants -MODEL_ID = "example1" - - -def test_download_data(): - """equivalent of Java downloadData() test""" - dataframe = DataframeGenerators.generate_random_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - - payload = { - "modelId": MODEL_ID, - "matchAll": [ - {"columnName": "gender", "operation": "EQUALS", "values": [0]}, - {"columnName": "race", "operation": "EQUALS", "values": [0]}, - {"columnName": "income", "operation": "EQUALS", "values": [0]}, - ], - "matchAny": [ - {"columnName": "age", "operation": "BETWEEN", "values": [5, 10]}, - {"columnName": "age", "operation": "BETWEEN", "values": [50, 70]}, - ], - "matchNone": [{"columnName": "age", "operation": "BETWEEN", "values": [55, 65]}], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df[(result_df["age"] > 55) & (result_df["age"] < 65)]) == 0 - assert len(result_df[result_df["gender"] == 1]) == 0 - assert len(result_df[result_df["race"] == 1]) == 0 - assert len(result_df[result_df["income"] == 1]) == 0 - assert len(result_df[(result_df["age"] >= 10) & (result_df["age"] < 50)]) == 0 - assert len(result_df[result_df["age"] > 70]) == 0 - - -def test_download_text_data(): - """equivalent of Java downloadTextData() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "make", - "operation": "EQUALS", - "values": ["Chevy", "Ford", "Dodge"], - }, - { - "columnName": "year", - "operation": "BETWEEN", - "values": [1990, 2050], - }, - ], - } - - response = client.post("/data/download", json=payload) - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df[result_df["year"] < 1990]) == 0 - assert len(result_df[result_df["make"] == "GMC"]) == 0 - assert len(result_df[result_df["make"] == "Buick"]) == 0 - - -def 
test_download_text_data_between_error(): - """equivalent of Java downloadTextDataBetweenError() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "make", - "operation": "BETWEEN", - "values": ["Chevy", "Ford", "Dodge"], - } - ], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 400 - assert ( - "BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received 3 values" - in response.text - ) - assert ( - "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values" - in response.text - ) - - -def test_download_text_data_invalid_column_error(): - """equivalent of Java downloadTextDataInvalidColumnError() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "mak123e", - "operation": "EQUALS", - "values": ["Chevy", "Ford"], - } - ], - } - - response = client.post("/data/download", json=payload) - assert response.status_code == 400 - assert "No feature or output found with name=" in response.text - - -def test_download_text_data_invalid_operation_error(): - """equivalent of Java downloadTextDataInvalidOperationError() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "mak123e", - "operation": "DOESNOTEXIST", - "values": ["Chevy", "Ford"], - } - ], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 400 - assert "RowMatch operation must be one of [BETWEEN, EQUALS]" in response.text - - -def test_download_text_data_internal_column(): - """equivalent of Java downloadTextDataInternalColumn() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - dataframe.loc[0:499, "trustyai.TAG"] = "TRAINING" - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "trustyai.TAG", - "operation": "EQUALS", - "values": ["TRAINING"], - } - ], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df) == 500 - - -def test_download_text_data_internal_column_index(): - """equivalent of Java downloadTextDataInternalColumnIndex() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - expected_rows = dataframe.iloc[0:10].copy() - payload = { - "modelId": MODEL_ID, - "matchAll": [ - { - "columnName": "trustyai.INDEX", - "operation": "BETWEEN", - "values": [0, 10], - } - ], - } - response = client.post("/data/download", json=payload) - print(f"Response status: {response.status_code}") - print(f"Response text: {response.text}") - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df) == 10 - input_cols = ["year", "make", "color"] - for i in range(10): - for col in input_cols: - assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" - 
- -def test_download_text_data_internal_column_timestamp(): - """equivalent of Java downloadTextDataInternalColumnTimestamp() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1, -1) - base_time = datetime.now() - for i in range(100): - new_row = DataframeGenerators.generate_random_text_dataframe(1, i) - # Use milliseconds to simulate Thread.sleep(1) and ensure ascending order - timestamp = (base_time + timedelta(milliseconds=i + 1)).isoformat() - # Fix this line - change to UPPERCASE - new_row["trustyai.TIMESTAMP"] = [timestamp] - dataframe = pd.concat([dataframe, new_row], ignore_index=True) - mock_storage.save_dataframe(dataframe, MODEL_ID) - extract_idx = 50 - n_to_get = 10 - expected_rows = dataframe.iloc[extract_idx : extract_idx + n_to_get].copy() - start_time = dataframe.iloc[extract_idx]["trustyai.TIMESTAMP"] - end_time = dataframe.iloc[extract_idx + n_to_get]["trustyai.TIMESTAMP"] - payload = { - "modelId": MODEL_ID, - "matchAny": [ - { - "columnName": "trustyai.TIMESTAMP", - "operation": "BETWEEN", - "values": [start_time, end_time], - } - ], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df) == 10 - input_cols = ["year", "make", "color"] - for i in range(10): - for col in input_cols: - assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" - - -def test_download_text_data_internal_column_timestamp_unparseable(): - """equivalent of Java downloadTextDataInternalColumnTimestampUnparseable() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = { - "modelId": MODEL_ID, - "matchAny": [ - { - "columnName": "trustyai.TIMESTAMP", - "operation": "BETWEEN", - "values": ["not a timestamp", "also not a timestamp"], - } - ], - } - response = client.post("/data/download", json=payload) - assert response.status_code == 400 - assert "unparseable as an ISO_LOCAL_DATE_TIME" in response.text - - -def test_download_text_data_null_request(): - """equivalent of Java downloadTextDataNullRequest() test""" - dataframe = DataframeGenerators.generate_random_text_dataframe(1000) - mock_storage.save_dataframe(dataframe, MODEL_ID) - payload = {"modelId": MODEL_ID} - response = client.post("/data/download", json=payload) - assert response.status_code == 200 - result = response.json() - result_df = pd.read_csv(StringIO(result["dataCSV"])) - assert len(result_df) == 1000 \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint.py b/tests/endpoints/test_upload_endpoint.py deleted file mode 100644 index 12ccf09..0000000 --- a/tests/endpoints/test_upload_endpoint.py +++ /dev/null @@ -1,484 +0,0 @@ -import copy -import json -import os -import pickle -import shutil -import sys -import tempfile -import uuid - -import h5py -import numpy as np -import pytest - -TEMP_DIR = tempfile.mkdtemp() -os.environ["STORAGE_DATA_FOLDER"] = TEMP_DIR -from fastapi.testclient import TestClient - -from src.main import app -from src.service.constants import ( - INPUT_SUFFIX, - METADATA_SUFFIX, - OUTPUT_SUFFIX, - TRUSTYAI_TAG_PREFIX, -) -from src.service.data.storage import get_storage_interface - - -def pytest_sessionfinish(session, exitstatus): - """Clean up the temporary directory after all tests are done.""" - if os.path.exists(TEMP_DIR): - shutil.rmtree(TEMP_DIR) - - -pytest.hookimpl(pytest_sessionfinish) -client 
= TestClient(app) -MODEL_ID = "example1" - - -def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0): - """Generate a test payload with specific dimensions and data types.""" - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - input_data = [] - for i in range(n_rows): - if n_input_cols == 1: - input_data.append(i + input_offset) - else: - row = [i + j + input_offset for j in range(n_input_cols)] - input_data.append(row) - output_data = [] - for i in range(n_rows): - if n_output_cols == 1: - output_data.append(i * 2 + output_offset) - else: - row = [i * 2 + j + output_offset for j in range(n_output_cols)] - output_data.append(row) - payload = { - "model_name": model_name, - "data_tag": tag, - "is_ground_truth": False, - "request": { - "inputs": [ - { - "name": "input", - "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows], - "datatype": datatype, - "data": input_data, - } - ] - }, - "response": { - "outputs": [ - { - "name": "output", - "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows], - "datatype": datatype, - "data": output_data, - } - ] - }, - } - return payload - - -def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): - """Generate a test payload with multi-dimensional tensors like real data.""" - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - input_data = [] - for row_idx in range(n_rows): - row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)] - input_data.append(row) - output_data = [] - for row_idx in range(n_rows): - row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)] - output_data.append(row) - payload = { - "model_name": model_name, - "data_tag": tag, - "is_ground_truth": False, - "request": { - "inputs": [ - { - "name": "multi_input", - "shape": [n_rows, n_input_cols], - "datatype": datatype, - "data": input_data, - } - ] - }, - "response": { - "outputs": [ - { - "name": "multi_output", - "shape": [n_rows, n_output_cols], - "datatype": datatype, - "data": output_data, - } - ] - }, - } - return payload - - -def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): - """Generate a payload with mismatched shapes and non-unique names.""" - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)] - mismatched_rows = n_rows - 1 if n_rows > 1 else 1 - input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)] - output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)] - payload = { - "model_name": model_name, - "data_tag": tag, - "is_ground_truth": False, - "request": { - "inputs": [ - { - "name": "same_name", - "shape": [n_rows, n_input_cols], - "datatype": datatype, - "data": input_data_1, - }, - { - "name": "same_name", - "shape": [mismatched_rows, n_input_cols], - "datatype": datatype, - "data": input_data_2, - }, - ] - }, - "response": { - "outputs": [ - { - "name": "multi_output", - "shape": [n_rows, n_output_cols], - "datatype": datatype, - "data": output_data, - } - ] - }, - } - return payload - - -def get_data_from_storage(model_name, suffix): - """Get data from storage file.""" - storage = get_storage_interface() - filename = storage._get_filename(model_name + suffix) - if not os.path.exists(filename): - return None - with h5py.File(filename, "r") as f: - if model_name 
+ suffix in f: - data = f[model_name + suffix][:] - column_names = f[model_name + suffix].attrs.get("column_names", []) - return {"data": data, "column_names": column_names} - - -def get_metadata_ids(model_name): - """Extract actual IDs from metadata storage.""" - storage = get_storage_interface() - filename = storage._get_filename(model_name + METADATA_SUFFIX) - if not os.path.exists(filename): - return [] - ids = [] - with h5py.File(filename, "r") as f: - if model_name + METADATA_SUFFIX in f: - metadata = f[model_name + METADATA_SUFFIX][:] - column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) - id_idx = next((i for i, name in enumerate(column_names) if name.lower() == "id"), None) - if id_idx is not None: - for row in metadata: - try: - if hasattr(row, "__getitem__") and len(row) > id_idx: - id_val = row[id_idx] - else: - row_data = pickle.loads(row.tobytes()) - id_val = row_data[id_idx] - if isinstance(id_val, np.ndarray): - ids.append(str(id_val)) - else: - ids.append(str(id_val)) - except Exception as e: - print(f"Error processing ID from row {len(ids)}: {e}") - continue - print(f"Successfully extracted {len(ids)} IDs: {ids}") - return ids - - -def get_metadata_from_storage(model_name): - """Get metadata directly from storage file.""" - storage = get_storage_interface() - filename = storage._get_filename(model_name + METADATA_SUFFIX) - if not os.path.exists(filename): - return {"data": [], "column_names": []} - with h5py.File(filename, "r") as f: - if model_name + METADATA_SUFFIX in f: - metadata = f[model_name + METADATA_SUFFIX][:] - column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) - parsed_rows = [] - for row in metadata: - try: - row_data = pickle.loads(row.tobytes()) - parsed_rows.append(row_data) - except Exception as e: - print(f"Error unpickling metadata row: {e}") - - return {"data": parsed_rows, "column_names": column_names} - return {"data": [], "column_names": []} - - -def count_rows_with_tag(model_name, tag): - """Count rows with a specific tag in metadata.""" - storage = get_storage_interface() - filename = storage._get_filename(model_name + METADATA_SUFFIX) - if not os.path.exists(filename): - return 0 - count = 0 - with h5py.File(filename, "r") as f: - if model_name + METADATA_SUFFIX in f: - metadata = f[model_name + METADATA_SUFFIX][:] - column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) - tag_idx = next( - (i for i, name in enumerate(column_names) if name.lower() == "tag"), - None, - ) - if tag_idx is not None: - for row in metadata: - try: - row_data = pickle.loads(row.tobytes()) - if tag_idx < len(row_data) and row_data[tag_idx] == tag: - count += 1 - except Exception as e: - print(f"Error unpickling tag: {e}") - return count - - -def post_test(payload, expected_status_code, check_msgs): - """Post a payload and check the response.""" - response = client.post("/data/upload", json=payload) - if response.status_code != expected_status_code: - print(f"\n=== DEBUG INFO ===") - print(f"Expected status: {expected_status_code}") - print(f"Actual status: {response.status_code}") - print(f"Response text: {response.text}") - print(f"Response headers: {dict(response.headers)}") - if hasattr(response, "json"): - try: - print(f"Response JSON: {response.json()}") - except: - pass - print(f"==================") - - assert response.status_code == expected_status_code - return response - - -# data upload tests -@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) -@pytest.mark.parametrize("n_input_cols", 
[1, 4]) -@pytest.mark.parametrize("n_output_cols", [1, 2]) -@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) -def test_upload_data(n_input_rows, n_input_cols, n_output_cols, datatype): - """Test uploading data with various dimensions and datatypes.""" - data_tag = "TRAINING" - payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, data_tag) - response = post_test(payload, 200, [f"{n_input_rows} datapoints"]) - inputs = get_data_from_storage(payload["model_name"], INPUT_SUFFIX) - outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) - assert inputs is not None, "Input data not found in storage" - assert outputs is not None, "Output data not found in storage" - assert len(inputs["data"]) == n_input_rows, "Incorrect number of input rows" - assert len(outputs["data"]) == n_input_rows, "Incorrect number of output rows" - tag_count = count_rows_with_tag(payload["model_name"], data_tag) - assert tag_count == n_input_rows, "Not all rows have the correct tag" - - -@pytest.mark.parametrize("n_rows", [1, 3, 5, 250]) -@pytest.mark.parametrize("n_input_cols", [2, 6]) -@pytest.mark.parametrize("n_output_cols", [4]) -@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) -def test_upload_multi_input_data(n_rows, n_input_cols, n_output_cols, datatype): - """Test uploading data with multiple input tensors.""" - data_tag = "TRAINING" - payload = generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, data_tag) - response = post_test(payload, 200, [f"{n_rows} datapoints"]) - inputs = get_data_from_storage(payload["model_name"], INPUT_SUFFIX) - outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) - assert inputs is not None, "Input data not found in storage" - assert outputs is not None, "Output data not found in storage" - assert len(inputs["data"]) == n_rows, "Incorrect number of input rows" - assert len(outputs["data"]) == n_rows, "Incorrect number of output rows" - assert len(inputs["column_names"]) == n_input_cols, "Incorrect number of input columns" - assert len(outputs["column_names"]) == n_output_cols, "Incorrect number of output columns" - assert len(inputs["column_names"]) >= 2, "Should have at least 2 input column names" - tag_count = count_rows_with_tag(payload["model_name"], data_tag) - assert tag_count == n_rows, "Not all rows have the correct tag" - - -def test_upload_multi_input_data_no_unique_name(): - """Test error case for non-unique tensor names.""" - payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") - response = client.post("/data/upload", json=payload) - assert response.status_code == 400 - assert "One or more errors" in response.text - assert "unique names" in response.text - assert "first dimension" in response.text - - -def test_upload_multiple_tagging(): - """Test uploading data with multiple tags.""" - n_payload1 = 50 - n_payload2 = 51 - tag1 = "TRAINING" - tag2 = "NOT TRAINING" - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1) - payload1["model_name"] = model_name - post_test(payload1, 200, [f"{n_payload1} datapoints"]) - payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2) - payload2["model_name"] = model_name - post_test(payload2, 200, [f"{n_payload2} datapoints"]) - tag1_count = count_rows_with_tag(model_name, tag1) - tag2_count = count_rows_with_tag(model_name, tag2) - assert tag1_count == n_payload1, f"Expected 
{n_payload1} rows with tag {tag1}" - assert tag2_count == n_payload2, f"Expected {n_payload2} rows with tag {tag2}" - inputs = get_data_from_storage(model_name, INPUT_SUFFIX) - assert len(inputs["data"]) == n_payload1 + n_payload2, "Incorrect total number of rows" - - -def test_upload_tag_that_uses_protected_name(): - """Test error when using a protected tag name.""" - invalid_tag = f"{TRUSTYAI_TAG_PREFIX}_something" - payload = generate_payload(5, 10, 1, "INT64", invalid_tag) - response = post_test(payload, 400, ["reserved for internal TrustyAI use only"]) - expected_msg = f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. Provided tag '{invalid_tag}' violates this restriction." - assert expected_msg in response.text - - -@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) -@pytest.mark.parametrize("n_input_cols", [1, 4]) -@pytest.mark.parametrize("n_output_cols", [1, 2]) -@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) -def test_upload_data_and_ground_truth(n_input_rows, n_input_cols, n_output_cols, datatype): - """Test uploading model data and corresponding ground truth data.""" - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING") - payload["model_name"] = model_name - payload["is_ground_truth"] = False - post_test(payload, 200, [f"{n_input_rows} datapoints"]) - ids = get_metadata_ids(model_name) - payload_gt = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING", 0, 1) - payload_gt["model_name"] = model_name - payload_gt["is_ground_truth"] = True - payload_gt["request"] = payload["request"] - payload_gt["request"]["inputs"][0]["execution_ids"] = ids - post_test(payload_gt, 200, [f"{n_input_rows} ground truths"]) - original_data = get_data_from_storage(model_name, OUTPUT_SUFFIX) - gt_data = get_data_from_storage(f"{model_name}_ground_truth", OUTPUT_SUFFIX) - assert len(original_data["data"]) == len(gt_data["data"]), "Row dimensions don't match" - assert len(original_data["column_names"]) == len(gt_data["column_names"]), "Column dimensions don't match" - original_ids = get_metadata_ids(model_name) - gt_ids = get_metadata_ids(f"{model_name}_ground_truth") - assert original_ids == gt_ids, "Ground truth IDs don't match original IDs" - - -def test_upload_mismatch_input_values(): - """Test error when ground truth inputs don't match original data.""" - n_input_rows = 5 - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") - payload0["model_name"] = model_name - post_test(payload0, 200, [f"{n_input_rows} datapoints"]) - ids = get_metadata_ids(model_name) - payload1 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING", 1, 0) - payload1["model_name"] = model_name - payload1["is_ground_truth"] = True - payload1["request"]["inputs"][0]["execution_ids"] = ids - response = client.post("/data/upload", json=payload1) - assert response.status_code == 400 - assert "Found fatal mismatches" in response.text or "inputs are not identical" in response.text - - -def test_upload_mismatch_input_lengths(): - """Test error when ground truth has different input lengths.""" - n_input_rows = 5 - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") - payload0["model_name"] = model_name - post_test(payload0, 200, [f"{n_input_rows} datapoints"]) - ids = get_metadata_ids(model_name) - 
payload1 = generate_payload(n_input_rows, 11, 1, "INT64", "TRAINING") - payload1["model_name"] = model_name - payload1["is_ground_truth"] = True - payload1["request"]["inputs"][0]["execution_ids"] = ids - response = client.post("/data/upload", json=payload1) - assert response.status_code == 400 - assert "Found fatal mismatches" in response.text - assert ( - "input shapes do not match. Observed inputs have length=10 while uploaded inputs have length=11" - in response.text - ) - - -def test_upload_mismatch_input_and_output_types(): - """Test error when ground truth has different data types.""" - n_input_rows = 5 - model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" - payload0 = generate_payload(n_input_rows, 10, 2, "INT64", "TRAINING") - payload0["model_name"] = model_name - post_test(payload0, 200, [f"{n_input_rows} datapoints"]) - ids = get_metadata_ids(model_name) - payload1 = generate_payload(n_input_rows, 10, 2, "FP32", "TRAINING", 0, 1) - payload1["model_name"] = model_name - payload1["is_ground_truth"] = True - payload1["request"]["inputs"][0]["execution_ids"] = ids - response = client.post("/data/upload", json=payload1) - print(f"Response status: {response.status_code}") - print(f"Response text: {response.text}") - assert response.status_code == 400 - assert "Found fatal mismatches" in response.text - assert "Class=Long != Class=Float" in response.text or "inputs are not identical" in response.text - - -def test_upload_gaussian_data(): - """Test uploading realistic Gaussian data.""" - payload = { - "model_name": "gaussian-credit-model", - "data_tag": "TRAINING", - "request": { - "inputs": [ - { - "name": "credit_inputs", - "shape": [2, 4], - "datatype": "FP64", - "data": [ - [ - 47.45380690750797, - 478.6846214843319, - 13.462184703540503, - 20.764525303373535, - ], - [ - 47.468246185717554, - 575.6911203538863, - 10.844143722475575, - 14.81343667761101, - ], - ], - } - ] - }, - "response": { - "model_name": "gaussian-credit-model__isvc-d79a7d395d", - "model_version": "1", - "outputs": [ - { - "name": "predict", - "datatype": "FP32", - "shape": [2, 1], - "data": [0.19013395683309373, 0.2754730253205645], - } - ], - }, - } - post_test(payload, 200, ["2 datapoints"]) \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint_maria.py b/tests/endpoints/test_upload_endpoint_maria.py new file mode 100644 index 0000000..255b41b --- /dev/null +++ b/tests/endpoints/test_upload_endpoint_maria.py @@ -0,0 +1,59 @@ +import asyncio +import os +import unittest + +from fastapi.testclient import TestClient + +from tests.endpoints.test_upload_endpoint_pvc import TestUploadEndpointPVC + + +class TestUploadEndpointMaria(TestUploadEndpointPVC): + + def setUp(self): + os.environ["SERVICE_STORAGE_FORMAT"] = 'MARIA' + os.environ["DATABASE_USERNAME"] = "trustyai" + os.environ["DATABASE_PASSWORD"] = "trustyai" + os.environ["DATABASE_HOST"] = "127.0.0.1" + os.environ["DATABASE_PORT"] = "3306" + os.environ["DATABASE_DATABASE"] = "trustyai-database" + + # Force reload of the global storage interface to use the new temp dir + from src.service.data import storage + self.storage_interface = storage.get_global_storage_interface(force_reload=True) + + # Re-create the FastAPI app to ensure it uses the new storage interface + from importlib import reload + import src.main + reload(src.main) + from src.main import app + self.client = TestClient(app) + + self.original_datasets = set(self.storage_interface.list_all_datasets()) + + def tearDown(self): + # delete any datasets we've created + new_datasets = 
set(self.storage_interface.list_all_datasets()) + for ds in new_datasets.difference(self.original_datasets): + asyncio.run(self.storage_interface.delete_dataset(ds)) + + def test_upload_data(self): + super().test_upload_data() + + def test_upload_multi_input_data(self): + super().test_upload_multi_input_data() + + def test_upload_multi_input_data_no_unique_name(self): + super().test_upload_multi_input_data_no_unique_name() + + def test_upload_multiple_tagging(self): + super().test_upload_multiple_tagging() + + def test_upload_tag_that_uses_protected_name(self): + super().test_upload_tag_that_uses_protected_name() + + def test_upload_gaussian_data(self): + super().test_upload_gaussian_data() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint_pvc.py b/tests/endpoints/test_upload_endpoint_pvc.py new file mode 100644 index 0000000..5024bb8 --- /dev/null +++ b/tests/endpoints/test_upload_endpoint_pvc.py @@ -0,0 +1,364 @@ +import asyncio +import itertools +import os +import shutil +import tempfile +import unittest +import uuid +import time + +from fastapi.testclient import TestClient + +from src.service.data.model_data import ModelData +from src.service.constants import ( + TRUSTYAI_TAG_PREFIX, +) + +MODEL_ID = "example1" + + +def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0): + """Generate a test payload with specific dimensions and data types.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for i in range(n_rows): + if n_input_cols == 1: + input_data.append(i + input_offset) + else: + row = [i + j + input_offset for j in range(n_input_cols)] + input_data.append(row) + output_data = [] + for i in range(n_rows): + if n_output_cols == 1: + output_data.append(i * 2 + output_offset) + else: + row = [i * 2 + j + output_offset for j in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "input", + "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "output", + "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a test payload with multi-dimensional tensors like real data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for row_idx in range(n_rows): + row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)] + input_data.append(row) + output_data = [] + for row_idx in range(n_rows): + row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "multi_input", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a payload with mismatched shapes and non-unique 
names.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)] + mismatched_rows = n_rows - 1 if n_rows > 1 else 1 + input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)] + output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)] + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "same_name", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data_1, + }, + { + "name": "same_name", + "shape": [mismatched_rows, n_input_cols], + "datatype": datatype, + "data": input_data_2, + }, + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def count_rows_with_tag(model_name, tag): + """Count rows with a specific tag in metadata.""" + metadata_df = asyncio.run(ModelData(model_name).get_metadata_as_df()) + return metadata_df['tags'].apply(lambda tags: tag in tags).sum() + + +def get_metadata_ids(model_name): + """Count rows with a specific tag in metadata.""" + metadata_df = asyncio.run(ModelData(model_name).get_metadata_as_df()) + return metadata_df['id'].tolist() + + +class TestUploadEndpointPVC(unittest.TestCase): + + def setUp(self): + self.TEMP_DIR = tempfile.mkdtemp() + os.environ["STORAGE_DATA_FOLDER"] = self.TEMP_DIR + + # Force reload of the global storage interface to use the new temp dir + from src.service.data import storage + storage.get_global_storage_interface(force_reload=True) + + # Re-create the FastAPI app to ensure it uses the new storage interface + from importlib import reload + import src.main + reload(src.main) + from src.main import app + self.client = TestClient(app) + + def tearDown(self): + if os.path.exists(self.TEMP_DIR): + shutil.rmtree(self.TEMP_DIR) + + + def post_test(self, payload, expected_status_code, check_msgs): + """Post a payload and check the response.""" + response = self.client.post("/data/upload", json=payload) + if response.status_code != expected_status_code: + print(f"\n=== DEBUG INFO ===") + print(f"Expected status: {expected_status_code}") + print(f"Actual status: {response.status_code}") + print(f"Response text: {response.text}") + print(f"Response headers: {dict(response.headers)}") + if hasattr(response, "json"): + try: + print(f"Response JSON: {response.json()}") + except: + pass + print(f"==================") + + self.assertEqual(response.status_code, expected_status_code) + return response + + + # data upload tests + def test_upload_data(self): + n_input_rows_options = [1, 5, 250] + n_input_cols_options = [1, 4] + n_output_cols_options = [1, 2] + datatype_options = ["INT64", "INT32", "FP32", "FP64", "BOOL"] + + for idx, (n_input_rows, n_input_cols, n_output_cols, datatype) in enumerate(itertools.product( + n_input_rows_options, n_input_cols_options, n_output_cols_options, datatype_options + )): + with self.subTest( + f"subtest-{idx}", + n_input_rows=n_input_rows, + n_input_cols=n_input_cols, + n_output_cols=n_output_cols, + datatype=datatype, + ): + """Test uploading data with various dimensions and datatypes.""" + data_tag = "TRAINING" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = self.post_test(payload, 200, [f"{n_input_rows} 
datapoints"]) + + inputs, outputs, metadata = asyncio.run(ModelData(payload["model_name"]).data()) + + self.assertEqual(response.status_code, 200) + self.assertIn(str(n_input_rows), response.text) + self.assertIsNotNone(inputs, "Input data not found in storage") + self.assertIsNotNone(outputs, "Output data not found in storage") + + self.assertEqual(len(inputs), n_input_rows, "Incorrect number of input rows") + self.assertEqual(len(outputs), n_input_rows, "Incorrect number of output rows") + + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + self.assertEqual(tag_count, n_input_rows, "Not all rows have the correct tag") + + + def test_upload_multi_input_data(self): + """Test uploading data with multiple input tensors.""" + n_rows_options = [1, 3, 5, 250] + n_input_cols_options = [2, 6] + n_output_cols_options = [4] + datatype_options = ["INT64", "INT32", "FP32", "FP64", "BOOL"] + + for n_rows, n_input_cols, n_output_cols, datatype in itertools.product( + n_rows_options, n_input_cols_options, n_output_cols_options, datatype_options + ): + with self.subTest( + n_rows=n_rows, + n_input_cols=n_input_cols, + n_output_cols=n_output_cols, + datatype=datatype, + ): + # Arrange + data_tag = "TRAINING" + payload = generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, data_tag) + + # Act + self.post_test(payload, 200, [f"{n_rows} datapoints"]) + + model_data = ModelData(payload["model_name"]) + inputs, outputs, metadata = asyncio.run(model_data.data()) + input_column_names, output_column_names, metadata_column_names = asyncio.run( + model_data.original_column_names()) + + # Assert + self.assertIsNotNone(inputs, "Input data not found in storage") + self.assertIsNotNone(outputs, "Output data not found in storage") + self.assertEqual(len(inputs), n_rows, "Incorrect number of input rows") + self.assertEqual(len(outputs), n_rows, "Incorrect number of output rows") + self.assertEqual(len(input_column_names), n_input_cols, "Incorrect number of input columns") + self.assertEqual(len(output_column_names), n_output_cols, "Incorrect number of output columns") + self.assertGreaterEqual(len(input_column_names), 2, "Should have at least 2 input column names") + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + self.assertEqual(tag_count, n_rows, "Not all rows have the correct tag") + + + def test_upload_multi_input_data_no_unique_name(self): + """Test error case for non-unique tensor names.""" + payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") + response = self.client.post("/data/upload", json=payload) + self.assertEqual(response.status_code,400) + self.assertIn("input shapes were mismatched", response.text) + self.assertIn("[250, 4]", response.text) + + + def test_upload_multiple_tagging(self): + """Test uploading data with multiple tags.""" + n_payload1 = 50 + n_payload2 = 51 + tag1 = "TRAINING" + tag2 = "NOT TRAINING " + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + + payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1) + payload1["model_name"] = model_name + self.post_test(payload1, 200, [f"{n_payload1} datapoints"]) + + payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2) + payload2["model_name"] = model_name + self.post_test(payload2, 200, [f"{n_payload2} datapoints"]) + + tag1_count = count_rows_with_tag(model_name, tag1) + tag2_count = count_rows_with_tag(model_name, tag2) + + self.assertEqual(tag1_count, n_payload1, f"Expected {n_payload1} rows with tag {tag1}") + 
self.assertEqual(tag2_count, n_payload2, f"Expected {n_payload2} rows with tag {tag2}") + + input_rows, _, _ = asyncio.run(ModelData(payload1["model_name"]).row_counts()) + self.assertEqual(input_rows, n_payload1 + n_payload2, "Incorrect total number of rows") + + + def test_upload_tag_that_uses_protected_name(self): + """Test error when using a protected tag name.""" + invalid_tag = f"{TRUSTYAI_TAG_PREFIX}_something" + payload = generate_payload(5, 10, 1, "INT64", invalid_tag) + response = self.post_test(payload, 400, ["reserved for internal TrustyAI use only"]) + expected_msg = f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. Provided tag '{invalid_tag}' violates this restriction." + self.assertIn(expected_msg, response.text) + + + def test_upload_gaussian_data(self): + """Test uploading realistic Gaussian data.""" + payload = { + "model_name": "gaussian-credit-model", + "data_tag": "TRAINING", + "request": { + "inputs": [ + { + "name": "credit_inputs", + "shape": [2, 4], + "datatype": "FP64", + "data": [ + [ + 47.45380690750797, + 478.6846214843319, + 13.462184703540503, + 20.764525303373535, + ], + [ + 47.468246185717554, + 575.6911203538863, + 10.844143722475575, + 14.81343667761101, + ], + ], + } + ] + }, + "response": { + "model_name": "gaussian-credit-model__isvc-d79a7d395d", + "model_version": "1", + "outputs": [ + { + "name": "predict", + "datatype": "FP32", + "shape": [2, 1], + "data": [0.19013395683309373, 0.2754730253205645], + } + ], + }, + } + self.post_test(payload, 200, ["2 datapoints"]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/service/data/test_mariadb_migration.py b/tests/service/data/test_mariadb_migration.py index 1eb4326..c6204aa 100644 --- a/tests/service/data/test_mariadb_migration.py +++ b/tests/service/data/test_mariadb_migration.py @@ -30,7 +30,7 @@ def setUp(self): def tearDown(self): - self.storage.reset_database() + asyncio.run(self.storage.reset_database()) async def _test_retrieve_data(self): @@ -39,8 +39,8 @@ async def _test_retrieve_data(self): self.assertEqual(len(available_datasets), 12) # model 1 checks - model_1_inputs = self.storage.read_data("model1_inputs", 0, 100) - model_1_metadata = self.storage.read_data("model1_metadata", 0, 1) + model_1_inputs = await self.storage.read_data("model1_inputs", 0, 100) + model_1_metadata = await self.storage.read_data("model1_metadata", 0, 1) self.assertTrue(np.array_equal(np.array([[0,1,2,3,4]]*100), model_1_inputs)) self.assertTrue(np.array_equal(np.array([0, 1, 2, 3, 4]), model_1_inputs[0])) self.assertEqual(model_1_metadata[0][0], datetime.datetime.fromisoformat("2025-06-09 12:19:06.074828")) @@ -48,10 +48,10 @@ async def _test_retrieve_data(self): self.assertEqual(model_1_metadata[0][2], "_trustyai_unlabeled") # model 3 checks - self.assertEqual(self.storage.get_aliased_column_names("model3_inputs"), ["year mapped", "make mapped", "color mapped"]) + self.assertEqual(await self.storage.get_aliased_column_names("model3_inputs"), ["year mapped", "make mapped", "color mapped"]) # model 4 checks - model_4_inputs_row0 = self.storage.read_data("model4_inputs", 0, 5) + model_4_inputs_row0 = await self.storage.read_data("model4_inputs", 0, 5) self.assertEqual(model_4_inputs_row0[0].tolist(), [0.0, "i'm text-0", True, 0]) self.assertEqual(model_4_inputs_row0[4].tolist(), [4.0, "i'm text-4", True, 8]) diff --git a/tests/service/data/test_mariadb_storage.py b/tests/service/data/test_mariadb_storage.py index 32be2fd..4d2b609 100644 
--- a/tests/service/data/test_mariadb_storage.py +++ b/tests/service/data/test_mariadb_storage.py @@ -6,6 +6,7 @@ import unittest import os import numpy as np +from sympy import print_tree from src.service.data.storage.maria.maria import MariaDBStorage @@ -25,10 +26,12 @@ def setUp(self): 3306, "trustyai-database", attempt_migration=False) + self.original_datasets = set(self.storage.list_all_datasets()) def tearDown(self): - self.storage.reset_database() + asyncio.run(self.storage.reset_database()) + async def _store_dataset(self, seed, n_rows=None, n_cols=None): n_rows = seed * 3 if n_rows is None else n_rows @@ -45,81 +48,80 @@ async def _test_retrieve_data(self): start_idx = dataset_idx n_rows = dataset_idx * 2 - retrieved_full_dataset = self.storage.read_data(dataset_name) - retrieved_partial_dataset = self.storage.read_data(dataset_name, start_idx, n_rows) + retrieved_full_dataset = await self.storage.read_data(dataset_name) + retrieved_partial_dataset = await self.storage.read_data(dataset_name, start_idx, n_rows) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) self.assertTrue(np.array_equal(retrieved_partial_dataset, original_dataset[start_idx:start_idx+n_rows])) async def _test_big_insert(self): original_dataset, _, dataset_name = await self._store_dataset(0, 5000, 10) - retrieved_full_dataset = self.storage.read_data(dataset_name) + retrieved_full_dataset = await self.storage.read_data(dataset_name) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) async def _test_single_row_insert(self): original_dataset, _, dataset_name = await self._store_dataset(0, 1, 10) - retrieved_full_dataset = self.storage.read_data(dataset_name, 0, 1) + retrieved_full_dataset = await self.storage.read_data(dataset_name, 0, 1) self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(original_dataset.shape, self.storage.dataset_shape(dataset_name)) - self.assertEqual(original_dataset.shape[0], self.storage.dataset_rows(dataset_name)) - self.assertEqual(original_dataset.shape[1], self.storage.dataset_cols(dataset_name)) + self.assertEqual(original_dataset.shape, await self.storage.dataset_shape(dataset_name)) + self.assertEqual(original_dataset.shape[0], await self.storage.dataset_rows(dataset_name)) + self.assertEqual(original_dataset.shape[1], await self.storage.dataset_cols(dataset_name)) - async def _test_single_row_retrieval(self): - original_dataset = np.arange(0, 
10).reshape(1, 10) - column_names = [alphabet[i] for i in range(original_dataset.shape[0])] + async def _test_vector_retrieval(self): + original_dataset = np.arange(0, 10) + column_names = ["single_column"] dataset_name = "dataset_single_row" await self.storage.write_data(dataset_name, original_dataset, column_names) - retrieved_full_dataset = self.storage.read_data(dataset_name)[0] + retrieved_full_dataset = await self.storage.read_data(dataset_name) + transposed_dataset = retrieved_full_dataset.reshape(-1) - self.assertTrue(np.array_equal(retrieved_full_dataset, original_dataset)) - self.assertEqual(1, self.storage.dataset_rows(dataset_name)) - self.assertEqual(10, self.storage.dataset_cols(dataset_name)) + self.assertTrue(np.array_equal(transposed_dataset, original_dataset)) + self.assertEqual(10, await self.storage.dataset_rows(dataset_name)) + self.assertEqual(1, await self.storage.dataset_cols(dataset_name)) async def _test_name_mapping(self): for dataset_idx in range(1, 10): original_dataset, column_names, dataset_name = await self._store_dataset(dataset_idx) name_mapping = {name: "aliased_" + name for i, name in enumerate(column_names) if i % 2 == 0} expected_mapping = [name_mapping.get(name, name) for name in column_names] - self.storage.apply_name_mapping(dataset_name, name_mapping) + await self.storage.apply_name_mapping(dataset_name, name_mapping) - self.assertEqual(column_names, self.storage.get_original_column_names(dataset_name)) - self.assertEqual(expected_mapping, self.storage.get_aliased_column_names(dataset_name)) + retrieved_original_names = await self.storage.get_original_column_names(dataset_name) + retrieved_aliased_names = await self.storage.get_aliased_column_names(dataset_name) -def run_async_test(coro): - """Helper function to run async tests.""" - loop = asyncio.new_event_loop() - return loop.run_until_complete(coro) + self.assertEqual(column_names, retrieved_original_names) + self.assertEqual(expected_mapping, retrieved_aliased_names) + def test_retrieve_data(self): + run_async_test(self._test_retrieve_data()) -TestMariaDBStorage.test_retrieve_data = lambda self: run_async_test( - self._test_retrieve_data() -) -TestMariaDBStorage.test_name_mapping = lambda self: run_async_test( - self._test_name_mapping() -) + def test_name_mapping(self): + run_async_test(self._test_name_mapping()) -TestMariaDBStorage.test_big_insert = lambda self: run_async_test( - self._test_big_insert() -) + def test_big_insert(self): + run_async_test(self._test_big_insert()) -TestMariaDBStorage.test_single_row_insert = lambda self: run_async_test( - self._test_single_row_insert() -) + def test_single_row_insert(self): + run_async_test(self._test_single_row_insert()) -TestMariaDBStorage._test_single_row_retrieval = lambda self: run_async_test( - self._test_single_row_retrieval() -) + def test_single_row_retrieval(self): + run_async_test(self._test_vector_retrieval()) +def run_async_test(coro): + """Helper function to run async tests.""" + loop = asyncio.new_event_loop() + return loop.run_until_complete(coro) + if __name__ == "__main__": diff --git a/tests/service/data/test_payload_reconciliation_maria.py b/tests/service/data/test_payload_reconciliation_maria.py index 025a7c4..f1bdfee 100644 --- a/tests/service/data/test_payload_reconciliation_maria.py +++ b/tests/service/data/test_payload_reconciliation_maria.py @@ -18,10 +18,11 @@ from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload from src.service.data.storage.maria.maria import MariaDBStorage from 
src.service.data.storage.pvc import PVCStorage +from tests.service.data.test_payload_reconciliation_pvc import TestPayloadReconciliation from tests.service.data.test_utils import ModelMeshTestData -class TestMariaPayloadReconciliation(unittest.TestCase): +class TestMariaPayloadReconciliation(TestPayloadReconciliation): """ Test class for ModelMesh payload reconciliation. """ @@ -58,184 +59,6 @@ def tearDown(self): """Clean up after tests.""" self.storage.reset_database() - async def _test_persist_input_payload(self): - """Test persisting an input payload.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True - ) - - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - - self.assertIsNotNone(retrieved_payload) - self.assertEqual(retrieved_payload.data, self.input_payload.data) - - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - self.assertIsNone(output_payload) - - async def _test_persist_output_payload(self): - """Test persisting an output payload.""" - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False - ) - - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNotNone(retrieved_payload) - self.assertEqual(retrieved_payload.data, self.output_payload.data) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - self.assertIsNone(input_payload) - - async def _test_full_reconciliation(self): - """Test the full payload reconciliation process.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True - ) - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False - ) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNotNone(input_payload) - self.assertIsNotNone(output_payload) - - df = ModelMeshPayloadParser.payloads_to_dataframe( - input_payload, output_payload, self.request_id, self.model_name - ) - - self.assertIsInstance(df, pd.DataFrame) - self.assertIn("input", df.columns) - self.assertIn("output_output", df.columns) - self.assertEqual(len(df), 5) # Based on our test data with 5 rows - - self.assertIn("id", df.columns) - self.assertEqual(df["id"].iloc[0], self.request_id) - - self.assertIn("model_id", df.columns) - self.assertEqual(df["model_id"].iloc[0], self.model_name) - - # Clean up - await self.storage.delete_modelmesh_payload(self.request_id, is_input=True) - await self.storage.delete_modelmesh_payload(self.request_id, is_input=False) - - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True - ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False - ) - - self.assertIsNone(input_payload) - self.assertIsNone(output_payload) - - async def _test_reconciliation_with_real_data(self): - """Test reconciliation with sample b64 encoded data from files.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - test_data_dir = os.path.join( - os.path.dirname(os.path.dirname(current_dir)), "data" - ) - - with open(os.path.join(test_data_dir, "input-sample.b64"), "r") as f: - sample_input_data = f.read().strip() - - with open(os.path.join(test_data_dir, 
"output-sample.b64"), "r") as f: - sample_output_data = f.read().strip() - - input_payload = PartialPayload(data=sample_input_data) - output_payload = PartialPayload(data=sample_output_data) - - request_id = str(uuid.uuid4()) - model_id = "sample-model" - - await self.storage.persist_modelmesh_payload( - input_payload, request_id, is_input=True - ) - await self.storage.persist_modelmesh_payload( - output_payload, request_id, is_input=False - ) - - stored_input = await self.storage.get_modelmesh_payload( - request_id, is_input=True - ) - stored_output = await self.storage.get_modelmesh_payload( - request_id, is_input=False - ) - - self.assertIsNotNone(stored_input) - self.assertIsNotNone(stored_output) - self.assertEqual(stored_input.data, sample_input_data) - self.assertEqual(stored_output.data, sample_output_data) - - with mock.patch.object( - ModelMeshPayloadParser, "payloads_to_dataframe" - ) as mock_to_df: - sample_df = pd.DataFrame( - { - "input_feature": [1, 2, 3], - "output_output_feature": [4, 5, 6], - "id": [request_id] * 3, - "model_id": [model_id] * 3, - "synthetic": [False] * 3, - } - ) - mock_to_df.return_value = sample_df - - df = ModelMeshPayloadParser.payloads_to_dataframe( - stored_input, stored_output, request_id, model_id - ) - - self.assertIsInstance(df, pd.DataFrame) - self.assertEqual(len(df), 3) - - mock_to_df.assert_called_once_with( - stored_input, stored_output, request_id, model_id - ) - - # Clean up - await self.storage.delete_modelmesh_payload(request_id, is_input=True) - await self.storage.delete_modelmesh_payload(request_id, is_input=False) - - self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=True) - ) - self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=False) - ) - - -def run_async_test(coro): - """Helper function to run async tests.""" - loop = asyncio.new_event_loop() - return loop.run_until_complete(coro) - - -TestMariaPayloadReconciliation.test_persist_input_payload = lambda self: run_async_test( - self._test_persist_input_payload() -) -TestMariaPayloadReconciliation.test_persist_output_payload = lambda self: run_async_test( - self._test_persist_output_payload() -) -TestMariaPayloadReconciliation.test_full_reconciliation = lambda self: run_async_test( - self._test_full_reconciliation() -) -TestMariaPayloadReconciliation.test_reconciliation_with_real_data = ( - lambda self: run_async_test(self._test_reconciliation_with_real_data()) -) - if __name__ == "__main__": unittest.main() diff --git a/tests/service/data/test_payload_reconciliation_pvc.py b/tests/service/data/test_payload_reconciliation_pvc.py index 23c22bc..b7553af 100644 --- a/tests/service/data/test_payload_reconciliation_pvc.py +++ b/tests/service/data/test_payload_reconciliation_pvc.py @@ -6,14 +6,10 @@ import unittest import tempfile import os -import base64 -import time -from datetime import datetime from unittest import mock import uuid import pandas as pd -import numpy as np from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload from src.service.data.storage.pvc import PVCStorage @@ -53,54 +49,54 @@ def tearDown(self): async def _test_persist_input_payload(self): """Test persisting an input payload.""" - await self.storage.persist_modelmesh_payload( - self.input_payload, self.request_id, is_input=True + await self.storage.persist_partial_payload( + self.input_payload, payload_id=self.request_id, is_input=True ) - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, 
is_input=True + retrieved_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) self.assertIsNotNone(retrieved_payload) self.assertEqual(retrieved_payload.data, self.input_payload.data) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNone(output_payload) async def _test_persist_output_payload(self): """Test persisting an output payload.""" - await self.storage.persist_modelmesh_payload( - self.output_payload, self.request_id, is_input=False + await self.storage.persist_partial_payload( + self.output_payload, payload_id=self.request_id, is_input=False ) - retrieved_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + retrieved_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNotNone(retrieved_payload) self.assertEqual(retrieved_payload.data, self.output_payload.data) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) self.assertIsNone(input_payload) async def _test_full_reconciliation(self): """Test the full payload reconciliation process.""" - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( self.input_payload, self.request_id, is_input=True ) - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( self.output_payload, self.request_id, is_input=False ) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNotNone(input_payload) @@ -122,14 +118,14 @@ async def _test_full_reconciliation(self): self.assertEqual(df["model_id"].iloc[0], self.model_name) # Clean up - await self.storage.delete_modelmesh_payload(self.request_id, is_input=True) - await self.storage.delete_modelmesh_payload(self.request_id, is_input=False) + await self.storage.delete_partial_payload(self.request_id, is_input=True) + await self.storage.delete_partial_payload(self.request_id, is_input=False) - input_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=True + input_payload = await self.storage.get_partial_payload( + self.request_id, is_input=True, is_modelmesh=True ) - output_payload = await self.storage.get_modelmesh_payload( - self.request_id, is_input=False + output_payload = await self.storage.get_partial_payload( + self.request_id, is_input=False, is_modelmesh=True ) self.assertIsNone(input_payload) @@ -154,18 +150,18 @@ async def _test_reconciliation_with_real_data(self): request_id = str(uuid.uuid4()) model_id = "sample-model" - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( input_payload, request_id, is_input=True ) - await self.storage.persist_modelmesh_payload( + await self.storage.persist_partial_payload( output_payload, request_id, is_input=False ) - stored_input = await self.storage.get_modelmesh_payload( - 
request_id, is_input=True + stored_input = await self.storage.get_partial_payload( + request_id, is_input=True, is_modelmesh=True ) - stored_output = await self.storage.get_modelmesh_payload( - request_id, is_input=False + stored_output = await self.storage.get_partial_payload( + request_id, is_input=False, is_modelmesh=True ) self.assertIsNotNone(stored_input) @@ -199,14 +195,14 @@ async def _test_reconciliation_with_real_data(self): ) # Clean up - await self.storage.delete_modelmesh_payload(request_id, is_input=True) - await self.storage.delete_modelmesh_payload(request_id, is_input=False) + await self.storage.delete_partial_payload(request_id, is_input=True) + await self.storage.delete_partial_payload(request_id, is_input=False) self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=True) + await self.storage.get_partial_payload(request_id, is_input=True, is_modelmesh=True) ) self.assertIsNone( - await self.storage.get_modelmesh_payload(request_id, is_input=False) + await self.storage.get_partial_payload(request_id, is_input=False, is_modelmesh=True) ) From b689e8e71e5c812a4377bcd1d235ca5faaf34e3d Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 10:28:02 +0100 Subject: [PATCH 03/10] Fix typo in pyproject --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2e85516..bc8b2b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,6 @@ eval = ["lm-eval[api]==0.4.4", "fastapi-utils>=0.8.0", "typing-inspect==0.9.0"] protobuf = ["numpy>=1.24.0,<3", "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2"] mariadb = ["mariadb>=1.1.12", "javaobj-py3==0.4.4"] ->>>>>>> 00aef587ef5ae523d62ec86c53c9c47428cb47bc - [tool.hatch.build.targets.sdist] include = ["src"] From 5ec2c5053cab8d12f6d45b957f91d80d7abe8f07 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 11:17:57 +0100 Subject: [PATCH 04/10] Address review items --- src/endpoints/consumer/consumer_endpoint.py | 6 ++---- src/endpoints/data/data_upload.py | 23 +++++++++------------ src/service/data/model_data.py | 23 ++++++++++++++++++++- 3 files changed, 34 insertions(+), 18 deletions(-) diff --git a/src/endpoints/consumer/consumer_endpoint.py b/src/endpoints/consumer/consumer_endpoint.py index d1c45d1..c0e4a84 100644 --- a/src/endpoints/consumer/consumer_endpoint.py +++ b/src/endpoints/consumer/consumer_endpoint.py @@ -1,7 +1,7 @@ # endpoints/consumer.py import asyncio import time -from datetime import datetime +from datetime import datetime, timezone import numpy as np from fastapi import APIRouter, HTTPException, Header @@ -157,7 +157,7 @@ async def write_reconciled_data( model_id, tags, id_): storage_interface = get_global_storage_interface() - iso_time = datetime.isoformat(datetime.utcnow()) + iso_time = datetime.now(timezone.utc).isoformat() unix_timestamp = time.time() metadata = np.array( [[None, iso_time, unix_timestamp, tags]] * len(input_array), dtype="O" @@ -198,8 +198,6 @@ async def reconcile_modelmesh_payloads( request_id: str, model_id: str, ): - storage_interface = get_global_storage_interface() - """Reconcile the input and output ModelMesh payloads into dataset entries.""" df = ModelMeshPayloadParser.payloads_to_dataframe( input_payload, output_payload, request_id, model_id diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 5e110a7..1920654 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -39,8 +39,7 @@ def validate_data_tag(tag: str) -> 
Optional[str]: @router.post("/data/upload") async def upload(payload: UploadPayload) -> Dict[str, str]: - """Upload model data - regular or ground truth.""" - error_msgs = [] + """Upload model data""" # validate tag tag_validation_msg = validate_data_tag(payload.data_tag) @@ -53,13 +52,15 @@ async def upload(payload: UploadPayload) -> Dict[str, str]: payload.response.model_name = payload.model_name req_id = str(uuid.uuid4()) - try: - model_data = ModelData(payload.model_name) + model_data = ModelData(payload.model_name) + datasets_exist = await model_data.datasets_exist() + + if all(datasets_exist): previous_data_points = (await model_data.row_counts())[0] - except: + else: previous_data_points = 0 - await consume_cloud_event(payload.response, req_id), + await consume_cloud_event(payload.response, req_id) await consume_cloud_event(payload.request, req_id, tag=payload.data_tag) model_data = ModelData(payload.model_name) @@ -74,12 +75,8 @@ async def upload(payload: UploadPayload) -> Dict[str, str]: except HTTPException as e: if "Could not reconcile_kserve KServe Inference" in str(e): - new_msg = e.detail.replace("Could not reconcile_kserve KServe Inference", - "Could not upload payload") - raise HTTPException(status_code=400, detail=new_msg) - else: - raise + raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}.") from e + raise e except Exception as e: - traceback.print_exc() logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True) - raise HTTPException(500, f"Internal server error: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") diff --git a/src/service/data/model_data.py b/src/service/data/model_data.py index d83c9f3..bc813b0 100644 --- a/src/service/data/model_data.py +++ b/src/service/data/model_data.py @@ -1,4 +1,4 @@ -import asyncio +import logging from typing import List, Optional import numpy as np @@ -7,6 +7,8 @@ from src.service.constants import * from src.service.data.storage import get_global_storage_interface +logger = logging.getLogger(__name__) + class ModelDataContainer: def __init__(self, model_name: str, input_data: np.ndarray, input_names: List[str], output_data: np.ndarray, output_names: List[str], metadata: np.ndarray, metadata_names: List[str]): @@ -32,6 +34,25 @@ def __init__(self, model_name): self.output_dataset = self.model_name+OUTPUT_SUFFIX self.metadata_dataset = self.model_name+METADATA_SUFFIX + async def datasets_exist(self) -> tuple[bool, bool, bool]: + """ + Checks if the requested model exists + """ + storage_interface = get_global_storage_interface() + input_exists = await storage_interface.dataset_exists(self.input_dataset) + output_exists = await storage_interface.dataset_exists(self.output_dataset) + metadata_exists = await storage_interface.dataset_exists(self.metadata_dataset) + + # warn if we're missing one of the expected datasets + dataset_checks = (input_exists, output_exists, metadata_exists) + if not all(dataset_checks): + expected_datasets = [self.input_dataset, self.output_dataset, self.metadata_dataset] + missing_datasets = [dataset for idx, dataset in enumerate(expected_datasets) if not dataset_checks[idx]] + logger.warning(f"Not all datasets present for model {self.model_name}: missing {missing_datasets}. 
This could be indicative of storage corruption or" + f"improper saving of previous model data.") + return dataset_checks + + async def row_counts(self) -> tuple[int, int, int]: """ Get the number of input, output, and metadata rows that exist in a model dataset From 737b7a49a9f476321d45e09d9b5a0de7fb973b45 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 11:29:07 +0100 Subject: [PATCH 05/10] Clean unused imports --- src/endpoints/data/data_upload.py | 7 ++----- .../data/storage/maria/legacy_maria_reader.py | 1 - src/service/data/storage/maria/maria.py | 1 - tests/endpoints/test_upload_endpoint_pvc.py | 3 +-- tests/service/data/test_mariadb_migration.py | 2 -- tests/service/data/test_mariadb_storage.py | 2 -- tests/service/data/test_modelmesh_parser.py | 3 --- .../data/test_payload_reconciliation_maria.py | 12 +----------- 8 files changed, 4 insertions(+), 27 deletions(-) diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 1920654..e423ec8 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -1,14 +1,11 @@ -import asyncio import logging -import traceback -from typing import Any, Dict, List, Optional +from typing import Dict, Optional -import numpy as np import uuid from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from src.endpoints.consumer.consumer_endpoint import (reconcile_kserve, consume_cloud_event) +from src.endpoints.consumer.consumer_endpoint import consume_cloud_event from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse from src.service.constants import TRUSTYAI_TAG_PREFIX from src.service.data.model_data import ModelData diff --git a/src/service/data/storage/maria/legacy_maria_reader.py b/src/service/data/storage/maria/legacy_maria_reader.py index 0627d98..c9bc3df 100644 --- a/src/service/data/storage/maria/legacy_maria_reader.py +++ b/src/service/data/storage/maria/legacy_maria_reader.py @@ -1,4 +1,3 @@ -import asyncio import javaobj import logging import pandas as pd diff --git a/src/service/data/storage/maria/maria.py b/src/service/data/storage/maria/maria.py index 7f57025..d241052 100644 --- a/src/service/data/storage/maria/maria.py +++ b/src/service/data/storage/maria/maria.py @@ -2,7 +2,6 @@ import io import json import logging -import time import mariadb import numpy as np diff --git a/tests/endpoints/test_upload_endpoint_pvc.py b/tests/endpoints/test_upload_endpoint_pvc.py index 5024bb8..f4633cc 100644 --- a/tests/endpoints/test_upload_endpoint_pvc.py +++ b/tests/endpoints/test_upload_endpoint_pvc.py @@ -5,7 +5,6 @@ import tempfile import unittest import uuid -import time from fastapi.testclient import TestClient @@ -190,7 +189,7 @@ def post_test(self, payload, expected_status_code, check_msgs): print(f"Response JSON: {response.json()}") except: pass - print(f"==================") + print("==================") self.assertEqual(response.status_code, expected_status_code) return response diff --git a/tests/service/data/test_mariadb_migration.py b/tests/service/data/test_mariadb_migration.py index c6204aa..2029e2d 100644 --- a/tests/service/data/test_mariadb_migration.py +++ b/tests/service/data/test_mariadb_migration.py @@ -5,8 +5,6 @@ import asyncio import datetime import unittest -import os -import logging import numpy as np from src.service.data.storage.maria.maria import MariaDBStorage diff --git a/tests/service/data/test_mariadb_storage.py b/tests/service/data/test_mariadb_storage.py index 4d2b609..b762300 100644 --- 
a/tests/service/data/test_mariadb_storage.py +++ b/tests/service/data/test_mariadb_storage.py @@ -4,9 +4,7 @@ import asyncio import unittest -import os import numpy as np -from sympy import print_tree from src.service.data.storage.maria.maria import MariaDBStorage diff --git a/tests/service/data/test_modelmesh_parser.py b/tests/service/data/test_modelmesh_parser.py index 17eabfe..a77fb6f 100644 --- a/tests/service/data/test_modelmesh_parser.py +++ b/tests/service/data/test_modelmesh_parser.py @@ -3,10 +3,7 @@ """ import unittest -from typing import Dict, Optional - import pandas as pd -from pydantic import BaseModel from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload from tests.service.data.test_utils import ModelMeshTestData diff --git a/tests/service/data/test_payload_reconciliation_maria.py b/tests/service/data/test_payload_reconciliation_maria.py index f1bdfee..72f0964 100644 --- a/tests/service/data/test_payload_reconciliation_maria.py +++ b/tests/service/data/test_payload_reconciliation_maria.py @@ -2,22 +2,12 @@ Tests for ModelMesh payload reconciliation.MariaDBStorage("root", "root", "127.0.0.1", 3306, "trustyai_database_v2") """ -import asyncio import unittest -import tempfile -import os -import base64 -import time -from datetime import datetime -from unittest import mock import uuid -import pandas as pd -import numpy as np -from src.service.data.modelmesh_parser import ModelMeshPayloadParser, PartialPayload +from src.service.data.modelmesh_parser import PartialPayload from src.service.data.storage.maria.maria import MariaDBStorage -from src.service.data.storage.pvc import PVCStorage from tests.service.data.test_payload_reconciliation_pvc import TestPayloadReconciliation from tests.service.data.test_utils import ModelMeshTestData From 9393d36d2cc81230ae44c240882913bfead1fa55 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 15:58:53 +0100 Subject: [PATCH 06/10] Fix failing tests --- src/endpoints/data/data_upload.py | 2 +- tests/endpoints/test_upload_endpoint_pvc.py | 1 + .../data/test_payload_reconciliation_maria.py | 4 +- .../test_consumer_endpoint_reconciliation.py | 55 ++++++++++--------- 4 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index e423ec8..2bfc37f 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -72,7 +72,7 @@ async def upload(payload: UploadPayload) -> Dict[str, str]: except HTTPException as e: if "Could not reconcile_kserve KServe Inference" in str(e): - raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}.") from e + raise HTTPException(status_code=400, detail=f"Could not upload payload for model {payload.model_name}: {str(e)}") from e raise e except Exception as e: logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True) diff --git a/tests/endpoints/test_upload_endpoint_pvc.py b/tests/endpoints/test_upload_endpoint_pvc.py index f4633cc..8ee223a 100644 --- a/tests/endpoints/test_upload_endpoint_pvc.py +++ b/tests/endpoints/test_upload_endpoint_pvc.py @@ -276,6 +276,7 @@ def test_upload_multi_input_data_no_unique_name(self): payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") response = self.client.post("/data/upload", json=payload) self.assertEqual(response.status_code,400) + print(response.text) self.assertIn("input shapes were mismatched", 
response.text) self.assertIn("[250, 4]", response.text) diff --git a/tests/service/data/test_payload_reconciliation_maria.py b/tests/service/data/test_payload_reconciliation_maria.py index 72f0964..715ec28 100644 --- a/tests/service/data/test_payload_reconciliation_maria.py +++ b/tests/service/data/test_payload_reconciliation_maria.py @@ -1,7 +1,7 @@ """ Tests for ModelMesh payload reconciliation.MariaDBStorage("root", "root", "127.0.0.1", 3306, "trustyai_database_v2") """ - +import asyncio import unittest import uuid @@ -47,7 +47,7 @@ def setUp(self): def tearDown(self): """Clean up after tests.""" - self.storage.reset_database() + asyncio.run(self.storage.reset_database()) if __name__ == "__main__": diff --git a/tests/service/test_consumer_endpoint_reconciliation.py b/tests/service/test_consumer_endpoint_reconciliation.py index 5d79797..4b6ee9b 100644 --- a/tests/service/test_consumer_endpoint_reconciliation.py +++ b/tests/service/test_consumer_endpoint_reconciliation.py @@ -26,9 +26,11 @@ def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() self.storage_patch = mock.patch( - "src.endpoints.consumer.consumer_endpoint.storage_interface" + "src.endpoints.consumer.consumer_endpoint.get_global_storage_interface" ) - self.mock_storage = self.storage_patch.start() + self.mock_get_storage = self.storage_patch.start() + self.mock_storage = mock.AsyncMock() + self.mock_get_storage.return_value = self.mock_storage self.parser_patch = mock.patch.object( ModelMeshPayloadParser, "parse_input_payload" @@ -92,8 +94,8 @@ def tearDown(self): async def _test_consume_input_payload(self): """Test consuming an input payload.""" - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None) + self.mock_storage.persist_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None) self.mock_parse_input.return_value = True self.mock_parse_output.side_effect = ValueError("Not an output payload") @@ -104,6 +106,7 @@ async def _test_consume_input_payload(self): } response = self.client.post("/consumer/kserve/v2", json=inference_payload) + print(response.text) self.assertEqual(response.status_code, 200) self.assertEqual( @@ -114,15 +117,15 @@ async def _test_consume_input_payload(self): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertTrue(call_args[2]) # is_input=True + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertTrue(call_kwargs["is_input"]) # is_input=True async def _test_consume_output_payload(self): """Test consuming an output payload.""" - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload = mock.AsyncMock(return_value=None) + self.mock_storage.persist_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload = mock.AsyncMock(return_value=None) self.mock_parse_input.side_effect = ValueError("Not an input payload") self.mock_parse_output.return_value = True @@ -143,23 +146,23 @@ async def _test_consume_output_payload(self): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - 
self.assertEqual(call_args[1], self.request_id) - self.assertFalse(call_args[2]) # is_input=False + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertFalse(call_kwargs["is_input"]) # is_input=True async def _test_reconcile_payloads(self): """Test reconciling both input and output payloads.""" # Setup mocks for correct interactions - self.mock_storage.get_modelmesh_payload = mock.AsyncMock() - self.mock_storage.get_modelmesh_payload.side_effect = [ + self.mock_storage.get_partial_payload = mock.AsyncMock() + self.mock_storage.get_partial_payload.side_effect = [ None, self.input_payload, ] - self.mock_storage.persist_modelmesh_payload = mock.AsyncMock() + self.mock_storage.persist_partial_payload = mock.AsyncMock() self.mock_storage.write_data = mock.AsyncMock() - self.mock_storage.delete_modelmesh_payload = mock.AsyncMock() + self.mock_storage.delete_partial_payload = mock.AsyncMock() with ( mock.patch( @@ -213,13 +216,13 @@ async def mock_gather_impl(*args, **kwargs): }, ) - self.mock_storage.persist_modelmesh_payload.assert_called_once() + self.mock_storage.persist_partial_payload.assert_called_once() + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertTrue(call_args[2]) # is_input=True + self.assertEqual(call_kwargs['payload_id'], self.request_id) + self.assertTrue(call_kwargs['is_input']) # is_input=True - self.mock_storage.persist_modelmesh_payload.reset_mock() + self.mock_storage.persist_partial_payload.reset_mock() mock_parse_input.side_effect = ValueError("Not an input") mock_parse_output.side_effect = lambda x: True @@ -246,9 +249,9 @@ async def mock_gather_impl(*args, **kwargs): }, ) - call_args = self.mock_storage.persist_modelmesh_payload.call_args[0] - self.assertEqual(call_args[1], self.request_id) - self.assertFalse(call_args[2]) # is_input=False + call_kwargs = self.mock_storage.persist_partial_payload.call_args[1] + self.assertEqual(call_kwargs["payload_id"], self.request_id) + self.assertFalse(call_kwargs["is_input"]) # is_input=False mock_reconcile.assert_called_once() reconcile_args = mock_reconcile.call_args[0] From bf371d16b307ccdd042385f031f4c3efa1606c03 Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Mon, 23 Jun 2025 16:06:22 +0100 Subject: [PATCH 07/10] Fix failing migration test --- tests/service/data/test_mariadb_migration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/service/data/test_mariadb_migration.py b/tests/service/data/test_mariadb_migration.py index 2029e2d..f85ffdc 100644 --- a/tests/service/data/test_mariadb_migration.py +++ b/tests/service/data/test_mariadb_migration.py @@ -34,7 +34,10 @@ def tearDown(self): async def _test_retrieve_data(self): # total data checks available_datasets = self.storage.list_all_datasets() - self.assertEqual(len(available_datasets), 12) + for i in [1,2,3,4]: + for split in ["inputs", "outputs", "metadata"]: + self.assertIn(f"model{i}_{split}", available_datasets) + self.assertGreaterEqual(len(available_datasets), 12) # model 1 checks model_1_inputs = await self.storage.read_data("model1_inputs", 0, 100) From 60909a67f89a673f0ceb14e2786993a7b7bd37db Mon Sep 17 00:00:00 2001 From: Rob Geada Date: Tue, 15 Jul 2025 09:25:07 +0100 Subject: [PATCH 08/10] Fix fastapi utils bug --- 
pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 63e70d5..2c6f60a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires-python = "~=3.11" readme = "README.md" dependencies = [ "fastapi>=0.115.9,<0.116", + "fastapi-utils>=0.8.0", "pandas>=2.2.3,<3", "prometheus-client>=0.21.1,<0.23", "pydantic>=2.4.2,<3", @@ -29,7 +30,7 @@ dev = [ "pytest-cov>=4.1.0,<7", "httpx>=0.25.0,<0.29", ] -eval = ["lm-eval[api]==0.4.4", "fastapi-utils>=0.8.0", "typing-inspect==0.9.0"] +eval = ["lm-eval[api]==0.4.4", "typing-inspect==0.9.0"] protobuf = ["numpy>=1.24.0,<3", "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2"] mariadb = ["mariadb>=1.1.12", "javaobj-py3==0.4.4"] From 6529324d1d464c4ead11b91bb47c113240fb0c20 Mon Sep 17 00:00:00 2001 From: Neri Carcasci Date: Thu, 28 Aug 2025 16:14:49 +0100 Subject: [PATCH 09/10] Refactor data upload/storage: unify PVC and Maria backends, async interface, dtype/shape validation --- pyproject.toml | 1 + src/endpoints/consumer/__init__.py | 94 +++++++++++++++++-- src/endpoints/data/data_upload.py | 8 +- src/service/data/storage/maria/maria.py | 25 +++-- src/service/data/storage/pvc.py | 27 ++++-- src/service/data/storage/storage_interface.py | 2 +- tests/service/data/test_async_contract.py | 36 +++++++ 7 files changed, 167 insertions(+), 26 deletions(-) create mode 100644 tests/service/data/test_async_contract.py diff --git a/pyproject.toml b/pyproject.toml index 2c6f60a..d3c573a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "h5py>=3.13.0,<4", "scikit-learn", "aif360", + "typing_inspect>=0.9", ] [project.optional-dependencies] diff --git a/src/endpoints/consumer/__init__.py b/src/endpoints/consumer/__init__.py index 15ad2a9..2c4c891 100644 --- a/src/endpoints/consumer/__init__.py +++ b/src/endpoints/consumer/__init__.py @@ -1,6 +1,7 @@ -from typing import Optional, Dict, List, Literal - -from pydantic import BaseModel +from typing import Optional, Dict, List, Literal, Any +from enum import Enum +from pydantic import BaseModel, model_validator, ConfigDict +import numpy as np PartialKind = Literal["request", "response"] @@ -51,13 +52,94 @@ def set_model_id(self, model_id: str): self.modelid = model_id +class KServeDataType(str, Enum): + BOOL = "BOOL" + INT8 = "INT8" + INT16 = "INT16" + INT32 = "INT32" + INT64 = "INT64" + UINT8 = "UINT8" + UINT16 = "UINT16" + UINT32 = "UINT32" + UINT64 = "UINT64" + FP16 = "FP16" + FP32 = "FP32" + FP64 = "FP64" + BYTES = "BYTES" + +K_SERVE_NUMPY_DTYPES = { + KServeDataType.INT8: np.int8, + KServeDataType.INT16: np.int16, + KServeDataType.INT32: np.int32, + KServeDataType.INT64: np.int64, + KServeDataType.UINT8: np.uint8, + KServeDataType.UINT16: np.uint16, + KServeDataType.UINT32: np.uint32, + KServeDataType.UINT64: np.uint64, + KServeDataType.FP16: np.float16, + KServeDataType.FP32: np.float32, + KServeDataType.FP64: np.float64, +} + class KServeData(BaseModel): + + model_config = ConfigDict(use_enum_values=True) + name: str shape: List[int] - datatype: str + datatype: KServeDataType parameters: Optional[Dict[str, str]] = None - data: List - + data: List[Any] + + @model_validator(mode="after") + def _validate_shape(self) -> "KServeData": + raw = np.array(self.data, dtype=object) + actual = tuple(raw.shape) + declared = tuple(self.shape) + if declared != actual: + raise ValueError( + f"Declared shape {declared} does not match data shape {actual}" + ) + return self + + @model_validator(mode="after") + def 
validate_data_matches_type(self) -> "KServeData": + flat = np.array(self.data, dtype=object).flatten() + + if self.datatype == KServeDataType.BYTES: + for v in flat: + if not isinstance(v, str): + raise ValueError( + f"All values must be JSON strings for datatype {self.datatype}; " + f"found {type(v).__name__}: {v}" + ) + return self + + if self.datatype == KServeDataType.BOOL: + for v in flat: + if not (isinstance(v, (bool, int)) and v in (0, 1, True, False)): + raise ValueError( + f"All values must be bool or 0/1 for datatype {self.datatype}; found {v}" + ) + return self + + np_dtype = K_SERVE_NUMPY_DTYPES.get(self.datatype) + if np_dtype is None: + raise ValueError(f"Unsupported datatype: {self.datatype}") + + if np.dtype(np_dtype).kind == "u": + for v in flat: + if isinstance(v, (int, float)) and v < 0: + raise ValueError( + f"Negative value {v} not allowed for unsigned type {self.datatype}" + ) + + try: + np.array(flat, dtype=np_dtype) + except (ValueError, TypeError) as e: + raise ValueError(f"Data cannot be cast to {self.datatype}: {e}") + + return self class KServeInferenceRequest(BaseModel): id: Optional[str] = None diff --git a/src/endpoints/data/data_upload.py b/src/endpoints/data/data_upload.py index 2bfc37f..845012a 100644 --- a/src/endpoints/data/data_upload.py +++ b/src/endpoints/data/data_upload.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from src.endpoints.consumer.consumer_endpoint import consume_cloud_event -from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse +from src.endpoints.consumer import KServeInferenceRequest, KServeInferenceResponse, KServeData from src.service.constants import TRUSTYAI_TAG_PREFIX from src.service.data.model_data import ModelData @@ -34,6 +34,8 @@ def validate_data_tag(tag: str) -> Optional[str]: ) return None + + @router.post("/data/upload") async def upload(payload: UploadPayload) -> Dict[str, str]: """Upload model data""" @@ -46,7 +48,9 @@ async def upload(payload: UploadPayload) -> Dict[str, str]: logger.info(f"Received upload request for model: {payload.model_name}") # overwrite response model name with provided model name - payload.response.model_name = payload.model_name + if payload.response.model_name != payload.model_name: + logger.warning(f"Response model name '{payload.response.model_name}' differs from request model name '{payload.model_name}'. 
Using '{payload.model_name}'.") + payload.response.model_name = payload.model_name req_id = str(uuid.uuid4()) model_data = ModelData(payload.model_name) diff --git a/src/service/data/storage/maria/maria.py b/src/service/data/storage/maria/maria.py index d241052..0bbdcc6 100644 --- a/src/service/data/storage/maria/maria.py +++ b/src/service/data/storage/maria/maria.py @@ -67,14 +67,17 @@ def __init__(self, user: str, password: str, host: str, port: int, database: str cursor.execute(f"CREATE TABLE IF NOT EXISTS `{self.partial_payload_table}` (payload_id varchar(255), is_input BOOLEAN, payload_data LONGBLOB)") if attempt_migration: - self._migrate_from_legacy_db() + # Schedule the migration to run asynchronously + import asyncio + loop = asyncio.get_event_loop() + loop.create_task(self._migrate_from_legacy_db()) # === MIGRATORS ================================================================================ - def _migrate_from_legacy_db(self): + async def _migrate_from_legacy_db(self): legacy_reader = LegacyMariaDBStorageReader(user=self.user, password=self.password, host=self.host, port=self.port, database=self.database) if legacy_reader.legacy_data_exists(): logger.info("Legacy TrustyAI v1 data exists in database, checking if a migration is necessary.") - asyncio.run(legacy_reader.migrate_data(self)) + await legacy_reader.migrate_data(self) # === INTERNAL HELPER FUNCTIONS ================================================================ @@ -115,14 +118,16 @@ async def dataset_exists(self, dataset_name: str) -> bool: except mariadb.ProgrammingError: return False - def list_all_datasets(self): - """ - List all available datasets in the database. - """ + def _list_all_datasets_sync(self): with self.connection_manager as (conn, cursor): cursor.execute(f"SELECT dataset_name FROM `{self.dataset_reference_table}`") - results = [x[0] for x in cursor.fetchall()] - return results + return [x[0] for x in cursor.fetchall()] + + async def list_all_datasets(self): + """ + List all datasets in the database. + """ + return await asyncio.to_thread(self._list_all_datasets_sync) @require_existing_dataset @@ -379,7 +384,7 @@ async def delete_dataset(self, dataset_name: str): conn.commit() async def delete_all_datasets(self): - for dataset_name in self.list_all_datasets(): + for dataset_name in await self.list_all_datasets(): logger.warning(f"Deleting dataset {dataset_name}") await self.delete_dataset(dataset_name) diff --git a/src/service/data/storage/pvc.py b/src/service/data/storage/pvc.py index d81b405..23dc24b 100644 --- a/src/service/data/storage/pvc.py +++ b/src/service/data/storage/pvc.py @@ -106,10 +106,15 @@ async def dataset_exists(self, dataset_name: str) -> bool: return False - def list_all_datasets(self) -> List[str]: - """ List all datasets known by the dataset """ - return [fname.replace(f"_{self.data_file}", "") for fname in os.listdir(self.data_directory) if self.data_file in fname] + def _list_all_datasets_sync(self) -> List[str]: + return [ + fname.replace(f"_{self.data_file}", "") + for fname in os.listdir(self.data_directory) + if self.data_file in fname + ] + async def list_all_datasets(self) -> List[str]: + return await asyncio.to_thread(self._list_all_datasets_sync) async def dataset_rows(self, dataset_name: str) -> int: """Number of data rows in dataset, returns a FileNotFoundError if the dataset does not exist""" @@ -207,14 +212,15 @@ async def _write_raw_data( ] # to-do: tune this value? 
- if isinstance(new_rows.data, np.dtypes.VoidDType): - new_rows = new_rows.astype("V400") + # if isinstance(new_rows.data, np.dtypes.VoidDType): + # new_rows = new_rows.astype("V400") dataset = db.create_dataset( allocated_dataset_name, data=new_rows, maxshape=max_shape, chunks=True, + dtype=new_rows.dtype # use the dtype of the data ) dataset.attrs[COLUMN_NAMES_ATTRIBUTE] = column_names dataset.attrs[BYTES_ATTRIBUTE] = is_bytes @@ -231,9 +237,16 @@ async def write_data(self, dataset_name: str, new_rows, column_names: List[str]) or not isinstance(new_rows, np.ndarray) and list_utils.contains_non_numeric(new_rows) ): + serialized = list_utils.serialize_rows(new_rows, MAX_VOID_TYPE_LENGTH) + arr = np.array(serialized) + if arr.ndim == 1: + arr = arr.reshape(-1, 1) await self._write_raw_data( - dataset_name, list_utils.serialize_rows(new_rows, MAX_VOID_TYPE_LENGTH), column_names - ) + dataset_name, + arr, + column_names, + is_bytes=True, +) else: await self._write_raw_data(dataset_name, np.array(new_rows), column_names) diff --git a/src/service/data/storage/storage_interface.py b/src/service/data/storage/storage_interface.py index 2b12182..84c005d 100644 --- a/src/service/data/storage/storage_interface.py +++ b/src/service/data/storage/storage_interface.py @@ -11,7 +11,7 @@ async def dataset_exists(self, dataset_name: str) -> bool: pass @abstractmethod - def list_all_datasets(self) -> List[str]: + async def list_all_datasets(self) -> List[str]: pass @abstractmethod diff --git a/tests/service/data/test_async_contract.py b/tests/service/data/test_async_contract.py new file mode 100644 index 0000000..99688df --- /dev/null +++ b/tests/service/data/test_async_contract.py @@ -0,0 +1,36 @@ +import inspect +import importlib +import pytest + + +CLASSES = [ + ("src.service.data.storage.maria.maria", "MariaDBStorage"), + ("src.service.data.storage.pvc", "PVCStorage"), +] + +SYNC_ALLOWED = { + "PVCStorage": {"allocate_valid_dataset_name", "get_lock"}, # no resource usage, non blocking +} + +def public_methods(cls): + for name, fn in inspect.getmembers(cls, predicate=inspect.isfunction): + if name.startswith("_"): + continue + if name in {"__init__", "__new__", "__class_getitem__"}: + continue + yield name, fn + +@pytest.mark.parametrize("module_path,class_name", CLASSES) +def test_storage_methods_are_async(module_path, class_name): + mod = importlib.import_module(module_path) + cls = getattr(mod, class_name) + allowed = SYNC_ALLOWED.get(class_name, set()) + + offenders = [] + for name, fn in public_methods(cls): + if name in allowed: + continue + if not inspect.iscoroutinefunction(fn): + offenders.append(name) + + assert not offenders, f"{class_name} has non-async methods: {offenders}" \ No newline at end of file From 24e2c0c8aece76723f50cdb0f1fe51fa0beb7887 Mon Sep 17 00:00:00 2001 From: Neri Carcasci Date: Thu, 28 Aug 2025 16:36:43 +0100 Subject: [PATCH 10/10] added bash script for e2e testing of upload & storage endpoint --- scripts/test_upload_endpoint.sh | 203 ++++++++++++++++++++++++++++++++ 1 file changed, 203 insertions(+) create mode 100644 scripts/test_upload_endpoint.sh diff --git a/scripts/test_upload_endpoint.sh b/scripts/test_upload_endpoint.sh new file mode 100644 index 0000000..ee814d1 --- /dev/null +++ b/scripts/test_upload_endpoint.sh @@ -0,0 +1,203 @@ +#!/usr/bin/env bash +# scripts/test_upload_endpoint.sh +# +# KServe-strict endpoint test for /data/upload +# +# +# Usage: +# ENDPOINT="https:///data/upload" \ +# MODEL="gaussian-credit-model" \ +# TAG="TRAINING" \ +# 
./scripts/test_upload_endpoint.sh + +set -uo pipefail + +# --- Config via env vars (no secrets hardcoded) --- +: "${ENDPOINT:?ENDPOINT is required, e.g. https://.../data/upload}" +MODEL="${MODEL:-gaussian-credit-model}" +# Separate model for BYTES to avoid mixing with an existing numeric dataset +MODEL_BYTES="${MODEL_BYTES:-${MODEL}-bytes}" +TAG="${TAG:-TRAINING}" +AUTH_HEADER="${AUTH_HEADER:-}" # e.g. 'Authorization: Bearer ' + +CURL_OPTS=( --silent --show-error -H "Content-Type: application/json" ) +[[ -n "$AUTH_HEADER" ]] && CURL_OPTS+=( -H "$AUTH_HEADER" ) + +RED=$'\033[31m'; GREEN=$'\033[32m'; YELLOW=$'\033[33m'; CYAN=$'\033[36m'; RESET=$'\033[0m' +pass_cnt=0; fail_cnt=0; results=() +have_jq=1; command -v jq >/dev/null 2>&1 || have_jq=0 + +line(){ printf '%s\n' "--------------------------------------------------------------------------------"; } +snippet(){ if (( have_jq )); then echo "$1" | jq -r 'tostring' 2>/dev/null | head -c 240; else echo "$1" | head -c 240; fi; } + +# ---------- payload builders ---------- +mk_inputs_2x4_int32() { + cat < uses MODEL + local req="$1" out="$2" + cat < 0 )); then + echo "${YELLOW}Details for failures:${RESET}" + for r in "${results[@]}"; do + IFS='|' read -r status name http body <<<"$r" + if [[ "$status" == "FAIL" ]]; then + printf "%s[FAIL]%s %s (HTTP %s)\n" "$RED" "$RESET" "$name" "$http" + printf " body: %s\n" "$body" + line + fi + done +fi \ No newline at end of file
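
For reference alongside the shell script above, here is a minimal Python sketch of the same /data/upload call. It is illustrative only: the base URL, model name, and concrete data values are assumptions and not part of the patch series, while the field names follow the UploadPayload model and the KServe v2 tensor layout (name, shape, datatype, data) that the new KServeData validators enforce; httpx is already present in the project's dev dependencies.

```python
# Hypothetical smoke test for /data/upload; BASE_URL and the data values are
# placeholders, not part of the patches above.
import httpx

BASE_URL = "http://localhost:8080"  # assumption: service running locally

payload = {
    "model_name": "gaussian-credit-model",
    "data_tag": "TRAINING",
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {
                # KServeData: the declared shape must match the nested data shape,
                # and every value must be castable to the declared datatype.
                "name": "credit_inputs",
                "shape": [2, 4],
                "datatype": "FP64",
                "data": [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
            }
        ]
    },
    "response": {
        # Matching the request's model_name avoids the mismatch warning logged
        # by the upload endpoint before it overwrites the response model name.
        "model_name": "gaussian-credit-model",
        "outputs": [
            {
                "name": "predict",
                "shape": [2, 1],
                "datatype": "FP64",
                "data": [[0.12], [0.87]],
            }
        ]
    },
}

resp = httpx.post(f"{BASE_URL}/data/upload", json=payload, timeout=30.0)
print(resp.status_code, resp.text)
```

On success the endpoint returns a short JSON status payload; payloads whose declared shape or datatype does not match the supplied data are rejected by the KServeData validators before the handler stores anything.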