55from abc import ABC , abstractmethod
66from collections import deque
77from copy import deepcopy
8- from functools import partial
9- from typing import List , Optional , Tuple
8+ from typing import Dict , List , Optional , Tuple
109
1110import numpy as np
1211import ray
@@ -28,48 +27,82 @@ def is_json_file(path: str) -> bool:
2827
2928
# Registry of priority functions, keyed by name. Implementations register
# themselves with ``@PRIORITY_FUNC.register_module(<name>)`` and are looked up
# by name with ``PRIORITY_FUNC.get(<name>)``.
PRIORITY_FUNC = Registry("priority_fn")
31- """
32- Each priority_fn,
33- Args:
34- item: List[Experience], assume that all experiences in it have the same model_version and use_count
35- kwargs: storage_config.replay_buffer_kwargs (except priority_fn)
36- Returns:
37- priority: float
38- put_into_queue: bool, decide whether to put item into queue
39- Note that put_into_queue takes effect both for new item from the explorer and for item sampled from the buffer.
40- """
30+
31+
class PriorityFunction(ABC):
    """Interface for callables that assign a priority to a buffer item.

    An instance is called with a single item and returns the item's priority
    together with a flag telling whether the item should be put into the queue.

    Call contract:
        Args:
            item: List[Experience]; all experiences in it are assumed to have
                the same ``model_version`` and ``use_count``.

        Returns:
            priority: float
            put_into_queue: bool, decides whether to put the item into the queue.

    Note that ``put_into_queue`` takes effect both for a new item from the
    explorer and for an item sampled from the buffer.

    Instances are built as ``cls(**kwargs)``, where ``kwargs`` starts from
    ``cls.default_config()`` and is then updated with the user-supplied
    ``priority_fn_args``; the keys returned by ``default_config`` must
    therefore be valid constructor keyword arguments.
    """

    @abstractmethod
    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
        """Calculate the priority of item.

        Returns:
            A ``(priority, put_into_queue)`` tuple.
        """

    @classmethod
    @abstractmethod
    def default_config(cls) -> Dict:
        """Return the default config (constructor keyword arguments)."""
4154
4255
@PRIORITY_FUNC.register_module("linear_decay")
class LinearDecayPriority(PriorityFunction):
    """Calculate priority by linear decay.

    Priority is calculated as `model_version - decay * use_count`. The item is
    always put back into the queue for reuse (as long as `reuse_cooldown_time`
    is not None).
    """

    def __init__(self, decay: float = 2.0):
        """Store the decay factor.

        Args:
            decay: Amount subtracted from the priority per unit of ``use_count``.
        """
        self.decay = decay

    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
        """Return ``(priority, put_into_queue)`` for ``item``.

        Only the first experience is inspected; its ``info`` must contain
        ``"model_version"`` and ``"use_count"``.
        """
        info = item[0].info
        priority = float(info["model_version"] - self.decay * info["use_count"])
        put_into_queue = True
        return priority, put_into_queue

    @classmethod
    def default_config(cls) -> Dict:
        """Return the default constructor arguments."""
        return {
            "decay": 2.0,
        }
76+
77+
@PRIORITY_FUNC.register_module("decay_limit_randomization")
class LinearDecayUseCountControlPriority(PriorityFunction):
    """Calculate priority by linear decay, use count control, and randomization.

    Priority is calculated as `model_version - decay * use_count`; if `sigma`
    is non-zero, priority is further perturbed by random Gaussian noise with
    standard deviation `sigma`. The item will be put back into the queue only
    if use count does not exceed `use_count_limit`.
    """

    def __init__(self, decay: float = 2.0, use_count_limit: int = 3, sigma: float = 0.0):
        """Store the priority parameters.

        Args:
            decay: Amount subtracted from the priority per unit of ``use_count``.
            use_count_limit: The item is put into the queue only while its
                ``use_count`` is below this limit; a value ``<= 0`` disables
                the limit.
            sigma: Standard deviation of the Gaussian noise added to the
                priority; noise is added only when ``sigma > 0``.
        """
        self.decay = decay
        self.use_count_limit = use_count_limit
        self.sigma = sigma

    def __call__(self, item: List[Experience]) -> Tuple[float, bool]:
        """Return ``(priority, put_into_queue)`` for ``item``.

        Only the first experience is inspected; its ``info`` must contain
        ``"model_version"`` and ``"use_count"``.
        """
        info = item[0].info
        priority = float(info["model_version"] - self.decay * info["use_count"])
        if self.sigma > 0.0:
            priority += float(np.random.randn() * self.sigma)
        # A non-positive limit means "no limit": the item is always re-queued.
        put_into_queue = (
            info["use_count"] < self.use_count_limit if self.use_count_limit > 0 else True
        )
        return priority, put_into_queue

    @classmethod
    def default_config(cls) -> Dict:
        """Return the default constructor arguments."""
        return {
            "decay": 2.0,
            "use_count_limit": 3,
            "sigma": 0.0,
        }
73106
74107
75108class QueueBuffer (ABC ):
@@ -168,7 +201,10 @@ def __init__(
168201 self .capacity = capacity
169202 self .item_count = 0
170203 self .priority_groups = SortedDict () # Maps priority -> deque of items
171- self .priority_fn = partial (PRIORITY_FUNC .get (priority_fn ), ** (priority_fn_args or {}))
204+ priority_fn_cls = PRIORITY_FUNC .get (priority_fn )
205+ kwargs = priority_fn_cls .default_config ()
206+ kwargs .update (priority_fn_args or {})
207+ self .priority_fn = priority_fn_cls (** kwargs )
172208 self .reuse_cooldown_time = reuse_cooldown_time
173209 self ._condition = asyncio .Condition () # For thread-safe operations
174210 self ._closed = False
0 commit comments