
Classifying the Iris Dataset with an SVM

2020-06-25 22:50

About the data:
feature: four attributes, 'sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'
target: three classes, 'setosa', 'versicolor', 'virginica'
Task: use the four attributes to assign each sample to one of the three classes.

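Approach:
1. Classify the data directly with an SVM.
2. Transform the features first, expanding them into a higher-dimensional space, and then classify.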
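
To make the second idea concrete, here is a small sketch of what a degree-2 polynomial expansion does to one sample. The sample values and variable names are made up for illustration; `PolynomialFeatures` is the scikit-learn transformer used later in this post.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

# Hypothetical sample with two features, a=2 and b=3
sample = np.array([[2.0, 3.0]])

poly = PolynomialFeatures(degree=2)
expanded = poly.fit_transform(sample)
# Columns are ordered as 1, a, b, a^2, a*b, b^2
print(expanded)

# The same expansion on four features (as in iris) produces 15 columns
n_out = PolynomialFeatures(degree=2).fit(np.zeros((1, 4))).n_output_features_
print(n_out)
```

A linear boundary in the expanded space corresponds to a curved boundary in the original one, which is why the expansion can help a linear SVM.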

from sklearn.datasets import load_iris
data=load_iris()

print(data["feature_names"])
print("****************************************")
print(data['filename'])
print("****************************************")
print(data["target_names"])
print("****************************************")

['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

E:\Anaconda3\lib\site-packages\sklearn\datasets\data\iris.csv

['setosa' 'versicolor' 'virginica']
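
The same object also exposes the array shape and the class balance, which is worth checking before splitting. This is a minimal sketch; the variable names `iris`, `n_samples`, `n_features` and `class_counts` are my own.

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Rows are samples, columns are the four measurements
n_samples, n_features = iris.data.shape
# Number of samples per class label (0, 1, 2)
class_counts = np.bincount(iris.target)

print(n_samples, n_features)
print(dict(zip(iris.target_names, class_counts.tolist())))
```

The classes are perfectly balanced, so a plain random split is usually acceptable here, although a stratified one keeps the balance exact.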

Splitting the data into training and test sets

from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
import numpy as np
from sklearn.preprocessing import scale

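feature=data["data"]
# Standardize each column to zero mean and unit variance
feature=scale(feature,axis=0)
# Despite its name, `test` holds the class labels (0, 1, 2)
test=data["target"]
# train_test_split returns (X_train, X_test, y_train, y_test), so with these names:
# train_x = training features, train_y = test features, test_x = training labels, test_y = test labels
train_x,train_y,test_x,test_y=train_test_split(feature,test,test_size=0.3)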
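
For comparison, here is a sketch of the same split using the conventional variable names. It differs from the code above in three ways that are my assumptions, not the original setup: `stratify=y` keeps the class balance, `random_state=0` makes the split repeatable, and the scaler is fitted on the training rows only so the test rows do not influence the scaling.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Return order is: training features, test features, training labels, test labels
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

# Fit the scaler on the training rows only, then apply it to both sets
scaler = StandardScaler().fit(X_train)
X_train_s = scaler.transform(X_train)
X_test_s = scaler.transform(X_test)

print(X_train_s.shape, X_test_s.shape)
print(np.bincount(y_train), np.bincount(y_test))
```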

1. Direct classification

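svm=SVC(kernel="linear",max_iter=-1,class_weight='balanced',tol=0.02)
# C must be strictly positive, so drop the 0 endpoint of the grid
params_grid=[{"C":np.linspace(0,100,2000)[1:]}]
# Scoring is negative MSE on the integer class labels; "accuracy" is the more usual choice for a classifier
grid_search=GridSearchCV(svm,params_grid,cv=2,scoring="neg_mean_squared_error",return_train_score=True)
# train_x holds the training features and test_x the training labels
grid_search.fit(train_x,test_x)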

from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
%matplotlib inline

def evaluation_mdoel(train_y,test_y):
    # Called with the test features (train_y) and the test labels (test_y)
    predict_curr=grid_search.predict(train_y)
    # One scatter plot per feature column: predicted label vs. true label
    for i in range(4):
        plt.scatter(train_y[:,i],predict_curr,label="predict")
        plt.scatter(train_y[:,i],test_y,label="true")
        plt.legend()
        plt.show()
    print(mean_squared_error(predict_curr,test_y))
evaluation_mdoel(train_y,test_y)

0.06666666666666667
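
Because the targets are class labels rather than quantities, accuracy and a confusion matrix are easier to interpret than a mean squared error. The following is a sketch of the same grid-search idea scored by accuracy; the pipeline, the logarithmic `C` grid, `cv=5`, `stratify=y` and `random_state=0` are all my assumptions rather than the settings used above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

# Scaling lives inside the pipeline, so each CV fold is scaled on its own training part
pipe = make_pipeline(StandardScaler(), SVC(kernel="linear"))
# Strictly positive C values spaced on a log scale
param_grid = {"svc__C": np.logspace(-2, 2, 9)}

search = GridSearchCV(pipe, param_grid, cv=5, scoring="accuracy")
search.fit(X_train, y_train)

y_pred = search.predict(X_test)
test_accuracy = accuracy_score(y_test, y_pred)
cm = confusion_matrix(y_test, y_pred)

print(search.best_params_)
print(test_accuracy)
print(cm)  # rows are true classes, columns are predicted classes
```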
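I don't know how to plot four-dimensional data, so I made do with scatter plots of one feature at a time.
Next, I apply a polynomial transform to the original data and compare the result with the direct approach.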
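
One common workaround for the four-dimensional plotting problem is to project the data onto two dimensions first. The sketch below uses PCA for this, which is a technique I am adding and not something the post uses; the helper names `project_to_2d` and `plot_projection` are hypothetical.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def project_to_2d(X):
    """Standardize X and project it onto its first two principal components."""
    X_scaled = StandardScaler().fit_transform(X)
    pca = PCA(n_components=2)
    return pca.fit_transform(X_scaled), pca.explained_variance_ratio_


def plot_projection(points, labels, title):
    """Scatter the 2D points, coloured by integer class label."""
    import matplotlib.pyplot as plt  # imported here so the projection works without it

    plt.scatter(points[:, 0], points[:, 1], c=labels)
    plt.xlabel("PC 1")
    plt.ylabel("PC 2")
    plt.title(title)
    plt.show()


X, y = load_iris(return_X_y=True)
points, variance_ratio = project_to_2d(X)
# Share of the total variance kept by the two components
print(points.shape, variance_ratio.sum())
# plot_projection(points, y, "true labels")
```

Calling `plot_projection` once with the true labels and once with the predicted labels gives two pictures that can be compared side by side.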

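from sklearn.preprocessing import PolynomialFeatures
# Default degree is 2
poly=PolynomialFeatures()
# Each pass adds all degree-2 terms of its input, so three passes reach degree 8
# (4 -> 15 -> 136 -> 9453 columns, many of them duplicates)
# Despite its name, target_new holds the expanded features
target_new=poly.fit_transform(feature)
target_new=poly.fit_transform(target_new)
target_new=poly.fit_transform(target_new)
# New random split; same return order as before (train features, test features, train labels, test labels)
train_x,train_y,test_x,test_y=train_test_split(target_new,test,test_size=0.3)
grid_search.fit(train_x,test_x)
grid_search.best_params_
grid_search.best_estimator_
evaluation_mdoel(train_y,test_y)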

0.022222222222222223
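In this run the polynomial-expanded features gave a lower test error (0.022 versus 0.067), at the cost of more computation. With 45 test samples these values correspond to 1 and 3 misclassified samples, and the two runs used different random splits, so the gap is small.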
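
A fairer comparison trains both variants on the same split. The sketch below does that with a single degree-2 expansion; the fixed `random_state=0`, `stratify=y`, the default `C`, the pipelines and the dictionary keys are my assumptions, and the printed accuracies will not necessarily match the numbers above.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures, StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
# One split shared by both models
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=0
)

models = {
    "linear": make_pipeline(StandardScaler(), SVC(kernel="linear")),
    "poly2 + linear": make_pipeline(
        StandardScaler(), PolynomialFeatures(degree=2), SVC(kernel="linear")
    ),
}

scores = {}
for name, model in models.items():
    model.fit(X_train, y_train)
    scores[name] = accuracy_score(y_test, model.predict(X_test))
    print(name, scores[name])
```

With only 45 test samples, repeating this over several values of `random_state` gives a better sense of whether the difference is real.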
