您的位置:首页 > 编程语言

代码备份：处理 SUN397 的代码，将其分为 80% 的训练数据以及 20% 的测试数据

2016-07-27 09:09 661 查看
处理 SUN397 的代码，将其分为 80% 的训练数据以及 20% 的测试数据

2016-07-27

%% Code for processing SUN397 scene classification
%  Only the "a" part is processed (24 categories, 6169 images in total); it
%  is used to train an initial classifier and predict the additional dataset.
%  Each category folder under impath is split at random into ~80% training
%  and ~20% test images. Every image goes to exactly one split. Training
%  images are resized to train_size and test images to test_size (sizes kept
%  from the original; presumably the test size matches the network input
%  crop -- confirm). Each written image is listed with its 0-based class
%  label in train_list.txt / test_list.txt as "<name>   <label> \n".
clc;
impath = '/home/wangxiao/Downloads/SUN397/SUN397/a/';
train_ratio = 0.8;          % fraction of each category used for training
train_size  = [256, 256];   % output size of training images
test_size   = [227, 227];   % output size of test images

train_im_savePath = '/home/wangxiao/Downloads/SUN397/selected_sun/train_images/';
test_im_savePath  = '/home/wangxiao/Downloads/SUN397/selected_sun/test_images/';

% Keep category folders only: dir() also returns '.', '..' and any plain
% files, and '.'/'..' are not guaranteed to be the first two entries.
files = dir(impath);
files = files([files.isdir] & ~ismember({files.name}, {'.', '..'}));

% 'w' rather than 'a' so that re-running the script does not duplicate lines.
train_fid = fopen('/home/wangxiao/Downloads/SUN397/selected_sun/train_list.txt', 'w');
test_fid  = fopen('/home/wangxiao/Downloads/SUN397/selected_sun/test_list.txt', 'w');
if train_fid == -1 || test_fid == -1
    if train_fid ~= -1, fclose(train_fid); end
    if test_fid  ~= -1, fclose(test_fid);  end
    error('Cannot open the train/test list files for writing.');
end

for i = 1:numel(files)
    label = i - 1;                          % 0-based label, in dir() order
    category = files(i).name;
    newPath = [impath, category, '/'];
    images = dir([newPath, '*.jpg']);
    num_per_kind = numel(images);

    % Draw ONE permutation per category so the train and test parts are
    % disjoint and together cover every image of the category.
    random_num = randperm(num_per_kind);
    num_train = round(num_per_kind * train_ratio);

    for j = 1:num_per_kind
        disp([' ==> deal with Class: ', num2str(i), '        ==> disp image:  ', ...
              num2str(j), '/', num2str(num_per_kind), ' waiting . . . ']);
        im_name = images(random_num(j)).name;
        im = imread([newPath, im_name]);

        if j <= num_train
            % Train data.
            imwrite(imresize(im, train_size), [train_im_savePath, im_name]);
            fprintf(train_fid, '%s   %d \n', im_name, label);
        else
            % Test data.
            imwrite(imresize(im, test_size), [test_im_savePath, im_name]);
            fprintf(test_fid, '%s   %d \n', im_name, label);
        end
    end
end

fclose(train_fid);
fclose(test_fid);


  

%% Relabel the Sun-100 (50% split) lists and keep the first 80 classes.
% Each list is read with importdata. The code uses the returned .data (numeric
% class ids) and .textdata(:, 1) (image names). Class ids are remapped to
% consecutive 0-based labels in ascending id order. The first N rows of each
% list are then written to a new text file as "<name> <label>" per line.
% Note: the variable name `path` (kept from the original) shadows MATLAB's
% built-in path() function within this workspace.
path = '/home/wangxiao/Downloads/SUN397/Sun-100/';
file1 = importdata([path, 'Sun_100_Labeled_Train_0.5_.txt' ]);
file2 = importdata([path, 'Sun_100_UnLabel_Train_0.5_.txt' ]);
file3 = importdata([path, 'Sun_100_Test_0.5_.txt' ]);

% Rows kept from each list. They are taken to be the number of rows that belong
% to the first 80 classes (TODO confirm). Taking a prefix only selects whole
% classes if each list is grouped by class in ascending id order.
num_selected_rows = [9060, 9180, 4560];

lists = {file1, file2, file3};
selected = cell(1, numel(lists));
for k = 1:numel(lists)
    % ic(r) is the position of data(r) among the sorted unique ids, so
    % ic - 1 is the 0-based class label of row r. This replaces a start/end
    % index loop that read ia(i+1) past the end of ia on the last class.
    [~, ~, ic] = unique(lists{k}.data);
    labels = ic(:) - 1;
    selected{k} = labels(1:num_selected_rows(k));
end
select_labelMatrix   = selected{1};
select_labelMatrix_2 = selected{2};
select_labelMatrix_3 = selected{3};

%% Save the selected 80 classes into txt files.
savePath = '/home/wangxiao/Downloads/SUN397/Sun-100/';
out_names = {'Sun_80_50%_Labeled_data.txt', ...
             'Sun_80_50%_Unlabeled_data.txt', ...
             'Sun_80_50%_test_data.txt'};

for k = 1:numel(lists)
    out_file = [savePath, out_names{k}];
    % 'w' rather than 'a' so that re-running does not duplicate lines.
    fid = fopen(out_file, 'w');
    if fid == -1
        error('Cannot open %s for writing.', out_file);
    end
    names = lists{k}.textdata(:, 1);
    for r = 1:numel(selected{k})
        % One "name label" pair per line, with no stray leading/trailing space.
        fprintf(fid, '%s %d\n', names{r}, selected{k}(r));
    end
    fclose(fid);
end


  
内容来自用户分享和网络整理，不保证内容的准确性；如有侵权内容，可联系管理员处理。点击这里给我发消息
标签: