您的位置:首页 > 其它

基于Tensorflow和DCGAN生成动漫头像实践(一)

2018-03-30 18:48 309 查看
前言:学习tensorflow和深度学习有一段时间了,一直停留在运行别人的代码和跑mnsit和cifar10数据集上,决定从简单的动漫头像生成着手代码,经过无数的debug后终于完成大概,此间主要参考的有以下两个代码,一个是别人写的DCGAN动漫头像生成,另一个是pix2pix的tensorflow实现代码。
动漫头像生成:https://blog.csdn.net/sinat_33741547/article/details/77871170?locationNum=5&fps=1阿城
pix2pix代码:https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py

说明:本部分是数据是数据处理部分,采用的数据是别人提取好的动漫头像,共50000多张,将这些图片转化为tensorflow官方的标准数据TFrecord格式,这个格式的在tensorflow处理的时侯读取速度会快不少
数据来源

百度网盘  密码:g5qa

代码#!/usr/bin/env python2
# -*- coding: utf-8 -*-
'''
读取图片数据并转化为tensorflow官方的TFrecord格式
'''
import tensorflow as tf
import os
import sys
import time

def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def get_TF():
train_dir = "./faces/" #定义读取图片的路径
data = []
for file in os.listdir(train_dir): #将图片的路径存储到data list中
data.append(train_dir+file)

stdi,stdo,stde=sys.stdin,sys.stdout,sys.stderr #如果没有这部分会提示编码错误
reload(sys) #python3的reload在其他包中
sys.setdefaultencoding('utf-8')
sys.stdin,sys.stdout,sys.stderr=stdi,stdo,stde #改正reload之后print输出不了的问题

sess=tf.Session()
file_at = 0
start_time = time.time()
for i in range(len(data)):

image_path = data[i] #枚举每个图片的路径
image_raw_data = tf.gfile.FastGFile(image_path,'r').read()
img_data = tf.image.decode_jpeg(image_raw_data,channels=3) #将读取到的图片按照jpeg的格式解压成tensor的形式
img_data = img_data.eval(session=sess)
image_raw = img_data.tobytes() #将图片的tensor变成字符串

example = tf.train.Example(features=tf.train.Features(feature={ #构造TFrecord形式的example
'height':_int64_feature(img_data.shape[0]),
'width':_int64_feature(img_data.shape[1]),
'channel':_int64_feature(img_data.shape[2]),
'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])) #之后需要的只有'image_raw',其他可以不定义
}))
if i % 500 == 0: #500个example存储为一个TFrecord文件
file_at += 1
filename = ("./TFrecord/data-tfrecords-%.5d" % file_at)
if i>0:
writer.close()
writer = tf.python_io.TFRecordWriter(filename)
print("%d steps,using time %f" % (i,time.time()-start_time))
start_time =time.time()
writer.write(example.SerializeToString()) #将examples写入TFrecord文件

writer.close()

get_TF()
在程序实际运行的时候,一开始处理很快,但是后来生成一个TFrecord文件就越运行越慢,查了资料没发现其他人有出现这个问题,没有解决。当然,也可以直接读取原图片训练。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  深度学习 GAN TFrecord