您的位置:首页 > 理论基础 > 计算机网络

解决Tensorflow读取MNIST数据集时网络超时问题

2018-02-10 18:12 701 查看
最近在学习TensorFlow,比较烦人的是使用
tensorflow.examples.tutorials.mnist.input_data
读取数据

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/temp/mnist_data/')
X = mnist.test.images.reshape(-1, n_steps, n_inputs)
y = mnist.test.labels


时,经常出现网络连接错误



解决方法其实很简单,这里我们可以看一下
input_data.py
的源代码(这里截取关键部分)

def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath


可以看到,代码会先检查文件是否存在,如果不存在再进行下载,那么我是不是自己下载数据不就行了?

MNIST的数据集是从Yann LeCun教授的官网下载,下载完成之后修改一下我们读取数据的代码,加上我们下载的路径即可

from tensorflow.examples.tutorials.mnist import input_data
import os

data_path = os.path.join('.', 'temp', 'data')
mnist = input_data.read_data_sets(datapath)
X = mnist.test.images.reshape(-1, n_steps, n_inputs)
y = mnist.test.labels


测试一下



成功!
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: