您的位置:首页 > 理论基础 > 计算机网络

pytorch: 学习笔记9, pytorch实现全连接网络(多层感知机)

2020-08-05 21:28 260 查看
pytorch实现全连接网络(多层感知机)

网络模型为3层(含输入层):
输入(28×28)784个特征单元(神经元);
隐藏层:256个单元;
输出层:10个单元(比如softmax的10分类)。

代码:
import torch
from torch import nn
from torch.nn import init
import torchvision
import torchvision.transforms as transforms
import sys
import time

class FlattenLayer(torch.nn.Module):
    """Collapse every dimension after the first into a single dimension.

    Intended as the first layer of an ``nn.Sequential`` so that multi-
    dimensional batches can be fed straight into ``nn.Linear`` layers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):  # x shape: (batch, *, *, ...)
        """Return ``x`` reshaped to ``(x.shape[0], -1)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous
        inputs (for example the result of ``transpose`` or ``permute``) are
        accepted instead of raising; contiguous inputs are still returned as
        a view without copying.
        """
        return x.reshape(x.shape[0], -1)

def load_data_fashion_mnist(batch_size, root='Datasets/FashionMNIST'):
    """Build the Fashion-MNIST training and test data loaders.

    Args:
        batch_size: Number of samples per mini-batch for both loaders.
        root: Directory passed to ``torchvision.datasets.FashionMNIST``.
            ``download=False`` is used, so no download is requested here.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``DataLoader`` objects. Only
        the training loader shuffles its samples.
    """
    # 0 means no extra worker processes are used to load data (Windows case).
    worker_count = 0 if sys.platform.startswith('win') else 4

    # Create both datasets first (train, then test), then wrap them.
    train_set, test_set = [
        torchvision.datasets.FashionMNIST(
            root=root,
            train=is_train,
            download=False,
            transform=transforms.ToTensor(),
        )
        for is_train in (True, False)
    ]

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=worker_count,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=worker_count,
    )
    return train_loader, test_loader

# 评估
def evaluate_accuracy(data_iter, net):
    """Return the fraction of samples that ``net`` classifies correctly.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches, where ``y`` holds
            the target class index of each sample.
        net: Callable mapping a batch ``X`` to per-class scores; the
            predicted class is the ``argmax`` over dimension 1.

    Returns:
        Number of correct predictions divided by the number of samples seen.

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    # Remember the mode so the caller's training loop is not left in eval mode.
    was_training = is_module and net.training
    if is_module:
        # Evaluation mode makes layers such as dropout deterministic.
        net.eval()

    correct, total = 0.0, 0
    try:
        # No gradients are needed for scoring; skip building the graph.
        with torch.no_grad():
            for X, y in data_iter:
                correct += (net(X).argmax(dim=1) == y).float().sum().item()
                total += y.shape[0]
    finally:
        if was_training:
            net.train()
    return correct / total

# 训练,随机梯度下降
def sgd(params, lr, batch_size):
    """Apply one in-place mini-batch SGD step to every parameter.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``.

    Args:
        params: Iterable of tensors whose ``.grad`` has been populated by a
            backward pass.
        lr: Learning rate.
        batch_size: Divisor applied to the gradient before the update.
    """
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                # Parameter took no part in the last backward pass.
                continue
            param -= lr * param.grad / batch_size

# 训练
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    """Train ``net`` and print loss and accuracy after every epoch.

    Either ``optimizer`` or both ``params`` and ``lr`` must be supplied. When
    ``optimizer`` is given it drives the updates; otherwise the module-level
    ``sgd`` helper is applied to ``params`` with ``lr`` and ``batch_size``.

    Args:
        net: Callable mapping a batch ``X`` to per-class scores.
        train_iter: Iterable of ``(X, y)`` training batches.
        test_iter: Iterable of ``(X, y)`` batches passed to
            ``evaluate_accuracy`` after each epoch.
        loss: Callable ``loss(y_hat, y)``; its result is summed to a scalar.
        num_epochs: Number of passes over ``train_iter``.
        batch_size: Gradient divisor used by the manual ``sgd`` fallback.
        params: Parameters to update when no ``optimizer`` is given.
        lr: Learning rate for the manual ``sgd`` fallback.
        optimizer: Object providing ``zero_grad()`` and ``step()``.

    Raises:
        ValueError: If neither ``optimizer`` nor both ``params`` and ``lr``
            are provided.
    """
    if optimizer is None:
        if params is None or lr is None:
            raise ValueError(
                'train_ch3 requires either optimizer or both params and lr')
        # Materialise once: a generator would be exhausted after one batch.
        params = list(params)

    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # Clear gradients left over from the previous batch.
            if optimizer is not None:
                optimizer.zero_grad()
            else:
                for param in params:
                    if param.grad is not None:
                        param.grad.zero_()

            l.backward()
            if optimizer is not None:
                optimizer.step()
            else:
                sgd(params, lr, batch_size)

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        # NOTE: the reported loss is the sum of per-batch values divided by
        # the sample count; if `loss` already averages over the batch, the
        # printed figure is smaller than the true per-sample loss.
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))

