在MNIST数据集上训练一个卷积网络自编码器
2018-07-31 16:30
204 查看
版权声明:转载请表明来源,多谢! https://blog.csdn.net/kingfoulin/article/details/81315937
首先你肯定知道了什么是自编码,一般我们常见的自编码是使用的多层感知机来实现的,也就是多层的全连接神经网络结构。本小记中我们使用CNN实现一个七层的卷积神经网络构成的自编码器。
自编码器使用很广泛,我觉得它的思想就是同一空间的数据操作的最优结果,数据经过压缩(数据向前传播到达中间的层所得到的结果),然后数据的解压缩过程(数据到达网络的输出),回到原来的空间。这个思想很关键,利用这个思想我们可以对元数据进行一些去噪操作,等等。
训练程序:
#!user/bin/python # _*_ coding:utf-8 _*_ import tensorflow as tf import matplotlib.pyplot as pt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True) batch_size = 64 sess = tf.InteractiveSession() def weight_varibales(shape): initial = tf.truncated_normal(shape=shape, stddev=0.1) return tf.Variable(initial) def bias_variables(shape): return tf.Variable(tf.constant(0.1, shape=shape)) def conv2d(x, w): return tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding='SAME') # 四层结构的反卷积操作 inputdata = tf.placeholder(tf.float32, shape=[None, 28 * 28]) real_label = tf.placeholder(tf.float32, shape=[None, 28 * 28]) # 正确的标签 input_image = tf.reshape(inputdata, [-1, 28, 28, 1]) # layer1 w_conv1 = weight_varibales([5, 5, 1, 32]) b_conv1 = bias_variables([32]) h_conv1 = tf.nn.relu(conv2d(input_image, w_conv1) + b_conv1) # 14*14*32 # layer2 w_conv2 = weight_varibales([5, 5, 32, 64]) b_conv2 = bias_variables([64]) h_conv2 = tf.nn.relu(conv2d(h_conv1, w_conv2) + b_conv2) # 7*7*64 # layer3 w_conv3 = weight_varibales([3, 3, 64, 128]) b_conv3 = bias_variables([128]) h_conv3 = tf.nn.relu(conv2d(h_conv2, w_conv3) + b_conv3) # 4*4*128 # decode_layer 1 input 4*4*128 w_decode1 = weight_varibales([3, 3, 64, 128]) h_decode1 = tf.nn.conv2d_transpose(h_conv3, w_decode1, [batch_size, 7, 7, 64], [1, 2, 2, 1], padding="SAME") h_out1 = tf.nn.relu(h_decode1) # decode layer 2 input 7*7*64 w_decode2 = weight_varibales([3, 3, 32, 64]) h_decode2 = tf.nn.conv2d_transpose(h_out1, w_decode2, [batch_size, 14, 14, 32], [1, 2, 2, 1], padding='SAME') h_out2 = tf.nn.relu(h_decode2) # decode layer 3 input 14*14*32 w_decode3 = weight_varibales([5, 5, 1, 32]) h_decode3 = tf.nn.conv2d_transpose(h_out2, w_decode3, [batch_size, 28, 28, 1], [1, 2, 2, 1], padding='SAME') h_out = tf.nn.relu(h_decode3) # the network's output is 28*28*1 output = tf.reshape(h_out, [-1, 28 * 28]) loss = tf.reduce_mean(pow(output - inputdata, 2)) train_step = tf.train.AdamOptimizer(0.001).minimize(loss) tf.summary.scalar('loss', loss) saver = tf.train.Saver() write = tf.summary.FileWriter('pic/',sess.graph) # 注意这里的后面是需要加上sess.graph的,不然的话无法在tensorflow显示流程图 merge = tf.summary.merge_all() sess.run(tf.initialize_all_variables()) for i in range(5000): images, labels = mnist.train.next_batch(batch_size) result = sess.run(merge, {inputdata: images}) write.add_summary(result) if i % 100 == 0: print('loss: ', loss.eval({inputdata:images})) sess.run(train_step, {inputdata: images}) saver.save(sess, 'model/')
检测程序:
#!user/bin/python # _*_ coding:utf-8 _*_ import tensorflow as tf import matplotlib.pyplot as pt import numpy as np from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('../MNIST_data/', one_hot=True) batch_size = 2 sess = tf.InteractiveSession() def weight_varibales(shape): initial = tf.truncated_normal(shape=shape, stddev=0.1) return tf.Variable(initial) def bias_variables(shape): return tf.Variable(tf.constant(0.1, shape=shape)) def conv2d(x, w): return tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding='SAME') # 四层结构的反卷积操作 inputdata = tf.placeholder(tf.float32, shape=[None, 28 * 28]) real_label = tf.placeholder(tf.float32, shape=[None, 28 * 28]) # 正确的标签 input_image = tf.reshape(inputdata, [-1, 28, 28, 1]) # layer1 w_conv1 = weight_varibales([5, 5, 1, 32]) b_conv1 = bias_variables([32]) h_conv1 = tf.nn.relu(conv2d(input_image, w_conv1) + b_conv1) # 14*14*32 # layer2 w_conv2 = weight_varibales([5, 5, 32, 64]) b_conv2 = bias_variables([64]) h_conv2 = tf.nn.relu(conv2d(h_conv1, w_conv2) + b_conv2) # 7*7*64 # layer3 w_conv3 = weight_varibales([3, 3, 64, 128]) b_conv3 = bias_variables([128]) h_conv3 = tf.nn.relu(conv2d(h_conv2, w_conv3) + b_conv3) # 4*4*128 # decode_layer 1 input 4*4*128 w_decode1 = weight_varibales([3, 3, 64, 128]) h_decode1 = tf.nn.conv2d_transpose(h_conv3, w_decode1, [batch_size, 7, 7, 64], [1, 2, 2, 1], padding="SAME") h_out1 = tf.nn.relu(h_decode1) # decode layer 2 input 7*7*64 w_decode2 = weight_varibales([3, 3, 32, 64]) h_decode2 = tf.nn.conv2d_transpose(h_out1, w_decode2, [batch_size, 14, 14, 32], [1, 2, 2, 1], padding='SAME') h_out2 = tf.nn.relu(h_decode2) # decode layer 3 input 14*14*32 w_decode3 = weight_varibales([5, 5, 1, 32]) h_decode3 = tf.nn.conv2d_transpose(h_out2, w_decode3, [batch_size, 28, 28, 1], [1, 2, 2, 1], padding='SAME') h_out = tf.nn.relu(h_decode3) # the network's output is 28*28*1 output = tf.reshape(h_out, [-1, 28 * 28]) loss = tf.reduce_mean(pow(output - inputdata, 2)) train_step = tf.train.AdamOptimizer(0.001).minimize(loss) saver = tf.train.Saver() saver.restore(sess, 'model/') images, labels = mnist.test.next_batch(batch_size) print(output.shape) pt.subplot(1, 2, 1) pt.imshow(images[0].reshape(28, 28), cmap='gray') pt.subplot(1, 2, 2) pt.imshow(output.eval({inputdata: images})[0].reshape(28, 28), cmap='gray') pt.show()
由上面的程序中,我们在测试的时候,使用的
batch_size=2,主要测试时候可视化数据基本要求不多。因此我们就使用batch_size很小的数据集进行测试。 阅读更多
相关文章推荐
- matconvnet环境下训练自己的数据集及模型测试-mnist网络结构-cifar10部分数据集
- UFLDL矢量化编程练习:用MNIST数据集的稀疏自编码器训练实现
- Mnist数据集下载、转换为lmdb,训练、测试、生成mean文件、生成label.txt、单张图片分类测试、可视化网络、可视化loss和accurate
- 基于pycaffe的网络训练和结果分析(mnist数据集)
- Tensorflow:Android调用Tensorflow Mobile版本API(1)-训练一个网络
- MNIST数据集的卷积神经网络训练代码具体实现示例--Tensorflow 框架
- Matconvnet学习——利用mnist网络训练自己的数据分辨左右手
- 使用TensorFlow训练MNIST数据集,SystemExit异常的解决方案
- TensorFlow训练mnist数据集(卷积神经网络lenet5)
- Caffe 之 使用非图片的鸢尾花(IRIS)数据集(hdf5格式) 训练网络模型
- Caffe使用——01 以LeNet训练Mnist数据集为例
- 卷积:如何成为一个很厉害的神经网络
- 利用DNN训练mnist数据集(1)
- 使用TensorFlow slim文件夹当中的inception_v4网络训练自己的分类数据集
- FCN网络的训练——以SIFT-Flow 数据集为例
- Caffe 之 使用非图片的鸢尾花(IRIS)数据集(hdf5格式) 训练网络模型
- 使用逻辑回归方法(softmax regression)识别MNIST手写体数字、使用CNN神经网络识别MNIST手写体数字、使用tensorboard可视化训练过程数据
- CIFAR-10数据集比MNIST训练难度高许多
- 神经网络 tensorflow教程 2.2 下载MNIST 数据集(保存所有图片)
- caffe学习笔记——用lenet网络及mnist数据集测试caffe