import logging

from fastapi import APIRouter, HTTPException

from src.service.utils.download import (
    DataRequestPayload,
    DataResponsePayload,
    apply_filters,
    load_model_dataframe,
)

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/data/download")
async def download_data(payload: DataRequestPayload) -> DataResponsePayload:
    """Download model data with filtering.

    Loads the model's combined dataframe, applies the payload's
    matchAll/matchAny/matchNone filters, and returns the result as CSV.

    Raises:
        HTTPException: propagated as-is from the loader/filter layer
            (e.g. 404 for an unknown model, 400 for bad filters);
            500 for any unexpected error.
    """
    try:
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("Received data download request for model: %s", payload.modelId)
        df = await load_model_dataframe(payload.modelId)
        if df.empty:
            # No stored rows for this model: return an empty CSV body.
            return DataResponsePayload(dataCSV="")
        df = apply_filters(df, payload)
        return DataResponsePayload(dataCSV=df.to_csv(index=False))
    except HTTPException:
        # Preserve status codes raised by helpers (404/400) instead of
        # collapsing them into a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error downloading data: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}")
import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from src.service.utils.upload import process_upload_request

router = APIRouter()
logger = logging.getLogger(__name__)


class UploadPayload(BaseModel):
    """Request body for /data/upload.

    Attributes map 1:1 to the JSON payload; ``response`` is optional so
    ground-truth-only uploads can omit model outputs at the top level.
    """

    model_name: str
    data_tag: Optional[str] = None
    is_ground_truth: bool = False
    request: Dict[str, Any]
    response: Optional[Dict[str, Any]] = None


@router.post("/data/upload")
async def upload(payload: UploadPayload) -> Dict[str, str]:
    """Upload model data - regular or ground truth.

    Delegates all validation and storage to ``process_upload_request``;
    this handler only logs and maps unexpected failures to HTTP 500.
    """
    try:
        logger.info("Received upload request for model: %s", payload.model_name)
        result = await process_upload_request(payload)
        logger.info("Upload completed for model: %s", payload.model_name)
        return result
    except HTTPException:
        # Re-raise 4xx errors produced by the processing layer unchanged.
        raise
    except Exception as e:
        logger.error(
            f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}",
            exc_info=True,
        )
        raise HTTPException(500, f"Internal server error: {str(e)}")
import logging
import numbers
from typing import Any, List

import pandas as pd
from fastapi import HTTPException
from pydantic import BaseModel, Field

from src.service.data.storage import get_storage_interface

logger = logging.getLogger(__name__)


class RowMatcher(BaseModel):
    """Represents a row matching condition for data filtering."""

    columnName: str
    operation: str  # "EQUALS" or "BETWEEN"
    values: List[Any]


class DataRequestPayload(BaseModel):
    """Request payload for data download operations."""

    modelId: str
    matchAny: List[RowMatcher] = Field(default_factory=list)
    matchAll: List[RowMatcher] = Field(default_factory=list)
    matchNone: List[RowMatcher] = Field(default_factory=list)


class DataResponsePayload(BaseModel):
    """Response payload containing filtered data as CSV."""

    dataCSV: str


def get_storage() -> Any:
    """Get storage interface instance."""
    return get_storage_interface()


async def load_model_dataframe(model_id: str) -> pd.DataFrame:
    """Load a model's inputs, outputs and metadata into a single DataFrame.

    Metadata columns are re-exposed under the ``trustyai.`` prefix so that
    filters can address them (e.g. ``trustyai.TIMESTAMP``).

    Raises:
        HTTPException: 404 if any of the model's datasets are missing,
            500 for any other storage failure.
    """
    try:
        storage = get_storage_interface()
        input_data, input_cols = await storage.read_data(f"{model_id}_inputs")
        output_data, output_cols = await storage.read_data(f"{model_id}_outputs")
        metadata_data, metadata_cols = await storage.read_data(f"{model_id}_metadata")
        if input_data is None or output_data is None or metadata_data is None:
            raise HTTPException(404, f"Model {model_id} not found")
        df = pd.DataFrame()
        if len(input_data) > 0:
            input_df = pd.DataFrame(input_data, columns=input_cols)
            df = pd.concat([df, input_df], axis=1)
        if len(output_data) > 0:
            output_df = pd.DataFrame(output_data, columns=output_cols)
            df = pd.concat([df, output_df], axis=1)
        if len(metadata_data) > 0:
            metadata_df = pd.DataFrame(metadata_data, columns=metadata_cols)
            # Map raw storage column names to the public trustyai.* names;
            # unknown metadata columns pass through unchanged.
            trusty_mapping = {
                "ID": "trustyai.ID",
                "MODEL_ID": "trustyai.MODEL_ID",
                "TIMESTAMP": "trustyai.TIMESTAMP",
                "TAG": "trustyai.TAG",
                "INDEX": "trustyai.INDEX",
            }
            for orig_col in metadata_cols:
                trusty_col = trusty_mapping.get(orig_col, orig_col)
                df[trusty_col] = metadata_df[orig_col]
        return df
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error loading model dataframe: {e}")
        raise HTTPException(500, f"Error loading model data: {str(e)}")


def apply_filters(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame:
    """Apply all filters to DataFrame with performance optimization.

    Timestamp filters need per-row datetime parsing, so they go through the
    boolean-mask path; everything else uses a single pandas ``query``.
    """
    if not any([payload.matchAll, payload.matchAny, payload.matchNone]):
        return df
    if _has_timestamp_filters(payload):
        logger.debug("Using boolean mask approach for timestamp filters")
        return _apply_filters_with_boolean_masks(df, payload)
    logger.debug("Using query approach for non-timestamp filters")
    return _apply_filters_with_query(df, payload)


def _has_timestamp_filters(payload: DataRequestPayload) -> bool:
    """Check if payload contains any timestamp filters."""
    for matcher_list in [payload.matchAll or [], payload.matchAny or [], payload.matchNone or []]:
        for matcher in matcher_list:
            if matcher.columnName == "trustyai.TIMESTAMP":
                return True
    return False


def _apply_filters_with_query(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame:
    """Apply filters using pandas query (optimized for non-timestamp filters)."""
    query_expr = _build_query_expression(df, payload)
    if query_expr:
        logger.debug(f"Executing query: {query_expr}")
        try:
            df = df.query(query_expr)
        except Exception as e:
            logger.error(f"Query execution failed: {query_expr}")
            raise HTTPException(status_code=400, detail=f"Filter execution failed: {str(e)}")
    return df


def _apply_filters_with_boolean_masks(df: pd.DataFrame, payload: DataRequestPayload) -> pd.DataFrame:
    """Apply filters using boolean masks (optimized for timestamp filters).

    matchAll and matchNone AND into the final mask; matchAny is OR-combined
    first, then ANDed in.
    """
    final_mask = pd.Series(True, index=df.index)
    if payload.matchAll:
        for matcher in payload.matchAll:
            final_mask &= _get_matcher_mask(df, matcher, negate=False)
    if payload.matchNone:
        for matcher in payload.matchNone:
            final_mask &= _get_matcher_mask(df, matcher, negate=True)
    if payload.matchAny:
        any_mask = pd.Series(False, index=df.index)
        for matcher in payload.matchAny:
            any_mask |= _get_matcher_mask(df, matcher, negate=False)
        final_mask &= any_mask
    return df[final_mask]


def _get_matcher_mask(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.Series:
    """Get boolean mask for a single matcher with comprehensive validation.

    Raises:
        HTTPException: 400 for an unknown operation or column.
    """
    column_name = matcher.columnName
    values = matcher.values
    if matcher.operation not in ["EQUALS", "BETWEEN"]:
        raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]")
    if column_name not in df.columns:
        raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}")
    if matcher.operation == "EQUALS":
        mask = df[column_name].isin(values)
    else:  # BETWEEN — only remaining valid operation after the check above.
        mask = _create_between_mask(df, column_name, values)
    return ~mask if negate else mask


def _create_between_mask(df: pd.DataFrame, column_name: str, values: List[Any]) -> pd.Series:
    """Create boolean mask for BETWEEN operation with type-specific handling.

    Bounds are half-open: lower inclusive, upper exclusive.
    """
    errors = []
    if len(values) != 2:
        errors.append(
            f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received {len(values)} values"
        )
    if column_name == "trustyai.TIMESTAMP":
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        try:
            start_time = pd.to_datetime(str(values[0]))
            end_time = pd.to_datetime(str(values[1]))
            df_times = pd.to_datetime(df[column_name])
            return (df_times >= start_time) & (df_times < end_time)
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}"
            )
    elif column_name == "trustyai.INDEX":
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        min_val, max_val = sorted([int(v) for v in values])
        return (df[column_name] >= min_val) & (df[column_name] < max_val)
    else:
        if not all(isinstance(v, numbers.Number) for v in values):
            errors.append(
                "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values"
            )
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        min_val, max_val = sorted(values)
        try:
            if df[column_name].dtype in ["int64", "float64", "int32", "float32"]:
                return (df[column_name] >= min_val) & (df[column_name] < max_val)
            # Non-numeric dtype: attempt coercion so string-typed numeric
            # columns still compare; errors="raise" surfaces bad cells.
            numeric_column = pd.to_numeric(df[column_name], errors="raise")
            return (numeric_column >= min_val) & (numeric_column < max_val)
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400,
                detail=f"Column '{column_name}' contains non-numeric values that cannot be compared with BETWEEN operation.",
            )


