
Commit 095a331

Merge commit 095a331 (2 parents: e55a95f + 94e39ab)

File tree: 9 files changed (+957 / -11 lines)


README.md

Lines changed: 34 additions & 2 deletions
@@ -1,9 +1,23 @@
-# Popular Model-free Reinforcement Learning Algorithms [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=State-of-the-art-Model-free-Reinforcement-Learning-Algorithms%20&url=hhttps://github.com/quantumiracle/STOA-RL-Algorithms&hashtags=RL)
+# Popular Model-free Reinforcement Learning Algorithms
+<!-- [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=State-of-the-art-Model-free-Reinforcement-Learning-Algorithms%20&url=hhttps://github.com/quantumiracle/STOA-RL-Algorithms&hashtags=RL) -->
 
 **PyTorch** and **Tensorflow 2.0** implementation of state-of-the-art model-free reinforcement learning algorithms on both OpenAI Gym environments and a self-implemented Reacher environment.
 
-Algorithms include **Soft Actor-Critic (SAC), Deep Deterministic Policy Gradient (DDPG), Twin Delayed DDPG (TD3), Actor-Critic (AC/A2C), Proximal Policy Optimization (PPO), QT-Opt (including Cross-entropy (CE) Method)**, **PointNet**, **Transporter**, **Recurrent Policy Gradient**, **Soft Decision Tree**, **Probabilistic Mixture-of-Experts**, etc.
+Algorithms include:
+* **Actor-Critic (AC/A2C)**;
+* **Soft Actor-Critic (SAC)**;
+* **Deep Deterministic Policy Gradient (DDPG)**;
+* **Twin Delayed DDPG (TD3)**;
+* **Proximal Policy Optimization (PPO)**;
+* **QT-Opt (including Cross-entropy (CE) Method)**;
+* **PointNet**;
+* **Transporter**;
+* **Recurrent Policy Gradient**;
+* **Soft Decision Tree**;
+* **Probabilistic Mixture-of-Experts**;
+* **QMIX**;
+* etc.
 
 Please note that this repo is more a personal collection of algorithms I implemented and tested during my research and study than an official open-source library/package. However, I think sharing it could be helpful to others, and I welcome discussions on my implementations. I did not spend much time cleaning or structuring the code, and you may notice several versions of the implementation for some algorithms; I intentionally keep all of them here for reference and comparison. Also, this repo contains only the **PyTorch** implementation.
@@ -36,6 +50,10 @@ Since Tensorflow 2.0 has already incorporated the dynamic graph construction ins
 `sac_discrete.py`: for discrete action space.
 
 paper (the author is actually one of my classmates at IC): https://arxiv.org/abs/1910.07207
+
+**SAC Discrete PER**
+
+`sac_discrete_per.py`: for discrete action space, with prioritized experience replay (PER).
 
 * **Deep Deterministic Policy Gradient (DDPG)**:

@@ -86,6 +104,8 @@ Since Tensorflow 2.0 has already incorporated the dynamic graph construction ins
 `td3_lstm.py`: TD3 with LSTM policy.
 
 `sac_v2_lstm.py`: SAC with LSTM policy.
+
+`sac_v2_gru.py`: SAC with GRU policy.
 
 References:
@@ -107,6 +127,18 @@ Since Tensorflow 2.0 has already incorporated the dynamic graph construction ins
 paper: [Probabilistic Mixture-of-Experts for Efficient Deep Reinforcement Learning](https://arxiv.org/pdf/2104.09122)
 
+* **QMIX**:
+
+`qmix.py`: a fully cooperative multi-agent RL algorithm; the demo environment uses [pettingzoo](https://www.pettingzoo.ml/atari/entombed_cooperative).
+
+paper: http://proceedings.mlr.press/v80/rashid18a.html
+
+* **Phasic Policy Gradient (PPG)**:
+
+todo
+
+paper: [Phasic Policy Gradient](http://proceedings.mlr.press/v139/cobbe21a.html)
 
 * **Maximum a Posteriori Policy Optimisation (MPO)**:
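For readers unfamiliar with QMIX, below is a minimal PyTorch sketch of the mixing-network idea from the paper (Rashid et al., 2018); it is not the code in `qmix.py`. Per-agent Q-values are combined into a joint Q_tot by a small mixing network whose weights are generated by hypernetworks conditioned on the global state and constrained to be non-negative, which keeps Q_tot monotonic in each agent's Q-value. All layer sizes and names here are illustrative.

```python
import torch
import torch.nn as nn

class QMixer(nn.Module):
    def __init__(self, n_agents, state_dim, mixing_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.mixing_dim = mixing_dim
        # hypernetworks: map the global state to the mixing network's weights and biases
        self.hyper_w1 = nn.Linear(state_dim, n_agents * mixing_dim)
        self.hyper_b1 = nn.Linear(state_dim, mixing_dim)
        self.hyper_w2 = nn.Linear(state_dim, mixing_dim)
        self.hyper_b2 = nn.Sequential(nn.Linear(state_dim, mixing_dim), nn.ReLU(),
                                      nn.Linear(mixing_dim, 1))

    def forward(self, agent_qs, state):
        # agent_qs: (batch, n_agents), state: (batch, state_dim)
        batch = agent_qs.size(0)
        # abs() enforces non-negative mixing weights -> monotonicity of Q_tot in each agent's Q
        w1 = torch.abs(self.hyper_w1(state)).view(batch, self.n_agents, self.mixing_dim)
        b1 = self.hyper_b1(state).view(batch, 1, self.mixing_dim)
        hidden = torch.relu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1)   # (batch, 1, mixing_dim)
        w2 = torch.abs(self.hyper_w2(state)).view(batch, self.mixing_dim, 1)
        b2 = self.hyper_b2(state).view(batch, 1, 1)
        q_tot = torch.bmm(hidden, w2) + b2                               # (batch, 1, 1)
        return q_tot.view(batch, 1)

# example: 2 agents, global state of size 8, batch of 4
mixer = QMixer(n_agents=2, state_dim=8)
q_tot = mixer(torch.randn(4, 2), torch.randn(4, 8))   # -> shape (4, 1)
```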

Binary file (1.77 KB) not shown.

common/buffers.py

Lines changed: 50 additions & 5 deletions
@@ -36,6 +36,56 @@ def __len__(
     def get_length(self):
         return len(self.buffer)
 
+class ReplayBufferPER:
+    """
+    Replay buffer with Prioritized Experience Replay (PER),
+    using the TD error as sampling weight. This is a simple version without a sum tree.
+
+    Reference:
+    https://github.com/Felhof/DiscreteSAC/blob/main/utilities/ReplayBuffer.py
+    """
+    def __init__(self, capacity):
+        self.capacity = capacity
+        self.buffer = []
+        self.position = 0
+        self.weights = np.zeros(int(capacity))
+        self.max_weight = 10**-2
+        self.delta = 10**-4
+
+    def push(self, state, action, reward, next_state, done):
+        if len(self.buffer) < self.capacity:
+            self.buffer.append(None)
+        self.buffer[self.position] = (state, action, reward, next_state, done)
+        self.weights[self.position] = self.max_weight  # a new sample gets the maximum weight
+        self.position = int((self.position + 1) % self.capacity)  # as a ring buffer
+
+    def sample(self, batch_size):
+        # sample over the filled part of the buffer, so entries are not lost after the ring buffer wraps around
+        upper_bound = len(self.buffer)
+        set_weights = self.weights[:upper_bound] + self.delta
+        probabilities = set_weights / np.sum(set_weights)
+        self.indices = np.random.choice(range(upper_bound), batch_size, p=probabilities, replace=False)
+        batch = [self.buffer[i] for i in self.indices]
+        state, action, reward, next_state, done = map(np.stack, zip(*batch))  # stack for each element
+        '''
+        the * serves as unpack: sum(a,b) <=> batch=(a,b), sum(*batch) ;
+        zip: a=[1,2], b=[2,3], zip(a,b) => [(1, 2), (2, 3)] ;
+        the map serves as mapping the function on each list element: map(square, [2,3]) => [4,9] ;
+        np.stack((1,2)) => array([1, 2])
+        '''
+        return state, action, reward, next_state, done
+
+    def update_weights(self, prediction_errors):
+        # after a learning step, use the new TD errors of the sampled batch as its priorities
+        max_error = max(prediction_errors)
+        self.max_weight = max(self.max_weight, max_error)
+        self.weights[self.indices] = prediction_errors
+
+    def __len__(self):
+        # note: len(replay_buffer) is not available through a multiprocessing manager proxy, hence get_length() below
+        return len(self.buffer)
+
+    def get_length(self):
+        return len(self.buffer)
 
 class ReplayBufferLSTM:
     """
@@ -73,9 +123,6 @@ def __len__(
     def get_length(self):
         return len(self.buffer)
 
-
-
-
 class ReplayBufferLSTM2:
     """
     Replay buffer for agent with LSTM network additionally storing previous action,
@@ -128,8 +175,6 @@ def get_length(self):
         return len(self.buffer)
 
 
-
-
 class ReplayBufferGRU:
     """
     Replay buffer for agent with GRU network additionally storing previous action,
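For context, here is a minimal usage sketch of the `ReplayBufferPER` class added above, assuming it is importable from `common/buffers.py` as laid out in this repo. Each transition is sampled with probability proportional to its stored weight plus `delta`, and `update_weights()` writes the sampled batch's latest TD errors back as priorities; the TD errors below are random placeholders for whatever the learner actually computes.

```python
import numpy as np
from common.buffers import ReplayBufferPER  # path as laid out in this repo

buffer = ReplayBufferPER(capacity=1000)

# fill the buffer with dummy transitions (4-dim state, binary action)
for _ in range(100):
    state, next_state = np.random.randn(4), np.random.randn(4)
    buffer.push(state, np.random.randint(2), np.random.randn(), next_state, False)

# draw a prioritized batch: sampling probability is proportional to weight + delta
states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)

# after a learning step, write the batch's absolute TD errors back as new priorities
td_errors = np.abs(np.random.randn(32))  # placeholder for |Q_target - Q_predicted|
buffer.update_weights(td_errors)
```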

common/wrappers.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import numpy as np
+import gym
+
+
+class Dict2TupleWrapper():
+    """ Wrap the PettingZoo envs to have a similar style as LaserFrame in NFSP """
+    def __init__(self, env, keep_info=False):
+        super(Dict2TupleWrapper, self).__init__()
+        self.env = env
+        self.num_agents = env.num_agents
+        self.keep_info = keep_info  # if True, keep info as a dict
+        if len(env.observation_space.shape) > 1:  # image observation
+            old_shape = env.observation_space.shape
+            self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.uint8)
+            self.obs_type = 'rgb_image'
+        else:
+            self.observation_space = env.observation_space
+            self.obs_type = 'ram'
+        self.action_space = env.action_space
+        self.observation_spaces = env.observation_spaces
+        self.action_spaces = env.action_spaces
+        try:  # both pettingzoo and slimevolley can work with this
+            self.agents = env.agents
+        except AttributeError:
+            self.agents = env.unwrapped.agents
+
+    @property
+    def unwrapped(self):
+        return self.env
+
+    @property
+    def spec(self):
+        return self.env.spec
+
+    def observation_swapaxis(self, observation):
+        # swap the channel axis to the front for each agent's image observation: (H, W, C) -> (C, W, H)
+        return (np.swapaxes(observation[0], 2, 0), np.swapaxes(observation[1], 2, 0))
+
+    def reset(self):
+        obs_dict = self.env.reset()
+        if self.obs_type == 'ram':
+            return tuple(obs_dict.values())
+        else:
+            return self.observation_swapaxis(tuple(obs_dict.values()))
+
+    def step(self, actions):
+        actions = {agent_name: action for agent_name, action in zip(self.agents, actions)}
+        obs, rewards, dones, infos = self.env.step(actions)
+        if self.obs_type == 'ram':
+            o = tuple(obs.values())
+        else:
+            o = self.observation_swapaxis(tuple(obs.values()))
+        r = list(rewards.values())
+        d = list(dones.values())
+        if self.keep_info:  # a special case for VectorEnv
+            info = infos
+        else:
+            info = list(infos.values())
+        del obs, rewards, dones, infos
+        # r = self._zerosum_filter(r)
+
+        return o, r, d, info
+
+    def _zerosum_filter(self, r):
+        # zero-sum filter: makes a non-zero-sum game zero-sum (e.g. tennis_v2)
+        # by giving the other agent the negated reward of the scoring agent
+        if np.sum(r) != 0:
+            nonzero_idx = np.nonzero(r)[0][0]
+            r[1 - nonzero_idx] = -r[nonzero_idx]
+        return r
+
+    def seed(self, seed):
+        self.env.seed(seed)
+        np.random.seed(seed)
+
+    def render(self):
+        self.env.render()
+
+    def close(self):
+        self.env.close()
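A small self-contained sketch of how `Dict2TupleWrapper` is intended to be used: it converts the dict-keyed observations, rewards, dones, and infos of a PettingZoo-style parallel environment into agent-ordered tuples/lists. The `StubEnv` below is a hypothetical stand-in for a real multi-agent env (e.g. the entombed_cooperative demo), so its spaces and agent names are purely illustrative.

```python
import numpy as np
import gym

from common.wrappers import Dict2TupleWrapper  # path as laid out in this repo


class StubEnv:
    """Hypothetical dict-based two-agent env, mimicking the PettingZoo parallel API."""
    def __init__(self):
        self.agents = ['first_0', 'second_0']
        self.num_agents = 2
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
        self.action_space = gym.spaces.Discrete(4)
        self.observation_spaces = {a: self.observation_space for a in self.agents}
        self.action_spaces = {a: self.action_space for a in self.agents}
        self.spec = None

    def reset(self):
        return {a: self.observation_space.sample() for a in self.agents}

    def step(self, actions):  # actions: dict keyed by agent name
        obs = {a: self.observation_space.sample() for a in self.agents}
        rewards = {a: 0.0 for a in self.agents}
        dones = {a: False for a in self.agents}
        infos = {a: {} for a in self.agents}
        return obs, rewards, dones, infos

    def seed(self, seed):
        pass

    def render(self):
        pass

    def close(self):
        pass


env = Dict2TupleWrapper(StubEnv())
obs = env.reset()                                          # tuple of per-agent observations
actions = [env.action_space.sample() for _ in range(env.num_agents)]
obs, rewards, dones, infos = env.step(actions)             # tuple / list outputs instead of dicts
env.close()
```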