if __name__ == '__main__':
    # Layer widths: input (784) --> hidden (256) --> output (10).
    num_inputs, num_hiddens, num_outputs = 784, 256, 10

    # Define the network as a flatten step followed by two linear layers
    # with a ReLU in between.
    layers = [
        FlattenLayer(),
        nn.Linear(num_inputs, num_hiddens),
        nn.ReLU(),
        nn.Linear(num_hiddens, num_outputs),
    ]
    net = nn.Sequential(*layers)

    # Initialise every parameter tensor; net.parameters() is an iterable over
    # the parameters of all layers.
    for tensor in net.parameters():
        init.normal_(tensor, mean=0, std=0.01)

    param_list = list(net.parameters())
    print('len: ', len(param_list))
    # net.parameters() itself is a generator, so this prints something like
    # <generator object Module.parameters at 0x...>.
    print('param init: ', net.parameters())
    print('param list init: ', param_list)

    # Load the data.
    batch_size = 256
    train_iter, test_iter = load_data_fashion_mnist(batch_size)

    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)

    # Train; params and lr keep their None defaults because an optimizer
    # is supplied.
    num_epochs = 5
    train_ch3(
        net,
        train_iter,
        test_iter,
        loss,
        num_epochs,
        batch_size,
        optimizer=optimizer,
    )
代码中几点说明:

net.parameters()
net.parameters() 为:网络定义后,网络模型中的各层的参数 ( <generator object Module.parameters at 0x000001A0B1356620> )
可以转化为一个list查看其参数:list(net.parameters())
torch.nn.init.normal_(params, mean=0, std=0.01)
对参数params进行均值为0、标准差为0.01的正态分布初始化

