您的位置:首页 > 数据库

解析MNIST数据库

2016-06-01 16:41 447 查看

使用Python将MNIST数据库解析为图片

import struct
import numpy as np
#import matplotlib.pyplot as plt
import Image
import sys

# Export MNIST training digits as BMP files.
#
# Usage: script.py <input_path> <output_path>
# Files are written as "<output_path>/<label>_<ids>.bmp", where <ids> is the
# zero-based running number of the image within its digit class.

input_path = sys.argv[1]  # directory holding the decompressed MNIST files
output_path = sys.argv[2]  # directory that receives the generated images

# Highest per-digit running number that is still written to disk. Numbering
# starts at 0, so at most MAX_IDS + 1 images are saved for each digit.
MAX_IDS = 50

# =====read labels=====
# Header: two big-endian unsigned 32-bit integers (magic, item count),
# followed by one unsigned byte per label.
label_file = input_path + '/train-labels.idx1-ubyte'
with open(label_file, 'rb') as label_fp:
    label_buf = label_fp.read()

label_index = 0
label_magic, label_numImages = struct.unpack_from('>II', label_buf, label_index)
label_index += struct.calcsize('>II')
# Take the label count from the header instead of hard-coding 60000.
labels = struct.unpack_from('>%dB' % label_numImages, label_buf, label_index)

# =====read train images=====
label_map = {}  # digit -> running number of the latest image seen for it
train_file = input_path + '/train-images.idx3-ubyte'
with open(train_file, 'rb') as train_fp:
    buf = train_fp.read()

# Header: four big-endian unsigned 32-bit integers.
index = 0
magic, numImages, numRows, numColumns = struct.unpack_from('>IIII', buf, index)
index += struct.calcsize('>IIII')

# One unsigned byte per pixel; the size comes from the header, not from 784.
image_fmt = '>%dB' % (numRows * numColumns)
image_size = struct.calcsize(image_fmt)

# Only walk over entries that have both an image and a label.
for k in range(min(numImages, label_numImages)):
    label = labels[k]
    ids = label_map.get(label, -1) + 1
    label_map[label] = ids
    if ids > MAX_IDS:
        continue

    # Derive the offset from the image position rather than from a running
    # cursor, so that skipping an image can never shift later images onto
    # the wrong labels.
    offset = index + k * image_size
    im = struct.unpack_from(image_fmt, buf, offset)

    im = np.array(im, dtype='uint8')
    im = im.reshape(numRows, numColumns)
    im = Image.fromarray(im)
    im.save(output_path + '/%s_%s.bmp' % (label, ids), 'bmp')


Matlab解析

引自:http://blog.csdn.net/wangyuquanliuli/article/details/17378317

主程序

% k-nearest-neighbour classification of the MNIST test set.
% Each test image is compared with every training image using the sum of
% absolute pixel differences; the K closest training labels vote.
trainImages = loadMNISTImages('train-images.idx3-ubyte');
trainLabels = loadMNISTLabels('train-labels.idx1-ubyte');
N = 784; % not referenced below; kept so the workspace matches the original
K = 100;% can be any other value
testImages = loadMNISTImages('t10k-images.idx3-ubyte');
testLabels = loadMNISTLabels('t10k-labels.idx1-ubyte');
% Images are stored one per column, so count columns explicitly; length()
% returns the larger dimension, which only happens to be the image count.
trainLength = size(trainImages, 2);
testLength = size(testImages, 2);
K = min(K, trainLength); % cannot vote with more neighbours than exist
testResults = zeros(1, testLength);
timerStart = tic;
for i = 1:testLength
    % Distance from test image i to every training image (1 x trainLength)
    comp = sum(abs(bsxfun(@minus, trainImages, testImages(:, i))), 1);
    [~, ind] = sort(comp);
    compLabel = trainLabels(ind(1:K));
    % Majority vote; mode() picks the smallest label when counts tie
    testResults(i) = mode(compLabel);

    disp(testResults(i));
    disp(testLabels(i));
end
% Compute the error on the test set
errorCount = nnz(testResults(:) ~= testLabels(:));

