CVAE(条件自编码) Condition GAN (条件GAN) 和 VAE-GAN模型之间的区别之CVAE
2019-03-20 19:23
351 查看
版权声明:版权归世界上所有无产阶级所有 https://blog.csdn.net/qq_41776781/article/details/88697174
# 使用CVAE(条件自编码) 训练fashion-mnist数据集
# Train a Conditional VAE (CVAE) on the (fashion-)mnist dataset.
import os
import time

import numpy as np
import tensorflow as tf

from ops import *
from utils import *


class CVAE(object):
    """Conditional Variational Auto-Encoder.

    The encoder maps a real image (conditioned on its one-hot label) to the
    mean/stddev of a Gaussian posterior; the decoder reconstructs the image
    from a reparameterized sample of that posterior concatenated with the
    label.  Loss = reconstruction negative log-likelihood + KL divergence
    to the standard-normal prior (the standard VAE ELBO).
    """

    model_name = "CVAE"  # name used for checkpoint/result paths

    def __init__(self, sess, epoch, batch_size, z_dim, dataset_name,
                 checkpoint_dir, result_dir, log_dir):
        """Store hyper-parameters and load the dataset.

        Args:
            sess: a tf.Session the model will run in.
            epoch: number of training epochs.
            batch_size: mini-batch size (also baked into the graph shapes).
            z_dim: dimensionality of the latent vector.
            dataset_name: 'mnist' or 'fashion-mnist' (others unsupported).
            checkpoint_dir, result_dir, log_dir: output directories.

        Raises:
            NotImplementedError: for any unsupported dataset name.
        """
        self.sess = sess
        self.dataset_name = dataset_name
        self.checkpoint_dir = checkpoint_dir
        self.result_dir = result_dir
        self.log_dir = log_dir
        self.epoch = epoch
        self.batch_size = batch_size
        # Parameters of the normal distribution used to draw the fixed
        # visualization noise in train().
        self.mean = 0
        self.var = 1

        if dataset_name == 'mnist' or dataset_name == 'fashion-mnist':
            # image geometry (28x28 grayscale)
            self.input_height = 28
            self.input_width = 28
            self.output_height = 28
            self.output_width = 28

            self.z_dim = z_dim  # dimension of noise-vector
            self.y_dim = 10     # dimension of condition-vector (label)
            self.c_dim = 1      # number of image channels

            # training hyper-parameters
            self.learning_rate = 0.0002
            self.beta1 = 0.5

            # test
            self.sample_num = 64  # number of generated images to be saved

            # load mnist / fashion-mnist (helper from utils)
            self.data_X, self.data_y = load_mnist(self.dataset_name)

            # number of batches in a single epoch
            self.num_batches = len(self.data_X) // self.batch_size
        else:
            print("********there is no other dataset to do *********")
            raise NotImplementedError

    def encoder(self, x, y, is_training=True, reuse=False):
        """Map image x conditioned on label y to a Gaussian posterior.

        Returns:
            (mean, stddev) tensors of shape [batch_size, z_dim]; stddev is
            produced through softplus so it is strictly positive.
        """
        with tf.variable_scope("encoder", reuse=reuse):
            # Broadcast the one-hot label over the spatial dims and
            # concatenate it to the image as extra channels.
            y = tf.reshape(y, [self.batch_size, 1, 1, self.y_dim])
            x = conv_cond_concat(x, y)

            net = lrelu(conv2d(x, 64, 4, 4, 2, 2, name='en_conv1'))
            net = lrelu(bn(conv2d(net, 128, 4, 4, 2, 2, name='en_conv2'),
                           is_training=is_training, scope='en_bn2'))
            net = tf.reshape(net, [self.batch_size, -1])
            net = lrelu(bn(linear(net, 1024, scope='en_fc3'),
                           is_training=is_training, scope='en_bn3'))

            # One linear layer emits both halves of the Gaussian parameters.
            gaussian_params = linear(net, 2 * self.z_dim, scope='en_fc4')
            mean = gaussian_params[:, :self.z_dim]
            stddev = tf.nn.softplus(gaussian_params[:, self.z_dim:])
        return mean, stddev

    def decoder(self, z, y, is_training=True, reuse=False):
        """Reconstruct a [batch, 28, 28, 1] image from latent z and label y.

        Output passes through a sigmoid, so pixel values lie in (0, 1).
        """
        with tf.variable_scope("decoder", reuse=reuse):
            z = concat([z, y], 1)
            net = tf.nn.relu(bn(linear(z, 1024, scope='de_fc1'),
                                is_training=is_training, scope='de_bn1'))
            net = tf.nn.relu(bn(linear(net, 128 * 7 * 7, scope='de_fc2'),
                                is_training=is_training, scope='de_bn2'))
            net = tf.reshape(net, [self.batch_size, 7, 7, 128])
            net = tf.nn.relu(
                bn(deconv2d(net, [self.batch_size, 14, 14, 64], 4, 4, 2, 2,
                            name='de_dc3'),
                   is_training=is_training, scope='de_bn3'))
            out = tf.nn.sigmoid(
                deconv2d(net, [self.batch_size, 28, 28, 1], 4, 4, 2, 2,
                         name='de_dc4'))
            return out

    def build_model(self):
        """Build the training graph: placeholders, loss, optimizer, sampler."""
        image_dims = [self.input_height, self.input_width, self.c_dim]
        bs = self.batch_size

        self.inputs = tf.placeholder(tf.float32, [bs] + image_dims,
                                     name='real_images')
        self.y = tf.placeholder(tf.float32, [bs, self.y_dim], name='y')
        self.z = tf.placeholder(tf.float32, [bs, self.z_dim], name='z')

        # Encoder returns the posterior mean and stddev; reparameterization
        # trick: z = mu + sigma * eps, eps ~ N(0, 1).
        mu, sigma = self.encoder(self.inputs, self.y,
                                 is_training=True, reuse=False)
        z = mu + sigma * tf.random_normal(tf.shape(mu), 0, 1, dtype=tf.float32)

        # Decoder reconstructs the input image.
        self.out = self.decoder(z, self.y, is_training=True, reuse=False)

        # Bernoulli reconstruction log-likelihood per image.
        # BUGFIX: the logs are clamped with 1e-8 (as the KL term below already
        # was) so a saturated sigmoid output of exactly 0 or 1 cannot produce
        # log(0) = -inf and NaN gradients.
        marginal_likelihood = tf.reduce_sum(
            self.inputs * tf.log(1e-8 + self.out)
            + (1 - self.inputs) * tf.log(1e-8 + 1 - self.out),
            [1, 2])
        # KL(N(mu, sigma^2) || N(0, 1)), closed form.
        KL_divergence = 0.5 * tf.reduce_sum(
            tf.square(mu) + tf.square(sigma)
            - tf.log(1e-8 + tf.square(sigma)) - 1,
            [1])

        self.neg_loglikelihood = -tf.reduce_mean(marginal_likelihood)
        self.KL_divergence = tf.reduce_mean(KL_divergence)

        # Negative ELBO: reconstruction NLL + KL regularizer.
        self.loss = self.neg_loglikelihood + self.KL_divergence

        # Optimizer; UPDATE_OPS dependency keeps batch-norm statistics fresh.
        # NOTE(review): learning_rate*5 (= 0.001) contradicts the declared
        # self.learning_rate = 0.0002 — kept as-is, confirm intent.
        t_vars = tf.trainable_variables()
        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            self.optim = tf.train.AdamOptimizer(self.learning_rate * 5,
                                                beta1=self.beta1) \
                .minimize(self.loss, var_list=t_vars)

        # Sampler: is_training=False so batch-norm uses moving averages and
        # no parameters are updated when generating images.
        self.fake_images = self.decoder(self.z, self.y,
                                        is_training=False, reuse=True)

        # BUGFIX: the original called tf.summary.merge_all() without ever
        # creating a summary, so it returned None and the sess.run() in
        # train() crashed on a None fetch.  Register the loss scalars first.
        tf.summary.scalar("nll", self.neg_loglikelihood)
        tf.summary.scalar("kl", self.KL_divergence)
        tf.summary.scalar("loss", self.loss)
        self.merged_summary_op = tf.summary.merge_all()

    def train(self):
        """Run the training loop and periodically save sample grids."""
        tf.global_variables_initializer().run()

        # Fixed noise + the first batch_size labels, reused for every
        # visualization so samples are comparable across steps.
        self.sample_z = np.random.normal(
            self.mean, self.var,
            (self.batch_size, self.z_dim)).astype(np.float32)
        self.test_labels = self.data_y[0:self.batch_size]

        start_epoch = 0
        start_batch_id = 0
        counter = 1
        start_time = time.time()

        for epoch in range(start_epoch, self.epoch):
            for idx in range(start_batch_id, self.num_batches):
                batch_images = self.data_X[
                    idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_labels = self.data_y[
                    idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_z = np.random.uniform(
                    -1, 1, [self.batch_size, self.z_dim]).astype(np.float32)

                _, summary_str, loss, nll_loss, kl_loss = self.sess.run(
                    [self.optim, self.merged_summary_op, self.loss,
                     self.neg_loglikelihood, self.KL_divergence],
                    feed_dict={self.inputs: batch_images,
                               self.y: batch_labels,
                               self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, loss: %.8f, "
                      "nll: %.8f, kl: %.8f"
                      % (epoch, idx, self.num_batches,
                         time.time() - start_time, loss, nll_loss, kl_loss))

                # save training results for every 300 steps
                if np.mod(counter, 300) == 0:
                    samples = self.sess.run(
                        self.fake_images,
                        feed_dict={self.z: self.sample_z,
                                   self.y: self.test_labels})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(
                        samples[:manifold_h * manifold_w, :, :, :],
                        [manifold_h, manifold_w],
                        './' + check_folder(self.result_dir + '/'
                                            + self.model_dir)
                        + '/' + self.model_name
                        + '_train_{:02d}_{:04d}.png'.format(epoch, idx))

            # After the first epoch, always start from batch 0 (only relevant
            # if start_batch_id ever came from a restored checkpoint).
            start_batch_id = 0

    @property
    def model_dir(self):
        """Directory name encoding model, dataset, batch size and z_dim."""
        return "{}_{}_{}_{}".format(
            self.model_name, self.dataset_name, self.batch_size, self.z_dim)
训练的结果:
相关文章推荐
- CVAE(条件自编码) Condition GAN (条件GAN) 和 VAE-GAN模型之间的区别之 VAE-GAN
- CVAE(条件自编码) Condition GAN (条件GAN) 和 VAE-GAN模型之间的区别
- VAE、GAN、Info-GAN:全解深度学习三大生成模型
- Java虚拟机的理解与内存模型之间的区别
- 我的JavaScript回顾之路_01—0206—++在前在后区别/&&和||/条件判断语句/循环语句的区别/字符串类型数字和数字类型之间的转换
- Laravel的ORM模型的find(),findOrFail(),first(),firstOrFail(),get(),list(),toArray()之间的区别是什么?
- 各种编码之间的区别:ASCII、Unicode、UTF-8
- OSI和TCP/IP模型之间的区别-----无线网络通讯协议有哪些
- 字符编码之ASCII、Unicode以及utf-8之间的联系与区别
- 各种编码之间的区别 用法 总结
- 通过这几天的研究,终于明白了Unicode和UTF-8之间编码的区别。Unicode是一个字符集,而UTF-8是Unicode的其中一种,Unicode是定长的都为双字节,而UTF-8是可变的,对于
- 联合概率与联合分布、条件概率与条件分布、边缘概率与边缘分布、贝叶斯定理、生成模型(Generative Model)和判别模型(Discriminative Model)的区别
- 三大深度学习生成模型:VAE、GAN及其变种
- 三大深度学习生成模型:VAE、GAN及其变种
- 字符,字符集,编码之间的区别
- 编码问题,UTF,ISO8859-1,unicode,ACSii,GBK之间的区别
- JAVA 编码中文问题系统透彻讲解 UNICODE GBK UTF-8 ISO-8859-1 之间的区别
- 自己总结的 五种IO模型之间的关系与区别
- 卷积学习与传统稀疏编码、ICA模型学习区别(逐步补充)
- 深度学习的三大生成模型:VAE、GAN及其变种