def _build_query_expression(df: pd.DataFrame, payload: DataRequestPayload) -> str:
    """Build optimized pandas query expression for all filters."""
    conditions = []
    if payload.matchAll:
        for matcher in payload.matchAll:
            condition = _build_condition(df, matcher, negate=False)
            if condition:
                conditions.append(condition)
    if payload.matchNone:
        for matcher in payload.matchNone:
            condition = _build_condition(df, matcher, negate=True)
            if condition:
                conditions.append(condition)
    if payload.matchAny:
        any_conditions = []
        for matcher in payload.matchAny:
            condition = _build_condition(df, matcher, negate=False)
            if condition:
                any_conditions.append(condition)
        if any_conditions:
            any_expr = " | ".join(f"({cond})" for cond in any_conditions)
            conditions.append(f"({any_expr})")
    return " & ".join(f"({cond})" for cond in conditions) if conditions else ""


def _build_condition(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> str:
    """Build a single condition for pandas query.

    Raises:
        HTTPException: 400 for an unknown operation or column.
    """
    column_name = matcher.columnName
    values = matcher.values
    if matcher.operation not in ["EQUALS", "BETWEEN"]:
        raise HTTPException(status_code=400, detail="RowMatch operation must be one of [BETWEEN, EQUALS]")
    if column_name not in df.columns:
        raise HTTPException(status_code=400, detail=f"No feature or output found with name={column_name}")
    safe_column = _sanitize_column_name(column_name)
    if matcher.operation == "EQUALS":
        condition = _build_equals_condition(safe_column, values, df[column_name].dtype)
    else:  # BETWEEN — only remaining valid operation after the check above.
        condition = _build_between_condition(safe_column, values, column_name, df[column_name].dtype)
    return f"~({condition})" if negate else condition


def _sanitize_column_name(column_name: str) -> str:
    """Sanitize column name for pandas query syntax.

    Backtick-quote names containing dots (e.g. trustyai.*) so ``DataFrame.query``
    does not parse them as attribute access.
    """
    if "." in column_name or column_name.startswith("trustyai"):
        return f"`{column_name}`"
    return column_name


def _build_equals_condition(safe_column: str, values: List[Any], dtype) -> str:
    """Build EQUALS condition for query with optimization.

    A single value becomes ``==``; multiple values become ``isin``.
    """
    if len(values) == 1:
        val = _format_value_for_query(values[0], dtype)
        return f"{safe_column} == {val}"
    formatted_values = [_format_value_for_query(v, dtype) for v in values]
    values_str = "[" + ", ".join(formatted_values) + "]"
    return f"{safe_column}.isin({values_str})"


def _build_between_condition(safe_column: str, values: List[Any], original_column: str, dtype) -> str:
    """Build BETWEEN condition for query with comprehensive validation.

    Bounds are half-open: lower inclusive, upper exclusive.
    """
    errors = []
    if len(values) != 2:
        errors.append(
            f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received {len(values)} values"
        )
    if original_column == "trustyai.TIMESTAMP":
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        try:
            start_time = pd.to_datetime(str(values[0]))
            end_time = pd.to_datetime(str(values[1]))
            return f"'{start_time}' <= {safe_column} < '{end_time}'"
        except Exception as e:
            raise HTTPException(
                status_code=400, detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}"
            )
    elif original_column == "trustyai.INDEX":
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        min_val, max_val = sorted([int(v) for v in values])
        return f"{min_val} <= {safe_column} < {max_val}"
    else:
        if not all(isinstance(v, numbers.Number) for v in values):
            errors.append(
                "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values"
            )
        if errors:
            raise HTTPException(status_code=400, detail=", ".join(errors))
        min_val, max_val = sorted(values)
        return f"{min_val} <= {safe_column} < {max_val}"


def _format_value_for_query(value: Any, dtype) -> str:
    """Format value appropriately for pandas query syntax.

    Strings (and anything non-numeric) are single-quoted with embedded
    quotes backslash-escaped; numbers pass through unquoted.
    """
    if isinstance(value, str):
        escaped = value.replace("'", "\\'")
        return f"'{escaped}'"
    elif isinstance(value, (int, float)):
        return str(value)
    else:
        escaped = str(value).replace("'", "\\'")
        return f"'{escaped}'"
from src.service.data.modelmesh_parser import ModelMeshPayloadParser
from src.service.data.storage import get_storage_interface
from src.service.utils import list_utils
from src.endpoints.consumer.consumer_endpoint import process_payload


logger = logging.getLogger(__name__)


# Storage backend stores metadata as fixed-width unicode; values longer than
# this are rejected by _validate_metadata_lengths.
METADATA_STRING_MAX_LENGTH = 100


class ValidationError(Exception):
    """Validation errors."""

    pass


class ProcessingError(Exception):
    """Processing errors."""

    pass


