Skip to content

Commit 958d7f1

Browse files
resolve nested async loop issue
1 parent 58a2639 commit 958d7f1

File tree

8 files changed

+169
-143
lines changed

8 files changed

+169
-143
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
[project.optional-dependencies]
2121
dev = [
2222
"pytest>=7.4.2,<9",
23+
"pytest-asyncio>=0.26.0,<2",
2324
"isort>=5.12.0,<6",
2425
"flake8>=6.1.0,<7",
2526
"mypy>=1.5.1,<2",

src/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
raise_exceptions=False
5757
)
5858
async def schedule_metrics_calculation():
59-
prometheus_scheduler.calculate()
59+
await prometheus_scheduler.calculate()
6060

6161
@asynccontextmanager
6262
async def lifespan(app: FastAPI):

src/service/data/datasources/data_source.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def __init__(self) -> None:
3838

3939
# MODEL TRACKING OPERATIONS
4040

41-
def get_known_models(self) -> Set[str]:
41+
async def get_known_models(self) -> Set[str]:
4242
"""Get the set of known model IDs."""
4343
return self.known_models.copy()
4444

45-
def add_model_to_known(self, model_id: str) -> None:
45+
async def add_model_to_known(self, model_id: str) -> None:
4646
"""Add a model to the known models set."""
4747
self.known_models.add(model_id)
4848

4949
# DATAFRAME READS
5050

51-
def get_dataframe(self, model_id: str) -> pd.DataFrame:
51+
async def get_dataframe(self, model_id: str) -> pd.DataFrame:
5252
"""
5353
Get a dataframe for the given model ID using the default batch size.
5454
@@ -62,9 +62,9 @@ def get_dataframe(self, model_id: str) -> pd.DataFrame:
6262
DataframeCreateException: If the dataframe cannot be created
6363
"""
6464
batch_size = int(os.environ.get("SERVICE_BATCH_SIZE", "100"))
65-
return self.get_dataframe_with_batch_size(model_id, batch_size)
65+
return await self.get_dataframe_with_batch_size(model_id, batch_size)
6666

67-
def get_dataframe_with_batch_size(
67+
async def get_dataframe_with_batch_size(
6868
self, model_id: str, batch_size: int
6969
) -> pd.DataFrame:
7070
"""
@@ -83,16 +83,16 @@ def get_dataframe_with_batch_size(
8383
try:
8484
model_data = ModelData(model_id)
8585

86-
input_rows, output_rows, metadata_rows = asyncio.run(model_data.row_counts())
86+
input_rows, output_rows, metadata_rows = await model_data.row_counts()
8787

8888
available_rows = min(input_rows, output_rows, metadata_rows)
8989

9090
start_row = max(0, available_rows - batch_size)
9191
n_rows = min(batch_size, available_rows)
9292

93-
input_data, output_data, metadata = asyncio.run(model_data.data(start_row=start_row, n_rows=n_rows))
93+
input_data, output_data, metadata = await model_data.data(start_row=start_row, n_rows=n_rows)
9494

95-
input_names, output_names, metadata_names = asyncio.run(model_data.column_names())
95+
input_names, output_names, metadata_names = await model_data.column_names()
9696

9797

9898
# Combine the data into a single dataframe
@@ -121,7 +121,7 @@ def get_dataframe_with_batch_size(
121121
f"Error creating dataframe for model={model_id}: {str(e)}"
122122
)
123123

124-
def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
124+
async def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
125125
"""
126126
Get a dataframe with only organic data (not synthetic).
127127
@@ -135,7 +135,7 @@ def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
135135
Raises:
136136
DataframeCreateException: If the dataframe cannot be created
137137
"""
138-
df = self.get_dataframe_with_batch_size(model_id, batch_size)
138+
df = await self.get_dataframe_with_batch_size(model_id, batch_size)
139139

140140
# Filter out any rows with the unlabeled tag (synthetic data)
141141
if UNLABELED_TAG in df.columns:
@@ -145,7 +145,7 @@ def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
145145

146146
# METADATA READS
147147

148-
def get_metadata(self, model_id: str) -> StorageMetadata:
148+
async def get_metadata(self, model_id: str) -> StorageMetadata:
149149
"""
150150
Get metadata for the given model ID.
151151
@@ -164,8 +164,8 @@ def get_metadata(self, model_id: str) -> StorageMetadata:
164164
try:
165165
model_data = ModelData(model_id)
166166

167-
input_rows, output_rows, metadata_rows = asyncio.run(model_data.row_counts())
168-
input_names, output_names, metadata_names = asyncio.run(model_data.column_names())
167+
input_rows, output_rows, metadata_rows = await model_data.row_counts()
168+
input_names, output_names, metadata_names = await model_data.column_names()
169169

170170
input_items = {}
171171
for i, name in enumerate(input_names):
@@ -195,7 +195,7 @@ def get_metadata(self, model_id: str) -> StorageMetadata:
195195
f"Error getting metadata for model={model_id}: {str(e)}"
196196
)
197197

198-
def has_metadata(self, model_id: str) -> bool:
198+
async def has_metadata(self, model_id: str) -> bool:
199199
"""
200200
Check if metadata exists for the given model ID.
201201
@@ -206,14 +206,14 @@ def has_metadata(self, model_id: str) -> bool:
206206
True if metadata exists, False otherwise
207207
"""
208208
try:
209-
return self.get_metadata(model_id) is not None
209+
return await self.get_metadata(model_id) is not None
210210
except Exception as e:
211211
logger.error(f"Error checking if metadata exists for model={model_id}: {str(e)}")
212212
return False
213213

214214
# DATAFRAME QUERIES
215215

216-
def get_num_observations(self, model_id: str) -> int:
216+
async def get_num_observations(self, model_id: str) -> int:
217217
"""
218218
Get the number of observations for the corresponding model.
219219
@@ -223,10 +223,10 @@ def get_num_observations(self, model_id: str) -> int:
223223
Returns:
224224
The number of observations
225225
"""
226-
metadata = self.get_metadata(model_id)
226+
metadata: StorageMetadata = await self.get_metadata(model_id)
227227
return metadata.get_observations()
228228

229-
def has_recorded_inferences(self, model_id: str) -> bool:
229+
async def has_recorded_inferences(self, model_id: str) -> bool:
230230
"""
231231
Check to see if a particular model has recorded inferences.
232232
@@ -236,10 +236,10 @@ def has_recorded_inferences(self, model_id: str) -> bool:
236236
Returns:
237237
True if the model has received inference data
238238
"""
239-
metadata = self.get_metadata(model_id)
239+
metadata: StorageMetadata = await self.get_metadata(model_id)
240240
return metadata.is_recorded_inferences()
241241

242-
def get_verified_models(self) -> List[str]:
242+
async def get_verified_models(self) -> List[str]:
243243
"""
244244
Get the list of model IDs that are confirmed to have metadata in storage.
245245
@@ -250,19 +250,19 @@ def get_verified_models(self) -> List[str]:
250250

251251
# Check all known models for metadata
252252
for model_id in self.known_models:
253-
if self.has_metadata(model_id):
253+
if await self.has_metadata(model_id):
254254
verified_models.append(model_id)
255255

256256
if not verified_models:
257-
discovered_models = self._discover_models_from_storage()
257+
discovered_models = await self._discover_models_from_storage()
258258
for model_id in discovered_models:
259-
if self.has_metadata(model_id):
260-
self.add_model_to_known(model_id)
259+
if await self.has_metadata(model_id):
260+
await self.add_model_to_known(model_id)
261261
verified_models.append(model_id)
262262

263263
return verified_models
264264

265-
def _discover_models_from_storage(self) -> List[str]:
265+
async def _discover_models_from_storage(self) -> List[str]:
266266
"""
267267
Discover model IDs from storage.
268268
@@ -294,7 +294,7 @@ def get_ground_truth_name(model_id: str) -> str:
294294
"""
295295
return model_id + DataSource.GROUND_TRUTH_SUFFIX
296296

297-
def has_ground_truths(self, model_id: str) -> bool:
297+
async def has_ground_truths(self, model_id: str) -> bool:
298298
"""
299299
Check if ground truths exist for a model.
300300
@@ -304,9 +304,9 @@ def has_ground_truths(self, model_id: str) -> bool:
304304
Returns:
305305
True if ground truths exist, False otherwise
306306
"""
307-
return self.has_metadata(self.get_ground_truth_name(model_id))
307+
return await self.has_metadata(self.get_ground_truth_name(model_id))
308308

309-
def get_ground_truths(self, model_id: str) -> pd.DataFrame:
309+
async def get_ground_truths(self, model_id: str) -> pd.DataFrame:
310310
"""
311311
Get ground-truth dataframe for this particular model.
312312
@@ -316,11 +316,11 @@ def get_ground_truths(self, model_id: str) -> pd.DataFrame:
316316
Returns:
317317
The ground-truth dataframe
318318
"""
319-
return self.get_dataframe(self.get_ground_truth_name(model_id))
319+
return await self.get_dataframe(self.get_ground_truth_name(model_id))
320320

321321
# UTILITY METHODS
322322

323-
def save_dataframe(
323+
async def save_dataframe(
324324
self, dataframe: pd.DataFrame, model_id: str, overwrite: bool = False
325325
) -> None:
326326
"""
@@ -332,7 +332,7 @@ def save_dataframe(
332332
overwrite: If true, overwrite existing data. Otherwise, append.
333333
"""
334334
# Add to known models
335-
self.add_model_to_known(model_id)
335+
await self.add_model_to_known(model_id)
336336

337337
# TODO: In a full implementation, this would save the dataframe to storage
338338
logger.info(f"Saving dataframe for model {model_id} (overwrite={overwrite})")

src/service/payloads/metrics/request_reconciler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
class RequestReconciler:
1818
@staticmethod
19-
def reconcile(request: BaseMetricRequest, data_source: DataSource) -> None:
19+
async def reconcile(request: BaseMetricRequest, data_source: DataSource) -> None:
2020
"""
2121
Reconcile a metric request with the data source.
2222
2323
Args:
2424
request: The metric request to reconcile
2525
data_source: The data source to use for reconciliation
2626
"""
27-
storage_metadata: StorageMetadata = data_source.get_metadata(request.model_id)
27+
storage_metadata: StorageMetadata = await data_source.get_metadata(request.model_id)
2828
RequestReconciler.reconcile_with_metadata(request, storage_metadata)
2929

3030
@staticmethod

src/service/prometheus/prometheus_scheduler.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,19 @@ def get_all_requests_flat(self) -> Dict[uuid.UUID, BaseMetricRequest]:
6161
result.update(metric_dict)
6262
return result
6363

64-
def calculate(self) -> None:
64+
async def calculate(self) -> None:
6565
"""Calculate scheduled metrics."""
66-
self.calculate_manual(False)
66+
await self.calculate_manual(False)
6767

68-
def calculate_manual(self, throw_errors: bool = True) -> None:
68+
async def calculate_manual(self, throw_errors: bool = True) -> None:
6969
"""
7070
Calculate scheduled metrics.
7171
7272
Args:
7373
throw_errors: If True, errors will be thrown. If False, they will just be logged.
7474
"""
7575
try:
76-
verified_models = self.data_source.get_verified_models()
76+
verified_models = await self.data_source.get_verified_models()
7777

7878
# Global service statistic
7979
self.publisher.gauge(
@@ -86,15 +86,15 @@ def calculate_manual(self, throw_errors: bool = True) -> None:
8686

8787
for model_id in verified_models:
8888
# Global model statistics
89-
total_observations = self.data_source.get_num_observations(model_id)
89+
total_observations = await self.data_source.get_num_observations(model_id)
9090
self.publisher.gauge(
9191
model_name=model_id,
9292
id=PrometheusPublisher.generate_uuid(model_id),
9393
metric_name="MODEL_OBSERVATIONS_TOTAL",
9494
value=total_observations,
9595
)
9696

97-
has_recorded_inferences = self.data_source.has_recorded_inferences(
97+
has_recorded_inferences = await self.data_source.has_recorded_inferences(
9898
model_id
9999
)
100100

@@ -122,7 +122,7 @@ def calculate_manual(self, throw_errors: bool = True) -> None:
122122
default=self.service_config.get("batch_size", 100),
123123
)
124124

125-
df = self.data_source.get_organic_dataframe(
125+
df = await self.data_source.get_organic_dataframe(
126126
model_id, max_batch_size
127127
)
128128

@@ -161,17 +161,17 @@ def calculate_manual(self, throw_errors: bool = True) -> None:
161161
else:
162162
logger.error(f"Error calculating metrics: {e}")
163163

164-
def register(
164+
async def register(
165165
self, metric_name: str, id: uuid.UUID, request: BaseMetricRequest
166166
) -> None:
167167
"""Register a metric request."""
168-
RequestReconciler.reconcile(request, self.data_source)
168+
await RequestReconciler.reconcile(request, self.data_source)
169169
with self._requests_lock:
170170
if metric_name not in self.requests:
171171
self.requests[metric_name] = {}
172172
self.requests[metric_name][id] = request
173173

174-
def delete(self, metric_name: str, id: uuid.UUID) -> None:
174+
async def delete(self, metric_name: str, id: uuid.UUID) -> None:
175175
"""Delete a metric request."""
176176
with self._requests_lock:
177177
if metric_name in self.requests and id in self.requests[metric_name]:

0 commit comments

Comments (0)