diff --git a/deep_q_network.py b/deep_q_network.py index 1294f96..f6a9ee8 100755 --- a/deep_q_network.py +++ b/deep_q_network.py @@ -123,7 +123,7 @@ def trainNetwork(s, readout, h_fc1, sess): if random.random() <= epsilon: print("----------Random Action----------") action_index = random.randrange(ACTIONS) - a_t[random.randrange(ACTIONS)] = 1 + a_t[action_index] = 1 else: action_index = np.argmax(readout_t) a_t[action_index] = 1