结果:
len:  4
param init:  <generator object Module.parameters at 0x000001A0B1356620>
param list init:  [Parameter containing:
tensor([[ 0.0096,  0.0091,  0.0036,  ..., -0.0016, -0.0025,  0.0082],
[ 0.0014, -0.0060, -0.0170,  ..., -0.0087, -0.0124,  0.0029],
[-0.0002, -0.0053, -0.0205,  ...,  0.0140,  0.0015, -0.0047],
...,
[-0.0067, -0.0192, -0.0026,  ...,  0.0223,  0.0114,  0.0003],
[ 0.0033,  0.0095,  0.0097,  ...,  0.0015,  0.0030, -0.0138],
[ 0.0139, -0.0073,  0.0012,  ..., -0.0143,  0.0085,  0.0056]],
requires_grad=True), Parameter containing:
tensor([-3.3042e-03,  4.7892e-04, -5.6590e-03,  3.2377e-03,  3.5846e-03,
6.6989e-03,  2.3601e-03, -9.1927e-04,  2.4281e-02, -2.2155e-02,
5.3163e-03,  3.7543e-03,  4.3089e-03,  1.1811e-02, -4.1673e-03,
-1.9667e-02,  2.6118e-03, -3.2978e-03,  1.3942e-02, -1.6289e-02,
9.8179e-03, -2.2531e-02, -1.4156e-02, -1.4382e-03,  1.7384e-02,
-1.2549e-02,  9.4562e-03, -9.0459e-03,  8.4983e-03, -1.5124e-03,
-1.4963e-02, -7.0390e-03,  1.0951e-02, -1.6487e-02, -5.2332e-03,
-5.2680e-03, -1.7785e-03, -1.3423e-03,  1.9302e-03, -4.9111e-03,
1.7328e-03, -8.0625e-03, -3.9449e-03,  3.6381e-03,  1.1906e-02,
5.0710e-03,  5.1031e-03,  9.2445e-04,  2.6244e-02, -2.9451e-03,
9.6235e-03, -2.1532e-03, -1.3756e-02, -2.1489e-03, -1.3318e-02,
4.8365e-03, -1.0427e-02,  5.2636e-03,  8.1710e-03, -2.8734e-03,
-4.0999e-03, -3.3395e-03,  9.2141e-03,  1.8420e-02,  2.3903e-03,
6.3389e-03, -7.1875e-03,  9.3982e-03, -1.6983e-02, -1.9021e-03,
-6.3871e-03,  6.7952e-03, -1.2235e-02, -1.6785e-02, -6.6447e-03,
1.2196e-02,  7.3601e-03, -1.5027e-02, -2.6593e-03, -9.6182e-03,
-8.4485e-03,  2.2411e-02, -7.5373e-03,  3.6415e-02,  2.6785e-03,
1.9647e-02, -1.4472e-03, -2.1426e-03, -1.0003e-02, -6.0945e-03,
6.1464e-04,  6.1757e-03,  1.2456e-02,  1.0664e-02,  8.7811e-03,
-1.9107e-02, -8.5125e-03, -3.2865e-04,  1.0192e-02, -2.4412e-02,
-2.1226e-02,  1.0242e-02,  4.0445e-03, -3.3238e-03,  4.4551e-04,
1.7880e-02,  1.4732e-02,  7.4244e-04,  1.5565e-02,  6.3838e-03,
4.2519e-03,  3.7454e-04,  6.0372e-03,  1.0598e-02,  6.6352e-03,
9.3732e-03,  7.1993e-03, -8.0230e-03, -2.0376e-02,  1.7323e-03,
1.5667e-02, -1.0637e-02, -1.9101e-02, -8.6477e-03,  4.6590e-03,
-4.7290e-03,  1.2458e-02,  1.0215e-02,  1.4719e-02, -3.4490e-03,
-4.6496e-03,  6.5331e-03, -3.9560e-03, -1.1488e-02, -8.5887e-03,
1.5083e-02,  1.0957e-02,  1.9015e-02, -2.1299e-03, -8.0287e-03,
-1.4993e-02, -1.1674e-02,  7.0364e-03, -2.5001e-03, -1.0356e-03,
5.7498e-03,  5.7233e-04,  7.9161e-04, -6.0469e-03, -2.6913e-03,
6.7641e-03,  1.8129e-03,  1.5494e-03, -9.7351e-03,  6.8967e-05,
2.2971e-03, -9.1847e-03, -2.3717e-03, -6.4801e-03,  2.9549e-03,
-7.2387e-03, -1.6071e-02, -1.1841e-02, -4.3262e-03, -7.4287e-04,
-1.0381e-02, -1.9941e-02,  1.2515e-02,  1.1387e-02, -3.3133e-03,
1.3639e-02, -1.9078e-03, -1.5026e-02,  3.7264e-03,  1.2014e-02,
-8.0367e-03, -3.5969e-02,  6.3780e-03,  3.4895e-03,  1.5735e-02,
-5.6254e-04, -5.5807e-03,  5.4600e-04, -8.7495e-04,  7.8439e-03,
-1.2823e-02, -1.4356e-02,  7.8702e-03,  4.3848e-04,  5.3145e-03,
-6.1489e-03,  8.7027e-04, -1.0802e-03,  7.2241e-03,  5.0439e-03,
1.3031e-02,  7.4891e-03, -7.3666e-03, -6.0929e-03, -6.1948e-03,
8.1562e-03, -6.0273e-03, -1.0222e-02, -1.7376e-03, -1.2922e-02,
1.1247e-02, -1.0559e-02, -1.5887e-02,  1.0038e-02, -1.4515e-02,
-9.5886e-03,  1.2830e-02,  8.8126e-03, -9.1111e-03,  6.2043e-03,
1.9829e-02,  1.5241e-02,  2.2486e-03,  9.0140e-03,  1.7259e-02,
-5.6758e-03,  4.1752e-03,  4.8623e-04,  1.9457e-02,  8.3239e-03,
-1.1590e-02, -5.5052e-03, -2.0561e-02,  2.8499e-03,  1.1046e-02,
-7.4051e-03,  1.1231e-02,  1.4840e-02,  4.9973e-03,  1.3801e-02,
-1.4826e-02, -7.4246e-03, -1.5146e-02,  1.2617e-02,  7.5188e-03,
1.9418e-02, -1.0118e-03, -8.8281e-03, -5.6416e-03,  1.8890e-04,
-3.9850e-03, -4.7776e-03,  9.0903e-03, -3.2510e-02,  3.5589e-03,
3.9693e-03,  1.9995e-02,  2.7695e-03,  9.5730e-03, -9.2412e-03,
1.0012e-02], requires_grad=True), Parameter containing:
tensor([[-0.0074, -0.0100,  0.0084,  ...,  0.0004, -0.0123, -0.0015],
[-0.0017,  0.0017,  0.0104,  ..., -0.0067, -0.0016, -0.0096],
[-0.0132, -0.0034,  0.0193,  ...,  0.0191,  0.0004,  0.0105],
...,
[ 0.0167, -0.0144,  0.0048,  ...,  0.0061, -0.0083,  0.0072],
[ 0.0018,  0.0048,  0.0050,  ...,  0.0015, -0.0165,  0.0046],
[-0.0039,  0.0027, -0.0014,  ...,  0.0078, -0.0054, -0.0089]],
requires_grad=True), Parameter containing:
tensor([-0.0058, -0.0037,  0.0099, -0.0099,  0.0018, -0.0193, -0.0041,  0.0043,
-0.0114, -0.0049], requires_grad=True)]
epoch 1, loss 0.0031, train acc 0.699, test acc 0.763
epoch 2, loss 0.0019, train acc 0.817, test acc 0.785
epoch 3, loss 0.0017, train acc 0.844, test acc 0.844
epoch 4, loss 0.0015, train acc 0.856, test acc 0.798
epoch 5, loss 0.0014, train acc 0.865, test acc 0.845

Process finished with exit code 0

参考:
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.init.html (torch.nn.init)
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: