您的位置:首页 > 理论基础 > 计算机网络

快速上手生成对抗生成网络生成手写数字集(直接上代码以及详细注释,亲测可用)

2018-11-02 19:10 281 查看

GAN的原理其实很简单,就是生成网络G, 和判别网络D的对抗过程, 生成网络努力使得生成的虚假物品更加真实,而判别网络努力分别出哪些是G生成的,哪些是真实的,在这样一个对抗的过程中两个网络的能力不断得到提升。最终达到一个相对平衡的结果:理想状态下, G学习到真实数据的分布,而D无法分别出真实数据和生成数据, 即D(真) = D(假) = 0.5.
这里插入几张模型生成的图片,从左到右分别是随机生成的图片,100轮之后的图片,2000轮之后的图片,8000轮之后的图片。 代码虽然有100多行,但注释大概占了一般左右。一起交流,一起进步!

import tensorflow as tf
from matplotlib import gridspec
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import cv2
"""
定义各个超参数,包括输入图片的大小, 隐藏层大小, 学习率,batch_size, 迭代次数等
"""
# D 代表判别网络, 定义判别网络的参数
D_input_size = 784
D_H_layer1 = 200
D_output_size = 1

# G 代表生成网络,定义生成网络的参数
G_input_size = 100
G_H_layer1 = 300
G_output_size = 784

Learning_rate = 1e-3
iterations = 50000
batch_size = 16

"""
:param D_input_size: 判别网络的第一层输入(样本图片(生成图片)的大小),需要输入给判别网络做判断的图片的大小
:param D_H_layer1: 判别网络的第一个隐藏层的大小,暂时设置成默认值200
:param D_output: 判别网络的第输出层的大小, 因为输出是一个概率值所以只有一个神经元
所以判别网络的模型是 784--》200——》1
"""
D_W1 = tf.Variable(tf.truncated_normal([D_input_size, D_H_layer1], stddev=0.1), name="D_W1",dtype=tf.float32)
D_b1 = tf.Variable(tf.zeros([D_H_layer1]), name="D_b1", dtype=tf.float32)
D_W2 = tf.Variable(tf.truncated_normal([D_H_layer1, D_output_size], stddev=0.1), name="D_W2",dtype=tf.float32)
D_b2 = tf.Variable(tf.zeros([D_output_size]), name="D_b2", dtype=tf.float32)

"""
:param G_input_size:生成网络的输入,默认设置成100,是一些随机噪声,从这些噪声中 逐步的生成理想的图片
:param G_H_layer1: 生成网络的第一层隐藏层的大小,默认设置成300
:param G_output_size: 生成网络的输出层大小,即要生成的图片大小,显然应该是与生成网络的输入一致
所以生成网络的模型是100——》300——》784
"""
G_W1 = tf.Variable(tf.truncated_normal([G_input_size, G_H_layer1], stddev=0.1), name="G_W1",dtype=tf.float32)
G_b1 = tf.Variable(tf.zeros([G_H_layer1]), name="G_b1", dtype=tf.float32)
G_W2 = tf.Variable(tf.truncated_normal([G_H_layer1, G_output_size], stddev=0.1), name="G_W2",dtype=tf.float32)
G_b2 = tf.Variable(tf.zeros([G_output_size]), name="G_b2", dtype=tf.float32)

# 用列表保存,后面构建网络,和更新参数的时候会用到
G_variables = [G_W1, G_b1, G_W2, G_b2]
D_variables = [D_W1, D_b1, D_W2, D_b2]

# 判别网络
def discriminator(D_input):
D_A1 = tf.nn.relu(tf.matmul(D_input, D_W1) + D_b1)
D_output = tf.nn.sigmoid(tf.matmul(D_A1, D_W2) + D_b2)
return D_output

# 生成网络
def generator(G_input):
G_A1 = tf.nn.relu(tf.matmul(G_input, G_W1) + G_b1)
G_output = tf.sigmoid(tf.matmul(G_A1, G_W2) + G_b2)
return G_output

# 使用placeholder来定义生成网络的输入, G_image 表示生成的图片, real_image 表示真实的图片
G_input = tf.placeholder(tf.float32, shape=[None, 100])
G_image = generator(G_input)
real_image = tf.placeholder(tf.float32, shape=[None, 784])
# 判别网络对生成图片的判别概率
D_fake = discriminator(G_image)
# 判别网络对真实图片的判别概率
D_real = discriminator(real_image)

# 至此就到了最重要的一步了! 定义生成网络和判别网络的损失函数
# 首先是判别网络的损失函数:判别网络的目的是能很好的区分真实图片和生成图片,
# D_real 是判别网络对真实图片的判别概率,当然是越大越好,D_fake是判别网络对
# 生成图片的判别概率,判别器越好则越能是D_fake 变小,即(1.0 - D_fake)越大
# 越好,由此可得我们的目的是最小化目标函数D_loss
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
# 其次来看生成网络的损失函数,生成的目的是尽可能的骗过判别网络,使得生成的图片和
# 样本中的片面尽量相同,故对于生成网络来说D_fake 越大越能体现其优越。即G_loss越小
# 越好
G_loss = -tf.reduce_mean(tf.log(D_fake))

# 使用反向传播算法来更新参数,注意var_list表示需要更新的参数列表
D_train = tf.train.AdamOptimizer(Learning_rate).minimize(D_loss, var_list=D_variables)
G_train = tf.train.AdamOptimizer(Learning_rate).minimize(G_loss, var_list=G_variables)
# 读取mnist数据
mnist = input_data.read_data_sets("MNIST_DATA", one_hot=True)

# 画图函数
def plot(generate_pictures):
fig = plt.figure(figsize=(4, 4))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0.05, hspace=0.05)

for i, sample in enumerate(generate_pictures):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
return fig

# 打开一个会话

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 构建生成图片的路径, 若路径不存在则新建一个
if not os.path.exists('output_pictures/'):
os.makedirs('output_pictures/')

for i in range(iterations):
# 从mnist 中读取真实的图片
real_picture, _ = mnist.train.next_batch(batch_size)
# 运行判别网络和生成网络, 注意要feed两个placeholder
current_generate_loss, _ = sess.run([G_loss, G_train], feed_dict={
G_input: np.random.uniform(-1., 1., size=[batch_size, 100])})
current_discriminator_loss, _ = sess.run([D_loss, D_train],
feed_dict={G_input: np.random.uniform(-1., 1.,
size=[batch_size,
100]),
real_image: real_picture})
# 每隔100轮保存一下生成器生成的图片
if i % 100 == 0:
# np.random.uniform()生成大小为【self.batch_size, 100】的均匀分布,作为生成网络的输入
generate_pictures = sess.run(G_image,
feed_dict={G_input: np.random.uniform(-0.5, 0.5,
size=[batch_size, 100])})
# generate_pictures为16 x 784 的矩阵,每一行表示一张图片,从中随机抽取一张保存下来
import matplotlib.pyplot as plt

fig = plot(generate_pictures)
plt.savefig('output_pictures/image{}.jpg'.format(str(i // 100)), bbox_inches='tight')
plt.close(fig)
# 显示单张图
single_picture = generate_pictures[0]
# 这里因为生成图片的最后一层采用的是signoid函数,输出值为0-1,而像素值是0-255,所以乘以255
single_picture = np.reshape(single_picture, (28, 28)) * 255
cv2.imwrite("output_pictures/A{}.jpg".format(str(i // 100)), single_picture)

# 输入连个网络当前的损失
print("Iterations: " + str(i) + ",the D_loss is %.4f,  and the G_loss is %.4f" % (
current_discriminator_loss, current_generate_loss))
阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