您的位置:首页 > 其它

TensorFlow 下 mnist 数据集的操作及可视化

2017-03-16 15:12 417 查看
from tensorflow.examples.tutorials.mnist import input_data


首先需要连网下载数据集:

mnsit = input_data.read_data_sets(train_dir='./MNIST_DATA', one_hot=True)
# 如果当前文件夹下没有 MNIST_DATA,会首先创建该文件夹,然后下载 mnist 数据集


训练集与测试集的划分:

X_train, y_train = mnist.train.images, mnist.train.labels
# 返回的 X_train 是 numpy 下的 多维数组,(55000, 784)
X_test, y_test = mnist.test.images, mnist.test.labels
# (10000, 784)
X_valid, y_valid = mnist.valid.images, mnist.valid.labels
# (5000, 784)


当然可以通过迭代的形式以一定 batch_size 读取数据:

mnist.train.next_batch(100)


mnist.train.next_batch() ⇒ 返回两个值,一个是图像数据,一个是图像数据对应的类别信息。

>> X_batch, y_batch = mnist.train.next_batch(100)
>> X_batch.shape
(100, 784)
>> y_batch.shape
(100, 10)                 # one hot 编码


1. 可视化

# images:9*(28*28) 的 numpy.ndarray
# y_ 表示其真实的标签信息
def plot_mnist_3_3(images, y_, y=None):
assert images.shape[0] == len(y_)
fig, axes = plt.subplots(3, 3)
for i, ax in enumerate(axes.flat):
ax.imshow(images[i].reshape(image_shp), cmap='binary')
if y is None:
xlabel = 'True: {}'.format(y_[i])
else:
xlabel = 'True: {0}, Pred: {1}'.format(y_[i], y[i])
ax.set_xlabel(xlabel)
ax.set_xticks([])
ax.set_yticks([])
plt.show()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