基于DCGAN的手写数字生成
2018-03-08 14:59
423 查看
生成对抗模型,简单的可以理解为一个生成模型G,一个判别模型D,判别模型用于判断一个给定的图片是否真实图片(从数据集里获取的图片),生成模型的任务是去创造一个看起来像真的图片一样的图片,一开始的时候这两个模型都是没有经过训练的,这两个模型一起对抗训练,生成模型产生一张图片去欺骗判别模型,然后判别模型去判断这张图片是真是假,最终在这两个模型训练的过程中,两个模型的能力越来越强,最终达到稳态。
图中,右边的生成模型G,我们输入一个噪声z,通过生成样本x,由于这个样本是假的,判别模型G应该尽力将该模型x样本判别为0,即为假,而生成模型G会努力将x样本进行改进让判别模型将x判断为1,即为真。而在左边,由于输入的是真实样本,所以判别模型应当将输出1,即为真。
当输入的是从数据集中取出的real Iamge 数据时,我们只需要考虑第二部分,D(x)为判别模型的输出,表示输入x为real 数据的概率,我们的目的是让判别模型的输出D(x)的输出尽量靠近1。
当输入的为fake数据时,我们只计算第一部分,G(z)是生成模型的输出,输出的是一张Fake Image。我们要做的是让D(G(z))的输出尽可能趋向于0。这样才能表示判别模型是有区分力的。
相对判别模型来说,这个损失函数其实就是交叉熵损失函数。计算loss,进行梯度反传。这里的梯度反传可以使用任何一种梯度修正的方法。
当更新完判别模型的参数后,我们再去更新生成模型的参数。
对于生成模型来说,我们要做的是让G(z)产生的数据尽可能的和数据集中的数据一样。就是所谓的同样的数据分布。那么我们要做的就是最小化生成模型的误差,即只将由G(z)产生的误差传给生成模型。
但是针对判别模型的预测结果,要对梯度变化的方向进行改变。当判别模型认为G(z)输出为真实数据集的时候和认为输出为噪声数据的时候,梯度更新方向要进行改变。
即最终的损失函数为:
其中
表示判别模型的预测类别,对预测概率取整,为0或者1.用于更改梯度方向,阈值可以自己设置,或者正常的话就是0.5。
DCGAN中的架构
DCGAN即使深度卷积生成对抗网络。
可以看到,其实就是讲一个noise通过反卷积还原成一张图片。这个就是生成模型的工作。而判别模型就是一个普通的CNN,卷积神经网络,输出值为0或1.
这次来实现利用DCGAN,从一个噪音向量生成手写数字,样本来自MNIST,其实基于这个思想可以做很多事。
核心代码如下:# 28 x 28 的图片
img_rows, img_cols = 28, 28
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
跑起来以后,结果如下
可以大致看的出来,有几个图片有那么些手写数字的样子了。
后序还有挺多优化的余地,可以尝试不同的如wgan,或者ebgan这样的基于能量的,使用encoder,decoder的架构。应该可以得到更好的结果。如果计算机性能允许,并且有数据集,做人脸生成之类的也都是可以的。
图中,右边的生成模型G,我们输入一个噪声z,通过生成样本x,由于这个样本是假的,判别模型G应该尽力将该模型x样本判别为0,即为假,而生成模型G会努力将x样本进行改进让判别模型将x判断为1,即为真。而在左边,由于输入的是真实样本,所以判别模型应当将输出1,即为真。
判别模型的损失函数:
当输入的是从数据集中取出的real Iamge 数据时,我们只需要考虑第二部分,D(x)为判别模型的输出,表示输入x为real 数据的概率,我们的目的是让判别模型的输出D(x)的输出尽量靠近1。
当输入的为fake数据时,我们只计算第一部分,G(z)是生成模型的输出,输出的是一张Fake Image。我们要做的是让D(G(z))的输出尽可能趋向于0。这样才能表示判别模型是有区分力的。
相对判别模型来说,这个损失函数其实就是交叉熵损失函数。计算loss,进行梯度反传。这里的梯度反传可以使用任何一种梯度修正的方法。
当更新完判别模型的参数后,我们再去更新生成模型的参数。
给出生成模型的损失函数:
对于生成模型来说,我们要做的是让G(z)产生的数据尽可能的和数据集中的数据一样。就是所谓的同样的数据分布。那么我们要做的就是最小化生成模型的误差,即只将由G(z)产生的误差传给生成模型。
但是针对判别模型的预测结果,要对梯度变化的方向进行改变。当判别模型认为G(z)输出为真实数据集的时候和认为输出为噪声数据的时候,梯度更新方向要进行改变。
即最终的损失函数为:
其中
表示判别模型的预测类别,对预测概率取整,为0或者1.用于更改梯度方向,阈值可以自己设置,或者正常的话就是0.5。
DCGAN中的架构
DCGAN即使深度卷积生成对抗网络。
可以看到,其实就是讲一个noise通过反卷积还原成一张图片。这个就是生成模型的工作。而判别模型就是一个普通的CNN,卷积神经网络,输出值为0或1.
这次来实现利用DCGAN,从一个噪音向量生成手写数字,样本来自MNIST,其实基于这个思想可以做很多事。
核心代码如下:# 28 x 28 的图片
img_rows, img_cols = 28, 28
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
# 做个指示器,告诉算法,现在这个net(要么是dis要么是gen),能不能被继续train def make_trainable(net, val): net.trainable = val for l in net.layers: l.trainable = val
dropout_rate = 0.25 # 设置gen和dis的opt opt = Adam(lr=1e-3) dopt = Adam(lr=1e-4) nch = 200 # 造个GEN nch = 200 g_input = Input(shape=[100]) # 倒过来的CNN第一层(也就是普通CNN那个flatten那一层) H = Dense(nch*14*14, init='glorot_normal')(g_input) H = BatchNormGAN()(H) H = Activation('relu')(H) H = Reshape( [nch, 14, 14] )(H) # upscale上去2倍大。也就是从14x14 到 28x28 H = UpSampling2D(size=(2, 2))(H) # CNN滤镜 H = Convolution2D(int(nch/2), 3, 3, border_mode='same', init='glorot_uniform')(H) H = BatchNormGAN()(H) H = Activation('relu')(H) # CNN滤镜 H = Convolution2D(int(nch/4), 3, 3, border_mode='same', init='glorot_uniform')(H) H = BatchNormGAN()(H) H = Activation('relu')(H) # 合成一个大图片 H = Convolution2D(1, 1, 1, border_mode='same', init='glorot_uniform')(H) g_V = Activation('sigmoid')(H) generator = Model(g_input,g_V) generator.compile(loss='binary_crossentropy', optimizer=opt) generator.summary() # 造个DIS # 这就是一个正常的CNN d_input = Input(shape=shp) # 滤镜 H = Convolution2D(256, 5, 5, subsample=(2, 2), border_mode = 'same', activation='relu')(d_input) #H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) # 滤镜 H = Convolution2D(512, 5, 5, subsample=(2, 2), border_mode = 'same', activation='relu')(H) #H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) H = Flatten()(H) # flatten之后,接MLP H = Dense(256)(H) #H = LeakyReLU(0.2)(H) H = Dropout(dropout_rate)(H) # 出一个结果,『是』或者『不是』 d_V = Dense(2,activation='softmax')(H) discriminator = Model(d_input,d_V) discriminator.compile(loss='categorical_crossentropy', optimizer=dopt) discriminator.summary() make_trainable(discriminator, False) # 为stacked GAN做准备 # 然后合成一个GAN的构架 gan_input = Input(shape=[100]) H = generator(gan_input) gan_V = discriminator(H) GAN = Model(gan_input, gan_V) GAN.compile(loss='categorical_crossentropy', optimizer=opt) GAN.summary()
ntrain = 10000 trainidx = random.sample(range(0,X_train.shape[0]), ntrain) XT = X_train[trainidx,:,:,:] noise_gen = np.random.uniform(0,1,size=[XT.shape[0],100]) generated_images = generator.predict(noise_gen) X = np.concatenate((XT, generated_images)) n = XT.shape[0] y = np.zeros([2*n,2]) y[:n,1] = 1 y[n:,0] = 1 #提取训练discriminator,让它可以识别对错 make_trainable(discriminator,True) discriminator.fit(X,y, nb_epoch=1, batch_size=32) y_hat = discriminator.predict(X)
def train_for_n(nb_epoch=5000, plt_frq=25,BATCH_SIZE=32): for e in tqdm(range(nb_epoch)): # 生成图片 image_batch = X_train[np.random.randint(0,X_train.shape[0],size=BATCH_SIZE),:,:,:] noise_gen = np.random.uniform(0,1,size=[BATCH_SIZE,100]) generated_images = generator.predict(noise_gen) # 训练DIS X = np.concatenate((image_batch, generated_images)) y = np.zeros([2*BATCH_SIZE,2]) y[0:BATCH_SIZE,1] = 1 y[BATCH_SIZE:,0] = 1 # 当然,要让DIS可以被训练 make_trainable(discriminator,True) d_loss = discriminator.train_on_batch(X,y) losses["d"].append(d_loss) # 训练 Generator-Discriminator stack noise_tr = np.random.uniform(0,1,size=[BATCH_SIZE,100]) y2 = np.zeros([BATCH_SIZE,2]) y2[:,1] = 1 # 这个时候,让DIS不能被变化。保证判断结果一致性 make_trainable(discriminator,False) g_loss = GAN.train_on_batch(noise_tr, y2 ) losses["g"].append(g_loss) # 生成图片 if e%plt_frq==plt_frq-1: plot_loss(losses) plot_gen()
跑起来以后,结果如下
可以大致看的出来,有几个图片有那么些手写数字的样子了。
后序还有挺多优化的余地,可以尝试不同的如wgan,或者ebgan这样的基于能量的,使用encoder,decoder的架构。应该可以得到更好的结果。如果计算机性能允许,并且有数据集,做人脸生成之类的也都是可以的。
相关文章推荐
- 基于Tensorflow和DCGAN生成动漫头像实践(二)
- 基于opencv的手写数字识别(MFC,HOG,SVM)
- 从一到二:利用mnist训练集生成的caffemodel对mnist测试集与自己手写的数字进行测试
- [DL]3.基于CNN的手写数字识别
- DCGAN例子学习-MNIST 手写体数字生成
- [置顶] java实现基于Mnist数据集的手写数字识别
- OpenCV手写数字字符识别(基于k近邻算法)
- opencv 基于KNN的手写数字字符识别
- 利用Pornzilla扩展的小书签功能基于URL中数字自动批量生成URL
- python tensorflow 基于cnn实现手写数字识别
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
- TensorFlow小试牛刀(2):GAN生成手写数字
- 基于opencv的手写数字识别(MFC,HOG,SVM)
- 【webAI】基于deeplearn.js的Mnist手写数字识别
- 基于机器学习多种方法的kaggle竞赛入门之手写数字的图像识别预测
- 基于tensorflow的MNIST手写数字识别
- 深度学习与神经网络实战:快速构建一个基于神经网络的手写数字识别系统
- 如何使用TensorFlow和VAE模型生成手写数字
- 基于Fisher线性判别分析的手写数字识别
- [Python]基于CNN的MNIST手写数字识别