Port deep_q_network.py to the TensorFlow 1.x compatibility API.

Changes carried by this patch:
  * import tensorflow.compat.v1 as tf and call tf.disable_v2_behavior()
  * weight_variable: tf.truncated_normal -> tf.random.truncated_normal
  * playGame: create the session via tf.compat.v1.InteractiveSession()

NOTE(review): tf is already bound to tensorflow.compat.v1 by the added
import, so the extra "tf.compat.v1." prefix in playGame looks redundant,
and the old call is left behind as a commented-out line. Confirm and
clean up in a follow-up.
NOTE(review): line breaks and indentation were restored from the hunk
line counts, assuming 4-space indentation. Verify against the target file.

diff --git a/deep_q_network.py b/deep_q_network.py
index 1294f96..50ebd11 100755
--- a/deep_q_network.py
+++ b/deep_q_network.py
@@ -9,7 +9,8 @@
 import random
 import numpy as np
 from collections import deque
-
+import tensorflow.compat.v1 as tf
+tf.disable_v2_behavior()
 GAME = 'bird' # the name of the game being played for log files
 ACTIONS = 2 # number of valid actions
 GAMMA = 0.99 # decay rate of past observations
@@ -21,8 +22,10 @@
 BATCH = 32 # size of minibatch
 FRAME_PER_ACTION = 1
 
+
 def weight_variable(shape):
-    initial = tf.truncated_normal(shape, stddev = 0.01)
+    initial = tf.random.truncated_normal(shape, stddev = 0.01)
+
     return tf.Variable(initial)
 
 def bias_variable(shape):
@@ -204,7 +207,9 @@ def trainNetwork(s, readout, h_fc1, sess):
         '''
 
 def playGame():
-    sess = tf.InteractiveSession()
+    #sess = tf.InteractiveSession()
+    sess=tf.compat.v1.InteractiveSession()
+
     s, readout, h_fc1 = createNetwork()
     trainNetwork(s, readout, h_fc1, sess)
 