您的位置:首页 > 其它

Keras笔记——关于DCGAN的实现

2017-06-16 10:55 169 查看
前几天学习了一下GAN的相关知识,有NIPS 2016中的教程

还有知乎专栏的令人拍案叫绝的Wasserstein GAN,以及后续Wasserstein GAN最新进展:从weight clipping到gradient penalty,更加先进的Lipschitz限制手法

这两篇文章推导写的很好,有需要推荐直接看论文

还有深入浅出 GAN·原理篇文字版(完整)|干货

先是找到了这个https://github.com/jacobgil/keras-dcgan

这个对应的keras版本相对较旧,里面有些需要修改,稍微提一下

先是generator和discriminator里语法需要修改下,对应目前的2.0版本

def generator_model():
model = Sequential()
model.add(Dense(units=1024, input_dim=100))
model.add(Activation('tanh'))
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model


def discriminator_model():
model = Sequential()
model.add(Conv2D(64, (5, 5),
padding='same',
input_shape=(28, 28, 1)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (5, 5)))
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation('tanh'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
return model


还有一处是def train(BATCH_SIZE)中的

X_train = X_train.reshape((X_train.shape[0], 1) + X_train.shape[1:])


改为

X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)


这里是因为我选择tenseoflow作为后台,同时生成图像的函数要修改如下

def combine_images(generated_images):
num = generated_images.shape[0]
width = int(math.sqrt(num))
height = int(math.ceil(float(num)/width))
shape = generated_images.shape[1:3]
image = np.zeros((height*shape[0], width*shape[1]),
dtype=generated_images.dtype)
for index, img in enumerate(generated_images):
i = int(index/width)
j = index % width
image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \
img[ :, :, 0]
return image


还有一点是如果下载数据集的时候网不好(比如我),手动下载后放到~/.keras/datasets里即可

后来又发现了这个dcgan_minst.py

这个是基于tensorflow 1.0 和 keras 2.0版的,和前面的区别是网络结构有所不同,前面是论文中的DCGAN结构,这个有所修改



最终效果嘛,我只迭代了1000次,次数太少效果不太好哈~



1000次结果
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  Keras DCGAN