Skip to content

Commit e66cf17

Browse files
committed
Update doc string
1 parent bfa2671 commit e66cf17

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

trinity/buffer/storage/queue.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,17 @@ def is_json_file(path: str) -> bool:
2626
return path.endswith(".json") or path.endswith(".jsonl")
2727

2828

29-
# Each priority_fn,
30-
# Args: item, kwargs
31-
# Returns: priority (float), put_into_queue (bool, decide whether to put item into queue)
32-
# Note that put_into_queue takes effect both for new item from the explorer
33-
# and for item sampled from the buffer.
3429
PRIORITY_FUNC = Registry("priority_fn")
30+
"""
31+
Each priority_fn,
32+
Args:
33+
item: List[Experience], assume that all experiences in it have the same model_version and use_count
34+
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
35+
Returns:
36+
priority: float
37+
put_into_queue: bool, decide whether to put item into queue
38+
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
39+
"""
3540

3641

3742
@PRIORITY_FUNC.register_module("linear_decay")
@@ -62,11 +67,10 @@ def linear_decay_use_count_control_priority(
6267
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
6368
if sigma > 0.0:
6469
priority += float(np.random.randn() * sigma)
65-
put_into_queue = item[0].info["use_count"] < use_count_limit
70+
put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
6671
return priority, put_into_queue
6772

6873

69-
7074
class QueueBuffer(ABC):
7175
@abstractmethod
7276
async def put(self, exps: List[Experience]) -> None:

0 commit comments

Comments
 (0)