class KServeDataAdapter:
    """Convert upload tensors to consumer endpoint format."""

    def __init__(self, tensor_dict: Dict[str, Any], numpy_array: np.ndarray):
        """Initialize adapter with validated data."""
        self._name = tensor_dict.get("name", "unknown")
        self._shape = tensor_dict.get("shape", [])
        self._datatype = tensor_dict.get("datatype", "FP64")
        self._data = numpy_array  # Keep numpy array intact

    @property
    def name(self) -> str:
        return self._name

    @property
    def shape(self) -> List[int]:
        return self._shape

    @property
    def datatype(self) -> str:
        return self._datatype

    @property
    def data(self) -> np.ndarray:
        """Returns numpy array with .shape attribute as expected by consumer endpoint."""
        return self._data


class ConsumerEndpointAdapter:
    """Consumer endpoint's expected structure."""

    def __init__(self, adapted_tensors: List[KServeDataAdapter]):
        self.tensors = adapted_tensors
        self.id = f"upload_request_{uuid.uuid4().hex[:8]}"


async def process_upload_request(payload: Any) -> Dict[str, str]:
    """Process complete upload request with validation and data handling.

    Routes to ground-truth or regular handling based on payload flags.

    Raises:
        HTTPException: 400 for any validation/processing failure.
    """
    try:
        model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
        if payload.data_tag:
            error = validate_data_tag(payload.data_tag)
            if error:
                raise HTTPException(400, error)
        inputs = payload.request.get("inputs", [])
        outputs = payload.response.get("outputs", []) if payload.response else []
        if not inputs:
            raise HTTPException(400, "Missing input tensors")
        if payload.is_ground_truth and not outputs:
            raise HTTPException(400, "Ground truth uploads require output tensors")

        input_arrays, input_names, _, execution_ids = process_tensors_using_kserve_logic(inputs)
        if outputs:
            output_arrays, output_names, _, _ = process_tensors_using_kserve_logic(outputs)
        else:
            output_arrays, output_names = [], []
        error = validate_input_shapes(input_arrays, input_names)
        if error:
            raise HTTPException(400, f"One or more errors in input tensors: {error}")
        if payload.is_ground_truth:
            return await _process_ground_truth_data(
                model_name, input_arrays, input_names, output_arrays, output_names, execution_ids
            )
        return await _process_regular_data(
            model_name, input_arrays, input_names, output_arrays, output_names, execution_ids, payload.data_tag
        )
    except ProcessingError as e:
        raise HTTPException(400, str(e))
    except ValidationError as e:
        raise HTTPException(400, str(e))


async def _process_ground_truth_data(
    model_name: str,
    input_arrays: List[np.ndarray],
    input_names: List[str],
    output_arrays: List[np.ndarray],
    output_names: List[str],
    execution_ids: Optional[List[str]],
) -> Dict[str, str]:
    """Process ground truth data upload."""
    if not execution_ids:
        raise HTTPException(400, "Ground truth requires execution IDs")
    result = await handle_ground_truths(
        model_name,
        input_arrays,
        input_names,
        output_arrays,
        output_names,
        [sanitize_id(id) for id in execution_ids],
    )
    if not result.success:
        raise HTTPException(400, result.message)
    result_data = result.data
    if result_data is None:
        raise HTTPException(500, "Ground truth processing failed")
    # Ground truths live in a parallel "<model>_ground_truth" dataset pair.
    gt_name = f"{model_name}_ground_truth"
    storage_interface = get_storage_interface()
    await storage_interface.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
    await storage_interface.write_data(
        gt_name + METADATA_SUFFIX,
        result_data["metadata"],
        result_data["metadata_names"],
    )
    logger.info(f"Ground truth data saved for model: {model_name}")
    return {"message": result.message}


async def _process_regular_data(
    model_name: str,
    input_arrays: List[np.ndarray],
    input_names: List[str],
    output_arrays: List[np.ndarray],
    output_names: List[str],
    execution_ids: Optional[List[str]],
    data_tag: Optional[str],
) -> Dict[str, str]:
    """Process regular model data upload."""
    n_rows = input_arrays[0].shape[0]
    # Generate fresh IDs when the payload supplied none.
    exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
    input_data = _flatten_tensor_data(input_arrays, n_rows)
    output_data = _flatten_tensor_data(output_arrays, n_rows)
    metadata, metadata_cols = _create_metadata(exec_ids, model_name, data_tag)
    await save_model_data(
        model_name,
        np.array(input_data),
        input_names,
        np.array(output_data),
        output_names,
        metadata,
        metadata_cols,
    )
    logger.info(f"Regular data saved for model: {model_name}, rows: {n_rows}")
    return {"message": f"{n_rows} datapoints added to {model_name}"}


def _flatten_tensor_data(arrays: List[np.ndarray], n_rows: int) -> List[List[Any]]:
    """Flatten tensor arrays into row-based format for storage."""

    def flatten_row(arrays: List[np.ndarray], row: int) -> List[Any]:
        """Flatten arrays for a single row."""
        return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]

    return [flatten_row(arrays, i) for i in range(n_rows)]


def _create_metadata(
    execution_ids: List[str], model_name: str, data_tag: Optional[str]
) -> Tuple[np.ndarray, List[str]]:
    """Create metadata array for model data storage.

    Returns the metadata matrix and its column names.
    """
    current_timestamp = datetime.now().isoformat()
    metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG"]
    metadata_rows = [
        [
            str(eid),
            str(model_name),
            str(current_timestamp),
            str(data_tag or ""),
        ]
        for eid in execution_ids
    ]
    _validate_metadata_lengths(metadata_rows, metadata_cols)
    # NOTE(review): dtype literal reconstructed from the length validation
    # above — fixed-width unicode sized to METADATA_STRING_MAX_LENGTH;
    # confirm against the original source.
    metadata = np.array(metadata_rows, dtype=f"<U{METADATA_STRING_MAX_LENGTH}")
    return metadata, metadata_cols


def _validate_metadata_lengths(metadata_rows: List[List[str]], column_names: List[str]) -> None:
    """Validate that all metadata values fit within the defined string length limit.

    Raises:
        ValidationError: naming the offending column/row and a truncated preview.
    """
    for row_idx, row in enumerate(metadata_rows):
        for col_idx, value in enumerate(row):
            value_str = str(value)
            if len(value_str) > METADATA_STRING_MAX_LENGTH:
                col_name = column_names[col_idx] if col_idx < len(column_names) else f"column_{col_idx}"
                raise ValidationError(
                    f"Metadata field '{col_name}' in row {row_idx} exceeds maximum length "
                    f"of {METADATA_STRING_MAX_LENGTH} characters (got {len(value_str)} chars): "
                    f"'{value_str[:50]}{'...' if len(value_str) > 50 else ''}'"
                )


@dataclass
class GroundTruthValidationResult:
    """Result of ground truth validation."""

    success: bool
    message: str
    data: Optional[Dict[str, Any]] = None
    errors: List[str] = field(default_factory=list)


# numpy/python type -> Java-style type name used in validation messages.
TYPE_MAP = {
    np.int64: "Long",
    np.int32: "Integer",
    np.float32: "Float",
    np.float64: "Double",
    np.bool_: "Boolean",
    int: "Long",
    float: "Double",
    bool: "Boolean",
    str: "String",
}


def get_type_name(val: Any) -> str:
    """Get Java-style type name for a value (used in ground truth validation)."""
    if hasattr(val, "dtype"):
        return TYPE_MAP.get(val.dtype.type, "String")
    return TYPE_MAP.get(type(val), "String")


def sanitize_id(execution_id: str) -> str:
    """Sanitize execution ID."""
    return str(execution_id).strip()


