DCGAN论文笔记+源码解析
2017-09-11 16:45
330 查看
DCGAN论文笔记+源码解析
作者:wspba论文地址:UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS
源码地址:DCGAN in TensorFlow
DCGAN,Deep Convolutional Generative Adversarial Networks是生成对抗网络(Generative Adversarial Networks)的一种延伸,将卷积网络引入到生成式模型当中来做无监督的训练,利用卷积网络强大的特征提取能力来提高生成网络的学习效果。
DCGAN有以下特点:
1.在判别器模型中使用strided convolutions来替代空间池化(pooling),而在生成器模型中使用fractional strided convolutions,即deconv,反卷积层。
2.除了生成器模型的输出层和判别器模型的输入层,在网络其它层上都使用了Batch Normalization,使用BN可以稳定学习,有助于处理初始化不良导致的训练问题。
3.去除了全连接层,而直接使用卷积层连接生成器和判别器的输入层以及输出层。
4.在生成器的输出层使用Tanh激活函数,而在其它层使用ReLU;在判别器上使用leaky ReLU。
原论文中只给出了在LSUN实验上的生成器模型的结构图如下:
但是对于实验细节以及方法的介绍并不是很详细,于是便从源码入手来理解DCGAN的工作原理。
先看main.py:
with tf.Session(config=run_config) as sess: if FLAGS.dataset == 'mnist': dcgan = DCGAN( sess, input_width=FLAGS.input_width, input_height=FLAGS.input_height, output_width=FLAGS.output_width, output_height=FLAGS.output_height, batch_size=FLAGS.batch_size, y_dim=10, c_dim=1, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, is_crop=FLAGS.is_crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir)因为我们使用DCGAN来生成MNIST数字手写体图像,注意这里的y_dim=10,表示0到9这10个类别,c_dim=1,表示灰度图像。
再看model.py:
def discriminator(self, image, y=None, reuse=False): with tf.variable_scope("discriminator") as scope: if reuse: scope.reuse_variables() yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) x = conv_cond_concat(image, yb) h0 = lrelu(conv2d(x, self.c_dim + self.y_dim, name='d_h0_conv')) h0 = conv_cond_concat(h0, yb) h1 = lrelu(self.d_bn1(conv2d(h0, self.df_dim + self.y_dim, name='d_h1_conv'))) h1 = tf.reshape(h1, [self.batch_size, -1]) h1 = tf.concat_v2([h1, y], 1) h2 = lrelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin'))) h2 = tf.concat_v2([h2, y], 1) h3 = linear(h2, 1, 'd_h3_lin') return tf.nn.sigmoid(h3), h3这里batch_size=64,image的维度为[64 28 28 1],y的维度是[64 10],yb的维度[64 1 1 10],x将image和yb连接起来,这相当于是使用了Conditional GAN,为图像提供标签作为条件信息,于是x的维度是[64 28 28 11],将x输入到卷积层conv2d,conv2d的代码如下:
def conv2d(input_, output_dim, k_h b31e =5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): with tf.variable_scope(name): w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], initializer=tf.truncated_normal_initializer(stddev=stddev)) conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) return conv卷积核的大小为5*5,stride为[1 2 2 1],通过2的卷积步长可以替代pooling进行降维,padding=‘SAME’,则卷积的输出维度为[64 14 14 11]。然后使用batch normalization及leaky ReLU的激活层,输出与yb再进行concat,得到h0,维度为[64 14 14 21]。同理,h1的维度为[64 7*7*74+10],h2的维度为[64 1024+10],然后连接一个线性输出,得到h3,维度为[64 1],由于我们希望判别器的输出代表概率,所以最终使用一个sigmoid的激活。
def generator(self, z, y=None): with tf.variable_scope("generator") as scope: s_h, s_w = self.output_height, self.output_width s_h2, s_h4 = int(s_h/2), int(s_h/4) s_w2, s_w4 = int(s_w/2), int(s_w/4) # yb = tf.expand_dims(tf.expand_dims(y, 1),2) yb = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim]) z = tf.concat_v2([z, y], 1) h0 = tf.nn.relu( self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin'))) h0 = tf.concat_v2([h0, y], 1) h1 = tf.nn.relu(self.g_bn1( linear(h0, self.gf_dim*2*s_h4*s_w4, 'g_h1_lin'))) h1 = tf.reshape(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2]) h1 = conv_cond_concat(h1, yb) h2 = tf.nn.relu(self.g_bn2(deconv2d(h1, [self.batch_size, s_h2, s_w2, self.gf_dim * 2], name='g_h2'))) h2 = conv_cond_concat(h2, yb) return tf.nn.sigmoid( deconv2d(h2, [self.batch_size, s_h, s_w, self.c_dim], name='g_h3'))output_height和output_width为28,因此s_h和s_w为28,s_h2和s_w2为14,s_h4和s_w4为7。在这里z为平均分布的随机分布数,维度为[64 100],y的维度为[64 10],yb的维度是[64 1 1 10],z与y进行一个concat得到[64 110]的tensor,输入到一个线性层,输出维度是[64 1024],再经过batch normalization以及ReLU激活,并与y进行concat,输出h0的维度是[64 1034],同样的再经过一个线性层输出维度为[64
128*7*7],再进行reshape并与yb进行concat,得到h1,维度为[64 7 7 138],然后输入到一个deconv2d,做一个反卷积,也就是文中说的fractional strided convolutions,再经过batch normalization以及ReLU激活,并与yb进行concat,输出h2的维度是[64 14 14 138],最后再输入到deconv2d层以及sigmoid激活,得到生成器的输出,维度为[64 28 28 1]。
生成器以及判别器的输出:
self.G = self.generator(self.z, self.y) self.D, self.D_logits = \ self.discriminator(inputs, self.y, reuse=False) self.D_, self.D_logits_ = \ self.discriminator(self.G, self.y, reuse=True)其中D表示真实数据的判别器输出,D_表示生成数据的判别器输出。
再看损失函数:
self.d_loss_real = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=self.D_logits, targets=tf.ones_like(self.D))) self.d_loss_fake = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=self.D_logits_, targets=tf.zeros_like(self.D_))) self.g_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits( logits=self.D_logits_, targets=tf.ones_like(self.D_)))即对于真实数据,判别器的损失函数d_loss_real为判别器输出与1的交叉熵,而对于生成数据,判别器的损失函数d_loss_fake为输出与0的交叉熵,因此判别器的损失函数d_loss=d_loss_real+d_loss_fake;生成器的损失函数是g_loss判别器对于生成数据的输出与1的交叉熵。
优化器:
d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ .minimize(self.d_loss, var_list=self.d_vars) g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \ .minimize(self.g_loss, var_list=self.g_vars)训练阶段:
for epoch in xrange(config.epoch): batch_idxs = min(len(data_X), config.train_size) // config.batch_size for idx in xrange(0, batch_idxs): batch_images = data_X[idx*config.batch_size:(idx+1)*config.batch_size] batch_labels = data_y[idx*config.batch_size:(idx+1)*config.batch_size] batch_images = np.array(batch).astype(np.float32)[:, :, :, None] batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \ .astype(np.float32) # Update D network _, summary_str = self.sess.run([d_optim, self.d_sum], feed_dict={ self.inputs: batch_images, self.z: batch_z, self.y:batch_labels, }) self.writer.add_summary(summary_str, counter) # Update G network _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={ self.z: batch_z, self.y:batch_labels, }) self.writer.add_summary(summary_str, counter) # Run g_optim twice to make sure that d_loss does not go to zero (different from paper) _, summary_str = self.sess.run([g_optim, self.g_sum], feed_dict={ self.z: batch_z, self.y:batch_labels }) self.writer.add_summary(summary_str, counter) errD_fake = self.d_loss_fake.eval({ self.z: batch_z, self.y:batch_labels }) errD_real = self.d_loss_real.eval({ self.inputs: batch_images, self.y:batch_labels }) errG = self.g_loss.eval({ self.z: batch_z, self.y: batch_labels }) counter += 1与论文中不同的是,这里在一个batch中,更新两次生成器,更新一次判别器。
实验结果:
由于自己的笔记本配置有限,仅使用CPU来运行速度较慢,因此epoch仅设置为2,对于MNIST手写数字数据集的生成情况如下
相关文章推荐
- DCGAN论文笔记+源码解析
- DCGAN论文笔记+源码解析
- DCGAN论文笔记+源码解析
- Neural Style论文笔记+源码解析
- InfoGAN论文笔记+源码解析
- Scala中类型变量Bounds代码实战及其在Spark中的应用源码解析之Scala学习笔记-34
- Android入门笔记之源码解析一
- 学习笔记:springmvc4.3源码学习:spring解析配置文件过程
- ClassTag 、Manifest、ClassManifest、TypeTag代码实战及其在Spark中的应用源码解析之Scala学习笔记-37
- jQuery 源码解析笔记(一)
- 第67讲:Scala并发编程匿名Actor、消息传递、偏函数实战解析及其在Spark源码中的应用解析学习笔记
- ZooKeeper源码学习笔记(1)--client端解析
- 论文阅读笔记:图像分割方法deeplab以及Hole算法解析(diliation)
- Scala中隐式转换内幕操作规则揭秘、最佳实践及其在Spark中的应用源码解析之Scala学习笔记-55
- 【jbpm4.4源码阅读笔记】engine的解析与生成
- OpenJDK源码研究笔记(八)-详细解析如何读取Java字节码文件(.class)
- Spring 源码深度解析笔记 - Spring 模块划分
- [置顶] 【安卓笔记】Volley全方位解析,带你从源码的角度彻底理解
- Retrofit源码学习笔记(2)-CallAdapter解析