Skip to content

Commit cd82bfc

Browse files
authored
Refactor Priority Function (agentscope-ai#344)
1 parent 4576dab commit cd82bfc

File tree

4 files changed

+104
-45
lines changed

4 files changed

+104
-45
lines changed

tests/buffer/queue_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ async def test_priority_queue_reuse_count_control(self):
326326
path=BUFFER_FILE_PATH,
327327
replay_buffer=ReplayBufferConfig(
328328
enable=True,
329-
priority_fn="linear_decay_use_count_control_randomization",
329+
priority_fn="decay_limit_randomization",
330330
reuse_cooldown_time=0.5,
331331
priority_fn_args={"decay": 1.2, "use_count_limit": 2, "sigma": 0.0},
332332
),

trinity/buffer/storage/queue.py

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
from abc import ABC, abstractmethod
66
from collections import deque
77
from copy import deepcopy
8-
from functools import partial
9-
from typing import List, Optional, Tuple
8+
from typing import Dict, List, Optional, Tuple
109

1110
import numpy as np
1211
import ray
@@ -28,48 +27,82 @@ def is_json_file(path: str) -> bool:
2827

2928

3029
PRIORITY_FUNC = Registry("priority_fn")
31-
"""
32-
Each priority_fn,
33-
Args:
34-
item: List[Experience], assume that all experiences in it have the same model_version and use_count
35-
kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
36-
Returns:
37-
priority: float
38-
put_into_queue: bool, decide whether to put item into queue
39-
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
40-
"""
30+
31+
32+
class PriorityFunction(ABC):
33+
"""
34+
Each priority_fn,
35+
Args:
36+
item: List[Experience], assume that all experiences in it have the same model_version and use_count
37+
priority_fn_args: Dict, the arguments for priority_fn
38+
39+
Returns:
40+
priority: float
41+
put_into_queue: bool, decide whether to put item into queue
42+
43+
Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
44+
"""
45+
46+
@abstractmethod
47+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
48+
"""Calculate the priority of item."""
49+
50+
@classmethod
51+
@abstractmethod
52+
def default_config(cls) -> Dict:
53+
"""Return the default config."""
4154

4255

4356
@PRIORITY_FUNC.register_module("linear_decay")
44-
def linear_decay_priority(
45-
item: List[Experience],
46-
decay: float = 2.0,
47-
) -> Tuple[float, bool]:
57+
class LinearDecayPriority(PriorityFunction):
4858
"""Calculate priority by linear decay.
4959
5060
Priority is calculated as `model_version - decay * use_count. The item is always put back into the queue for reuse (as long as `reuse_cooldown_time` is not None).
5161
"""
52-
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
53-
put_into_queue = True
54-
return priority, put_into_queue
55-
56-
57-
@PRIORITY_FUNC.register_module("linear_decay_use_count_control_randomization")
58-
def linear_decay_use_count_control_priority(
59-
item: List[Experience],
60-
decay: float = 2.0,
61-
use_count_limit: int = 3,
62-
sigma: float = 0.0,
63-
) -> Tuple[float, bool]:
62+
63+
def __init__(self, decay: float = 2.0):
64+
self.decay = decay
65+
66+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
67+
priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
68+
put_into_queue = True
69+
return priority, put_into_queue
70+
71+
@classmethod
72+
def default_config(cls) -> Dict:
73+
return {
74+
"decay": 2.0,
75+
}
76+
77+
78+
@PRIORITY_FUNC.register_module("decay_limit_randomization")
79+
class LinearDecayUseCountControlPriority(PriorityFunction):
6480
"""Calculate priority by linear decay, use count control, and randomization.
6581
6682
Priority is calculated as `model_version - decay * use_count`; if `sigma` is non-zero, priority is further perturbed by random Gaussian noise with standard deviation `sigma`. The item will be put back into the queue only if use count does not exceed `use_count_limit`.
6783
"""
68-
priority = float(item[0].info["model_version"] - decay * item[0].info["use_count"])
69-
if sigma > 0.0:
70-
priority += float(np.random.randn() * sigma)
71-
put_into_queue = item[0].info["use_count"] < use_count_limit if use_count_limit > 0 else True
72-
return priority, put_into_queue
84+
85+
def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0):
86+
self.decay = decay
87+
self.use_count_limit = use_count_limit
88+
self.sigma = sigma
89+
90+
def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
91+
priority = float(item[0].info["model_version"] - self.decay * item[0].info["use_count"])
92+
if self.sigma > 0.0:
93+
priority += float(np.random.randn() * self.sigma)
94+
put_into_queue = (
95+
item[0].info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
96+
)
97+
return priority, put_into_queue
98+
99+
@classmethod
100+
def default_config(cls) -> Dict:
101+
return {
102+
"decay": 2.0,
103+
"use_count_limit": 3,
104+
"sigma": 0.0,
105+
}
73106