%Print out the classification error on the test set
errorRate = errorCount / testLength
% Pass the handle returned by tic explicitly; calling tic again here would
% restart the timer instead of reading it.
toc(timerStart);
elapsedSeconds = toc(timerStart);
disp(elapsedSeconds);


两个子程序

function images = loadMNISTImages(filename)
%loadMNISTImages Read an MNIST image file.
%   images = loadMNISTImages(filename) returns a double matrix of size
%   (numRows*numCols) x numImages, one image per column, with pixel values
%   rescaled from 0..255 to [0,1].
%
%   Raises an assertion error when the file cannot be opened, when its
%   magic number is not 2051, or when the amount of pixel data does not
%   match the dimensions declared in the header.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file on every exit path, including the failed assertions below
closeFile = onCleanup(@() fclose(fp));

% Header: four big-endian 32-bit integers
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2051, ['Bad magic number in ', filename, '']);
numImages = fread(fp, 1, 'int32', 0, 'ieee-be');
numRows = fread(fp, 1, 'int32', 0, 'ieee-be');
numCols = fread(fp, 1, 'int32', 0, 'ieee-be');

images = fread(fp, inf, 'unsigned char');
% Fail with a clear message instead of a reshape size error
assert(numel(images) == numRows * numCols * numImages, ...
    ['Mismatch in pixel count in ', filename, '']);

% The data is laid out as numCols x numRows per image and then swapped so
% that every slice is numRows x numCols
images = reshape(images, numCols, numRows, numImages);
images = permute(images,[2 1 3]);

% Reshape to #pixels x #examples
images = reshape(images, numRows * numCols, numImages);
% Convert to double and rescale to [0,1]
images = double(images) / 255;
end


function labels = loadMNISTLabels(filename)
%loadMNISTLabels Read an MNIST label file.
%   labels = loadMNISTLabels(filename) returns a numLabels x 1 column
%   vector containing one label per entry of the file.
%
%   Raises an assertion error when the file cannot be opened, when its
%   magic number is not 2049, or when the number of labels read differs
%   from the count declared in the header.
fp = fopen(filename, 'rb');
assert(fp ~= -1, ['Could not open ', filename, '']);
% Close the file on every exit path, including the failed assertions below
closeFile = onCleanup(@() fclose(fp));

% Header: two big-endian 32-bit integers
magic = fread(fp, 1, 'int32', 0, 'ieee-be');
assert(magic == 2049, ['Bad magic number in ', filename, '']);
numLabels = fread(fp, 1, 'int32', 0, 'ieee-be');

% The rest of the file holds one unsigned byte per label
labels = fread(fp, inf, 'unsigned char');
assert(size(labels,1) == numLabels, 'Mismatch in label count');
end


C++解析

引自:http://blog.csdn.net/fengbingchun/article/details/49611549

#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "opencv2/core/core.hpp"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"

using namespace std;

// Returns i with the order of its four bytes reversed.
// Assumes int is 32 bits wide.
int ReverseInt(int i)
{
    // Work on an unsigned copy: moving a byte >= 0x80 into the top byte of
    // a signed int overflows, whereas unsigned shifts are always defined.
    const unsigned int u = static_cast<unsigned int>(i);
    const unsigned int swapped = ((u & 0x000000FFu) << 24) |
                                 ((u & 0x0000FF00u) << 8) |
                                 ((u & 0x00FF0000u) >> 8) |
                                 ((u & 0xFF000000u) >> 24);
    return static_cast<int>(swapped);
}

void read_Mnist(string filename, vector<cv::Mat> &vec)
{
ifstream file (filename, ios::binary);
if (file.is_open()) {
int magic_number = 0;
int number_of_images = 0;
int n_rows = 0;
int n_cols = 0;
file.read((char*) &magic_number, sizeof(magic_number));
magic_number = ReverseInt(magic_number);
file.read((char*) &number_of_images,sizeof(number_of_images));
number_of_images = ReverseInt(number_of_images);
file.read((char*) &n_rows, sizeof(n_rows));
n_rows = ReverseInt(n_rows);
file.read((char*) &n_cols, sizeof(n_cols));
n_cols = ReverseInt(n_cols);

for(int i = 0; i < number_of_images; ++i) {
cv::Mat tp = cv::Mat::zeros(n_rows, n_cols, CV_8UC1);
for(int r = 0; r < n_rows; ++r) {
for(int c = 0; c < n_cols; ++c) {
unsigned char temp = 0;
file.read((char*) &temp, sizeof(temp));
tp.at<uchar>(r, c) = (int) temp;
}
}
vec.push_back(tp);
}
}
}

