Commit 8a0eb28
🚧 fixed storage interface issue by adding fresh storage per request
1 parent: bc0cf51

File tree: 3 files changed (+122 / -106 lines)
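The change drops the module-level storage singleton in data_upload.py in favor of calling get_storage_interface() inside each request/call, and wraps the upload handler in a broad try/except. A minimal sketch of the pattern this moves to, not the project's actual storage code: HDF5Storage, the environment variable, and the simplified handler below are hypothetical stand-ins. The point is that the interface is built from the current configuration on every call, so a per-request lookup picks up whatever backend a test or deployment has configured, whereas a module-level `storage = get_storage_interface()` would be frozen at import time.

import os


class HDF5Storage:
    """Stand-in storage backend; only the constructor matters for this sketch."""

    def __init__(self, data_dir: str):
        self.data_dir = data_dir


def get_storage_interface() -> HDF5Storage:
    # Re-read configuration on each call instead of caching one instance at import.
    return HDF5Storage(os.environ.get("STORAGE_DATA_DIR", "/tmp/model_data"))


async def upload(payload: dict) -> dict:
    storage = get_storage_interface()  # fresh interface per request, as in the diff below
    # ... validate the payload and call storage.write_data(...) here ...
    return {"message": "ok"}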

src/endpoints/data/data_upload.py
Lines changed: 76 additions & 66 deletions

@@ -7,7 +7,7 @@
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel

-from src.service.constants import METADATA_SUFFIX, OUTPUT_SUFFIX
+from src.service.constants import INPUT_SUFFIX, METADATA_SUFFIX, OUTPUT_SUFFIX
 from src.service.data.modelmesh_parser import ModelMeshPayloadParser
 from src.service.data.storage import get_storage_interface
 from src.service.utils.upload import (
@@ -21,8 +21,6 @@

 router = APIRouter()
 logger = logging.getLogger(__name__)
-storage = get_storage_interface()
-

 class UploadPayload(BaseModel):
     model_name: str
@@ -35,69 +33,81 @@ class UploadPayload(BaseModel):
 @router.post("/data/upload")
 async def upload(payload: UploadPayload) -> Dict[str, str]:
     """Upload model data - regular or ground truth."""
-    model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
-    if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
-        raise HTTPException(400, error)
-    inputs = payload.request.get("inputs", [])
-    outputs = payload.response.get("outputs", [])
-    if not inputs or not outputs:
-        raise HTTPException(400, "Missing input or output tensors")
-    input_arrays, input_names, _, execution_ids = process_tensors(inputs)
-    output_arrays, output_names, _, _ = process_tensors(outputs)
-    if error := validate_input_shapes(input_arrays, input_names):
-        raise HTTPException(400, f"One or more errors in input tensors: {error}")
-    if payload.is_ground_truth:
-        if not execution_ids:
-            raise HTTPException(400, "Ground truth requires execution IDs")
-        result = await handle_ground_truths(
-            model_name,
-            input_arrays,
-            input_names,
-            output_arrays,
-            output_names,
-            [sanitize_id(id) for id in execution_ids],
-        )
-        if not result.success:
-            raise HTTPException(400, result.message)
-        result_data = result.data
-        if result_data is None:
-            raise HTTPException(500, "Ground truth processing failed")
-        gt_name = f"{model_name}_ground_truth"
-        await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
-        await storage.write_data(
-            gt_name + METADATA_SUFFIX,
-            result_data["metadata"],
-            result_data["metadata_names"],
-        )
-        return {"message": result.message}
-    else:
-        n_rows = input_arrays[0].shape[0]
-        exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]
+    try:
+        # Get fresh storage interface for each request
+        storage = get_storage_interface()
+
+        model_name = ModelMeshPayloadParser.standardize_model_id(payload.model_name)
+        if payload.data_tag and (error := validate_data_tag(payload.data_tag)):
+            raise HTTPException(400, error)
+        inputs = payload.request.get("inputs", [])
+        outputs = payload.response.get("outputs", [])
+        if not inputs or not outputs:
+            raise HTTPException(400, "Missing input or output tensors")
+        input_arrays, input_names, _, execution_ids = process_tensors(inputs)
+        output_arrays, output_names, _, _ = process_tensors(outputs)
+        if error := validate_input_shapes(input_arrays, input_names):
+            raise HTTPException(400, f"One or more errors in input tensors: {error}")
+
+        if payload.is_ground_truth:
+            if not execution_ids:
+                raise HTTPException(400, "Ground truth requires execution IDs")
+            result = await handle_ground_truths(
+                model_name,
+                input_arrays,
+                input_names,
+                output_arrays,
+                output_names,
+                [sanitize_id(id) for id in execution_ids],
+            )
+            if not result.success:
+                raise HTTPException(400, result.message)
+            result_data = result.data
+            if result_data is None:
+                raise HTTPException(500, "Ground truth processing failed")
+            gt_name = f"{model_name}_ground_truth"
+            await storage.write_data(gt_name + OUTPUT_SUFFIX, result_data["outputs"], result_data["output_names"])
+            await storage.write_data(
+                gt_name + METADATA_SUFFIX,
+                result_data["metadata"],
+                result_data["metadata_names"],
+            )
+            return {"message": result.message}
+        else:
+            n_rows = input_arrays[0].shape[0]
+            exec_ids = execution_ids or [str(uuid.uuid4()) for _ in range(n_rows)]

-        def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
-            return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]
+            def flatten(arrays: List[np.ndarray], row: int) -> List[Any]:
+                return [x for arr in arrays for x in (arr[row].flatten() if arr.ndim > 1 else [arr[row]])]

-        input_data = [flatten(input_arrays, i) for i in range(n_rows)]
-        output_data = [flatten(output_arrays, i) for i in range(n_rows)]
-        cols = ["id", "model_id", "timestamp", "tag"]
-        current_timestamp = datetime.now().isoformat()
-        metadata_rows = [
-            [
-                str(eid),
-                str(model_name),
-                str(current_timestamp),
-                str(payload.data_tag or ""),
+            input_data = [flatten(input_arrays, i) for i in range(n_rows)]
+            output_data = [flatten(output_arrays, i) for i in range(n_rows)]
+            cols = ["id", "model_id", "timestamp", "tag"]
+            current_timestamp = datetime.now().isoformat()
+            metadata_rows = [
+                [
+                    str(eid),
+                    str(model_name),
+                    str(current_timestamp),
+                    str(payload.data_tag or ""),
+                ]
+                for eid in exec_ids
             ]
-            for eid in exec_ids
-        ]
-        metadata = np.array(metadata_rows, dtype="<U100")
-        await save_model_data(
-            model_name,
-            np.array(input_data),
-            input_names,
-            np.array(output_data),
-            output_names,
-            metadata,
-            cols,
-        )
-        return {"message": f"{n_rows} datapoints added to {model_name}"}
+            metadata = np.array(metadata_rows, dtype="<U100")
+            await save_model_data(
+                model_name,
+                np.array(input_data),
+                input_names,
+                np.array(output_data),
+                output_names,
+                metadata,
+                cols,
+            )
+            return {"message": f"{n_rows} datapoints added to {model_name}"}
+
+    except HTTPException:
+        # Re-raise HTTP exceptions as-is
+        raise
+    except Exception as e:
+        logger.error(f"Unexpected error in upload endpoint for model {payload.model_name}: {str(e)}", exc_info=True)
+        raise HTTPException(500, f"Internal server error: {str(e)}")
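For orientation, a request against the /data/upload route above might look like the following. Only the top-level fields (model_name, data_tag, is_ground_truth, request.inputs, response.outputs) come from this diff; the per-tensor layout (name/shape/datatype/data) is an assumed KServe-style shape, and the URL and tag value are placeholders.

# Hypothetical client call; the tensor schema and URL are assumptions, not taken from this commit.
import requests

payload = {
    "model_name": "example-model",
    "data_tag": "TRAINING",
    "is_ground_truth": False,
    "request": {
        "inputs": [
            {"name": "feature", "shape": [2, 3], "datatype": "FP64",
             "data": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]},
        ]
    },
    "response": {
        "outputs": [
            {"name": "prediction", "shape": [2, 1], "datatype": "FP64",
             "data": [[0.0], [1.0]]},
        ]
    },
}

resp = requests.post("http://localhost:8080/data/upload", json=payload)
print(resp.status_code, resp.json())  # expect a "... datapoints added to ..." message on success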

src/service/utils/upload.py
Lines changed: 11 additions & 1 deletion

@@ -170,6 +170,8 @@ def __init__(self, model_name: str):

     async def initialize(self) -> None:
         """Load existing data."""
+        # Get fresh storage interface for each call
+        storage_interface = get_storage_interface()
         self.inputs, _ = await storage_interface.read_data(self.model_name + INPUT_SUFFIX)
         self.outputs, _ = await storage_interface.read_data(self.model_name + OUTPUT_SUFFIX)
         self.metadata, _ = await storage_interface.read_data(self.model_name + METADATA_SUFFIX)
@@ -216,6 +218,8 @@ async def validate_data(
                 return f"ID={exec_id} output type mismatch at position {i + 1}: Class={existing_type} != Class={uploaded_type}"
         if output_names:
             try:
+                # Get fresh storage interface for each call
+                storage_interface = get_storage_interface()
                 stored_output_names = await storage_interface.get_original_column_names(self.model_name + OUTPUT_SUFFIX)
                 if len(stored_output_names) != len(output_names):
                     return (
@@ -233,6 +237,8 @@ async def validate_data(
                 logger.warning(f"Could not validate output names for {exec_id}: {e}")
         if input_names:
             try:
+                # Get fresh storage interface for each call
+                storage_interface = get_storage_interface()
                 stored_input_names = await storage_interface.get_original_column_names(self.model_name + INPUT_SUFFIX)
                 if len(stored_input_names) != len(input_names):
                     return (
@@ -262,6 +268,8 @@ async def handle_ground_truths(
     """Handle ground truth validation."""
     if not execution_ids:
         return GroundTruthValidationResult(success=False, message="No execution IDs provided.")
+    # Get fresh storage interface for each call
+    storage_interface = get_storage_interface()
     if not await storage_interface.dataset_exists(model_name + INPUT_SUFFIX):
         return GroundTruthValidationResult(success=False, message=f"Model {model_name} not found.")
     validator = GroundTruthValidator(model_name)
@@ -316,11 +324,13 @@ async def save_model_data(
     metadata_data: np.ndarray,
     metadata_names: List[str],
 ) -> Dict[str, Any]:
+    # Get fresh storage interface for each call
+    storage_interface = get_storage_interface()
     """Save model data to storage."""
     await storage_interface.write_data(model_name + INPUT_SUFFIX, input_data, input_names)
     await storage_interface.write_data(model_name + OUTPUT_SUFFIX, output_data, output_names)
     await storage_interface.write_data(model_name + METADATA_SUFFIX, metadata_data, metadata_names)
     return {
         "model_name": model_name,
         "rows": len(input_data),
-    }
+    }

tests/endpoints/test_upload_endpoint.py
Lines changed: 35 additions & 39 deletions

@@ -6,7 +6,6 @@
 import sys
 import tempfile
 import uuid
-import asyncio

 import h5py
 import numpy as np
@@ -25,13 +24,7 @@
 )
 from src.service.data.storage import get_storage_interface

-@pytest.fixture(autouse=True)
-def reset_storage():
-    """Reset storage interface for each test."""
-    import src.service.data.storage
-    src.service.data.storage._storage_interface = None
-    yield
-
+
 def pytest_sessionfinish(session, exitstatus):
     """Clean up the temporary directory after all tests are done."""
     if os.path.exists(TEMP_DIR):
@@ -136,43 +129,46 @@ def generate_mismatched_shape_no_unique_name_multi_input_payload(n_rows, n_input

 def get_data_from_storage(model_name, suffix):
     """Get data from storage file."""
-    async def _get_data():
-        storage = get_storage_interface()
-        try:
-            data, column_names = await storage.read_data(model_name + suffix)
-            return {"data": data, "column_names": column_names}
-        except Exception as e:
-            print(f"Error reading {model_name + suffix}: {e}")
-            return None
-
-    try:
-        return asyncio.run(_get_data())
-    except Exception as e:
-        print(f"Async error for {model_name + suffix}: {e}")
+    storage = get_storage_interface()
+    filename = storage._get_filename(model_name + suffix)
+    if not os.path.exists(filename):
         return None
+    with h5py.File(filename, "r") as f:
+        if model_name + suffix in f:
+            data = f[model_name + suffix][:]
+            column_names = f[model_name + suffix].attrs.get("column_names", [])
+            return {"data": data, "column_names": column_names}


 def get_metadata_ids(model_name):
     """Extract actual IDs from metadata storage."""
-    async def _get_ids():
-        storage = get_storage_interface()
-        try:
-            metadata, column_names = await storage.read_data(model_name + METADATA_SUFFIX)
+    storage = get_storage_interface()
+    filename = storage._get_filename(model_name + METADATA_SUFFIX)
+    if not os.path.exists(filename):
+        return []
+    ids = []
+    with h5py.File(filename, "r") as f:
+        if model_name + METADATA_SUFFIX in f:
+            metadata = f[model_name + METADATA_SUFFIX][:]
+            column_names = f[model_name + METADATA_SUFFIX].attrs.get("column_names", [])
            id_idx = next((i for i, name in enumerate(column_names) if name.lower() == "id"), None)
-            if id_idx is not None and metadata is not None:
-                ids = []
+            if id_idx is not None:
                 for row in metadata:
-                    if hasattr(row, "__len__") and len(row) > id_idx:
-                        ids.append(str(row[id_idx]))
-            return ids
-        except Exception as e:
-            print(f"Error getting metadata: {e}")
-            return []
-
-    try:
-        return asyncio.run(_get_ids())
-    except Exception:
-        return []
+                    try:
+                        if hasattr(row, "__getitem__") and len(row) > id_idx:
+                            id_val = row[id_idx]
+                        else:
+                            row_data = pickle.loads(row.tobytes())
+                            id_val = row_data[id_idx]
+                        if isinstance(id_val, np.ndarray):
+                            ids.append(str(id_val))
+                        else:
+                            ids.append(str(id_val))
+                    except Exception as e:
+                        print(f"Error processing ID from row {len(ids)}: {e}")
+                        continue
+            print(f"Successfully extracted {len(ids)} IDs: {ids}")
+            return ids


 def get_metadata_from_storage(model_name):
@@ -443,4 +439,4 @@ def test_upload_gaussian_data():
             ],
         },
     }
-    post_test(payload, 200, ["2 datapoints"])
+    post_test(payload, 200, ["2 datapoints"])
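The reworked test helpers above read the HDF5 files directly instead of driving the async storage API through asyncio.run. As a standalone illustration of that direct-read style (the file path and dataset name below are placeholders, and the column_names attribute is assumed to be stored the way the helpers read it):

# Sketch of reading a dataset and its column names straight from an HDF5 file.
import h5py
import numpy as np


def read_dataset(filename: str, dataset_name: str):
    with h5py.File(filename, "r") as f:
        if dataset_name not in f:
            return None
        data = f[dataset_name][:]  # load the full array into memory
        column_names = list(f[dataset_name].attrs.get("column_names", []))
        return {"data": np.asarray(data), "column_names": column_names}


# Example usage with placeholder names:
# read_dataset("/tmp/model_data/example-model_outputs.hdf5", "example-model_outputs")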
