您的位置:首页 > 其它

用PaddlePaddle实现图像分类-ResNet(动态图版)

2020-04-14 12:08 731 查看

【推荐阅读】微服务还能火多久?>>>

ResNet

ResNet(Residual Network)是2015年ImageNet图像分类、图像物体定位和图像物体检测比赛的冠军。针对随着网络训练加深导致准确度下降的问题,ResNet提出了残差学习方法来减轻训练深层网络的困难。在已有设计思路(BN, 小卷积核,全卷积网络)的基础上,引入了残差模块。每个残差模块包含两条路径,其中一条路径是输入特征的直连通路,另一条路径对该特征做两到三次卷积操作得到该特征的残差,最后再将两条路径上的特征相加。

残差模块如图1所示,左边是基本模块连接方式,由两个输出通道数相同的3x3卷积组成。右边是瓶颈模块(Bottleneck)连接方式,之所以称为瓶颈,是因为上面的1x1卷积用来降维(图示例即256->64),下面的1x1卷积用来升维(图示例即64->256),这样中间3x3卷积的输入和输出通道数都较小(图示例即64->64)。


图1. 残差模块

图2展示了50、101、152层网络连接示意图,使用的是瓶颈模块。这三个模型的区别在于每组中残差模块的重复次数不同(见图右上角)。ResNet训练收敛较快,成功的训练了上百乃至近千层的卷积神经网络。


图2. 基于ImageNet的ResNet模型

ResNet解读博客https://blog.csdn.net/lanran2/article/details/79057994

In[1]
# 解压花朵数据集
!cd data/data2815 && unzip -q flower_photos.zip
In[2]
import codecs
import os
import random
import shutil
from PIL import Image

train_ratio = 4.0 / 5

all_file_dir = 'data/data2815'
class_list = [c for c in os.listdir(all_file_dir) if os.path.isdir(os.path.join(all_file_dir, c)) and not c.endswith('Set') and not c.startswith('.')]
class_list.sort()
print(class_list)
train_image_dir = os.path.join(all_file_dir, "trainImageSet")
if not os.path.exists(train_image_dir):
os.makedirs(train_image_dir)

eval_image_dir = os.path.join(all_file_dir, "evalImageSet")
if not os.path.exists(eval_image_dir):
os.makedirs(eval_image_dir)

train_file = codecs.open(os.path.join(all_file_dir, "train.txt"), 'w')
eval_file = codecs.open(os.path.join(all_file_dir, "eval.txt"), 'w')

with codecs.open(os.path.join(all_file_dir, "label_list.txt"), "w") as label_list:
label_id = 0
for class_dir in class_list:
label_list.write("{0}\t{1}\n".format(label_id, class_dir))
image_path_pre = os.path.join(all_file_dir, class_dir)
for file in os.listdir(image_path_pre):
try:
img = Image.open(os.path.join(image_path_pre, file))
if random.uniform(0, 1) <= train_ratio:
shutil.copyfile(os.path.join(image_path_pre, file), os.path.join(train_image_dir, file))
train_file.write("{0}\t{1}\n".format(os.path.join(train_image_dir, file), label_id))
else:
shutil.copyfile(os.path.join(image_path_pre, file), os.path.join(eval_image_dir, file))
eval_file.write("{0}\t{1}\n".format(os.path.join(eval_image_dir, file), label_id))
except Exception as e:
pass
# 存在一些文件打不开,此处需要稍作清洗
label_id += 1

train_file.close()
eval_file.close()
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
In[4]
#模型训练
!python train.py
2020-03-06 16:56:59,114-INFO: model saved at epoch 49, best accuracy is 0.9454861111111111
2020-03-06 16:56:59,114 - train.py[line:166] - INFO: model saved at epoch 49, best accuracy is 0.9454861111111111
2020-03-06 16:56:59,115-INFO: Final loss: [0.24703844]
2020-03-06 16:56:59,115 - train.py[line:167] - INFO: Final loss: [0.24703844]
In[5]
#训练50轮,在验证集上评估
!python eval.py
W0306 16:57:08.802810   231 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0306 16:57:08.806915   231 device_context.cc:245] device: 0, cuDNN Version: 7.3.
0.82777035
In[6]
#模型预测
!python infer.py
W0306 16:59:22.282727   291 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0306 16:59:22.286063   291 device_context.cc:245] device: 0, cuDNN Version: 7.3.
checkpoint loaded
image data/data2815/sunflowers/3840761441_7c648abf4d_n.jpg Infer result is: sunflowers

点击链接,使用AI Studio一键上手实践项目吧:https://aistudio.baidu.com/aistudio/projectdetail/204995 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  PaddlePaddle