您的位置:首页 > 其它

PyTorch(三)——使用训练好的模型测试自己图片

2017-05-31 09:24 751 查看
目录连接

(1) 数据处理

(2) 搭建和自定义网络

(3) 使用训练好的模型测试自己图片

(4) 视频数据的处理

(5) PyTorch源码修改之增加ConvLSTM层

(6) 梯度反向传递(BackPropogate)的理解

(总) PyTorch遇到令人迷人的BUG

PyTorch的学
4000
习和使用(三)

上一篇文章中实现了如何增加一个自定义的Loss,以Siamese network为例。现在实现使用训练好的该网络对自己手写的数字图片进行测试。

首先需要对训练时的权重进行保存,然后在测试时直接加载即可。

torch.save(net, path)
torch.load(path)


即可。

然后自己手写的图片进行处理。

把需要测试的图片放入一个文件夹中,然后使用然后对图片数据进行加载,对图片数据进行归一化处理,并且调整大小为(B,C,H,W)。

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((1.1618,), (1.1180,))])

def get_files(directory):
return [os.path.join(directory, f) for f in sorted(list(os.listdir(directory)))
if os.path.isfile(os.path.join(directory, f))]
images = np.array([])
file = get_files('./data/figure')
for i, item in enumerate(file):
print('Processing %i of %i (%s)' % (i+1, len(file), item))
image = transform(Image.open(item).convert('L'))
images = np.append(images, image.numpy())

img = images.reshape(-1, 1, 28, 28)
img = torch.from_numpy(img).float()
label = torch.ones(5,1).long()


其加载后的数据可视化为:



最后加载模型并测试。

torch.load('./saveT.pt')
def test(data, label):
net.eval()

data, label = Variable(data, volatile=True), Variable(label)
output = net(data)
out = output.view(-1, 4)
test_loss = criterion(out[:, 0:2], out[:, 2:4], label).data[0]
pred = classify(out.data[:, 0:2], out.data[:, 2:4])
correct = pred.eq(label.data).sum()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: