-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreinforcement_q_learning.py
More file actions
49 lines (35 loc) · 1.1 KB
/
reinforcement_q_learning.py
File metadata and controls
49 lines (35 loc) · 1.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# -*- coding: utf-8 -*-
"""
Reinforcement Learning (DQN) tutorial
=====================================
This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent
on the CartPole-v0 task from the `OpenAI Gym <https://gym.openai.com/>`__.

**Task**

The agent has to decide between two actions - moving the cart left or
right - so that the pole attached to it stays upright.

**Module setup**

Importing this module has side effects:

* it creates the ``CartPole-v0`` environment and binds it to ``env``;
* it enables matplotlib interactive mode (and imports IPython's ``display``
  when an "inline" backend is active);
* it selects CPU or CUDA tensor constructors (``FloatTensor``,
  ``LongTensor``, ``ByteTensor``, ``Tensor``) depending on
  ``torch.cuda.is_available()``.
"""
# Standard library.
import math
import random
from copy import deepcopy
from itertools import count

# Third-party.
import gym
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Project-local modules.
# NOTE(review): if these modules define classes of the same name, code that
# uses them must refer to ``DQN.DQN`` / ``ReplayMemory.ReplayMemory`` (or switch
# to ``from DQN import DQN``) — confirm the intended usage.
import DQN
import ReplayMemory

# ``.unwrapped`` returns the base environment without the wrappers that
# ``gym.make`` adds around it.
env = gym.make('CartPole-v0').unwrapped

# Set up matplotlib. An "inline" backend indicates we are running under
# IPython/Jupyter; ``display`` is imported only in that case, presumably so
# later plotting code can refresh figures in the notebook.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Enable interactive mode so figures update without blocking the script.
plt.ion()

# Pick tensor constructors once: CUDA variants when a GPU is available,
# CPU variants otherwise. ``Tensor`` is an alias for the float type.
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor
Tensor = FloatTensor
#