您的位置:首页 > 其它

Pytorch的nn.module

2020-02-04 02:41 295 查看

包引入

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

定义类

class net(nn.Module):
def __init__(self):
super(net, self).__init__()
self.conv1 = nn.sequential(
nn.Conv2d(1, 16, kernal_size=5, padding = 1 ),
#输入深度1,输入深度16,卷积核5*5,0填充为1
nn.BatchNorm2d(16),     #归一化
nn.ReLU(),
nn.MaxPool2d(2)), #池化,核是2

self.conv2 = nn.sequential(
nn.Conv2d(16, 32, kernal_size=2, padding = 1 ),
#输入深度1,输入深度16,卷积核5*5,0填充为1
nn.BatchNorm2d(32),     #归一化
nn.ReLU(),
nn.MaxPool2d(3)), #池化,核是3

self.fc = nn.Linear(32,2)  #全连接层

def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = out.view(out.size(0), -1) #一维是out.size(0),剩下的在2维
out = self.fc(out)
return out

注意:view函数只能由于contiguous的张量上,具体而言,就是在内存中连续存储的张量。具体而言,可以参看link。所以,当tensor之前调用了transpose, permute函数就会是tensor内存中变得不再连续,就不能调用view函数。
所以,应该提前做tensor.contiguous()的操作!

训练

def main():
torch.set_default_tensor_type("torch.DoubleTensor")
trainx = np.array([[[[0.0, 0, 1, 0, 0],
[0, 1, 1, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0],
[1, 1, 1, 1, 1]]],
[[[0, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
[1, 0, 1, 1, 0],
[0, 1, 1, 0, 0],
[1, 1, 1, 1, 1]]]])
trainy = np.array([[1, 0.0], [0, 1]])
N = net()

trainx = torch.from_numpy(trainx)
trainy = torch.from_numpy(trainy)
print(trainx.shape)
input()
criterion = nn.MSELoss()  #平方损失函数
optimizer = optim.SGD(N.parameters(), lr=1e-4) #随机梯度下降
num_epochs = 10000

for epoch in range(num_epochs):
inputs = Variable(trainx)
target = Variable(trainy)

out = N(inputs)
loss = criterion(out, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

if (epoch) % 50 == 0:
print("loss:{}".format(loss.data))

N.eval() #转变为测试模式
predict = N(Variable(trainx))
predict = predict.data.numpy()
print(predict)
pass

main()

转自:https://www.jianshu.com/p/76fb6e8e59e6
https://blog.csdn.net/m0_37586991/article/details/88371251

  • 点赞
  • 收藏
  • 分享
  • 文章举报
yuki___ 发布了4 篇原创文章 · 获赞 0 · 访问量 469 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: