您的位置:首页 > 其它

CVAE(条件变分自编码器)、Conditional GAN(条件 GAN)和 VAE-GAN 模型之间的区别

2019-03-20 19:44 393 查看
版权声明:版权归世界上所有无产阶级所有 https://blog.csdn.net/qq_41776781/article/details/88697627

这个文件主要负责读取模型所需的数据(MNIST),以及保存图像等;相关实现可以直接参考 DCGAN 的工具代码。如有不明白的地方,欢迎留言提问。

 

[code]import math
import random

import scipy.misc
import numpy as np
from time import gmtime, strftime
from six.moves import xrange
import matplotlib.pyplot as plt
import os, gzip

import tensorflow as tf
import tensorflow.contrib.slim as slim

def load_mnist(data_dir="../Dataset/mnist_data/", num_train=60000, num_test=10000):
    """Load the gzip-compressed MNIST-style files and return shuffled data.

    The training and test splits are concatenated, shuffled with
    ``np.random.shuffle`` and returned together.

    Args:
        data_dir: Directory holding the four ``*-ubyte.gz`` files. A different
            dataset stored under the same file names can be selected by
            pointing this at another directory.
        num_train: Number of records read from the ``train-*`` files.
        num_test: Number of records read from the ``t10k-*`` files.

    Returns:
        A tuple ``(X, y_vec)``: ``X`` is a float array of shape
        ``(num_train + num_test, 28, 28, 1)`` holding the raw bytes divided
        by 255, and ``y_vec`` is a float array of shape
        ``(num_train + num_test, 10)`` with a single 1.0 per row at the
        label's index.
    """

    def extract_data(filename, num_data, head_size, data_size):
        # Skip the header bytes, then read num_data records of data_size bytes.
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
        # Builtin float replaces the np.float alias, which NumPy removed.
        return np.frombuffer(buf, dtype=np.uint8).astype(float)

    image_bytes = 28 * 28

    # Image files: 16 header bytes are skipped; label files: 8 header bytes.
    trX = extract_data(
        os.path.join(data_dir, 'train-images-idx3-ubyte.gz'), num_train, 16, image_bytes
    ).reshape((num_train, 28, 28, 1))
    trY = extract_data(
        os.path.join(data_dir, 'train-labels-idx1-ubyte.gz'), num_train, 8, 1
    ).reshape((num_train))
    teX = extract_data(
        os.path.join(data_dir, 't10k-images-idx3-ubyte.gz'), num_test, 16, image_bytes
    ).reshape((num_test, 28, 28, 1))
    teY = extract_data(
        os.path.join(data_dir, 't10k-labels-idx1-ubyte.gz'), num_test, 8, 1
    ).reshape((num_test))

    X = np.concatenate((trX, teX), axis=0)
    # Builtin int replaces the removed np.int alias; labels become indices.
    y = np.concatenate((trY, teY), axis=0).astype(int)

    data_index = np.arange(X.shape[0])
    print("*****************dataX**************", len(X))
    # Shuffle one index array and apply it to both arrays so that every
    # image stays paired with its label.
    np.random.shuffle(data_index)
    X = X[data_index, :, :, :]
    y = y[data_index]

    # One-hot encode over 10 classes in a single vectorized assignment.
    y_vec = np.zeros((len(y), 10), dtype=float)
    y_vec[np.arange(len(y)), y] = 1.0

    return X / 255., y_vec

def check_folder(log_dir):
    """Make sure ``log_dir`` exists and hand the same path back.

    Args:
        log_dir: Path to check; created (including missing parents) when
            nothing exists at that path yet.

    Returns:
        ``log_dir`` unchanged, so the call can be used inline.
    """
    # Guard clause: nothing to create when the path is already present.
    if os.path.exists(log_dir):
        return log_dir
    os.makedirs(log_dir)
    return log_dir

def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False):
    """Read an image from disk and run it through ``transform``.

    Args:
        image_path: Location of the image file, forwarded to ``imread``.
        input_height: Forwarded to ``transform`` (crop height when cropping).
        input_width: Forwarded to ``transform`` (crop width when cropping).
        resize_height: Target height forwarded to ``transform``.
        resize_width: Target width forwarded to ``transform``.
        crop: Forwarded to ``transform``; selects crop-then-resize vs. resize.
        grayscale: Forwarded to ``imread``.

    Returns:
        Whatever ``transform`` returns for the loaded image.
    """
    raw_image = imread(image_path, grayscale)
    return transform(
        raw_image,
        input_height,
        input_width,
        resize_height,
        resize_width,
        crop,
    )

def save_images(images, size, image_path):
    """Undo the input scaling on ``images`` and write them out as one picture.

    Args:
        images: Batch passed through ``inverse_transform`` before saving.
        size: Grid layout forwarded to ``imsave``.
        image_path: Destination file path forwarded to ``imsave``.

    Returns:
        The return value of ``imsave``.
    """
    rescaled = inverse_transform(images)
    return imsave(rescaled, size, image_path)

def imread(path, grayscale = False):
    """Read an image file and return its pixels as a float array.

    Args:
        path: Path of the image file to read.
        grayscale: When true, the file is read with ``flatten=True``.

    Returns:
        The array produced by ``scipy.misc.imread`` cast to float.
    """
    # NOTE(review): scipy.misc.imread has been deprecated since SciPy 1.0 and
    # is absent from current SciPy releases -- confirm the pinned SciPy
    # version still provides it before relying on this helper.
    if grayscale:
        pixels = scipy.misc.imread(path, flatten=True)
    else:
        pixels = scipy.misc.imread(path)
    # Builtin float replaces the np.float alias, which NumPy removed.
    return pixels.astype(float)

