您的位置:首页 > 其它

Scikit-learn中使用SVM对文本进行分类

2017-09-14 20:59 846 查看

(一)背景

  本人最近弄了两个和svm算法有关的大作业,一个是处理手写数字识别的,另外一个是文本分类的。最开始,我用libsvm提供的包进行分类。但是总是会出现分成一类的情况。有时候数据归一化之后,就不会分成一类,能够正常分类;但有时候源数据能正常分类,而归一化数据之后却会分成只有一类的情况。

  这让我感到很苦恼,并且了解了一下SVM算法,但是仍然不知道问题所在。后来在大神指点下,使用机器学习框架scikit-learn中自带的svm算法,能够有效的避免这种分成一类的现象,下面以某个文本分类为例子,对scikit-learn中的svm算法用法进行讲解。

(二)系统环境及scikit-learn安装

  本人使用的ubuntu系统,在安装skicit-learn之前,请保证安装了Python3(尽量最新版),以及pip3(便于Python模块的安装)。在安装了Python3和pip3的前提下,要安装skicit-learn,得先安装numpy和scipy模块。

  1.安装numpy模块

sudo pip3 install numpy


  2.安装scipy模块

sudo pip3 install scipy


  3.安装scikit-learn

sudo pip3 install scikit-learn


  也可以先下载好相应的whl文件,然后用pip3命令进行安装

(三)文本分类

  先给出文本分类的数据,包括两个文件,一个是train_data.txt 和predict_data.txt文件,这两个文件的格式是标准的libsvm数据文件的格式,即:

lable1 dimension1:value1 dimension2:value2......
lable2 dimension1:value1 dimension2:value2......


数据网盘地址为:http://pan.baidu.com/s/1hrLoHMk

提取码为:yjxx

1.scikit-learn中svm介绍

  scikit-learn中用于svm分类的类包括三个类,SVC,NuSVC和LinearSVC这三个类。SVC和NuSVC两种方法类似,但是接受的参数有细微不同,而且底层数学原理不一样。LinearSVC是指核函数为线性核的SVM。

2.SVC用法介绍

  要使用SVC需要将数据的标签类别和维向量分开保存,例如文本类别用y数组包括,文本内容向量用x数组保存。然后再使用numpy.array对数据进行规整,编程标准的数组格式,再调用SVC的fit()函数训练出模型。训练好模型之后用SVC.predict()函数根据文本内容向量预测出文本的类别。

其中SVC.fit(train_y, train_x)中参数train_y为文本类别,是一个一维数组;train_x为文本内容向量,是一个二维数组。经过fit()函数处
理后,模型就会保存在SVC中(也许SVC类中有一个model变量)。fit()函数无返回值。

SVC.predict(predict_x),参数predict_x是需要分类的文本内容的向量,是一个二维数组。predict()函数的返回值是一个文本类别数组。


3.文本分类代码:

__author__ = 'liuwei'

import os
import pickle
import numpy as np
from sklearn.svm import SVC

class DataProc(object):
'''get datas from the file'''

def __init__(self, fileName):
self.fileName = fileName
self.readData()

def readData(self):
self.main_data = []

with open(self.fileName, 'r') as file:
data = file.readlines()

for item in data:
item = item.rstrip('\n')                                   #去除尾部换行符
item = item.rstrip(' ')                                    #去除尾部空格
items = item.split(' ')

f_items = []

for son_item in items:
f_items.append(float(son_item))                        #将字符类型转化为浮点类型

self.main_data.append(f_items)

self.main_data = np.array(self.main_data)                          #转换为二维数组

def getY(self):
'''get all labels,is the first column of the array'''

self.y_data = self.main_data[:,0]                                  #获取所有文本的类别便签,即第一列

return self.y_data

def getX(self):
'''get the data of text except the first column'''

self.x_data = self.main_data[:,1:]                                 #获取所有文本的内容向量,除第一列以外的列

return self.x_data

class SVM(object):
'''use the svm algorithm in the skicit-learn to predict the class of text'''

def __init__(self, train_dataproc, predict_dataproc):
self.__train_data = train_dataproc
self.__pred_data = predict_dataproc
self.__svc = SVC()

def train(self):
'''train data,and general a model'''

train_y = self.__train_data.getY()                                  #所有训练文本类别
train_x = self.__train_data.getX()                                  #所有训练文本内容向量

self.__svc.fit(train_x, train_y)                                    #开始训练,

#we can save the model in .pkl file
self.model_presistence()                                            #将SVC对象持久化,相当于将模型持久化

def predict(self):
'''predict the data'''

predict_y = self.__pred_data.getY()                                 #
predict_x = self.__pred_data.getX()

test_data = np.array(predict_x)

res = self.__svc.predict(test_data)                                 #开始预测,结果在res中

accu = 0

for i in range(len(predict_y)):
if predict_y[i] == res[i]:                                      #统计正确率
accu += 1

accu = accu / len(predict_y)

print('the accuracy is %f' %accu)                                   #输出正确率

def model_presistence(self):
'''save a model in .pkl file'''

fileObject = open('SVM.pkl', 'wb')
pickle.dump(self.__svc, fileObject)                                 #将SVC持久化
fileObject.close()

def read_model(self):
'''load a model from a .pkl file'''

fileName = 'SVM.pkl'
fileObject = open(fileName, 'rb')
self.__svc = pickle.load(fileObject)                                #读取SVC

if __name__ == '__main__':
train_data = DataProc('/home/liuwei/python-space/datas/train_data.txt')
predict_data = DataProc('/home/liuwei/python-space/datas/predict_data.txt')

svm = SVM(train_data, predict_data)

#开始训练
svm.train()

#开始预测
svm.predict()


4.运行结果:



可以看出,分类的正确率为67.8%
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: