快速上手生成对抗生成网络生成手写数字集(直接上代码以及详细注释,亲测可用)
2018-11-02 19:10
281 查看
GAN的原理其实很简单,就是生成网络G, 和判别网络D的对抗过程, 生成网络努力使得生成的虚假物品更加真实,而判别网络努力分别出哪些是G生成的,哪些是真实的,在这样一个对抗的过程中两个网络的能力不断得到提升。最终达到一个相对平衡的结果:理想状态下, G学习到真实数据的分布,而D无法分别出真实数据和生成数据, 即D(真) = D(假) = 0.5.
这里插入几张模型生成的图片,从左到右分别是随机生成的图片,100轮之后的图片,2000轮之后的图片,8000轮之后的图片。 代码虽然有100多行,但注释大概占了一般左右。一起交流,一起进步!
import tensorflow as tf from matplotlib import gridspec from tensorflow.examples.tutorials.mnist import input_data import numpy as np import os import cv2 """ 定义各个超参数,包括输入图片的大小, 隐藏层大小, 学习率,batch_size, 迭代次数等 """ # D 代表判别网络, 定义判别网络的参数 D_input_size = 784 D_H_layer1 = 200 D_output_size = 1 # G 代表生成网络,定义生成网络的参数 G_input_size = 100 G_H_layer1 = 300 G_output_size = 784 Learning_rate = 1e-3 iterations = 50000 batch_size = 16 """ :param D_input_size: 判别网络的第一层输入(样本图片(生成图片)的大小),需要输入给判别网络做判断的图片的大小 :param D_H_layer1: 判别网络的第一个隐藏层的大小,暂时设置成默认值200 :param D_output: 判别网络的第输出层的大小, 因为输出是一个概率值所以只有一个神经元 所以判别网络的模型是 784--》200——》1 """ D_W1 = tf.Variable(tf.truncated_normal([D_input_size, D_H_layer1], stddev=0.1), name="D_W1",dtype=tf.float32) D_b1 = tf.Variable(tf.zeros([D_H_layer1]), name="D_b1", dtype=tf.float32) D_W2 = tf.Variable(tf.truncated_normal([D_H_layer1, D_output_size], stddev=0.1), name="D_W2",dtype=tf.float32) D_b2 = tf.Variable(tf.zeros([D_output_size]), name="D_b2", dtype=tf.float32) """ :param G_input_size:生成网络的输入,默认设置成100,是一些随机噪声,从这些噪声中 逐步的生成理想的图片 :param G_H_layer1: 生成网络的第一层隐藏层的大小,默认设置成300 :param G_output_size: 生成网络的输出层大小,即要生成的图片大小,显然应该是与生成网络的输入一致 所以生成网络的模型是100——》300——》784 """ G_W1 = tf.Variable(tf.truncated_normal([G_input_size, G_H_layer1], stddev=0.1), name="G_W1",dtype=tf.float32) G_b1 = tf.Variable(tf.zeros([G_H_layer1]), name="G_b1", dtype=tf.float32) G_W2 = tf.Variable(tf.truncated_normal([G_H_layer1, G_output_size], stddev=0.1), name="G_W2",dtype=tf.float32) G_b2 = tf.Variable(tf.zeros([G_output_size]), name="G_b2", dtype=tf.float32) # 用列表保存,后面构建网络,和更新参数的时候会用到 G_variables = [G_W1, G_b1, G_W2, G_b2] D_variables = [D_W1, D_b1, D_W2, D_b2] # 判别网络 def discriminator(D_input): D_A1 = tf.nn.relu(tf.matmul(D_input, D_W1) + D_b1) D_output = tf.nn.sigmoid(tf.matmul(D_A1, D_W2) + D_b2) return D_output # 生成网络 def generator(G_input): G_A1 = tf.nn.relu(tf.matmul(G_input, G_W1) + G_b1) G_output = tf.sigmoid(tf.matmul(G_A1, G_W2) + G_b2) return G_output # 使用placeholder来定义生成网络的输入, G_image 表示生成的图片, real_image 表示真实的图片 G_input = tf.placeholder(tf.float32, shape=[None, 100]) G_image = generator(G_input) real_image = tf.placeholder(tf.float32, shape=[None, 784]) # 判别网络对生成图片的判别概率 D_fake = discriminator(G_image) # 判别网络对真实图片的判别概率 D_real = discriminator(real_image) # 至此就到了最重要的一步了! 定义生成网络和判别网络的损失函数 # 首先是判别网络的损失函数:判别网络的目的是能很好的区分真实图片和生成图片, # D_real 是判别网络对真实图片的判别概率,当然是越大越好,D_fake是判别网络对 # 生成图片的判别概率,判别器越好则越能是D_fake 变小,即(1.0 - D_fake)越大 # 越好,由此可得我们的目的是最小化目标函数D_loss D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake)) # 其次来看生成网络的损失函数,生成的目的是尽可能的骗过判别网络,使得生成的图片和 # 样本中的片面尽量相同,故对于生成网络来说D_fake 越大越能体现其优越。即G_loss越小 # 越好 G_loss = -tf.reduce_mean(tf.log(D_fake)) # 使用反向传播算法来更新参数,注意var_list表示需要更新的参数列表 D_train = tf.train.AdamOptimizer(Learning_rate).minimize(D_loss, var_list=D_variables) G_train = tf.train.AdamOptimizer(Learning_rate).minimize(G_loss, var_list=G_variables) # 读取mnist数据 mnist = input_data.read_data_sets("MNIST_DATA", one_hot=True) # 画图函数 def plot(generate_pictures): fig = plt.figure(figsize=(4, 4)) gs = gridspec.GridSpec(4, 4) gs.update(wspace=0.05, hspace=0.05) for i, sample in enumerate(generate_pictures): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(sample.reshape(28, 28), cmap='Greys_r') return fig # 打开一个会话 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 构建生成图片的路径, 若路径不存在则新建一个 if not os.path.exists('output_pictures/'): os.makedirs('output_pictures/') for i in range(iterations): # 从mnist 中读取真实的图片 real_picture, _ = mnist.train.next_batch(batch_size) # 运行判别网络和生成网络, 注意要feed两个placeholder current_generate_loss, _ = sess.run([G_loss, G_train], feed_dict={ G_input: np.random.uniform(-1., 1., size=[batch_size, 100])}) current_discriminator_loss, _ = sess.run([D_loss, D_train], feed_dict={G_input: np.random.uniform(-1., 1., size=[batch_size, 100]), real_image: real_picture}) # 每隔100轮保存一下生成器生成的图片 if i % 100 == 0: # np.random.uniform()生成大小为【self.batch_size, 100】的均匀分布,作为生成网络的输入 generate_pictures = sess.run(G_image, feed_dict={G_input: np.random.uniform(-0.5, 0.5, size=[batch_size, 100])}) # generate_pictures为16 x 784 的矩阵,每一行表示一张图片,从中随机抽取一张保存下来 import matplotlib.pyplot as plt fig = plot(generate_pictures) plt.savefig('output_pictures/image{}.jpg'.format(str(i // 100)), bbox_inches='tight') plt.close(fig) # 显示单张图 single_picture = generate_pictures[0] # 这里因为生成图片的最后一层采用的是signoid函数,输出值为0-1,而像素值是0-255,所以乘以255 single_picture = np.reshape(single_picture, (28, 28)) * 255 cv2.imwrite("output_pictures/A{}.jpg".format(str(i // 100)), single_picture) # 输入连个网络当前的损失 print("Iterations: " + str(i) + ",the D_loss is %.4f, and the G_loss is %.4f" % ( current_discriminator_loss, current_generate_loss))阅读更多
相关文章推荐
- BP神经网络以及在手写数字分类中python代码的详细注释
- 优化版本对生成对抗网络生成手写数字集(附代码详解)
- 生成对抗网络介绍(附TensorFlow代码)
- 50行代码实现对抗生成网络GAN
- 七种常见的排序算法--c++直接上代码,注释详细
- CLIP PATH (MASK) GENERATOR是一款在线制作生成clip-path路径的工具,可以直接生成SVG代码以及配合Mask制作蒙板。
- 对抗神经网络学习(一)——GAN实现mnist手写数字生成(tensorflow实现)
- OGRESE 地形Tile材质的生成 源码以及详细注释 所有源码出自OGRESE,注释部分出自OGRESE
- wp7检测网络是否可用以及网络开启简单代码段
- pix2code:从UI截图直接生成代码的神经网络工具
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
- tensorflow 1.01中GAN(生成对抗网络)手写字体生成例子(MINST)的测试
- 生成对抗网络GANs理解(附代码)
- java代码如何快速添加作者描述的注释最好能有详细的图解
- sublime text 3:创建可重复用的代码片段php文件头部注释信息快速生成
- 对抗生成网络及代码实例
- PyTorch快速入门教程十(GANs以及对抗网络)
- 学习笔记:生成对抗网络(Generative Adversarial Nets)(附代码)
- Quicksort 快速排序—注意点以及代码实现(笔试手写代码)
- Eclipse中的快捷键快速生成常用代码(例如无参、带参构造,set、get方法),以及Java中重要的内存分析(栈、堆、方法区、常量池)