def merge_images(images, size):
    """Return ``images`` after ``inverse_transform``.

    Args:
        images: Batch handed to ``inverse_transform``.
        size: Accepted for signature compatibility; not used by the body.

    Returns:
        The result of ``inverse_transform(images)``.
    """
    restored = inverse_transform(images)
    return restored

def merge(images, size):
    """Tile a batch of equally sized images into one grid image.

    Images are placed left to right, top to bottom: image ``idx`` lands in
    row ``idx // size[1]`` and column ``idx % size[1]``.

    Args:
        images: Array of shape ``(N, H, W, C)`` with ``C`` in ``{1, 3, 4}``,
            or ``(N, H, W)`` for images without a channel axis.
        size: Pair ``(rows, cols)`` describing the grid.

    Returns:
        A float array of shape ``(H * rows, W * cols, C)`` for 3- or
        4-channel input, or ``(H * rows, W * cols)`` for single-channel and
        channel-less input. Grid cells without an image stay zero.

    Raises:
        ValueError: If a 4-D batch has a channel count other than 1, 3 or 4.
    """
    h, w = images.shape[1], images.shape[2]
    rows, cols = size[0], size[1]

    if images.ndim == 3:
        # Channel-less (N, H, W) batch: the error message below has always
        # advertised HxW images, but this case used to hit an IndexError on
        # images.shape[3].
        tiles = images
        img = np.zeros((h * rows, w * cols))
    elif images.shape[3] in (3, 4):
        tiles = images
        img = np.zeros((h * rows, w * cols, images.shape[3]))
    elif images.shape[3] == 1:
        # Drop the singleton channel axis so tiles are plain 2-D.
        tiles = images[:, :, :, 0]
        img = np.zeros((h * rows, w * cols))
    else:
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')

    for idx, tile in enumerate(tiles):
        col = idx % cols
        row = idx // cols
        img[row * h:row * h + h, col * w:col * w + w] = tile
    return img

def imsave(images, size, path):
    """Tile ``images`` into a grid with ``merge`` and write it to ``path``.

    Args:
        images: Batch of images forwarded to ``merge``.
        size: Grid layout forwarded to ``merge``.
        path: Destination file path.

    Returns:
        The return value of ``scipy.misc.imsave``.
    """
    # NOTE(review): scipy.misc.imsave has been deprecated since SciPy 1.0 and
    # is absent from current SciPy releases -- confirm the pinned version.
    grid = merge(images, size)
    # Remove any length-1 axes before handing the array to the writer.
    squeezed = np.squeeze(grid)
    return scipy.misc.imsave(path, squeezed)

def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64):
    """Cut a centered window out of ``x`` and resize it.

    Args:
        x: Image array; its first two axes are treated as height and width.
        crop_h: Height of the window to cut.
        crop_w: Width of the window to cut; ``None`` means "same as crop_h".
        resize_h: Height passed to the resize call.
        resize_w: Width passed to the resize call.

    Returns:
        The result of ``scipy.misc.imresize`` on the cropped window.
    """
    # NOTE(review): scipy.misc.imresize has been deprecated since SciPy 1.0
    # and is absent from current SciPy releases -- confirm the pinned version.
    crop_w = crop_h if crop_w is None else crop_w
    height, width = x.shape[:2]
    # Offsets that leave equal margins on both sides (up to rounding).
    top = int(round((height - crop_h) / 2.))
    left = int(round((width - crop_w) / 2.))
    window = x[top:top + crop_h, left:left + crop_w]
    return scipy.misc.imresize(window, [resize_h, resize_w])

def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True):
    """Resize ``image`` (optionally center-cropping first) and rescale it.

    Args:
        image: Image array to process.
        input_height: Crop height handed to ``center_crop`` when ``crop`` is true.
        input_width: Crop width handed to ``center_crop`` when ``crop`` is true.
        resize_height: Output height.
        resize_width: Output width.
        crop: When true, crop around the center before resizing; otherwise
            resize the whole image directly.

    Returns:
        The resized pixels divided by 127.5 and shifted down by 1, so a pixel
        value of 0 becomes -1 and 255 becomes 1.
    """
    if not crop:
        resized = scipy.misc.imresize(image, [resize_height, resize_width])
    else:
        resized = center_crop(image, input_height, input_width, resize_height, resize_width)
    scaled = np.array(resized) / 127.5
    return scaled - 1.

def inverse_transform(images):
    """Shift ``images`` up by one and halve the result.

    This maps -1 to 0 and 1 to 1, i.e. it is the arithmetic inverse of the
    ``x / 127.5 - 1`` scaling up to the factor of 255.

    Args:
        images: Number or array to rescale.

    Returns:
        ``(images + 1) / 2`` computed elementwise.
    """
    shifted = images + 1.
    return shifted / 2.

def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
    """Save a 2-D scatter plot of ``z`` colored by class to ``name``.

    Args:
        z: Array whose columns 0 and 1 are used as x and y coordinates.
        id: 2-D array; the argmax along axis 1 of each row picks the point's
            color (presumably one-hot labels -- verify against the caller).
        z_range_x: The x axis is limited to ``[-z_range_x, z_range_x]``.
        z_range_y: The y axis is limited to ``[-z_range_y, z_range_y]``.
        name: Output file path passed to ``plt.savefig``.
    """
    # Number of colorbar ticks and colors requested from discrete_cmap.
    N = 10
    fig = plt.figure(figsize=(8, 6))
    try:
        plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
        plt.colorbar(ticks=range(N))
        axes = plt.gca()
        axes.set_xlim([-z_range_x, z_range_x])
        axes.set_ylim([-z_range_y, z_range_y])
        plt.grid(True)
        plt.savefig(name)
    finally:
        # Close the figure so repeated calls (e.g. once per epoch) do not
        # accumulate open figures in pyplot's global state.
        plt.close(fig)

 

 

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