pytorch: 学习笔记9, pytorch实现全连接网络(多层感知机)
2020-08-05 21:28
260 查看
pytorch实现全连接网络(多层感知机)
网络模型为3层(含输入层):
输入(28×28)784个特征单元(神经元);
隐藏层:256个单元;
输出层:10 (比如softmax的10分类)。
import torch from torch import nn from torch.nn import init import torchvision import torchvision.transforms as transforms import sys import time class FlattenLayer(torch.nn.Module): def __init__(self): super(FlattenLayer, self).__init__() def forward(self, x): # x shape: (batch, *, *, ...) return x.view(x.shape[0], -1) def load_data_fashion_mnist(batch_size, root='Datasets/FashionMNIST'): mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=False, transform=transforms.ToTensor()) mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=False, transform=transforms.ToTensor()) if sys.platform.startswith('win'): num_workers = 0 # 0表示不用额外的进程来加速读取数据 else: num_workers = 4 train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers) test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers) return train_iter, test_iter # 评估 def evaluate_accuracy(data_iter, net): acc_sum, n = 0.0, 0 for X, y in data_iter: acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0] return acc_sum / n # 训练,随机梯度下降 def sgd(params, lr, batch_size): for param in params: param.data -= lr * param.grad / batch_size # 注意这里更改param时用的param.data # 训练 def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None): for epoch in range(num_epochs): train_l_sum, train_acc_sum, n = 0.0, 0.0, 0 for X, y in train_iter: y_hat = net(X) l = loss(y_hat, y).sum() # 梯度清零 optimizer.zero_grad() l.backward() optimizer.step() train_l_sum += l.item() train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item() n += y.shape[0] test_acc = evaluate_accuracy(test_iter, net) print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f' % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc)) if __name__ == '__main__': num_inputs, num_outputs, num_hiddens = 784, 10, 256 # 定义网络: 输入层(784) --> 隐藏层(256) --> 
输出层(10) net = nn.Sequential( FlattenLayer(), nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.Linear(num_hiddens, num_outputs), ) # print('len: ', len(list(net.parameters()))) # print('param: ', net.parameters()) # print('param list: ', list(net.parameters())) # 初始化网络参数 for params in net.parameters(): # net.parameters() 为各层的网络参数,可迭代 init.normal_(params, mean=0, std=0.01) print('len: ', len(list(net.parameters()))) print('param init: ', net.parameters()) # <generator object Module.parameters at 0x000001A0B1356620> print('param list init: ', list(net.parameters())) # 加载数据 batch_size = 256 train_iter, test_iter = load_data_fashion_mnist(batch_size) loss = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.5) # 训练 num_epochs = 5 train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)代码中几点说明:
① net.parameters()
net.parameters() 为:网络定义后,网络模型中的各层的参数 ( <generator object Module.parameters at 0x000001A0B1356620> )
可以转化为一个list查看其参数:list(net.parameters())
②torch.nn.init.normal_(params, mean=0, std=0.01)
对参数params进行均值为0、标准差为0.01的正态分布初始化
len: 4 param init: <generator object Module.parameters at 0x000001A0B1356620> param list init: [Parameter containing: tensor([[ 0.0096, 0.0091, 0.0036, ..., -0.0016, -0.0025, 0.0082], [ 0.0014, -0.0060, -0.0170, ..., -0.0087, -0.0124, 0.0029], [-0.0002, -0.0053, -0.0205, ..., 0.0140, 0.0015, -0.0047], ..., [-0.0067, -0.0192, -0.0026, ..., 0.0223, 0.0114, 0.0003], [ 0.0033, 0.0095, 0.0097, ..., 0.0015, 0.0030, -0.0138], [ 0.0139, -0.0073, 0.0012, ..., -0.0143, 0.0085, 0.0056]], requires_grad=True), Parameter containing: tensor([-3.3042e-03, 4.7892e-04, -5.6590e-03, 3.2377e-03, 3.5846e-03, 6.6989e-03, 2.3601e-03, -9.1927e-04, 2.4281e-02, -2.2155e-02, 5.3163e-03, 3.7543e-03, 4.3089e-03, 1.1811e-02, -4.1673e-03, -1.9667e-02, 2.6118e-03, -3.2978e-03, 1.3942e-02, -1.6289e-02, 9.8179e-03, -2.2531e-02, -1.4156e-02, -1.4382e-03, 1.7384e-02, -1.2549e-02, 9.4562e-03, -9.0459e-03, 8.4983e-03, -1.5124e-03, -1.4963e-02, -7.0390e-03, 1.0951e-02, -1.6487e-02, -5.2332e-03, -5.2680e-03, -1.7785e-03, -1.3423e-03, 1.9302e-03, -4.9111e-03, 1.7328e-03, -8.0625e-03, -3.9449e-03, 3.6381e-03, 1.1906e-02, 5.0710e-03, 5.1031e-03, 9.2445e-04, 2.6244e-02, -2.9451e-03, 9.6235e-03, -2.1532e-03, -1.3756e-02, -2.1489e-03, -1.3318e-02, 4.8365e-03, -1.0427e-02, 5.2636e-03, 8.1710e-03, -2.8734e-03, -4.0999e-03, -3.3395e-03, 9.2141e-03, 1.8420e-02, 2.3903e-03, 6.3389e-03, -7.1875e-03, 9.3982e-03, -1.6983e-02, -1.9021e-03, -6.3871e-03, 6.7952e-03, -1.2235e-02, -1.6785e-02, -6.6447e-03, 1.2196e-02, 7.3601e-03, -1.5027e-02, -2.6593e-03, -9.6182e-03, -8.4485e-03, 2.2411e-02, -7.5373e-03, 3.6415e-02, 2.6785e-03, 1.9647e-02, -1.4472e-03, -2.1426e-03, -1.0003e-02, -6.0945e-03, 6.1464e-04, 6.1757e-03, 1.2456e-02, 1.0664e-02, 8.7811e-03, -1.9107e-02, -8.5125e-03, -3.2865e-04, 1.0192e-02, -2.4412e-02, -2.1226e-02, 1.0242e-02, 4.0445e-03, -3.3238e-03, 4.4551e-04, 1.7880e-02, 1.4732e-02, 7.4244e-04, 1.5565e-02, 6.3838e-03, 4.2519e-03, 3.7454e-04, 6.0372e-03, 1.0598e-02, 6.6352e-03, 9.3732e-03, 7.1993e-03, 
-8.0230e-03, -2.0376e-02, 1.7323e-03, 1.5667e-02, -1.0637e-02, -1.9101e-02, -8.6477e-03, 4.6590e-03, -4.7290e-03, 1.2458e-02, 1.0215e-02, 1.4719e-02, -3.4490e-03, -4.6496e-03, 6.5331e-03, -3.9560e-03, -1.1488e-02, -8.5887e-03, 1.5083e-02, 1.0957e-02, 1.9015e-02, -2.1299e-03, -8.0287e-03, -1.4993e-02, -1.1674e-02, 7.0364e-03, -2.5001e-03, -1.0356e-03, 5.7498e-03, 5.7233e-04, 7.9161e-04, -6.0469e-03, -2.6913e-03, 6.7641e-03, 1.8129e-03, 1.5494e-03, -9.7351e-03, 6.8967e-05, 2.2971e-03, -9.1847e-03, -2.3717e-03, -6.4801e-03, 2.9549e-03, -7.2387e-03, -1.6071e-02, -1.1841e-02, -4.3262e-03, -7.4287e-04, -1.0381e-02, -1.9941e-02, 1.2515e-02, 1.1387e-02, -3.3133e-03, 1.3639e-02, -1.9078e-03, -1.5026e-02, 3.7264e-03, 1.2014e-02, -8.0367e-03, -3.5969e-02, 6.3780e-03, 3.4895e-03, 1.5735e-02, -5.6254e-04, -5.5807e-03, 5.4600e-04, -8.7495e-04, 7.8439e-03, -1.2823e-02, -1.4356e-02, 7.8702e-03, 4.3848e-04, 5.3145e-03, -6.1489e-03, 8.7027e-04, -1.0802e-03, 7.2241e-03, 5.0439e-03, 1.3031e-02, 7.4891e-03, -7.3666e-03, -6.0929e-03, -6.1948e-03, 8.1562e-03, -6.0273e-03, -1.0222e-02, -1.7376e-03, -1.2922e-02, 1.1247e-02, -1.0559e-02, -1.5887e-02, 1.0038e-02, -1.4515e-02, -9.5886e-03, 1.2830e-02, 8.8126e-03, -9.1111e-03, 6.2043e-03, 1.9829e-02, 1.5241e-02, 2.2486e-03, 9.0140e-03, 1.7259e-02, -5.6758e-03, 4.1752e-03, 4.8623e-04, 1.9457e-02, 8.3239e-03, -1.1590e-02, -5.5052e-03, -2.0561e-02, 2.8499e-03, 1.1046e-02, -7.4051e-03, 1.1231e-02, 1.4840e-02, 4.9973e-03, 1.3801e-02, -1.4826e-02, -7.4246e-03, -1.5146e-02, 1.2617e-02, 7.5188e-03, 1.9418e-02, -1.0118e-03, -8.8281e-03, -5.6416e-03, 1.8890e-04, -3.9850e-03, -4.7776e-03, 9.0903e-03, -3.2510e-02, 3.5589e-03, 3.9693e-03, 1.9995e-02, 2.7695e-03, 9.5730e-03, -9.2412e-03, 1.0012e-02], requires_grad=True), Parameter containing: tensor([[-0.0074, -0.0100, 0.0084, ..., 0.0004, -0.0123, -0.0015], [-0.0017, 0.0017, 0.0104, ..., -0.0067, -0.0016, -0.0096], [-0.0132, -0.0034, 0.0193, ..., 0.0191, 0.0004, 0.0105], ..., [ 0.0167, -0.0144, 0.0048, 
..., 0.0061, -0.0083, 0.0072], [ 0.0018, 0.0048, 0.0050, ..., 0.0015, -0.0165, 0.0046], [-0.0039, 0.0027, -0.0014, ..., 0.0078, -0.0054, -0.0089]], requires_grad=True), Parameter containing: tensor([-0.0058, -0.0037, 0.0099, -0.0099, 0.0018, -0.0193, -0.0041, 0.0043, -0.0114, -0.0049], requires_grad=True)] epoch 1, loss 0.0031, train acc 0.699, test acc 0.763 epoch 2, loss 0.0019, train acc 0.817, test acc 0.785 epoch 3, loss 0.0017, train acc 0.844, test acc 0.844 epoch 4, loss 0.0015, train acc 0.856, test acc 0.798 epoch 5, loss 0.0014, train acc 0.865, test acc 0.845 Process finished with exit code 0
参考:
https://pytorch.org/docs/stable/nn.html#
https://pytorch.org/docs/stable/nn.init.html (torch.nn.init)
https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter03_DL-basics/3.10_mlp-pytorch
相关文章推荐
- 【深度学习】多层感知机及其实现(pyTorch)
- 感知机模型学习笔记及Python实现
- 学习笔记TF026:多层感知机
- TensorFlow学习笔记(二):手写数字识别之多层感知机
- Tensorflow-多层感知机学习与实现
- Tensorflow 实战 笔记 (一)实现多层感知机
- 《统计学习方法》感知机学习笔记与Python实现
- 街景字符编码识别项目学习笔记(四)CNN介绍及字符识别模型的pytorch实现
- 神经网络学习笔记(十):多层感知机(中)--BP算法
- 感知机学习算法——统计学习方法笔记,代码实现
- Tensorflow学习之实现多层感知机
- 神经网络学习笔记(十一):多层感知机(下)
- 多层感知机进阶-基于keras的python学习笔记(八)
- HelloDNN,多层感知机MLP学习笔记
- pytorch 学习笔记 part 3 多层感知机
- 深度学习:多层感知机MLP数字识别的代码实现
- 深度学习笔记二:多层感知机(MLP)与神经网络结构
- 深度学习Deeplearning4j 入门实战(5):基于多层感知机的Mnist压缩以及在Spark实现
- 感知机模型学习笔记及Python实现
- 《tensorflow实战》学习2——实现多层感知机