CVAE(条件自编码) Condition GAN (条件GAN) 和 VAE-GAN模型之间的区别
2019-03-20 19:44
393 查看
版权声明:版权归世界上所有无产阶级所有 https://blog.csdn.net/qq_41776781/article/details/88697627
这个文件主要负责读取模型中所需要的数据,以及保存图像等,相关信息可以直接参考DCGAN, 有啥不明白的再问吧
[code]import math import random import scipy.misc import numpy as np from time import gmtime, strftime from six.moves import xrange import matplotlib.pyplot as plt import os, gzip import tensorflow as tf import tensorflow.contrib.slim as slim def load_mnist(): # 2019 可以选择不同的数据集 # data_dir = "../Dataset/fashion-mnist/" data_dir = "../Dataset/mnist_data/" def extract_data(filename, num_data, head_size, data_size): with gzip.open(filename) as bytestream: bytestream.read(head_size) buf = bytestream.read(data_size * num_data) data = np.frombuffer(buf, dtype=np.uint8).astype(np.float) return data data = extract_data(data_dir + 'train-images-idx3-ubyte.gz', 60000, 16, 28 * 28) trX = data.reshape((60000, 28, 28, 1)) data = extract_data(data_dir + 'train-labels-idx1-ubyte.gz', 60000, 8, 1) trY = data.reshape((60000)) data = extract_data(data_dir + 't10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28) teX = data.reshape((10000, 28, 28, 1)) data = extract_data(data_dir + 't10k-labels-idx1-ubyte.gz', 10000, 8, 1) teY = data.reshape((10000)) trY = np.asarray(trY) teY = np.asarray(teY) X = np.concatenate((trX, teX), axis=0) y = np.concatenate((trY, teY), axis=0).astype(np.int) data_index = np.arange(X.shape[0]) print("*****************dataX**************", len(X)) np.random.shuffle(data_index) X = X[data_index, :, :, :] y = y[data_index] y_vec = np.zeros((len(y), 10), dtype=np.float) for i, label in enumerate(y): y_vec[i, y[i]] = 1.0 return X / 255., y_vec def check_folder(log_dir): if not os.path.exists(log_dir): os.makedirs(log_dir) return log_dir def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, crop=True, grayscale=False): image = imread(image_path, grayscale) return transform(image, input_height, input_width, resize_height, resize_width, crop) def save_images(images, size, image_path): return imsave(inverse_transform(images), size, image_path) def imread(path, grayscale = False): if (grayscale): return scipy.misc.imread(path, flatten = True).astype(np.float) else: return scipy.misc.imread(path).astype(np.float) def merge_images(images, size): return inverse_transform(images) def merge(images, size): h, w = images.shape[1], images.shape[2] if (images.shape[3] in (3,4)): c = images.shape[3] img = np.zeros((h * size[0], w * size[1], c)) for idx, image in enumerate(images): i = idx % size[1] j = idx // size[1] img[j * h:j * h + h, i * w:i * w + w, :] = image return img elif images.shape[3]==1: img = np.zeros((h * size[0], w * size[1])) for idx, image in enumerate(images): i = idx % size[1] j = idx // size[1] img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] return img else: raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') def imsave(images, size, path): image = np.squeeze(merge(images, size)) return scipy.misc.imsave(path, image) def center_crop(x, crop_h, crop_w, resize_h=64, resize_w=64): if crop_w is None: crop_w = crop_h h, w = x.shape[:2] j = int(round((h - crop_h)/2.)) i = int(round((w - crop_w)/2.)) return scipy.misc.imresize(x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) def transform(image, input_height, input_width, resize_height=64, resize_width=64, crop=True): if crop: cropped_image = center_crop(image, input_height, input_width, resize_height, resize_width) else: cropped_image = scipy.misc.imresize(image, [resize_height, resize_width]) return np.array(cropped_image)/127.5 - 1. def inverse_transform(images): return (images+1.)/2. def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'): N = 10 plt.figure(figsize=(8, 6)) plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet')) plt.colorbar(ticks=range(N)) axes = plt.gca() axes.set_xlim([-z_range_x, z_range_x]) axes.set_ylim([-z_range_y, z_range_y]) plt.grid(True) plt.savefig(name)
相关文章推荐
- VAE、GAN、Info-GAN:全解深度学习三大生成模型
- Java虚拟机的理解与内存模型之间的区别
- 我的JavaScript回顾之路_01—0206—++在前在后区别/&&和||/条件判断语句/循环语句的区别/字符串类型数字和数字类型之间的转换
- Laravel的ORM模型的find(),findOrFail(),first(),firstOrFail(),get(),list(),toArray()之间的区别是什么?
- 各种编码之间的区别:ASCII、Unicode、UTF-8
- OSI和TCP/IP模型之间的区别-----无线网络通讯协议有哪些
- 字符编码之ASCII、Unicode以及utf-8之间的联系与区别
- 各种编码之间的区别 用法 总结
- 通过这几天的研究,终于明白了Unicode和UTF-8之间编码的区别。Unicode是一个字符集,而UTF-8是Unicode的其中一种,Unicode是定长的都为双字节,而UTF-8是可变的,对于
- 联合概率与联合分布、条件概率与条件分布、边缘概率与边缘分布、贝叶斯定理、生成模型(Generative Model)和判别模型(Discriminative Model)的区别
- 三大深度学习生成模型:VAE、GAN及其变种
- 三大深度学习生成模型:VAE、GAN及其变种
- 字符,字符集,编码之间的区别
- 编码问题,UTF,ISO8859-1,unicode,ACSii,GBK之间的区别
- JAVA 编码中文问题系统透彻讲解 UNICODE GBK UTF-8 ISO-8859-1 之间的区别
- 自己总结的 五种IO模型之间的关系与区别
- 卷积学习与传统稀疏编码、ICA模型学习区别(逐步补充)
- 深度学习的三大生成模型:VAE、GAN、GAN
- 机器学习填坑:你知道模型参数和超参数之间的区别吗?
- hive join on和where条件之间的区别