
Commit 5ac8b6b

dependabot[bot] and m-misiura authored and committed
🚧 Working on upload and download endpoints post PR review
Signed-off-by: m-misiura <[email protected]>
1 parent ccb3218 commit 5ac8b6b

File tree

8 files changed: +689 -414 lines changed

.github/workflows/python-tests.yaml

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ name: Python Tests
 
 on:
   push:
-    branches: [main, download_and_upload_endpoints]
+    branches: [main]
   pull_request:
     branches: [main]
 

pyproject.toml

Lines changed: 1 addition & 6 deletions
@@ -23,16 +23,11 @@ dev = [
     "isort>=5.12.0,<7",
     "flake8>=6.1.0,<8",
     "mypy>=1.5.1,<2",
-<<<<<<< HEAD
     "pytest-cov>=4.1.0,<5",
     "httpx>=0.25.0,<0.29",
-=======
-    "pytest-cov>=4.1.0,<7",
-    "httpx>=0.25.0,<0.26",
->>>>>>> 331eed9 (Update pytest-cov requirement from <5,>=4.1.0 to >=4.1.0,<7)
 ]
 eval = ["lm-eval[api]==0.4.4", "fastapi-utils>=0.8.0", "typing-inspect==0.9.0"]
-protobuf = ["numpy>=1.24.0,<2", "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2"]
+protobuf = ["numpy>=1.24.0,<3", "grpcio>=1.62.1,<2", "grpcio-tools>=1.62.1,<2"]
 
 [tool.hatch.build.targets.sdist]
 include = ["src"]
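The resolved pins can be sanity-checked at runtime. A minimal sketch, assuming the project is installed with both its dev and protobuf extras (e.g. pip install -e ".[dev,protobuf]"):

# Quick check that the installed versions satisfy the resolved pins.
# Assumes the dev and protobuf extras are installed; package names are
# the distribution names from pyproject.toml above.
from importlib.metadata import version

print(version("pytest-cov"))  # should satisfy >=4.1.0,<5 (the side kept in the merge conflict)
print(version("httpx"))       # should satisfy >=0.25.0,<0.29
print(version("numpy"))       # should satisfy >=1.24.0,<3 after the relaxed protobuf pin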

src/endpoints/data/data_download.py

Lines changed: 2 additions & 27 deletions
@@ -6,7 +6,7 @@
 from src.service.utils.download import (
     DataRequestPayload,
     DataResponsePayload,
-    apply_matcher,
+    apply_filters,  # ← New utility function
     load_model_dataframe,
 )
 
@@ -19,35 +19,10 @@ async def download_data(payload: DataRequestPayload) -> DataResponsePayload:
     """Download model data with filtering."""
     try:
         logger.info(f"Received data download request for model: {payload.modelId}")
-
-        # Load the dataframe
         df = await load_model_dataframe(payload.modelId)
-
         if df.empty:
             return DataResponsePayload(dataCSV="")
-        # Apply matchAll filters (AND logic)
-        if payload.matchAll:
-            for matcher in payload.matchAll:
-                df = apply_matcher(df, matcher, negate=False)
-        # Apply matchNone filters (NOT logic)
-        if payload.matchNone:
-            for matcher in payload.matchNone:
-                df = apply_matcher(df, matcher, negate=True)
-        base_df = df.copy()
-        # Apply matchAny filters (OR logic)
-        if payload.matchAny:
-            matching_dfs = []
-            for matcher in payload.matchAny:
-                matched_df = apply_matcher(base_df, matcher, negate=False)
-                if not matched_df.empty:
-                    matching_dfs.append(matched_df)
-            # Union all results
-            if matching_dfs:
-                df = pd.concat(matching_dfs, ignore_index=True).drop_duplicates()
-            else:
-                # No matches found, return empty dataframe with same columns
-                df = pd.DataFrame(columns=df.columns)
-        # Convert to CSV
+        df = apply_filters(df, payload)
         csv_data = df.to_csv(index=False)
         return DataResponsePayload(dataCSV=csv_data)
     except HTTPException:
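The matchAll/matchNone/matchAny branches now live behind apply_filters in src.service.utils.download, which is not part of this diff. The following is only a sketch of what such a helper could look like, reconstructed from the inline logic deleted above; the payload fields and the negate flag come from that code, while the assumption that apply_matcher still exists in the utils module is mine:

import pandas as pd

def apply_filters(df: pd.DataFrame, payload) -> pd.DataFrame:
    """Hypothetical consolidation of the filter logic removed from the endpoint."""
    # matchAll: every matcher must hold (AND logic)
    if payload.matchAll:
        for matcher in payload.matchAll:
            df = apply_matcher(df, matcher, negate=False)
    # matchNone: no matcher may hold (NOT logic)
    if payload.matchNone:
        for matcher in payload.matchNone:
            df = apply_matcher(df, matcher, negate=True)
    base_df = df.copy()
    # matchAny: at least one matcher must hold (OR logic); union the per-matcher results
    if payload.matchAny:
        matching_dfs = []
        for matcher in payload.matchAny:
            matched_df = apply_matcher(base_df, matcher, negate=False)
            if not matched_df.empty:
                matching_dfs.append(matched_df)
        if matching_dfs:
            df = pd.concat(matching_dfs, ignore_index=True).drop_duplicates()
        else:
            # No matches found: keep the schema, drop all rows
            df = pd.DataFrame(columns=df.columns)
    return df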

src/endpoints/data/data_upload.py

Lines changed: 8 additions & 82 deletions
@@ -10,104 +10,30 @@
 from src.service.constants import INPUT_SUFFIX, METADATA_SUFFIX, OUTPUT_SUFFIX
 from src.service.data.modelmesh_parser import ModelMeshPayloadParser
 from src.service.data.storage import get_storage_interface
-from src.service.utils.upload import (
-    handle_ground_truths,
-    process_tensors,
-    sanitize_id,
-    save_model_data,
-    validate_data_tag,
-    validate_input_shapes,
-)
+from src.service.utils.upload import process_upload_request
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
 
+
 class UploadPayload(BaseModel):
     model_name: str
     data_tag: Optional[str] = None
     is_ground_truth: bool = False
     request: Dict[str, Any]
-    response: Dict[str, Any]
+    response: Optional[Dict[str, Any]] = None
 
 
 @router.post("/data/upload")
 async def upload(payload: UploadPayload) -> Dict[str, str]:
     """Upload model data - regular or ground truth."""
     try:
-        # Get fresh storage interface for each request
-        storage = get_storage_interface()
-
-        model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
-        if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
-            raise HTTPException(400, error)
-        inputs = payload.request.get("inputs", [])
-        outputs = payload.response.get("outputs", [])
-        if not inputs or not outputs:
-            raise HTTPException(400, "Missing input or output tensors")
-        input_arrays, input_names, _, execution_ids = process_tensors(inputs)
-        output_arrays, output_names, _, _ = process_tensors(outputs)
-        if error := validate_input_shapes(input_arrays, input_names):
-            raise HTTPException(400, f"One or more errors in input tensors: {error}")
-
-        if payload.is_ground_truth:
-            if not execution_ids:
-                raise HTTPException(400, "Ground truth requires execution IDs")
-            result = await handle_ground_truths(
-                model_name,
-                input_arrays,
-                input_names,
-                output_arrays,
-                output_names,
-                [sanitize_id(id) for id in execution_ids],
-            )
-            if not result.success:
-                raise HTTPException(400, result.message)
-            result_data = result.data
-            if result_data is None:
-                raise HTTPException(500, "Ground truth processing failed")
-            gt_name = f"{model_name}_ground_truth"
-            await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
-            await storage.write_data(
-                gt_name + METADATA_SUFFIX,
-                result_data["metadata"],
-                result_data["metadata_names"],
-            )
-            return {"message": result.message}
-        else:
-            n_rows = input_arrays[0].shape[0]
-            exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
-
-            def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
-                return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]
-
-            input_data = [flatten(input_arrays, i) for i in range(n_rows)]
-            output_data = [flatten(output_arrays, i) for i in range(n_rows)]
-            cols = ["id", "model_id", "timestamp", "tag"]
-            current_timestamp = datetime.now().isoformat()
-            metadata_rows = [
-                [
-                    str(eid),
-                    str(model_name),
-                    str(current_timestamp),
-                    str(payload.data_tag or ""),
-                ]
-                for eid in exec_ids
-            ]
-            metadata = np.array(metadata_rows, dtype="<U100")
-            await save_model_data(
-                model_name,
-                np.array(input_data),
-                input_names,
-                np.array(output_data),
-                output_names,
-                metadata,
-                cols,
-            )
-            return {"message": f"{n_rows} datapoints added to {model_name}"}
-
+        logger.info(f"Received upload request for model: {payload.model_name}")
+        result = await process_upload_request(payload)
+        logger.info(f"Upload completed for model: {payload.model_name}")
+        return result
     except HTTPException:
-        # Re-raise HTTP exceptions as-is
         raise
     except Exception as e:
        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
-        raise HTTPException(500, f"Internal server error: {str(e)}")
+        raise HTTPException(500, f"Internal server error: {str(e)}")
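For reference, a minimal client-side sketch of calling the endpoint with httpx. The payload fields mirror UploadPayload, and the "inputs"/"outputs" keys are the ones the old handler read; the base URL, tensor values, and the KServe-style tensor dict shape (name/shape/datatype/data) are assumptions, since process_tensors is not shown in this diff:

# Hypothetical client call to POST /data/upload (httpx, synchronous).
# Base URL and tensor contents are illustrative only.
import httpx

payload = {
    "model_name": "example-model",
    "data_tag": "training",
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {"name": "input", "shape": [2, 3], "datatype": "FP64",
             "data": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}
        ]
    },
    "response": {
        "outputs": [
            {"name": "output", "shape": [2, 1], "datatype": "FP64",
             "data": [[0.1], [0.9]]}
        ]
    },
}

resp = httpx.post("http://localhost:8080/data/upload", json=payload)
resp.raise_for_status()
print(resp.json())  # e.g. {"message": "2 datapoints added to example-model"}

Note that response is now Optional, which matches ground-truth-style uploads where only new data is attached to existing execution IDs.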