def extract_row_data(arrays: List[np.ndarray], row_index: int) -> List[Any]:
    """Extract data from arrays for a specific row."""
    row_data = []
    for arr in arrays:
        if arr.ndim > 1:
            row_data.extend(arr[row_index].flatten())
        else:
            row_data.append(arr[row_index])
    return row_data


def process_tensors_using_kserve_logic(
    tensors: List[Dict[str, Any]],
) -> Tuple[List[np.ndarray], List[str], List[str], Optional[List[str]]]:
    """Process tensor data using consumer endpoint logic via clean adapter pattern.

    Returns (arrays, column names, datatypes, execution IDs or None).

    Raises:
        HTTPException: 400 on validation or consumer-endpoint failure.
    """
    if not tensors:
        return [], [], [], None
    validation_errors = _validate_tensor_inputs(tensors)
    if validation_errors:
        error_message = "One or more errors occurred: " + ". ".join(validation_errors)
        raise HTTPException(400, error_message)
    adapted_tensors = []
    execution_ids = None
    datatypes = []
    for tensor in tensors:
        # execution_ids are taken from the first tensor that carries them.
        if execution_ids is None:
            execution_ids = tensor.get("execution_ids")
        numpy_array = _convert_tensor_to_numpy(tensor)
        adapter = KServeDataAdapter(tensor, numpy_array)
        adapted_tensors.append(adapter)
        datatypes.append(adapter.datatype)
    try:
        adapter_payload = ConsumerEndpointAdapter(adapted_tensors)
        tensor_array, column_names = process_payload(adapter_payload, lambda payload: payload.tensors)
        arrays, all_names = _convert_consumer_results_to_upload_format(tensor_array, column_names, adapted_tensors)
        return arrays, all_names, datatypes, execution_ids
    except Exception as e:
        logger.error(f"Consumer endpoint processing failed: {e}")
        raise HTTPException(400, f"Tensor processing error: {str(e)}")


def _validate_tensor_inputs(tensors: List[Dict[str, Any]]) -> List[str]:
    """Validate tensor inputs and return list of error messages."""
    errors = []
    tensor_names = [tensor.get("name", f"tensor_{i}") for i, tensor in enumerate(tensors)]
    if len(tensor_names) != len(set(tensor_names)):
        errors.append("Input tensors must have unique names")
    shapes = [tensor.get("shape", []) for tensor in tensors]
    if len(shapes) > 1:
        first_dims = [shape[0] if shape else 0 for shape in shapes]
        if len(set(first_dims)) > 1:
            errors.append(f"Input tensors must have consistent first dimension. Found: {first_dims}")
    return errors


def _convert_tensor_to_numpy(tensor: Dict[str, Any]) -> np.ndarray:
    """Convert tensor dictionary to numpy array with proper dtype."""
    raw_data = tensor.get("data", [])

    # Mixed/string data falls back to object dtype.
    if list_utils.contains_non_numeric(raw_data):
        return np.array(raw_data, dtype="O")
    dtype_map = {"INT64": np.int64, "INT32": np.int32, "FP32": np.float32, "FP64": np.float64, "BOOL": np.bool_}
    datatype = tensor.get("datatype", "FP64")
    np_dtype = dtype_map.get(datatype, np.float64)
    return np.array(raw_data, dtype=np_dtype)


def _convert_consumer_results_to_upload_format(
    tensor_array: np.ndarray, column_names: List[str], adapted_tensors: List[KServeDataAdapter]
) -> Tuple[List[np.ndarray], List[str]]:
    """Convert consumer endpoint results back to upload format."""
    if len(adapted_tensors) == 1:
        # Single tensor case
        return [tensor_array], column_names
    arrays = []
    all_names = []
    col_start = 0
    for adapter in adapted_tensors:
        if len(adapter.shape) > 1:
            n_cols = adapter.shape[1]
            tensor_names = [f"{adapter.name}-{i}" for i in range(n_cols)]
        else:
            n_cols = 1
            tensor_names = [adapter.name]
        if tensor_array.ndim == 2:
            tensor_data = tensor_array[:, col_start : col_start + n_cols]
        else:
            tensor_data = tensor_array[col_start : col_start + n_cols]
        arrays.append(tensor_data)
        all_names.extend(tensor_names)
        col_start += n_cols
    return arrays, all_names


def validate_input_shapes(input_arrays: List[np.ndarray], input_names: List[str]) -> Optional[str]:
    """Validate input array shapes and names - collect ALL errors."""
    if not input_arrays:
        return None
    errors = []
    if len(set(input_names)) != len(input_names):
        errors.append("Input tensors must have unique names")
    first_dim = input_arrays[0].shape[0]
    for i, arr in enumerate(input_arrays[1:], 1):
        if arr.shape[0] != first_dim:
            errors.append(
                f"Input tensor '{input_names[i]}' has first dimension {arr.shape[0]}, "
                f"which doesn't match the first dimension {first_dim} of '{input_names[0]}'"
            )
    if errors:
        return ". ".join(errors) + "."
    return None


def validate_data_tag(tag: str) -> Optional[str]:
    """Validate data tag format and content. Returns an error message or None."""
    if not tag:
        return None
    if tag.startswith(TRUSTYAI_TAG_PREFIX):
        return (
            f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. "
            f"Provided tag '{tag}' violates this restriction."
        )
    return None


