@@ -302,6 +302,8 @@ def __init__(self, config: StorageConfig) -> None:
302302 self .logger = get_logger (f"queue_{ config .name } " , in_ray_actor = True )
303303 self .config = config
304304 self .capacity = config .capacity
305+ self .staleness_limit = config .max_staleness # Optional[int]
306+ self .max_model_version = 0 # max model version that queue has seen so far
305307 self .queue = QueueBuffer .get_queue (config )
306308 st_config = deepcopy (config )
307309 st_config .wrap_in_ray = False
@@ -351,6 +353,9 @@ async def put_batch(self, exp_list: List) -> None:
351353 await self .queue .put (exp_list )
352354 if self .writer is not None :
353355 self .writer .write (exp_list )
356+ for exp in exp_list :
357+ if exp .info ["model_version" ] > self .max_model_version :
358+ self .max_model_version = exp .info ["model_version" ]
354359
355360 async def get_batch (self , batch_size : int , timeout : float ) -> List :
356361 """Get batch of experience."""
@@ -361,6 +366,13 @@ async def get_batch(self, batch_size: int, timeout: float) -> List:
361366 raise StopAsyncIteration ("Queue is closed and no more items to get." )
362367 try :
363368 exp_list = await asyncio .wait_for (self .queue .get (), timeout = 1.0 )
369+ if (self .staleness_limit is not None ) and (self .staleness_limit > 0 ):
370+ exp_list = [
371+ exp
372+ for exp in exp_list
373+ if exp .info ["model_version" ]
374+ >= self .max_model_version - self .staleness_limit
375+ ]
364376 self .exp_pool .extend (exp_list )
365377 except asyncio .TimeoutError :
366378 if time .time () - start_time > timeout :
0 commit comments