74107

75108
class QueueBuffer(ABC):
@@ -168,7 +201,10 @@ def __init__(
168201
self.capacity = capacity
169202
self.item_count = 0
170203
self.priority_groups = SortedDict() # Maps priority -> deque of items
171-
self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **(priority_fn_args or {}))
204+
priority_fn_cls = PRIORITY_FUNC.get(priority_fn)
205+
kwargs = priority_fn_cls.default_config()
206+
kwargs.update(priority_fn_args or {})
207+
self.priority_fn = priority_fn_cls(**kwargs)
172208
self.reuse_cooldown_time = reuse_cooldown_time
173209
self._condition = asyncio.Condition() # For thread-safe operations
174210
self._closed = False

trinity/manager/config_manager.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
1515
from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY
1616
from trinity.common.constants import StorageType
17-
from trinity.manager.config_registry.buffer_config_manager import get_train_batch_size
17+
from trinity.manager.config_registry.buffer_config_manager import (
18+
get_train_batch_size,
19+
parse_priority_fn_args,
20+
)
1821
from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
1922
from trinity.manager.config_registry.trainer_config_manager import use_critic
2023
from trinity.utils.plugin_loader import load_plugins
@@ -190,7 +193,8 @@ def _expert_buffer_part(self):
190193
self.get_configs("storage_type")
191194
self.get_configs("experience_buffer_path")
192195
self.get_configs("enable_replay_buffer")
193-
self.get_configs("reuse_cooldown_time", "priority_fn", "priority_decay")
196+
self.get_configs("reuse_cooldown_time", "priority_fn")
197+
self.get_configs("priority_fn_args")
194198

195199
# TODO: used for SQL storage
196200
# self.buffer_advanced_tab = st.expander("Advanced Config")
@@ -592,9 +596,7 @@ def _gen_buffer_config(self):
592596
"enable": st.session_state["enable_replay_buffer"],
593597
"priority_fn": st.session_state["priority_fn"],
594598
"reuse_cooldown_time": st.session_state["reuse_cooldown_time"],
595-
"priority_fn_args": {
596-
"decay": st.session_state["priority_decay"],
597-
},
599+
"priority_fn_args": parse_priority_fn_args(st.session_state["priority_fn_args"]),
598600
}
599601

600602
if st.session_state["mode"] != "train":

trinity/manager/config_registry/buffer_config_manager.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import json
2+
3+
import pandas as pd
14
import streamlit as st
25

36
from trinity.buffer.storage.queue import PRIORITY_FUNC
@@ -328,13 +331,31 @@ def set_priority_fn(**kwargs):
328331
)
329332

330333

334+
def parse_priority_fn_args(raw_data: str):
335+
try:
336+
data = json.loads(raw_data)
337+
if data["priority_fn"] != st.session_state["priority_fn"]:
338+
raise ValueError
339+
return data["fn_args"]
340+
except (json.JSONDecodeError, KeyError, ValueError):
341+
print(f"Use `default_config` for {st.session_state['priority_fn']}")
342+
return PRIORITY_FUNC.get(st.session_state["priority_fn"]).default_config()
343+
344+
331345
@CONFIG_GENERATORS.register_config(
332-
default_value=0.1, visible=lambda: st.session_state["enable_replay_buffer"]
346+
default_value="", visible=lambda: st.session_state["enable_replay_buffer"]
333347
)
334-
def set_priority_decay(**kwargs):
335-
st.number_input(
336-
"Priority Decay",
337-
**kwargs,
348+
def set_priority_fn_args(**kwargs):
349+
key = kwargs.get("key")
350+
df = pd.DataFrame([parse_priority_fn_args(st.session_state[key])])
351+
df.index = [st.session_state["priority_fn"]]
352+
st.caption("Priority Function Args")
353+
df = st.data_editor(df)
354+
st.session_state[key] = json.dumps(
355+
{
356+
"fn_args": df.to_dict(orient="records")[0],
357+
"priority_fn": st.session_state["priority_fn"],
358+
}
338359
)
339360

340361

0 commit comments

Comments
 (0)