迁移学习CNN图像分类模型 - 花朵图片分类
2018-01-31 19:30
791 查看
训练一个好的卷积神经网络模型进行图像分类不仅需要计算资源还需要很长的时间。特别是模型比较复杂和数据量比较大的时候。普通的电脑动不动就需要训练几天的时间。为了能够快速地训练好自己的花朵图片分类器,我们可以使用别人已经训练好的模型参数,在此基础之上训练我们的模型。这个便属于迁移学习。本文提供训练数据集和代码下载。
原理:卷积神经网络模型总体上可以分为两部分,前面的卷积层和后面的全连接层。卷积层的作用是图片特征的提取,全连接层作用是特征的分类。我们的思路便是在inception-v3网络模型上,修改全连接层,保留卷积层。卷积层的参数使用的是别人已经训练好的,全连接层的参数需要我们初始化并使用我们自己的数据来训练和学习。
上面inception-v3模型图红色箭头前面部分是卷积层,后面是全连接层。我们需要修改修改全连接层,同时把模型的最终输出改为5。
由于这里使用了tensorflow框架,所以,我们需要获取上图红色箭头所在位置的张量
通过下面的链接下载inception-v3模型,其中包含已经训练好的参数。
模型下载链接:地址
训练数据花朵图片下载:地址
通过下面的代码加载模型,同时获取上面所述的两个张量。
由于我们模型的功能是对五种花进行分类,所以,我们需要修改全连接层,这里,我们只增加一个全连接层。全连接层的输入数据便是
最后便是定义交叉熵损失函数。模型使用反向传播训练,而训练的参数并不是模型的所有参数,仅仅是全连接层的参数,卷积层的参数是不变的。
那么接下来的是如何给我们的模型输入数据了,这里提供了几个操作数据的函数。由于训练数据集比较小,先把所有的图片通过
不到5分钟就可以训练好我们的模型,精确度还蛮高的。下图是本人运行的结果。
源码地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning
原理:卷积神经网络模型总体上可以分为两部分,前面的卷积层和后面的全连接层。卷积层的作用是图片特征的提取,全连接层作用是特征的分类。我们的思路便是在inception-v3网络模型上,修改全连接层,保留卷积层。卷积层的参数使用的是别人已经训练好的,全连接层的参数需要我们初始化并使用我们自己的数据来训练和学习。
上面inception-v3模型图红色箭头前面部分是卷积层,后面是全连接层。我们需要修改修改全连接层,同时把模型的最终输出改为5。
由于这里使用了tensorflow框架,所以,我们需要获取上图红色箭头所在位置的张量
BOTTLENECK_TENSOR_NAME(最后一个卷积层激活函数的输出值,个数为2048)以及模型最开始的输入数据的张量
JPEG_DATA_TENSOR_NAME。获取这两个张量的作用是,图片训练数据通过
JPEG_DATA_TENSOR_NAME张量输入模型,通过
BOTTLENECK_TENSOR_NAME张量获取通过卷积层之后的图片特征。
BOTTLENECK_TENSOR_SIZE = 2048 BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
通过下面的链接下载inception-v3模型,其中包含已经训练好的参数。
模型下载链接:地址
训练数据花朵图片下载:地址
通过下面的代码加载模型,同时获取上面所述的两个张量。
# 读取已经训练好的Inception-v3模型。 with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def( graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
由于我们模型的功能是对五种花进行分类,所以,我们需要修改全连接层,这里,我们只增加一个全连接层。全连接层的输入数据便是
BOTTLENECK_TENSOR_NAME张量。
# 定义一层全链接层 with tf.name_scope('final_training_ops'): weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001)) biases = tf.Variable(tf.zeros([n_classes])) logits = tf.matmul(bottleneck_input, weights) + biases final_tensor = tf.nn.softmax(logits)
最后便是定义交叉熵损失函数。模型使用反向传播训练,而训练的参数并不是模型的所有参数,仅仅是全连接层的参数,卷积层的参数是不变的。
# 定义交叉熵损失函数。 cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input) cross_entropy_mean = tf.reduce_mean(cross_entropy) train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)
那么接下来的是如何给我们的模型输入数据了,这里提供了几个操作数据的函数。由于训练数据集比较小,先把所有的图片通过
JPEG_DATA_TENSOR_NAME张量输入模型,然后获取
BOTTLENECK_TENSOR_NAME张量的值并保存到硬盘中。在模型训练的时候,从硬盘中读取所保存的
BOTTLENECK_TENSOR_NAME张量的值作为全连接层的输入数据。因为一张图片可能会被使用多次。
# 输入图片并获取`BOTTLENECK_TENSOR_NAME`张量的值 def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor) # 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于训练 def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor): # 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于测试。 def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)
不到5分钟就可以训练好我们的模型,精确度还蛮高的。下图是本人运行的结果。
源码地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning
相关文章推荐
- 第三阶段-tensorflow项目之图像image相关--图片分类训练效率之迁移学习
- tensorflow训练自己的数据集实现CNN图像分类2(保存模型&测试单张图片)
- Keras学习之三:用CNN实现cifar10图像分类模型
- [caffe]深度学习之图像分类模型Batch Normalization[BN-inception]解读
- [caffe]深度学习之图像分类模型AlexNet解读
- Deep Learning-TensorFlow (4) CNN卷积神经网络_CIFAR-10进阶图像分类模型(上)
- [caffe]深度学习之图像分类模型AlexNet解读
- 【机器学习PAI实践十】深度学习Caffe框架实现图像分类的模型训练
- TensorFlow之CNN图像分类及模型保存与调用
- Deep Learning_预训练CNN图片分类模型(AlexNet、VGG、GoogLeNet、Resnet.....)
- Deep Learning-TensorFlow (5) CNN卷积神经网络_CIFAR-10进阶图像分类模型(下)
- 【转】[caffe]深度学习之图像分类模型AlexNet解读
- [caffe]深度学习之图像分类模型AlexNet解读
- Caffe——python接口学习(6):用训练好的模型来分类新的图片
- Tensorflow实战学习(十六)【CNN实现、数据集、TFRecord、加载图像、模型、训练、调试】
- Tensorflow学习笔记--使用迁移学习做自己的图像分类器(Inception v3)
- 深度学习之图像分类模型AlexNet各层解读
- 利用CNN进行图像分类学习笔记
- [caffe]深度学习之图像分类模型AlexNet解读
- [caffe]深度学习之图像分类模型googlenet[inception v1]解读