
Writing a Python script to download and extract the MNIST, CIFAR10, and SVHN datasets (2)

2017-06-13 15:42
[UpdateTime: 20170613]
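I. Purpose
MNIST is to machine learning and deep learning what cout<<"hello world" is to programming (a comparison borrowed from the TensorFlow tutorial). I have only recently started with deep learning, while also continuing to study machine learning, and have come across many MNIST-based examples. This post builds on the previous one, "Writing a Python script to download and extract the MNIST dataset (1)", and extends it into a script that can download the MNIST, CIFAR10, and SVHN datasets. The idea is the same.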

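The packages this post depends on are listed in the import statements at the top of the script. Several deep learning frameworks had already been installed on the machine used for this post, so most of these packages were already present. If you hit a missing package, install it as the error message suggests (pip install xxx).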
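One way to see which packages are missing before running the script is sketched below. It assumes Python 3 (the script itself targets Python 2.7), and `missing_modules` and `REQUIRED` are names made up for this illustration; the module list is simply read off the script's import statements.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names in `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Hypothetical list, copied from the imports used by the script
# (scipy is only imported inside svhn_download).
REQUIRED = ["gzip", "tarfile", "subprocess", "os", "numpy", "six", "scipy"]

if __name__ == "__main__":
    for name in missing_modules(REQUIRED):
        print("missing:", name, "-> try: pip install", name)
```

For numpy, six, and scipy the pip package name matches the module name, which is why the hint can reuse `name` directly.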
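The datasets are introduced briefly below.
1. MNIST
== Overview: MNIST is a database of handwritten digits with 60000 training samples and 10000 test samples. It is a subset of the NIST database.
== MNIST files used in this post:
-- 'train-images-idx3-ubyte.gz'    // 60000 training images
-- 'train-labels-idx1-ubyte.gz'    // labels for the 60000 training images
-- 't10k-images-idx3-ubyte.gz'     // 10000 test images
-- 't10k-labels-idx1-ubyte.gz'     // labels for the 10000 test images
== Reference:
1) Converting MNIST with a C++ program based on OpenCV (with source code): http://blog.csdn.net/fengbingchun/article/details/49611549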
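The four MNIST files use the IDX format: a 4-byte magic number (two zero bytes, a data type code, and the number of dimensions), followed by one big-endian 32-bit size per dimension, followed by the raw data. The sketch below reads such a file with the standard library only. It assumes Python 3 and unsigned-byte data (type code 0x08, which is what the MNIST files use); `read_idx` is a name chosen for this illustration and is not part of the script.

```python
import gzip
import struct


def read_idx(path):
    """Read an unsigned-byte IDX file (optionally gzip-compressed).

    Returns (dims, data): a tuple of dimension sizes and the raw bytes.
    """
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rb") as f:
        # Magic number: two zero bytes, the type code, the number of dimensions.
        zero, type_code, ndim = struct.unpack(">HBB", f.read(4))
        if zero != 0 or type_code != 0x08:
            raise ValueError("not an unsigned-byte IDX file: %s" % path)
        # One big-endian unsigned 32-bit integer per dimension.
        dims = struct.unpack(">%dI" % ndim, f.read(4 * ndim))
        data = f.read()
    expected = 1
    for size in dims:
        expected *= size
    if len(data) != expected:
        raise ValueError("expected %d bytes of data, got %d" % (expected, len(data)))
    return dims, data
```

Called on the training image file, `dims` should come back as (60000, 28, 28); on a label file it has a single entry. Because the script decompresses the archives with `gzip -d`, the function accepts both the `.gz` files and the decompressed ones.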
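2. CIFAR10
== Overview: CIFAR-10 is a dataset for general object recognition collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. It consists of 60000 32*32 RGB color images in 10 classes: 50000 for training and 10000 for testing. Its main characteristic is that it moves recognition to general everyday objects and poses a multi-class classification problem. There is also CIFAR-100, among others.
== CIFAR10 files used in this post:
-- cifar-10-python.tar.gz    // before extraction
-- cifar-10-batches-py       // after extraction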
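After extraction, each batch file inside cifar-10-batches-py is a pickled dictionary holding the images as rows of 3072 bytes (1024 red, then 1024 green, then 1024 blue values, each plane in row-major order) together with a list of labels. A minimal sketch of loading one batch follows. It assumes Python 3 and numpy, and `load_cifar10_batch` is a hypothetical helper, not something the script defines.

```python
import pickle

import numpy as np


def load_cifar10_batch(path):
    """Load one CIFAR-10 batch file; returns (images, labels).

    images has shape (N, 32, 32, 3), dtype uint8; labels has shape (N,).
    """
    with open(path, "rb") as f:
        # The files were pickled with Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    data = np.asarray(batch[b"data"], dtype=np.uint8)
    # Each row is R plane, G plane, B plane: split the planes, then move the
    # channel axis last.
    images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    labels = np.asarray(batch[b"labels"], dtype=np.int64)
    return images, labels
```

A call such as `load_cifar10_batch("./cifar10/cifar-10-batches-py/data_batch_1")` would match the directory the script extracts into.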
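3. SVHN
== Overview: a real-world (Google Street View) dataset for recognizing house-number digits, with more than 600000 labeled digit images.
== SVHN files used in this post:
-- test_32x32.mat
-- train_32x32.mat
== Official site: http://ufldl.stanford.edu/housenumbers/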
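The two .mat files store the images in an array `X` of shape (32, 32, 3, N) and the labels in `y` of shape (N, 1), where the digit 0 is stored as label 10. The sketch below rearranges these into a sample-first layout. It assumes numpy (and scipy for the file loading); `svhn_to_arrays` and `load_svhn` are names invented for this illustration.

```python
import numpy as np


def svhn_to_arrays(X, y):
    """Convert SVHN arrays to images of shape (N, 32, 32, 3) and labels of shape (N,)."""
    images = np.transpose(X, (3, 0, 1, 2))
    labels = y.reshape(-1).astype(np.int64)  # astype copies, so y is left untouched
    labels[labels == 10] = 0                 # SVHN stores the digit 0 as label 10
    return images, labels


def load_svhn(path):
    """Load a cropped-digits SVHN .mat file such as train_32x32.mat."""
    import scipy.io as sio  # imported lazily, as the script does
    m = sio.loadmat(path)
    return svhn_to_arrays(m["X"], m["y"])
```

Keeping the conversion separate from the file loading makes it possible to check the reshaping on small synthetic arrays without downloading anything.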

II. Environment
1. Ubuntu environment: http://blog.csdn.net/houchaoqun_xmu/article/details/72453187
2. Anaconda2: http://blog.csdn.net/houchaoqun_xmu/article/details/72461592
III. Code
# Copyright 20170611 . All Rights Reserved.
# author: Chaoqun Hou
# Prerequisites:
# Python 2.7
# gzip, tarfile, subprocess, numpy, six; scipy (SVHN only)
# External commands: rm, gzip, curl
#
# ==============================================================================
"""Functions for downloading and unzipping the MNIST, CIFAR10 and SVHN data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import tarfile
import subprocess
import os
import numpy
from six.moves import urllib

def maybe_download(filename, data_dir, SOURCE_URL):
    """Download filename from SOURCE_URL into data_dir, unless it's already there."""
    filepath = os.path.join(data_dir, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath

def check_file(data_dir):
    """Return True if data_dir already exists; otherwise create it and return False."""
    if os.path.exists(data_dir):
        return True
    else:
        os.mkdir(data_dir)
        return False

def uzip_data(decompression_command, decompression_optional, target_path):
    # Decompress target_path by running an external command, e.g. `gzip -d <file>`.
    cmd = [decompression_command, decompression_optional, target_path]
    print('decompression', target_path)
    subprocess.call(cmd)

def mnist_download(data_dir):
    if check_file(data_dir):
        print(data_dir)
        print('dir mnist already exists.')

        # Delete the mnist dir.
        # If it already exists, it is removed and everything is downloaded again.
        cmd = ['rm', '-rf', data_dir]
        print('delete the dir', data_dir)
        subprocess.call(cmd)
        os.mkdir(data_dir)

    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
    data_keys = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
    for filename in data_keys:
        if os.path.isfile(os.path.join(data_dir, filename)):
            print("[warning...]", filename, "already exists.")
        else:
            maybe_download(filename, data_dir, SOURCE_URL)

    # Unzip the mnist data.
    uziped_data_keys = ['train-images-idx3-ubyte', 'train-labels-idx1-ubyte', 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte']
    for filename in uziped_data_keys:
        if os.path.isfile(os.path.join(data_dir, filename)):
            print("[warning...]", filename, "already exists.")
        else:
            target_path = os.path.join(data_dir, filename)
            # Pass the .gz path explicitly; gzip -d writes target_path and removes the archive.
            uzip_data('gzip', '-d', target_path + '.gz')

def cifar10_download(data_dir):
    if check_file(data_dir):
        print(data_dir)
        print('dir cifar10 already exists.')

        # Delete the cifar10 dir.
        cmd = ['rm', '-rf', data_dir]
        print('delete the dir', data_dir)
        subprocess.call(cmd)
        os.mkdir(data_dir)

    SOURCE_URL = 'https://www.cs.toronto.edu/~kriz/'
    filename = 'cifar-10-python.tar.gz'

    if os.path.isfile(os.path.join(data_dir, filename)):
        print("[warning...]", filename, "already exists.")
    else:
        target_path = os.path.join(data_dir, filename)
        # Build the URL by concatenation; os.path.join is meant for filesystem paths.
        target_url = SOURCE_URL + filename
        # maybe_download(filename, data_dir, SOURCE_URL)
        cmd = ['curl', target_url, '-o', target_path]
        print('Downloading CIFAR10')
        subprocess.call(cmd)

    decompressioned_name = 'cifar-10-batches-py'
    # The extracted result is a directory, so test with isdir rather than isfile.
    if os.path.isdir(os.path.join(data_dir, decompressioned_name)):
        print("[warning...]", decompressioned_name, "already exists.")
    else:
        print("data_dir = ", data_dir)
        target_path = os.path.join(data_dir, filename)
        print("target_path = ", target_path)
        with tarfile.open(target_path, 'r:gz') as tar:
            tar.extractall(data_dir)

# The Street View House Numbers (SVHN) Dataset
# SVHN download
def svhn_download(data_dir):
    import scipy.io as sio

    # svhn file loader: download the .mat file with curl, then read it back
    def svhn_loader(url, path):
        cmd = ['curl', url, '-o', path]
        subprocess.call(cmd)
        m = sio.loadmat(path)
        return m['X'], m['y']

    if check_file(data_dir):
        print('SVHN was downloaded.')
    else:
        data_url = 'http://ufldl.stanford.edu/housenumbers/train_32x32.mat'
        train_image, train_label = svhn_loader(data_url, os.path.join(data_dir, 'train_32x32.mat'))

        data_url = 'http://ufldl.stanford.edu/housenumbers/test_32x32.mat'
        test_image, test_label = svhn_loader(data_url, os.path.join(data_dir, 'test_32x32.mat'))

if __name__ == '__main__':
    # six maps input to raw_input on Python 2, so the prompt works on Python 2 and 3.
    from six.moves import input

    print("===== running - input_data() script =====")
    print("Please input [ 1 ] to download the mnist database.")
    print("Please input [ 2 ] to download the cifar10 database.")
    print("Please input [ 3 ] to download the svhn database.")
    print("Please input [ all ] to download the mnist, cifar10 and svhn database.")
    user_input = input()
    print("your input is", user_input)

    # Python 2.7 has no switch statement, so an if/elif chain stands in for:
    #
    # switch(user_input){
    # case '1': print("Download MNIST.");
    # case '2': print("Download Cifar10.");
    # case '3': print("Download SVHN.");
    # case 'all': print("Download MNIST, Cifar10 and SVHN.");
    # }

    if user_input == '1':
        mnist_download("./mnist")
    elif user_input == '2':
        cifar10_download("./cifar10")
    elif user_input == '3':
        svhn_download("./svhn")
    elif user_input == 'all':
        mnist_download("./mnist")
        cifar10_download("./cifar10")
        svhn_download("./svhn")
    else:
        print("your input is not correct.")

    print("============= =============")

Run the script with the following command:
python get_mnist.py
Output of the run (the download takes some time): [screenshot not available]
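The script shells out to `rm`, `gzip`, and `curl`, which ties it to a Unix-like environment such as the Ubuntu setup used here. The same three steps can be sketched with the standard library only, assuming Python 3. The helper names `reset_dir`, `download`, and `gunzip` are made up for this illustration and are not part of the original script.

```python
import gzip
import os
import shutil
import urllib.request


def reset_dir(data_dir):
    """Delete data_dir if present and recreate it empty (stands in for rm -rf + mkdir)."""
    shutil.rmtree(data_dir, ignore_errors=True)
    os.makedirs(data_dir)


def download(url, filepath):
    """Fetch url into filepath (stands in for curl url -o filepath); returns the size in bytes."""
    urllib.request.urlretrieve(url, filepath)
    return os.path.getsize(filepath)


def gunzip(gz_path, remove_source=True):
    """Decompress foo.gz to foo (stands in for gzip -d foo.gz); returns the output path."""
    if not gz_path.endswith(".gz"):
        raise ValueError("expected a .gz file: %s" % gz_path)
    out_path = gz_path[:-3]
    with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)  # streams in chunks, so large files are fine
    if remove_source:
        os.remove(gz_path)  # gzip -d also removes the archive by default
    return out_path
```

The trade-off is that `curl` prints a progress meter while `urlretrieve` stays silent unless a `reporthook` callback is supplied.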
