您的位置:首页 > 其它

模型实战:titannic (分类)(xgboost)(超参数搜索)

2019-03-13 20:46 381 查看

import pandas as pd
#本地读取训练集和测试集
train=pd.read_csv('train.csv')
test=pd.read_csv('test.csv')
#输出基本信息
print(train.info())
print(test.info())
selected_features=['Pclass','Sex','Age','Emarked','SibSip','Parch','Fare']
X_train=train[selected_features]
X_test=test[selected_features]
y_train=train['Survived']

一、填充缺失值

使用特征的众数或者平均值填充

#Emarked特征存在缺失值,需要补充
#某个特征,水平分类计数
print(X_train['Emarked'].value_counts())
print(X_test['Emarked'].value_counts())
#使用出现频率最高的水平填充缺失值
X_train['Emarked'].fillna('S',inplace=True)
X_test['Emarked'].fillna('S',inplace=True)

X_train['Age'].fillna(X_train['Age'].mean(),inplace=True)
X_test['Age'].fillna(X_test['Age'].mean(),inplace=True)
X_test['Fare'].fillna(X_train['Fare'].mean(),inplace=True)

X_train.info()
X_test.info()

二、处理类别特征,类别特征向量化
#特征向量化
from sklearn.feature_extraction import DictVectorizer
dict_vec=DictVectorizer(sparse=False)
X_train=dict_vec.fit_transform(X_train.to_dict(orient='record'))
dict_vec.feature_names_
X_test=dict_vec.transform(X_test.to_dict(orient='record'))

三、五折交叉验证,模型得分,可以用于选择模型

from xgboost import XGBClassifier
xgbc=XGBClassifier()
#使用五折交叉验证在训练集上对模型预测性能进行评估,五次得分的平均分
from sklearn.cross_validation import cross_val_score
cross_val_score(xgbc,X_train,y_train,cv=5).mean()

#xgboost 进行预测操作
xgbc.fit(X_train,y_train)
xgbc_y_predict=xgbc.predict(X_test)
xgbc_submission=pd.DataFrame({'PassengerId':test['PassengerId'],'Survived':xgbc_y_predict})
xgbc_submission.to_csv('/titanic/xgbc_submission.csv',index=False)

四、通过搜索最佳参数,对模型进行优化,网格搜索

#超参数搜索
from sklearn.grid_search import GridSearchCV
params={'max_depth':range(2,7),'n_estimators':range(100,1100,200),'learning_rate':[0.05,0.1,0.25,0.5,1.0]}
xgbc_best=XGBClassifier()
gs=GridSearchCV(xgbc_best,params,verbose=1,cv=5,n_jobs=-1)

#执行多线程并行网格搜索
gs.fit(X_train,y_train)
#所确定的最佳参数,以及模型的accuracy
print(gs.best_params_,gs.best_score_)
#输出最佳模型在测试集上的准确性
xgbc_best_y_predict=gs.predict(X_test)
xgbc_best_submission=pd.DataFrame({'PassengerId':test['PassengerId'],'Survived':xgbc_best_y_predict})
xgbc_submission.to_csv('/titanic/xgbc_best_submission.csv',index=False)

 

 

 

 

 

 

 

 

 

 

 


 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: