@@ -38,17 +38,17 @@ def __init__(self) -> None:
3838
3939 # MODEL TRACKING OPERATIONS
4040
41- def get_known_models (self ) -> Set [str ]:
41+ async def get_known_models (self ) -> Set [str ]:
4242 """Get the set of known model IDs."""
4343 return self .known_models .copy ()
4444
45- def add_model_to_known (self , model_id : str ) -> None :
45+ async def add_model_to_known (self , model_id : str ) -> None :
4646 """Add a model to the known models set."""
4747 self .known_models .add (model_id )
4848
4949 # DATAFRAME READS
5050
51- def get_dataframe (self , model_id : str ) -> pd .DataFrame :
51+ async def get_dataframe (self , model_id : str ) -> pd .DataFrame :
5252 """
5353 Get a dataframe for the given model ID using the default batch size.
5454
@@ -62,9 +62,9 @@ def get_dataframe(self, model_id: str) -> pd.DataFrame:
6262 DataframeCreateException: If the dataframe cannot be created
6363 """
6464 batch_size = int (os .environ .get ("SERVICE_BATCH_SIZE" , "100" ))
65- return self .get_dataframe_with_batch_size (model_id , batch_size )
65+ return await self .get_dataframe_with_batch_size (model_id , batch_size )
6666
67- def get_dataframe_with_batch_size (
67+ async def get_dataframe_with_batch_size (
6868 self , model_id : str , batch_size : int
6969 ) -> pd .DataFrame :
7070 """
@@ -83,16 +83,16 @@ def get_dataframe_with_batch_size(
8383 try :
8484 model_data = ModelData (model_id )
8585
86- input_rows , output_rows , metadata_rows = asyncio . run ( model_data .row_counts () )
86+ input_rows , output_rows , metadata_rows = await model_data .row_counts ()
8787
8888 available_rows = min (input_rows , output_rows , metadata_rows )
8989
9090 start_row = max (0 , available_rows - batch_size )
9191 n_rows = min (batch_size , available_rows )
9292
93- input_data , output_data , metadata = asyncio . run ( model_data .data (start_row = start_row , n_rows = n_rows ) )
93+ input_data , output_data , metadata = await model_data .data (start_row = start_row , n_rows = n_rows )
9494
95- input_names , output_names , metadata_names = asyncio . run ( model_data .column_names () )
95+ input_names , output_names , metadata_names = await model_data .column_names ()
9696
9797
9898 # Combine the data into a single dataframe
@@ -121,7 +121,7 @@ def get_dataframe_with_batch_size(
121121 f"Error creating dataframe for model={ model_id } : { str (e )} "
122122 )
123123
124- def get_organic_dataframe (self , model_id : str , batch_size : int ) -> pd .DataFrame :
124+ async def get_organic_dataframe (self , model_id : str , batch_size : int ) -> pd .DataFrame :
125125 """
126126 Get a dataframe with only organic data (not synthetic).
127127
@@ -135,7 +135,7 @@ def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
135135 Raises:
136136 DataframeCreateException: If the dataframe cannot be created
137137 """
138- df = self .get_dataframe_with_batch_size (model_id , batch_size )
138+ df = await self .get_dataframe_with_batch_size (model_id , batch_size )
139139
140140 # Filter out any rows with the unlabeled tag (synthetic data)
141141 if UNLABELED_TAG in df .columns :
@@ -145,7 +145,7 @@ def get_organic_dataframe(self, model_id: str, batch_size: int) -> pd.DataFrame:
145145
146146 # METADATA READS
147147
148- def get_metadata (self , model_id : str ) -> StorageMetadata :
148+ async def get_metadata (self , model_id : str ) -> StorageMetadata :
149149 """
150150 Get metadata for the given model ID.
151151
@@ -164,8 +164,8 @@ def get_metadata(self, model_id: str) -> StorageMetadata:
164164 try :
165165 model_data = ModelData (model_id )
166166
167- input_rows , output_rows , metadata_rows = asyncio . run ( model_data .row_counts () )
168- input_names , output_names , metadata_names = asyncio . run ( model_data .column_names () )
167+ input_rows , output_rows , metadata_rows = await model_data .row_counts ()
168+ input_names , output_names , metadata_names = await model_data .column_names ()
169169
170170 input_items = {}
171171 for i , name in enumerate (input_names ):
@@ -195,7 +195,7 @@ def get_metadata(self, model_id: str) -> StorageMetadata:
195195 f"Error getting metadata for model={ model_id } : { str (e )} "
196196 )
197197
198- def has_metadata (self , model_id : str ) -> bool :
198+ async def has_metadata (self , model_id : str ) -> bool :
199199 """
200200 Check if metadata exists for the given model ID.
201201
@@ -206,14 +206,14 @@ def has_metadata(self, model_id: str) -> bool:
206206 True if metadata exists, False otherwise
207207 """
208208 try :
209- return self .get_metadata (model_id ) is not None
209+ return await self .get_metadata (model_id ) is not None
210210 except Exception as e :
211211 logger .error (f"Error checking if metadata exists for model={ model_id } : { str (e )} " )
212212 return False
213213
214214 # DATAFRAME QUERIES
215215
216- def get_num_observations (self , model_id : str ) -> int :
216+ async def get_num_observations (self , model_id : str ) -> int :
217217 """
218218 Get the number of observations for the corresponding model.
219219
@@ -223,10 +223,10 @@ def get_num_observations(self, model_id: str) -> int:
223223 Returns:
224224 The number of observations
225225 """
226- metadata = self .get_metadata (model_id )
226+ metadata : StorageMetadata = await self .get_metadata (model_id )
227227 return metadata .get_observations ()
228228
229- def has_recorded_inferences (self , model_id : str ) -> bool :
229+ async def has_recorded_inferences (self , model_id : str ) -> bool :
230230 """
231231 Check to see if a particular model has recorded inferences.
232232
@@ -236,10 +236,10 @@ def has_recorded_inferences(self, model_id: str) -> bool:
236236 Returns:
237237 True if the model has received inference data
238238 """
239- metadata = self .get_metadata (model_id )
239+ metadata : StorageMetadata = await self .get_metadata (model_id )
240240 return metadata .is_recorded_inferences ()
241241
242- def get_verified_models (self ) -> List [str ]:
242+ async def get_verified_models (self ) -> List [str ]:
243243 """
244244 Get the list of model IDs that are confirmed to have metadata in storage.
245245
@@ -250,19 +250,19 @@ def get_verified_models(self) -> List[str]:
250250
251251 # Check all known models for metadata
252252 for model_id in self .known_models :
253- if self .has_metadata (model_id ):
253+ if await self .has_metadata (model_id ):
254254 verified_models .append (model_id )
255255
256256 if not verified_models :
257- discovered_models = self ._discover_models_from_storage ()
257+ discovered_models = await self ._discover_models_from_storage ()
258258 for model_id in discovered_models :
259- if self .has_metadata (model_id ):
260- self .add_model_to_known (model_id )
259+ if await self .has_metadata (model_id ):
260+ await self .add_model_to_known (model_id )
261261 verified_models .append (model_id )
262262
263263 return verified_models
264264
265- def _discover_models_from_storage (self ) -> List [str ]:
265+ async def _discover_models_from_storage (self ) -> List [str ]:
266266 """
267267 Discover model IDs from storage.
268268
@@ -294,7 +294,7 @@ def get_ground_truth_name(model_id: str) -> str:
294294 """
295295 return model_id + DataSource .GROUND_TRUTH_SUFFIX
296296
297- def has_ground_truths (self , model_id : str ) -> bool :
297+ async def has_ground_truths (self , model_id : str ) -> bool :
298298 """
299299 Check if ground truths exist for a model.
300300
@@ -304,9 +304,9 @@ def has_ground_truths(self, model_id: str) -> bool:
304304 Returns:
305305 True if ground truths exist, False otherwise
306306 """
307- return self .has_metadata (self .get_ground_truth_name (model_id ))
307+ return await self .has_metadata (self .get_ground_truth_name (model_id ))
308308
309- def get_ground_truths (self , model_id : str ) -> pd .DataFrame :
309+ async def get_ground_truths (self , model_id : str ) -> pd .DataFrame :
310310 """
311311 Get ground-truth dataframe for this particular model.
312312
@@ -316,11 +316,11 @@ def get_ground_truths(self, model_id: str) -> pd.DataFrame:
316316 Returns:
317317 The ground-truth dataframe
318318 """
319- return self .get_dataframe (self .get_ground_truth_name (model_id ))
319+ return await self .get_dataframe (self .get_ground_truth_name (model_id ))
320320
321321 # UTILITY METHODS
322322
323- def save_dataframe (
323+ async def save_dataframe (
324324 self , dataframe : pd .DataFrame , model_id : str , overwrite : bool = False
325325 ) -> None :
326326 """
@@ -332,7 +332,7 @@ def save_dataframe(
332332 overwrite: If true, overwrite existing data. Otherwise, append.
333333 """
334334 # Add to known models
335- self .add_model_to_known (model_id )
335+ await self .add_model_to_known (model_id )
336336
337337 # TODO: In a full implementation, this would save the dataframe to storage
338338 logger .info (f"Saving dataframe for model { model_id } (overwrite={ overwrite } )" )
0 commit comments