class GroundTruthValidator:
    """Ground truth validator."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.id_to_row: Dict[str, int] = {}
        self.inputs: Optional[np.ndarray] = None
        self.outputs: Optional[np.ndarray] = None
        self.metadata: Optional[np.ndarray] = None

    async def initialize(self) -> None:
        """Load existing data and build the execution-ID -> row index map."""
        storage_interface = get_storage_interface()
        self.inputs, _ = await storage_interface.read_data(self.model_name + INPUT_SUFFIX)
        self.outputs, _ = await storage_interface.read_data(self.model_name + OUTPUT_SUFFIX)
        self.metadata, _ = await storage_interface.read_data(self.model_name + METADATA_SUFFIX)
        metadata_cols = await storage_interface.get_original_column_names(self.model_name + METADATA_SUFFIX)
        # Fall back to column 0 if no column is literally named "ID".
        id_col = next((i for i, name in enumerate(metadata_cols) if name.upper() == "ID"), 0)
        if self.metadata is not None:
            for j, row in enumerate(self.metadata):
                self.id_to_row[str(row[id_col])] = j

    def find_row(self, exec_id: str) -> Optional[int]:
        """Find row index for execution ID."""
        return self.id_to_row.get(str(exec_id))

    async def validate_data(
        self,
        exec_id: str,
        uploaded_inputs: List[Any],
        uploaded_outputs: List[Any],
        row_idx: int,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
    ) -> Optional[str]:
        """Validate inputs and outputs against recorded inference data.

        Inputs must match the recorded row in length, type and value;
        outputs only in length and type (ground truths may differ in value).
        Returns an error string, or None when the row validates.
        """
        if self.inputs is None or self.outputs is None:
            return f"ID={exec_id} no existing data found"
        existing_inputs = self.inputs[row_idx]
        existing_outputs = self.outputs[row_idx]
        if len(existing_inputs) != len(uploaded_inputs):
            return f"ID={exec_id} input shapes do not match. Observed inputs have length={len(existing_inputs)} while uploaded inputs have length={len(uploaded_inputs)}"
        for i, (existing, uploaded) in enumerate(zip(existing_inputs, uploaded_inputs)):
            existing_type = get_type_name(existing)
            uploaded_type = get_type_name(uploaded)
            if existing_type != uploaded_type:
                return f"ID={exec_id} input type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}"
            if existing != uploaded:
                return f"ID={exec_id} inputs are not identical: value mismatch at position {i + 1}"
        if len(existing_outputs) != len(uploaded_outputs):
            return f"ID={exec_id} output shapes do not match. Observed outputs have length={len(existing_outputs)} while uploaded ground-truths have length={len(uploaded_outputs)}"
        for i, (existing, uploaded) in enumerate(zip(existing_outputs, uploaded_outputs)):
            existing_type = get_type_name(existing)
            uploaded_type = get_type_name(uploaded)
            if existing_type != uploaded_type:
                return f"ID={exec_id} output type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}"
        return None


async def handle_ground_truths(
    model_name: str,
    input_arrays: List[np.ndarray],
    input_names: List[str],
    output_arrays: List[np.ndarray],
    output_names: List[str],
    execution_ids: List[str],
    config: Optional[Any] = None,
) -> GroundTruthValidationResult:
    """Handle ground truth validation.

    Validates each execution ID against recorded inference data and collects
    the valid rows; any mismatch fails the whole batch.
    """
    if not execution_ids:
        return GroundTruthValidationResult(success=False, message="No execution IDs provided.")
    storage_interface = get_storage_interface()
    if not await storage_interface.dataset_exists(model_name + INPUT_SUFFIX):
        return GroundTruthValidationResult(success=False, message=f"Model {model_name} not found.")
    validator = GroundTruthValidator(model_name)
    await validator.initialize()
    errors = []
    valid_outputs = []
    valid_metadata = []
    n_rows = input_arrays[0].shape[0] if input_arrays else 0
    for i, exec_id in enumerate(execution_ids):
        if i >= n_rows:
            errors.append(f"ID={exec_id} index out of bounds")
            continue
        row_idx = validator.find_row(exec_id)
        if row_idx is None:
            errors.append(f"ID={exec_id} not found")
            continue
        uploaded_inputs = extract_row_data(input_arrays, i)
        uploaded_outputs = extract_row_data(output_arrays, i)
        error = await validator.validate_data(exec_id, uploaded_inputs, uploaded_outputs, row_idx)
        if error:
            errors.append(error)
            continue
        valid_outputs.append(uploaded_outputs)
        valid_metadata.append([exec_id])
    if errors:
        return GroundTruthValidationResult(
            success=False,
            # Cap the report at the first five mismatches to keep messages bounded.
            message="Found fatal mismatches between uploaded data and recorded inference data:\n"
            + "\n".join(errors[:5]),
            errors=errors,
        )
    if not valid_outputs:
        return GroundTruthValidationResult(success=False, message="No valid ground truths found.")
    return GroundTruthValidationResult(
        success=True,
        message=f"{len(valid_outputs)} ground truths added.",
        data={
            "outputs": np.array(valid_outputs),
            "output_names": output_names,
            "metadata": np.array(valid_metadata),
            "metadata_names": ["ID"],
        },
    )


async def save_model_data(
    model_name: str,
    input_data: np.ndarray,
    input_names: List[str],
    output_data: np.ndarray,
    output_names: List[str],
    metadata_data: np.ndarray,
    metadata_names: List[str],
) -> Dict[str, Any]:
    """Save model data to storage."""
    storage_interface = get_storage_interface()
    await storage_interface.write_data(model_name + INPUT_SUFFIX, input_data, input_names)
    await storage_interface.write_data(model_name + OUTPUT_SUFFIX, output_data, output_names)
    await storage_interface.write_data(model_name + METADATA_SUFFIX, metadata_data, metadata_names)
    logger.info(f"Saved model data for {model_name}: {len(input_data)} rows")
    return {
        "model_name": model_name,
        "rows": len(input_data),
    }
pd.DataFrame: + random = np.random.RandomState(0) + data = { + "age": [], + "gender": [], + "race": [], + "income": [], + "trustyai.ID": [], + "trustyai.MODEL_ID": [], + "trustyai.TIMESTAMP": [], + "trustyai.TAG": [], + "trustyai.INDEX": [], + } + for i in range(observations): + data["age"].append(i % feature_diversity) + data["gender"].append(1 if random.choice([True, False]) else 0) + data["race"].append(1 if random.choice([True, False]) else 0) + data["income"].append(1 if random.choice([True, False]) else 0) + data["trustyai.ID"].append(str(uuid.uuid4())) + data["trustyai.MODEL_ID"].append("example1") + data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) + data["trustyai.TAG"].append("") + data["trustyai.INDEX"].append(i) + return pd.DataFrame(data) + + @staticmethod + def generate_random_text_dataframe(observations: int, seed: int = 0) -> pd.DataFrame: + if seed < 0: + random = np.random.RandomState(0) + else: + random = np.random.RandomState(seed) + makes = ["Ford", "Chevy", "Dodge", "GMC", "Buick"] + colors = ["Red", "Blue", "White", "Black", "Purple", "Green", "Yellow"] + data = { + "year": [], + "make": [], + "color": [], + "value": [], + "trustyai.ID": [], + "trustyai.MODEL_ID": [], + "trustyai.TIMESTAMP": [], + "trustyai.TAG": [], + "trustyai.INDEX": [], + } + for i in range(observations): + data["year"].append(1970 + i % 50) + data["make"].append(makes[i % len(makes)]) + data["color"].append(colors[i % len(colors)]) + data["value"].append(random.random() * 50) + data["trustyai.ID"].append(str(uuid.uuid4())) + data["trustyai.MODEL_ID"].append("example1") + data["trustyai.TIMESTAMP"].append((datetime.now() - timedelta(seconds=i)).isoformat()) + data["trustyai.TAG"].append("") + data["trustyai.INDEX"].append(i) + return pd.DataFrame(data) + + +# Mock storage for testing +class MockStorage: + def __init__(self): + self.data = {} + + async def read_data(self, dataset_name: str): + if dataset_name.endswith("_outputs"): + 
model_id = dataset_name.replace("_outputs", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + output_data = self.data[model_id].get("output") + output_cols = self.data[model_id].get("output_cols", []) + return output_data, output_cols + elif dataset_name.endswith("_metadata"): + model_id = dataset_name.replace("_metadata", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + metadata_data = self.data[model_id].get("metadata") + metadata_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] + return metadata_data, metadata_cols + elif dataset_name.endswith("_inputs"): + model_id = dataset_name.replace("_inputs", "") + if model_id not in self.data: + raise Exception(f"Model {model_id} not found") + input_data = self.data[model_id].get("input") + input_cols = self.data[model_id].get("input_cols", []) + return input_data, input_cols + else: + raise Exception(f"Unknown dataset: {dataset_name}") + + def save_dataframe(self, df: pd.DataFrame, model_id: str): + input_cols = [col for col in df.columns if not col.startswith("trustyai.") and col not in ["income", "value"]] + output_cols = [col for col in df.columns if col in ["income", "value"]] + metadata_cols = [col for col in df.columns if col.startswith("trustyai.")] + input_data = df[input_cols].values if input_cols else np.array([]) + output_data = df[output_cols].values if output_cols else np.array([]) + metadata_data_cols = ["ID", "MODEL_ID", "TIMESTAMP", "TAG", "INDEX"] + metadata_values = [] + for _, row in df.iterrows(): + row_data = [] + for col in metadata_data_cols: + trusty_col = f"trustyai.{col}" + if trusty_col in df.columns: + value = row[trusty_col] + if col == "INDEX": + row_data.append(int(value)) + else: + row_data.append(str(value)) + else: + row_data.append("" if col != "INDEX" else 0) + metadata_values.append(row_data) + metadata_data = np.array(metadata_values, dtype=object) + self.data[model_id] = { + "dataframe": df, + "input": 
input_data, + "input_cols": input_cols, + "output": output_data, + "output_cols": output_cols, + "metadata": metadata_data, + } + + def reset(self): + self.data.clear() + + +mock_storage = MockStorage() + + +@pytest.fixture(autouse=True) +def setup_storage(): + """Setup mock storage for all tests""" + with patch("src.service.utils.download.get_storage_interface", return_value=mock_storage): + yield + + +@pytest.fixture(autouse=True) +def reset_storage(): + """Reset storage before each test""" + mock_storage.reset() + yield + + +# Test constants +MODEL_ID = "example1" + + +def test_download_data(): + """equivalent of Java downloadData() test""" + dataframe = DataframeGenerators.generate_random_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + + payload = { + "modelId": MODEL_ID, + "matchAll": [ + {"columnName": "gender", "operation": "EQUALS", "values": [0]}, + {"columnName": "race", "operation": "EQUALS", "values": [0]}, + {"columnName": "income", "operation": "EQUALS", "values": [0]}, + ], + "matchAny": [ + {"columnName": "age", "operation": "BETWEEN", "values": [5, 10]}, + {"columnName": "age", "operation": "BETWEEN", "values": [50, 70]}, + ], + "matchNone": [{"columnName": "age", "operation": "BETWEEN", "values": [55, 65]}], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df[(result_df["age"] > 55) & (result_df["age"] < 65)]) == 0 + assert len(result_df[result_df["gender"] == 1]) == 0 + assert len(result_df[result_df["race"] == 1]) == 0 + assert len(result_df[result_df["income"] == 1]) == 0 + assert len(result_df[(result_df["age"] >= 10) & (result_df["age"] < 50)]) == 0 + assert len(result_df[result_df["age"] > 70]) == 0 + + +def test_download_text_data(): + """equivalent of Java downloadTextData() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + 
mock_storage.save_dataframe(dataframe, MODEL_ID) + + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "make", + "operation": "EQUALS", + "values": ["Chevy", "Ford", "Dodge"], + }, + { + "columnName": "year", + "operation": "BETWEEN", + "values": [1990, 2050], + }, + ], + } + + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df[result_df["year"] < 1990]) == 0 + assert len(result_df[result_df["make"] == "GMC"]) == 0 + assert len(result_df[result_df["make"] == "Buick"]) == 0 + + +def test_download_text_data_between_error(): + """equivalent of Java downloadTextDataBetweenError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "make", + "operation": "BETWEEN", + "values": ["Chevy", "Ford", "Dodge"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert ( + "BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received 3 values" + in response.text + ) + assert ( + "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. 
Received non-numeric values" + in response.text + ) + + +def test_download_text_data_invalid_column_error(): + """equivalent of Java downloadTextDataInvalidColumnError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "mak123e", + "operation": "EQUALS", + "values": ["Chevy", "Ford"], + } + ], + } + + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "No feature or output found with name=" in response.text + + +def test_download_text_data_invalid_operation_error(): + """equivalent of Java downloadTextDataInvalidOperationError() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "mak123e", + "operation": "DOESNOTEXIST", + "values": ["Chevy", "Ford"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "RowMatch operation must be one of [BETWEEN, EQUALS]" in response.text + + +def test_download_text_data_internal_column(): + """equivalent of Java downloadTextDataInternalColumn() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + dataframe.loc[0:499, "trustyai.TAG"] = "TRAINING" + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "trustyai.TAG", + "operation": "EQUALS", + "values": ["TRAINING"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 500 + + +def test_download_text_data_internal_column_index(): + """equivalent of Java downloadTextDataInternalColumnIndex() test""" + dataframe = 
DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + expected_rows = dataframe.iloc[0:10].copy() + payload = { + "modelId": MODEL_ID, + "matchAll": [ + { + "columnName": "trustyai.INDEX", + "operation": "BETWEEN", + "values": [0, 10], + } + ], + } + response = client.post("/data/download", json=payload) + print(f"Response status: {response.status_code}") + print(f"Response text: {response.text}") + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 10 + input_cols = ["year", "make", "color"] + for i in range(10): + for col in input_cols: + assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" + + +def test_download_text_data_internal_column_timestamp(): + """equivalent of Java downloadTextDataInternalColumnTimestamp() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1, -1) + base_time = datetime.now() + for i in range(100): + new_row = DataframeGenerators.generate_random_text_dataframe(1, i) + # Use milliseconds to simulate Thread.sleep(1) and ensure ascending order + timestamp = (base_time + timedelta(milliseconds=i + 1)).isoformat() + # Fix this line - change to UPPERCASE + new_row["trustyai.TIMESTAMP"] = [timestamp] + dataframe = pd.concat([dataframe, new_row], ignore_index=True) + mock_storage.save_dataframe(dataframe, MODEL_ID) + extract_idx = 50 + n_to_get = 10 + expected_rows = dataframe.iloc[extract_idx : extract_idx + n_to_get].copy() + start_time = dataframe.iloc[extract_idx]["trustyai.TIMESTAMP"] + end_time = dataframe.iloc[extract_idx + n_to_get]["trustyai.TIMESTAMP"] + payload = { + "modelId": MODEL_ID, + "matchAny": [ + { + "columnName": "trustyai.TIMESTAMP", + "operation": "BETWEEN", + "values": [start_time, end_time], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = 
response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 10 + input_cols = ["year", "make", "color"] + for i in range(10): + for col in input_cols: + assert result_df.iloc[i][col] == expected_rows.iloc[i][col], f"Row {i}, column {col} mismatch" + + +def test_download_text_data_internal_column_timestamp_unparseable(): + """equivalent of Java downloadTextDataInternalColumnTimestampUnparseable() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = { + "modelId": MODEL_ID, + "matchAny": [ + { + "columnName": "trustyai.TIMESTAMP", + "operation": "BETWEEN", + "values": ["not a timestamp", "also not a timestamp"], + } + ], + } + response = client.post("/data/download", json=payload) + assert response.status_code == 400 + assert "unparseable as an ISO_LOCAL_DATE_TIME" in response.text + + +def test_download_text_data_null_request(): + """equivalent of Java downloadTextDataNullRequest() test""" + dataframe = DataframeGenerators.generate_random_text_dataframe(1000) + mock_storage.save_dataframe(dataframe, MODEL_ID) + payload = {"modelId": MODEL_ID} + response = client.post("/data/download", json=payload) + assert response.status_code == 200 + result = response.json() + result_df = pd.read_csv(StringIO(result["dataCSV"])) + assert len(result_df) == 1000 \ No newline at end of file diff --git a/tests/endpoints/test_upload_endpoint.py b/tests/endpoints/test_upload_endpoint.py new file mode 100644 index 0000000..12ccf09 --- /dev/null +++ b/tests/endpoints/test_upload_endpoint.py @@ -0,0 +1,484 @@ +import copy +import json +import os +import pickle +import shutil +import sys +import tempfile +import uuid + +import h5py +import numpy as np +import pytest + +TEMP_DIR = tempfile.mkdtemp() +os.environ["STORAGE_DATA_FOLDER"] = TEMP_DIR +from fastapi.testclient import TestClient + +from src.main import app +from src.service.constants import ( + INPUT_SUFFIX, 
+ METADATA_SUFFIX, + OUTPUT_SUFFIX, + TRUSTYAI_TAG_PREFIX, +) +from src.service.data.storage import get_storage_interface + + +def pytest_sessionfinish(session, exitstatus): + """Clean up the temporary directory after all tests are done.""" + if os.path.exists(TEMP_DIR): + shutil.rmtree(TEMP_DIR) + + +pytest.hookimpl(pytest_sessionfinish) +client = TestClient(app) +MODEL_ID = "example1" + + +def generate_payload(n_rows, n_input_cols, n_output_cols, datatype, tag, input_offset=0, output_offset=0): + """Generate a test payload with specific dimensions and data types.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for i in range(n_rows): + if n_input_cols == 1: + input_data.append(i + input_offset) + else: + row = [i + j + input_offset for j in range(n_input_cols)] + input_data.append(row) + output_data = [] + for i in range(n_rows): + if n_output_cols == 1: + output_data.append(i * 2 + output_offset) + else: + row = [i * 2 + j + output_offset for j in range(n_output_cols)] + output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "input", + "shape": [n_rows, n_input_cols] if n_input_cols > 1 else [n_rows], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "output", + "shape": [n_rows, n_output_cols] if n_output_cols > 1 else [n_rows], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a test payload with multi-dimensional tensors like real data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data = [] + for row_idx in range(n_rows): + row = [row_idx + col_idx * 10 for col_idx in range(n_input_cols)] + input_data.append(row) + output_data = [] + for row_idx in range(n_rows): + row = [row_idx * 2 + col_idx for col_idx in range(n_output_cols)] 
+ output_data.append(row) + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "multi_input", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data, + } + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, tag): + """Generate a payload with mismatched shapes and non-unique names.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + input_data_1 = [[row_idx + col_idx * 10 for col_idx in range(n_input_cols)] for row_idx in range(n_rows)] + mismatched_rows = n_rows - 1 if n_rows > 1 else 1 + input_data_2 = [[row_idx + col_idx * 20 for col_idx in range(n_input_cols)] for row_idx in range(mismatched_rows)] + output_data = [[row_idx * 2 + col_idx for col_idx in range(n_output_cols)] for row_idx in range(n_rows)] + payload = { + "model_name": model_name, + "data_tag": tag, + "is_ground_truth": False, + "request": { + "inputs": [ + { + "name": "same_name", + "shape": [n_rows, n_input_cols], + "datatype": datatype, + "data": input_data_1, + }, + { + "name": "same_name", + "shape": [mismatched_rows, n_input_cols], + "datatype": datatype, + "data": input_data_2, + }, + ] + }, + "response": { + "outputs": [ + { + "name": "multi_output", + "shape": [n_rows, n_output_cols], + "datatype": datatype, + "data": output_data, + } + ] + }, + } + return payload + + +def get_data_from_storage(model_name, suffix): + """Get data from storage file.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + suffix) + if not os.path.exists(filename): + return None + with h5py.File(filename, "r") as f: + if model_name + suffix in f: + data = f[model_name + suffix][:] + column_names = f[model_name + 
suffix].attrs.get("column_names", []) + return {"data": data, "column_names": column_names} + + +def get_metadata_ids(model_name): + """Extract actual IDs from metadata storage.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return [] + ids = [] + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + id_idx = next((i for i, name in enumerate(column_names) if name.lower() == "id"), None) + if id_idx is not None: + for row in metadata: + try: + if hasattr(row, "__getitem__") and len(row) > id_idx: + id_val = row[id_idx] + else: + row_data = pickle.loads(row.tobytes()) + id_val = row_data[id_idx] + if isinstance(id_val, np.ndarray): + ids.append(str(id_val)) + else: + ids.append(str(id_val)) + except Exception as e: + print(f"Error processing ID from row {len(ids)}: {e}") + continue + print(f"Successfully extracted {len(ids)} IDs: {ids}") + return ids + + +def get_metadata_from_storage(model_name): + """Get metadata directly from storage file.""" + storage = get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return {"data": [], "column_names": []} + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + parsed_rows = [] + for row in metadata: + try: + row_data = pickle.loads(row.tobytes()) + parsed_rows.append(row_data) + except Exception as e: + print(f"Error unpickling metadata row: {e}") + + return {"data": parsed_rows, "column_names": column_names} + return {"data": [], "column_names": []} + + +def count_rows_with_tag(model_name, tag): + """Count rows with a specific tag in metadata.""" + storage = 
get_storage_interface() + filename = storage._get_filename(model_name + METADATA_SUFFIX) + if not os.path.exists(filename): + return 0 + count = 0 + with h5py.File(filename, "r") as f: + if model_name + METADATA_SUFFIX in f: + metadata = f[model_name + METADATA_SUFFIX][:] + column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", []) + tag_idx = next( + (i for i, name in enumerate(column_names) if name.lower() == "tag"), + None, + ) + if tag_idx is not None: + for row in metadata: + try: + row_data = pickle.loads(row.tobytes()) + if tag_idx < len(row_data) and row_data[tag_idx] == tag: + count += 1 + except Exception as e: + print(f"Error unpickling tag: {e}") + return count + + +def post_test(payload, expected_status_code, check_msgs): + """Post a payload and check the response.""" + response = client.post("/data/upload", json=payload) + if response.status_code != expected_status_code: + print(f"\n=== DEBUG INFO ===") + print(f"Expected status: {expected_status_code}") + print(f"Actual status: {response.status_code}") + print(f"Response text: {response.text}") + print(f"Response headers: {dict(response.headers)}") + if hasattr(response, "json"): + try: + print(f"Response JSON: {response.json()}") + except: + pass + print(f"==================") + + assert response.status_code == expected_status_code + return response + + +# data upload tests +@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) +@pytest.mark.parametrize("n_input_cols", [1, 4]) +@pytest.mark.parametrize("n_output_cols", [1, 2]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_data(n_input_rows, n_input_cols, n_output_cols, datatype): + """Test uploading data with various dimensions and datatypes.""" + data_tag = "TRAINING" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = post_test(payload, 200, [f"{n_input_rows} datapoints"]) + inputs = 
get_data_from_storage(payload["model_name"], INPUT_SUFFIX) + outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) + assert inputs is not None, "Input data not found in storage" + assert outputs is not None, "Output data not found in storage" + assert len(inputs["data"]) == n_input_rows, "Incorrect number of input rows" + assert len(outputs["data"]) == n_input_rows, "Incorrect number of output rows" + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + assert tag_count == n_input_rows, "Not all rows have the correct tag" + + +@pytest.mark.parametrize("n_rows", [1, 3, 5, 250]) +@pytest.mark.parametrize("n_input_cols", [2, 6]) +@pytest.mark.parametrize("n_output_cols", [4]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_multi_input_data(n_rows, n_input_cols, n_output_cols, datatype): + """Test uploading data with multiple input tensors.""" + data_tag = "TRAINING" + payload = generate_multi_input_payload(n_rows, n_input_cols, n_output_cols, datatype, data_tag) + response = post_test(payload, 200, [f"{n_rows} datapoints"]) + inputs = get_data_from_storage(payload["model_name"], INPUT_SUFFIX) + outputs = get_data_from_storage(payload["model_name"], OUTPUT_SUFFIX) + assert inputs is not None, "Input data not found in storage" + assert outputs is not None, "Output data not found in storage" + assert len(inputs["data"]) == n_rows, "Incorrect number of input rows" + assert len(outputs["data"]) == n_rows, "Incorrect number of output rows" + assert len(inputs["column_names"]) == n_input_cols, "Incorrect number of input columns" + assert len(outputs["column_names"]) == n_output_cols, "Incorrect number of output columns" + assert len(inputs["column_names"]) >= 2, "Should have at least 2 input column names" + tag_count = count_rows_with_tag(payload["model_name"], data_tag) + assert tag_count == n_rows, "Not all rows have the correct tag" + + +def test_upload_multi_input_data_no_unique_name(): + 
"""Test error case for non-unique tensor names.""" + payload = generate_mismatched_shape_no_unique_name_multi_input_payload(250, 4, 3, "FP64", "TRAINING") + response = client.post("/data/upload", json=payload) + assert response.status_code == 400 + assert "One or more errors" in response.text + assert "unique names" in response.text + assert "first dimension" in response.text + + +def test_upload_multiple_tagging(): + """Test uploading data with multiple tags.""" + n_payload1 = 50 + n_payload2 = 51 + tag1 = "TRAINING" + tag2 = "NOT TRAINING" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload1 = generate_payload(n_payload1, 10, 1, "INT64", tag1) + payload1["model_name"] = model_name + post_test(payload1, 200, [f"{n_payload1} datapoints"]) + payload2 = generate_payload(n_payload2, 10, 1, "INT64", tag2) + payload2["model_name"] = model_name + post_test(payload2, 200, [f"{n_payload2} datapoints"]) + tag1_count = count_rows_with_tag(model_name, tag1) + tag2_count = count_rows_with_tag(model_name, tag2) + assert tag1_count == n_payload1, f"Expected {n_payload1} rows with tag {tag1}" + assert tag2_count == n_payload2, f"Expected {n_payload2} rows with tag {tag2}" + inputs = get_data_from_storage(model_name, INPUT_SUFFIX) + assert len(inputs["data"]) == n_payload1 + n_payload2, "Incorrect total number of rows" + + +def test_upload_tag_that_uses_protected_name(): + """Test error when using a protected tag name.""" + invalid_tag = f"{TRUSTYAI_TAG_PREFIX}_something" + payload = generate_payload(5, 10, 1, "INT64", invalid_tag) + response = post_test(payload, 400, ["reserved for internal TrustyAI use only"]) + expected_msg = f"The tag prefix '{TRUSTYAI_TAG_PREFIX}' is reserved for internal TrustyAI use only. Provided tag '{invalid_tag}' violates this restriction." 
+ assert expected_msg in response.text + + +@pytest.mark.parametrize("n_input_rows", [1, 5, 250]) +@pytest.mark.parametrize("n_input_cols", [1, 4]) +@pytest.mark.parametrize("n_output_cols", [1, 2]) +@pytest.mark.parametrize("datatype", ["INT64", "INT32", "FP32", "FP64", "BOOL"]) +def test_upload_data_and_ground_truth(n_input_rows, n_input_cols, n_output_cols, datatype): + """Test uploading model data and corresponding ground truth data.""" + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING") + payload["model_name"] = model_name + payload["is_ground_truth"] = False + post_test(payload, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload_gt = generate_payload(n_input_rows, n_input_cols, n_output_cols, datatype, "TRAINING", 0, 1) + payload_gt["model_name"] = model_name + payload_gt["is_ground_truth"] = True + payload_gt["request"] = payload["request"] + payload_gt["request"]["inputs"][0]["execution_ids"] = ids + post_test(payload_gt, 200, [f"{n_input_rows} ground truths"]) + original_data = get_data_from_storage(model_name, OUTPUT_SUFFIX) + gt_data = get_data_from_storage(f"{model_name}_ground_truth", OUTPUT_SUFFIX) + assert len(original_data["data"]) == len(gt_data["data"]), "Row dimensions don't match" + assert len(original_data["column_names"]) == len(gt_data["column_names"]), "Column dimensions don't match" + original_ids = get_metadata_ids(model_name) + gt_ids = get_metadata_ids(f"{model_name}_ground_truth") + assert original_ids == gt_ids, "Ground truth IDs don't match original IDs" + + +def test_upload_mismatch_input_values(): + """Test error when ground truth inputs don't match original data.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids 
= get_metadata_ids(model_name) + payload1 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING", 1, 0) + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text or "inputs are not identical" in response.text + + +def test_upload_mismatch_input_lengths(): + """Test error when ground truth has different input lengths.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 1, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload1 = generate_payload(n_input_rows, 11, 1, "INT64", "TRAINING") + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text + assert ( + "input shapes do not match. 
Observed inputs have length=10 while uploaded inputs have length=11" + in response.text + ) + + +def test_upload_mismatch_input_and_output_types(): + """Test error when ground truth has different data types.""" + n_input_rows = 5 + model_name = f"{MODEL_ID}_{uuid.uuid4().hex[:8]}" + payload0 = generate_payload(n_input_rows, 10, 2, "INT64", "TRAINING") + payload0["model_name"] = model_name + post_test(payload0, 200, [f"{n_input_rows} datapoints"]) + ids = get_metadata_ids(model_name) + payload1 = generate_payload(n_input_rows, 10, 2, "FP32", "TRAINING", 0, 1) + payload1["model_name"] = model_name + payload1["is_ground_truth"] = True + payload1["request"]["inputs"][0]["execution_ids"] = ids + response = client.post("/data/upload", json=payload1) + print(f"Response status: {response.status_code}") + print(f"Response text: {response.text}") + assert response.status_code == 400 + assert "Found fatal mismatches" in response.text + assert "Class=Long != Class=Float" in response.text or "inputs are not identical" in response.text + + +def test_upload_gaussian_data(): + """Test uploading realistic Gaussian data.""" + payload = { + "model_name": "gaussian-credit-model", + "data_tag": "TRAINING", + "request": { + "inputs": [ + { + "name": "credit_inputs", + "shape": [2, 4], + "datatype": "FP64", + "data": [ + [ + 47.45380690750797, + 478.6846214843319, + 13.462184703540503, + 20.764525303373535, + ], + [ + 47.468246185717554, + 575.6911203538863, + 10.844143722475575, + 14.81343667761101, + ], + ], + } + ] + }, + "response": { + "model_name": "gaussian-credit-model__isvc-d79a7d395d", + "model_version": "1", + "outputs": [ + { + "name": "predict", + "datatype": "FP32", + "shape": [2, 1], + "data": [0.19013395683309373, 0.2754730253205645], + } + ], + }, + } + post_test(payload, 200, ["2 datapoints"]) \ No newline at end of file