
Commit c7c6c77

✨ added upload and download endpoints together with tests akin to the Java ones
Signed-off-by: m-misiura <[email protected]>
1 parent aa250e3 commit c7c6c77

File tree: 6 files changed (+1488, -33 lines)
Lines changed: 45 additions & 20 deletions
@@ -1,32 +1,57 @@
-from fastapi import APIRouter, HTTPException
-from pydantic import BaseModel
-from typing import List, Any, Optional
 import logging
 
-router = APIRouter()
-logger = logging.getLogger(__name__)
-
-
-class RowMatcher(BaseModel):
-    columnName: str
-    operation: str
-    values: List[Any]
+import pandas as pd
+from fastapi import APIRouter, HTTPException
 
+from src.service.utils.download import (
+    DataRequestPayload,
+    DataResponsePayload,
+    apply_matcher,
+    load_model_dataframe,
+)
 
-class DataRequestPayload(BaseModel):
-    modelId: str
-    matchAny: Optional[List[RowMatcher]] = None
-    matchAll: Optional[List[RowMatcher]] = None
-    matchNone: Optional[List[RowMatcher]] = None
+router = APIRouter()
+logger = logging.getLogger(__name__)
 
 
 @router.post("/data/download")
-async def download_data(payload: DataRequestPayload):
-    """Download model data."""
+async def download_data(payload: DataRequestPayload) -> DataResponsePayload:
+    """Download model data with filtering."""
     try:
         logger.info(f"Received data download request for model: {payload.modelId}")
-        # TODO: Implement
-        return {"status": "success", "data": []}
+
+        # Load the dataframe
+        df = await load_model_dataframe(payload.modelId)
+
+        if df.empty:
+            return DataResponsePayload(dataCSV="")
+        # Apply matchAll filters (AND logic)
+        if payload.matchAll:
+            for matcher in payload.matchAll:
+                df = apply_matcher(df, matcher, negate=False)
+        # Apply matchNone filters (NOT logic)
+        if payload.matchNone:
+            for matcher in payload.matchNone:
+                df = apply_matcher(df, matcher, negate=True)
+        base_df = df.copy()
+        # Apply matchAny filters (OR logic)
+        if payload.matchAny:
+            matching_dfs = []
+            for matcher in payload.matchAny:
+                matched_df = apply_matcher(base_df, matcher, negate=False)
+                if not matched_df.empty:
+                    matching_dfs.append(matched_df)
+            # Union all results
+            if matching_dfs:
+                df = pd.concat(matching_dfs, ignore_index=True).drop_duplicates()
+            else:
+                # No matches found, return empty dataframe with same columns
+                df = pd.DataFrame(columns=df.columns)
+        # Convert to CSV
+        csv_data = df.to_csv(index=False)
+        return DataResponsePayload(dataCSV=csv_data)
+    except HTTPException:
+        raise
     except Exception as e:
         logger.error(f"Error downloading data: {str(e)}")
         raise HTTPException(status_code=500, detail=f"Error downloading data: {str(e)}")
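
For reference, a request against the new download endpoint carries a DataRequestPayload with optional matchAll / matchAny / matchNone filters. The sketch below is illustrative only: the host, port, model name, and tag value are placeholder assumptions, not part of this commit.

import requests  # hypothetical client call; host, port, and model name are placeholders

payload = {
    "modelId": "example-model",
    "matchAll": [
        {"columnName": "trustyai.TAG", "operation": "EQUALS", "values": ["TRAINING"]}
    ],
    "matchNone": [
        {"columnName": "trustyai.INDEX", "operation": "BETWEEN", "values": [0, 10]}
    ],
}

resp = requests.post("http://localhost:8080/data/download", json=payload)
resp.raise_for_status()
print(resp.json()["dataCSV"])  # matching rows, serialized as CSV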

src/endpoints/data/data_upload.py

Lines changed: 89 additions & 13 deletions
@@ -1,27 +1,103 @@
+import logging
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+import numpy as np
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
-from typing import Dict, Any
-import logging
+
+from src.service.constants import METADATA_SUFFIX, OUTPUT_SUFFIX
+from src.service.data.modelmesh_parser import ModelMeshPayloadParser
+from src.service.data.storage import get_storage_interface
+from src.service.utils.upload import (
+    handle_ground_truths,
+    process_tensors,
+    sanitize_id,
+    save_model_data,
+    validate_data_tag,
+    validate_input_shapes,
+)
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
+storage = get_storage_interface()
 
 
-class ModelInferJointPayload(BaseModel):
+class UploadPayload(BaseModel):
     model_name: str
-    data_tag: str = None
+    data_tag: Optional[str] = None
     is_ground_truth: bool = False
     request: Dict[str, Any]
     response: Dict[str, Any]
 
 
 @router.post("/data/upload")
-async def upload_data(payload: ModelInferJointPayload):
-    """Upload a batch of model data to TrustyAI."""
-    try:
-        logger.info(f"Received data upload for model: {payload.model_name}")
-        # TODO: Implement
-        return {"status": "success", "message": "Data uploaded successfully"}
-    except Exception as e:
-        logger.error(f"Error uploading data: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Error uploading data: {str(e)}")
+async def upload(payload: UploadPayload) -> Dict[str, str]:
+    """Upload model data - regular or ground truth."""
+    model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
+    if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
+        raise HTTPException(400, error)
+    inputs = payload.request.get("inputs", [])
+    outputs = payload.response.get("outputs", [])
+    if not inputs or not outputs:
+        raise HTTPException(400, "Missing input or output tensors")
+    input_arrays, input_names, _, execution_ids = process_tensors(inputs)
+    output_arrays, output_names, _, _ = process_tensors(outputs)
+    if error := validate_input_shapes(input_arrays, input_names):
+        raise HTTPException(400, f"One or more errors in input tensors: {error}")
+    if payload.is_ground_truth:
+        if not execution_ids:
+            raise HTTPException(400, "Ground truth requires execution IDs")
+        result = await handle_ground_truths(
+            model_name,
+            input_arrays,
+            input_names,
+            output_arrays,
+            output_names,
+            [sanitize_id(id) for id in execution_ids],
+        )
+        if not result.success:
+            raise HTTPException(400, result.message)
+        result_data = result.data
+        if result_data is None:
+            raise HTTPException(500, "Ground truth processing failed")
+        gt_name = f"{model_name}_ground_truth"
+        await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
+        await storage.write_data(
+            gt_name + METADATA_SUFFIX,
+            result_data["metadata"],
+            result_data["metadata_names"],
+        )
+        return {"message": result.message}
+    else:
+        n_rows = input_arrays[0].shape[0]
+        exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
+
+        def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
+            return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]
+
+        input_data = [flatten(input_arrays, i) for i in range(n_rows)]
+        output_data = [flatten(output_arrays, i) for i in range(n_rows)]
+        cols = ["id", "model_id", "timestamp", "tag"]
+        current_timestamp = datetime.now().isoformat()
+        metadata_rows = [
+            [
+                str(eid),
+                str(model_name),
+                str(current_timestamp),
+                str(payload.data_tag or ""),
+            ]
+            for eid in exec_ids
+        ]
+        metadata = np.array(metadata_rows, dtype="<U100")
+        await save_model_data(
+            model_name,
+            np.array(input_data),
+            input_names,
+            np.array(output_data),
+            output_names,
+            metadata,
+            cols,
+        )
+        return {"message": f"{n_rows} datapoints added to {model_name}"}

src/service/utils/download.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+import logging
+import pickle
+from datetime import datetime
+from typing import Any, List, Optional
+
+import pandas as pd
+from fastapi import HTTPException
+from pydantic import BaseModel
+
+from src.service.data.storage import get_storage_interface
+
+logger = logging.getLogger(__name__)
+
+
+class RowMatcher(BaseModel):
+    columnName: str
+    operation: str
+    values: List[Any]
+
+
+class DataRequestPayload(BaseModel):
+    modelId: str
+    matchAny: Optional[List[RowMatcher]] = []
+    matchAll: Optional[List[RowMatcher]] = []
+    matchNone: Optional[List[RowMatcher]] = []
+
+
+class DataResponsePayload(BaseModel):
+    dataCSV: str
+
+
+def get_storage() -> Any:
+    """Get storage instance"""
+    return get_storage_interface()
+
+
+def apply_matcher(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.DataFrame:
+    """Apply a single matcher to the dataframe."""
+    if matcher.operation not in ["EQUALS", "BETWEEN"]:
+        raise HTTPException(
+            status_code=400,
+            detail="RowMatch operation must be one of [BETWEEN, EQUALS]",
+        )
+    if matcher.operation == "EQUALS":
+        return apply_equals_matcher(df, matcher, negate)
+    elif matcher.operation == "BETWEEN":
+        return apply_between_matcher(df, matcher, negate)
+
+
+def apply_equals_matcher(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.DataFrame:
+    """Apply EQUALS matcher to dataframe."""
+    column_name = matcher.columnName
+    values = matcher.values
+    if column_name not in df.columns:
+        raise HTTPException(
+            status_code=400,
+            detail=f"No feature or output found with name={column_name}",
+        )
+    mask = df[column_name].isin(values)
+    if negate:
+        mask = ~mask
+    return df[mask]
+
+
+def apply_between_matcher(df: pd.DataFrame, matcher: RowMatcher, negate: bool = False) -> pd.DataFrame:
+    """Apply BETWEEN matcher to dataframe."""
+    column_name = matcher.columnName
+    values = matcher.values
+
+    if column_name not in df.columns:
+        raise HTTPException(
+            status_code=400,
+            detail=f"No feature or output found with name={column_name}",
+        )
+    errors = []
+    if len(values) != 2:
+        errors.append(
+            f"BETWEEN operation must contain exactly two values, describing the lower and upper bounds of the desired range. Received {len(values)} values"
+        )
+    if column_name == "trustyai.TIMESTAMP":
+        if errors:
+            combined_error = ", ".join(errors)
+            raise HTTPException(status_code=400, detail=combined_error)
+        try:
+            start_time = pd.to_datetime(str(values[0]))
+            end_time = pd.to_datetime(str(values[1]))
+            df_times = pd.to_datetime(df[column_name])
+            mask = (df_times >= start_time) & (df_times < end_time)
+        except Exception as e:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Timestamp value is unparseable as an ISO_LOCAL_DATE_TIME: {str(e)}",
+            )
+    elif column_name == "trustyai.INDEX":
+        if errors:
+            combined_error = ", ".join(errors)
+            raise HTTPException(status_code=400, detail=combined_error)
+        min_val, max_val = sorted([int(v) for v in values])
+        mask = (df[column_name] >= min_val) & (df[column_name] < max_val)
+    else:
+        if not all(isinstance(v, (int, float)) for v in values):
+            errors.append(
+                "BETWEEN operation must only contain numbers, describing the lower and upper bounds of the desired range. Received non-numeric values"
+            )
+        if errors:
+            combined_error = ", ".join(errors)
+            raise HTTPException(status_code=400, detail=combined_error)
+        min_val, max_val = sorted(values)
+        try:
+            mask = (pd.to_numeric(df[column_name], errors="coerce") >= min_val) & (
+                pd.to_numeric(df[column_name], errors="coerce") < max_val
+            )
+        except:
+            mask = (df[column_name].astype(str) >= str(min_val)) & (df[column_name].astype(str) < str(max_val))
+    if negate:
+        mask = ~mask
+    return df[mask]
+
+
+async def load_model_dataframe(model_id: str) -> pd.DataFrame:
+    """Load model dataframe from storage."""
+    storage = get_storage()
+    try:
+        input_data, input_cols = await storage.read_data(f"{model_id}_inputs")
+        output_data, output_cols = await storage.read_data(f"{model_id}_outputs")
+        metadata_data, metadata_cols = await storage.read_data(f"{model_id}_metadata")
+        if input_data is None or output_data is None or metadata_data is None:
+            raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
+        df = pd.DataFrame()
+        if len(input_data) > 0:
+            if input_data.ndim == 2 and len(input_cols) == 1 and input_data.shape[1] > 1:
+                col_name = input_cols[0]
+                for j in range(input_data.shape[1]):
+                    df[f"{col_name}_{j}"] = input_data[:, j]
+            else:
+                input_df = pd.DataFrame(input_data, columns=input_cols)
+                for col in input_cols:
+                    df[col] = input_df[col]
+        if len(output_data) > 0:
+            if output_data.ndim == 2 and len(output_cols) == 1 and output_data.shape[1] > 1:
+                col_name = output_cols[0]
+                for j in range(output_data.shape[1]):
+                    df[f"{col_name}_{j}"] = output_data[:, j]
+            else:
+                if output_data.ndim == 2:
+                    output_data = output_data.flatten()
+                output_df = pd.DataFrame({output_cols[0]: output_data})
+                for col in output_cols:
+                    df[col] = output_df[col]
+        if len(metadata_data) > 0 and isinstance(metadata_data[0], bytes):
+            deserialized_metadata = []
+            for row in metadata_data:
+                deserialized_row = pickle.loads(row)
+                deserialized_metadata.append(deserialized_row)
+            metadata_df = pd.DataFrame(deserialized_metadata, columns=metadata_cols)
+        else:
+            metadata_df = pd.DataFrame(metadata_data, columns=metadata_cols)
+        trusty_mapping = {
+            "id": "trustyai.ID",
+            "model_id": "trustyai.MODEL_ID",
+            "timestamp": "trustyai.TIMESTAMP",
+            "tag": "trustyai.TAG",
+        }
+        for orig_col in metadata_cols:
+            trusty_col = trusty_mapping.get(orig_col.lower(), orig_col)
+            df[trusty_col] = metadata_df[orig_col]
+        df["trustyai.INDEX"] = range(len(df))
+        return df
+    except Exception as e:
+        if "not found" in str(e).lower() or "MissingH5PYDataException" in str(type(e).__name__):
+            raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
+        raise HTTPException(status_code=500, detail=f"Error loading model data: {str(e)}")
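
To make the matcher semantics concrete, the standalone sketch below mimics what apply_equals_matcher and apply_between_matcher do on a toy dataframe; the column names and values are invented for illustration and are not part of the commit.

import pandas as pd

# Toy stand-in for the output of load_model_dataframe(); columns are placeholders.
df = pd.DataFrame({
    "age": [22, 35, 47, 61],
    "trustyai.TAG": ["TRAINING", "", "TRAINING", ""],
})
df["trustyai.INDEX"] = range(len(df))

# EQUALS keeps rows whose value is in the matcher's `values`; negate=True inverts the mask.
tagged = df[df["trustyai.TAG"].isin(["TRAINING"])]

# BETWEEN is a half-open range [lower, upper) after sorting the two bounds.
lower, upper = sorted([60, 30])
in_range = df[(pd.to_numeric(df["age"], errors="coerce") >= lower)
              & (pd.to_numeric(df["age"], errors="coerce") < upper)]

print(tagged["trustyai.INDEX"].tolist())    # [0, 2]
print(in_range["trustyai.INDEX"].tolist())  # [1, 2]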
