forked from Harold-lkk/d2lzh_pytorch_lkk
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path3.6_softmax-regression-scratch.py
More file actions
84 lines (66 loc) · 2.47 KB
/
3.6_softmax-regression-scratch.py
File metadata and controls
84 lines (66 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# -*- coding: utf-8 -*-
'''
@Author: lkk
@Date: 2019-12-16 19:33:08
@LastEditTime: 2019-12-17 14:56:08
@LastEditors: lkk
@Description:
'''
import torch
import torchvision
import numpy as np
import d2lzh_pytorch as d2l
def softmax(x):
    """Return the row-wise softmax of a 2-D tensor.

    Every entry is exponentiated and divided by the sum of the exponentials
    along dimension 1, so each row of the result sums to 1.

    Args:
        x: tensor whose dimension 1 holds the scores to normalise.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Softmax is unchanged by adding a constant to a whole row, so shift each
    # row by its maximum: the largest exponent becomes 0 and exp() can no
    # longer overflow to inf (which previously produced inf / inf = nan).
    row_max = x.max(dim=1, keepdim=True)[0]
    x_exp = (x - row_max).exp()
    partition = x_exp.sum(dim=1, keepdim=True)
    # keepdim=True leaves partition with a size-1 dim 1, so it broadcasts
    # across the columns of x_exp.
    return x_exp / partition
def net(x):
    """Run the softmax-regression model on a batch.

    The input is reshaped to rows of ``num_inputs`` values, multiplied by the
    module-level weight ``w``, offset by the module-level bias ``b`` and
    passed through ``softmax``.

    Args:
        x: input tensor; it is viewed as ``(-1, num_inputs)``.

    Returns:
        The output of ``softmax`` applied to the affine scores.
    """
    flat = x.view((-1, num_inputs))
    scores = flat.mm(w) + b
    return softmax(scores)
def cross_entropy(y_hat, y):
    """Return the negative log of the entry selected by each label.

    Args:
        y_hat: 2-D tensor; row ``i`` is indexed along dimension 1 by ``y[i]``.
        y: integer index tensor, reshaped to a single column before gathering.

    Returns:
        Column tensor (one row per label) holding ``-log`` of the gathered
        entries; no reduction is applied.
    """
    label_column = y.view(-1, 1)
    picked = torch.gather(y_hat, 1, label_column)
    return torch.log(picked).neg()
def accuracy(y_hat, y):
    """Return the fraction of rows whose arg-max along dim 1 equals the label.

    Args:
        y_hat: 2-D tensor of per-class scores, one row per sample.
        y: tensor of labels compared element-wise against the arg-max.

    Returns:
        The mean of the 0/1 match indicators as a Python float.
    """
    predicted = torch.argmax(y_hat, dim=1)
    matches = (predicted == y).float()
    return matches.mean().item()
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    Args:
        data_iter: iterable yielding ``(x, y)`` batches.
        net: callable mapping a batch ``x`` to a 2-D tensor of per-class
            scores; the arg-max along dim 1 is taken as the prediction.

    Returns:
        Number of correct predictions divided by the number of labels seen,
        as a Python float.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    acc_sum, n = 0.0, 0
    # Evaluation never calls backward(), so disable autograd to avoid
    # building and holding a computation graph for every batch.
    with torch.no_grad():
        for x, y in data_iter:
            acc_sum += (net(x).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):
    """Train ``net`` for ``num_epochs`` epochs and print per-epoch metrics.

    Two update paths are supported: if ``optimizer`` is given, it resets the
    gradients and applies the step; otherwise gradients on ``params`` are
    zeroed by hand and ``d2l.sgd(params, lr, batch_size)`` performs the
    update.

    Args:
        net: callable mapping a batch to per-class scores.
        train_iter: iterable of ``(x, y)`` training batches.
        test_iter: iterable of ``(x, y)`` batches passed to
            ``evaluate_accuracy`` after every epoch.
        loss: callable ``loss(y_hat, y)`` returning a tensor; its elements
            are summed to obtain the batch loss.
        num_epochs: number of passes over ``train_iter``.
        batch_size: forwarded to ``d2l.sgd`` on the manual path.
        params: tensors to update when no optimizer is supplied.
        lr: learning rate forwarded to ``d2l.sgd`` on the manual path.
        optimizer: optional object providing ``zero_grad()`` and ``step()``.

    Returns:
        None. One summary line is printed per epoch.
    """
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for x, y in train_iter:
            y_hat = net(x)
            batch_loss = loss(y_hat, y).sum()

            # Clear gradients left over from the previous step.
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None:
                # Check every parameter on its own: looking only at
                # params[0].grad would crash on a later parameter that has
                # no gradient yet, and would skip the reset entirely when
                # only the first one lacks a gradient.
                for param in params:
                    if param.grad is not None:
                        param.grad.data.zero_()

            batch_loss.backward()

            if optimizer is None:
                d2l.sgd(params, lr, batch_size)
            else:
                optimizer.step()

            train_l_sum += batch_loss.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]

        # Both training metrics are averaged over the samples seen this epoch.
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
# ---- Data ------------------------------------------------------------------
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# ---- Model parameters (read as globals by net) -----------------------------
num_inputs = 784
num_outputs = 10
# Weights are drawn from a normal distribution with mean 0 and std 0.01;
# the bias starts at zero. Both track gradients for training.
w = torch.normal(0,
                 0.01, (num_inputs, num_outputs),
                 dtype=torch.float,
                 requires_grad=True)
b = torch.zeros(num_outputs, dtype=torch.float, requires_grad=True)

# ---- Training --------------------------------------------------------------
num_epochs, lr = 5, 0.1
# No optimizer is passed, so train_ch3 takes its manual d2l.sgd path on [w, b].
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [w, b], lr)

# ---- Show predictions for the first test batch -----------------------------
# Use the builtin next(): Python 3 iterators expose __next__, and a `.next()`
# method is not part of the iterator protocol.
X, y = next(iter(test_iter))
true_labels = d2l.get_fashion_mnist_labels(y.numpy())
pred_labels = d2l.get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
# Each title shows the true label on the first line and the prediction below.
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
d2l.show_fashion_mnist(X[0:9], titles[0:9])