# buffer.py
import numpy as np


class ReplayBuffer:
    """Fixed-size circular buffer of (state, action, reward, next_state, done) transitions."""

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0  # total transitions stored so far (not capped at mem_size)
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool_)

    def store_transition(self, state, action, reward, state_, done):
        # Wrap around and overwrite the oldest entry once the buffer is full.
        index = self.mem_cntr % self.mem_size
        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        # Sample only from the slots that have actually been filled.
        max_mem = min(self.mem_cntr, self.mem_size)
        # replace=False prevents the same transition appearing twice in one batch
        # (requires batch_size <= max_mem).
        batch = np.random.choice(max_mem, batch_size, replace=False)
        states = self.state_memory[batch]
        states_ = self.new_state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        dones = self.terminal_memory[batch]
        return states, actions, rewards, states_, dones

    def save_to_file(self, filename):
        # np.savez appends ".npz" to the filename if it does not already end with it.
        np.savez(filename,
                 mem_size=self.mem_size,
                 mem_cntr=self.mem_cntr,
                 state_memory=self.state_memory,
                 new_state_memory=self.new_state_memory,
                 action_memory=self.action_memory,
                 reward_memory=self.reward_memory,
                 terminal_memory=self.terminal_memory)

    def load_from_file(self, filename):
        data = np.load(filename)
        # Scalar entries come back as 0-d arrays; cast them to plain ints.
        self.mem_size = int(data['mem_size'])
        self.mem_cntr = int(data['mem_cntr'])
        self.state_memory = data['state_memory']
        self.new_state_memory = data['new_state_memory']
        self.action_memory = data['action_memory']
        self.reward_memory = data['reward_memory']
        self.terminal_memory = data['terminal_memory']
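

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: fills the buffer
    # with random transitions, samples a batch, and round-trips it through
    # save/load. The shapes (input_shape=(8,), n_actions=2), the batch sizes,
    # and the temp-file path are arbitrary choices for illustration.
    import os
    import tempfile

    buf = ReplayBuffer(max_size=1000, input_shape=(8,), n_actions=2)
    rng = np.random.default_rng(0)
    for _ in range(64):
        buf.store_transition(state=rng.standard_normal(8),
                             action=rng.standard_normal(2),
                             reward=1.0,
                             state_=rng.standard_normal(8),
                             done=False)

    states, actions, rewards, states_, dones = buf.sample_buffer(32)
    print(states.shape, actions.shape)  # (32, 8) (32, 2)

    # Save, then restore into a fresh buffer and check the counter survives.
    path = os.path.join(tempfile.gettempdir(), "replay_buffer.npz")
    buf.save_to_file(path)
    buf2 = ReplayBuffer(max_size=1, input_shape=(1,), n_actions=1)
    buf2.load_from_file(path)
    assert buf2.mem_cntr == buf.mem_cntr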