void read_Mnist_Label(string filename, vector<int> &vec)
{
ifstream file (filename, ios::binary);
if (file.is_open()) {
int magic_number = 0;
int number_of_images = 0;
int n_rows = 0;
int n_cols = 0;
file.read((char*) &magic_number, sizeof(magic_number));
magic_number = ReverseInt(magic_number);
file.read((char*) &number_of_images,sizeof(number_of_images));
number_of_images = ReverseInt(number_of_images);

for(int i = 0; i < number_of_images; ++i) {
unsigned char temp = 0;
file.read((char*) &temp, sizeof(temp));
vec[i]= (int)temp;
}
}
}

// Builds the file stem "<number>_<index>" for the next image of a digit.
//
// number: digit label of the image.
// arr:    array of ten per-digit counters; arr[number] is incremented and
//         its new value becomes <index>, left-padded with zeros to five
//         characters (larger values are written unpadded). Counters are
//         expected to be non-negative.
//
// When number is outside 0..9 no counter is touched and the result is
// "<number>_" with an empty index.
std::string GetImageName(int number, int arr[])
{
    std::string index;

    if (number >= 0 && number < 10) {
        const int count = ++arr[number];
        // std::to_string sizes its own buffer, unlike sprintf into char[10],
        // which cannot hold every int value.
        index = std::to_string(count);
        if (index.size() < 5) {
            index = std::string(5 - index.size(), '0') + index;
        }
    }

    return std::to_string(number) + "_" + index;
}

int main()
{
//reference: http://eric-yuan.me/cpp-read-mnist/ //test images and test labels
//read MNIST image into OpenCV Mat vector
string filename_test_images = "D:/Download/t10k-images-idx3-ubyte/t10k-images.idx3-ubyte";
int number_of_test_images = 10000;
vector<cv::Mat> vec_test_images;

read_Mnist(filename_test_images, vec_test_images);

//read MNIST label into int vector
string filename_test_labels = "D:/Download/t10k-labels-idx1-ubyte/t10k-labels.idx1-ubyte";
vector<int> vec_test_labels(number_of_test_images);

read_Mnist_Label(filename_test_labels, vec_test_labels);

if (vec_test_images.size() != vec_test_labels.size()) {
cout<<"parse MNIST test file error"<<endl;
return -1;
}

//save test images
int count_digits[10];
for (int i = 0; i < 10; i++)
count_digits[i] = 0;

string save_test_images_path = "D:/Download/MNIST/test_images/";

for (int i = 0; i < vec_test_images.size(); i++) {
int number = vec_test_labels[i];
string image_name = GetImageName(number, count_digits);
image_name = save_test_images_path + image_name + ".jpg";

cv::imwrite(image_name, vec_test_images[i]);
}

//train images and train labels
//read MNIST image into OpenCV Mat vector
string filename_train_images = "D:/Download/train-images-idx3-ubyte/train-images.idx3-ubyte";
int number_of_train_images = 60000;
vector<cv::Mat> vec_train_images;

read_Mnist(filename_train_images, vec_train_images);

//read MNIST label into int vector
string filename_train_labels = "D:/Download/train-labels-idx1-ubyte/train-labels.idx1-ubyte";
vector<int> vec_train_labels(number_of_train_images);

read_Mnist_Label(filename_train_labels, vec_train_labels);

if (vec_train_images.size() != vec_train_labels.size()) {
cout<<"parse MNIST train file error"<<endl;
return -1;
}

//save train images
for (int i = 0; i < 10; i++)
count_digits[i] = 0;

string save_train_images_path = "D:/Download/MNIST/train_images/";

for (int i = 0; i < vec_train_images.size(); i++) {
int number = vec_train_labels[i];
string image_name = GetImageName(number, count_digits);
image_name = save_train_images_path + image_name + ".jpg";

cv::imwrite(image_name, vec_train_images[i]);
}

return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  深度学习