Skip to content

Commit 5323f67

Browse files
committed
A minimal implementation of staleness control
1 parent 73c81b7 commit 5323f67

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

trinity/buffer/storage/queue.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,8 @@ def __init__(self, config: StorageConfig) -> None:
302302
self.logger = get_logger(f"queue_{config.name}", in_ray_actor=True)
303303
self.config = config
304304
self.capacity = config.capacity
305+
self.staleness_limit = config.max_staleness # Optional[int]
306+
self.max_model_version = 0 # max model version that queue has seen so far
305307
self.queue = QueueBuffer.get_queue(config)
306308
st_config = deepcopy(config)
307309
st_config.wrap_in_ray = False
@@ -351,6 +353,9 @@ async def put_batch(self, exp_list: List) -> None:
351353
await self.queue.put(exp_list)
352354
if self.writer is not None:
353355
self.writer.write(exp_list)
356+
for exp in exp_list:
357+
if exp.info["model_version"] > self.max_model_version:
358+
self.max_model_version = exp.info["model_version"]
354359

355360
async def get_batch(self, batch_size: int, timeout: float) -> List:
356361
"""Get batch of experience."""
@@ -361,6 +366,13 @@ async def get_batch(self, batch_size: int, timeout: float) -> List:
361366
raise StopAsyncIteration("Queue is closed and no more items to get.")
362367
try:
363368
exp_list = await asyncio.wait_for(self.queue.get(), timeout=1.0)
369+
if (self.staleness_limit is not None) and (self.staleness_limit > 0):
370+
exp_list = [
371+
exp
372+
for exp in exp_list
373+
if exp.info["model_version"]
374+
>= self.max_model_version - self.staleness_limit
375+
]
364376
self.exp_pool.extend(exp_list)
365377
except asyncio.TimeoutError:
366378
if time.time() - start_time > timeout:

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class StorageConfig:
159159

160160
# used for StorageType.QUEUE
161161
capacity: int = 10000
162+
staleness_limit: Optional[int] = None
162163
max_read_timeout: float = 1800
163164
replay_buffer: Optional[ReplayBufferConfig] = field(default_factory=ReplayBufferConfig)
164165

0 commit comments

Comments
 (0)