
Using random forest, XGBoost, and logistic regression to predict the survival probability of Titanic passengers

2017-04-25 14:29
Sample data (after preprocessing: Sex encoded as 0/1, missing ages imputed, Embarked one-hot encoded):

,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked,Embarked_C,Embarked_Q,Embarked_S,Embarked_U
0,1,0,3,"Braund, Mr. Owen Harris",1,22.0,1,0,A/5 21171,7.25,,S,0,0,1,0
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",0,38.0,1,0,PC 17599,71.2833,C85,C,1,0,0,0
2,3,1,3,"Heikkinen, Miss. Laina",0,26.0,0,0,STON/O2. 3101282,7.925,,S,0,0,1,0
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",0,35.0,1,0,113803,53.1,C123,S,0,0,1,0
4,5,0,3,"Allen, Mr. William Henry",1,35.0,0,0,373450,8.05,,S,0,0,1,0
5,6,0,3,"Moran, Mr. James",1,23.8011805916,0,0,330877,8.4583,,Q,0,1,0,0
6,7,0,1,"McCarthy, Mr. Timothy J",1,54.0,0,0,17463,51.8625,E46,S,0,0,1,0
7,8,0,3,"Palsson, Master. Gosta Leonard",1,2.0,3,1,349909,21.075,,S,0,0,1,0
8,9,1,3,"Johnson, Mrs. Oscar W (Elisabeth Vilhelmina Berg)",0,27.0,0,2,347742,11.1333,,S,0,0,1,0
9,10,1,2,"Nasser, Mrs. Nicholas (Adele Achem)",0,14.0,1,0,237736,30.0708,,C,1,0,0,0
10,11,1,3,"Sandstrom, Miss. Marguerite Rut",0,4.0,1,1,PP 9549,16.7,G6,S,0,0,1,0
11,12,1,1,"Bonnell, Miss. Elizabeth",0,58.0,0,0,113783,26.55,C103,S,0,0,1,0
12,13,0,3,"Saundercock, Mr. William Henry",1,20.0,0,0,A/5. 2151,8.05,,S,0,0,1,0
13,14,0,3,"Andersson, Mr. Anders Johan",1,39.0,1,5,347082,31.275,,S,0,0,1,0
14,15,0,3,"Vestrom, Miss. Hulda Amanda Adolfina",0,14.0,0,0,350406,7.8542,,S,0,0,1,0
15,16,1,2,"Hewlett, Mrs. (Mary D Kingcome) ",0,55.0,0,0,248706,16.0,,S,0,0,1,0
16,17,0,3,"Rice, Master. Eugene",1,2.0,4,1,382652,29.125,,Q,0,1,0,0
17,18,1,2,"Williams, Mr. Charles Eugene",1,33.478692644,0,0,244373,13.0,,S,0,0,1,0
18,19,0,3,"Vander Planke, Mrs. Julius (Emelia Maria Vandemoortele)",0,31.0,1,0,345763,18.0,,S,0,0,1,0
19,20,1,3,"Masselmani, Mrs. Fatima",0,18.4510583333,0,0,2649,7.225,,C,1,0,0,0
20,21,0,2,"Fynney, Mr. Joseph J",1,35.0,0,0,239865,26.0,,S,0,0,1,0
21,22,1,2,"Beesley, Mr. Lawrence",1,34.0,0,0,248698,13.0,D56,S,0,0,1,0
22,23,1,3,"McGowan, Miss. Anna ""Annie""",0,15.0,0,0,330923,8.0292,,Q,0,1,0,0
23,24,1,1,"Sloper, Mr. William Thompson",1,28.0,0,0,113788,35.5,A6,S,0,0,1,0
24,25,0,3,"Palsson, Miss. Torborg Danira",0,8.0,3,1,349909,21.075,,S,0,0,1,0
25,26,1,3,"Asplund, Mrs. Carl Oscar (Selma Augusta Emilia Johansson)",0,38.0,1,5,347077,31.3875,,S,0,0,1,0
26,27,0,3,"Emir, Mr. Farred Chehab",1,34.8922936994,0,0,2631,7.225,,C,1,0,0,0
27,28,0,1,"Fortune, Mr. Charles Alexander",1,19.0,3,2,19950,263.0,C23 C25 C27,S,0,0,1,0
28,29,1,3,"O'Dwyer, Miss. Ellen ""Nellie""",0,22.8110194444,0,0,330959,7.8792,,Q,0,1,0,0
29,30,0,3,"Todoroff, Mr. Lalio",1,27.8541556913,0,0,349216,7.8958,,S,0,0,1,0
30,31,0,1,"Uruchurtu, Don. Manuel E",1,40.0,0,0,PC 17601,27.7208,,C,1,0,0,0
31,32,1,1,"Spencer, Mrs. William Augustus (Marie Eugenie)",0,38.0680685714,1,0,PC 17569,146.5208,B78,C,1,0,0,0
32,33,1,3,"Glynn, Miss. Mary Agatha",0,22.2371852543,0,0,335677,7.75,,Q,0,1,0,0
33,34,0,2,"Wheadon, Mr. Edward H",1,66.0,0,0,C.A. 24579,10.5,,S,0,0,1,0
34,35,0,1,"Meyer, Mr. Edgar Joseph",1,28.0,1,0,PC 17604,82.1708,,C,1,0,0,0
35,36,0,1,"Holverson, Mr. Alexander Oskar",1,42.0,1,0,113789,52.0,,S,0,0,1,0
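The encoded columns above can be reproduced with a few vectorized pandas calls. This is a minimal sketch of a compact alternative to the matching steps in load_data: mapping Sex to 0/1, filling missing fares with the median fare of the passenger's class, filling a missing Embarked with 'S', and one-hot encoding it. The five-row DataFrame is made up for illustration, and `preprocess` is a hypothetical helper name, not part of the original script.

```python
import numpy as np
import pandas as pd

# Tiny made-up frame with the raw Titanic columns the script touches
raw = pd.DataFrame({
    'Pclass':   [3, 1, 3, 1, 2],
    'Sex':      ['male', 'female', 'female', 'male', 'male'],
    'Fare':     [7.25, 71.2833, np.nan, np.nan, 13.0],
    'Embarked': ['S', 'C', 'Q', None, 'S'],
})


def preprocess(df):
    """Hypothetical helper: vectorized version of the encoding steps in load_data."""
    df = df.copy()
    # Sex: female -> 0, male -> 1
    df['Sex'] = df['Sex'].map({'female': 0, 'male': 1}).astype(int)
    # Missing fare -> median fare of the same Pclass (median skips NaN)
    class_median = df.groupby('Pclass')['Fare'].transform('median')
    df['Fare'] = df['Fare'].fillna(class_median)
    # Missing port -> 'S', then one 0/1 column per port
    df['Embarked'] = df['Embarked'].fillna('S')
    dummies = pd.get_dummies(df['Embarked'], prefix='Embarked', dtype=int)
    return pd.concat([df, dummies], axis=1)


out = preprocess(raw)
print(out)
```

`groupby(...).transform('median')` returns each row's class median aligned to the original index, so it plays the role of the loop over Pclass in load_data. `dtype=int` keeps the dummy columns as 0/1 instead of the booleans that pandas 2.x returns by default.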
#!/usr/bin/env python3
# -*- encoding:utf-8 -*-

import csv

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def show_accuracy(a, b, tip):
    # Percentage of positions where prediction a equals label b (tip is unused)
    acc = a.ravel() == b.ravel()
    acc_rate = 100 * float(acc.sum()) / a.size
    return acc_rate

def load_data(file_name, is_train):
    data = pd.read_csv(file_name)  # path to the data file
    # print('data.describe() =\n', data.describe())

    # Sex: replace 'female' with 0 and 'male' with 1, as int
    data['Sex'] = data['Sex'].map({'female': 0, 'male': 1}).astype(int)

    # Fill missing fares with the median fare of the passenger's class
    if len(data.Fare[data.Fare.isnull()]) > 0:
        fare = np.zeros(3)
        for f in range(0, 3):
            fare[f] = data[data.Pclass == f + 1]['Fare'].dropna().median()
        for f in range(0, 3):  # classes 1..3 -> indices 0..2
            data.loc[(data.Fare.isnull()) & (data.Pclass == f + 1), 'Fare'] = fare[f]

    # Age: the simple option is to replace missing values with the mean
    # mean_age = data['Age'].dropna().mean()
    # data.loc[(data.Age.isnull()), 'Age'] = mean_age
    if is_train:
        # Age: predict missing values with a random forest regressor
        print('Random forest imputation of missing ages: --start--')
        data_for_age = data[['Age', 'Survived', 'Fare', 'Parch', 'SibSp', 'Pclass']]
        age_exist = data_for_age.loc[(data.Age.notnull())]  # rows with a known age
        age_null = data_for_age.loc[(data.Age.isnull())]
        x = age_exist.values[:, 1:]
        y = age_exist.values[:, 0]
        # n_estimators: number of trees; more trees are generally better,
        # but training gets slower as the value grows (use at least 100)
        rfr = RandomForestRegressor(n_estimators=1000)
        rfr.fit(x, y)
        age_hat = rfr.predict(age_null.values[:, 1:])
        # Fill the empty Age values with the predictions
        data.loc[(data.Age.isnull()), 'Age'] = age_hat
        print('Random forest imputation of missing ages: --over--')
    else:
        # The test file has no Survived column, so predict age from the other features
        print('Random forest imputation of missing ages (test set): --start--')
        data_for_age = data[['Age', 'Fare', 'Parch', 'SibSp', 'Pclass']]
        age_exist = data_for_age.loc[(data.Age.notnull())]  # rows with a known age
        age_null = data_for_age.loc[(data.Age.isnull())]
        x = age_exist.values[:, 1:]
        y = age_exist.values[:, 0]
        rfr = RandomForestRegressor(n_estimators=1000)
        rfr.fit(x, y)
        age_hat = rfr.predict(age_null.values[:, 1:])
        data.loc[(data.Age.isnull()), 'Age'] = age_hat
        print('Random forest imputation of missing ages (test set): --over--')

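    # Port of embarkation: fill missing values with 'S'
    # (filling with 'U' instead keeps "unknown" as its own category and adds an Embarked_U column)
    data.loc[(data.Embarked.isnull()), 'Embarked'] = 'S'
    # One-hot encode into Embarked_C / Embarked_Q / Embarked_S; dtype=int keeps 0/1
    # (recent pandas returns booleans by default, which turns np.array(x) into an object array)
    embarked_data = pd.get_dummies(data.Embarked, dtype=int)
    embarked_data = embarked_data.rename(columns=lambda x: 'Embarked_' + str(x))
    data = pd.concat([data, embarked_data], axis=1)
    data.to_csv('New_Data.csv')

    x = data[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked_C', 'Embarked_Q', 'Embarked_S']]
    # x = data[['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']]
    y = None
    if 'Survived' in data:
        y = data['Survived']

    x = np.array(x)
    y = np.array(y)

    # Repeat every row 5 times. Caution: train_test_split runs after this, so copies of
    # the same passenger end up in both the training and the test set.
    x = np.tile(x, (5, 1))
    y = np.tile(y, (5, ))
    if is_train:
        return x, y
    return x, data['PassengerId']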

def write_result(c, c_type):
    file_name = '14.Titanic.test.csv'
    x, passenger_id = load_data(file_name, False)

    if c_type == 3:  # an XGBoost booster needs a DMatrix
        x = xgb.DMatrix(x)
    y = c.predict(x)
    y[y > 0.5] = 1
    y[~(y > 0.5)] = 0

    # x was tiled 5 times; zip() stops at len(passenger_id), i.e. after the first copy
    with open("Prediction_%d.csv" % c_type, "w", newline="") as predictions_file:
        open_file_object = csv.writer(predictions_file)
        open_file_object.writerow(["PassengerId", "Survived"])
        open_file_object.writerows(zip(passenger_id, y.astype(int)))

def totalSurvival(y_hat, tip):
    # Count and print the passengers predicted to survive (label 1)
    total = int(np.sum(y_hat == 1))
    print(tip, 'survivors:', total)

if __name__ == "__main__":
    # Load the training data and build the features
    x, y = load_data('14.Titanic.train.csv', True)
    # Split into training and test sets: x holds the sample features, y the labels;
    # test_size is the fraction held out for testing, random_state seeds the shuffle
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5, random_state=1)

    # Logistic regression (max_iter raised so the default solver converges on unscaled features)
    lr = LogisticRegression(penalty='l2', max_iter=1000)
    lr.fit(x_train, y_train)
    y_hat = lr.predict(x_test)
    lr_rate = show_accuracy(y_hat, y_test, 'Logistic regression')
    totalSurvival(y_hat, 'Logistic regression')

    # Random forest. n_estimators: number of trees; more trees are generally better,
    # but training gets slower as the value grows (use at least 100)
    rfc = RandomForestClassifier(n_estimators=100)
    rfc.fit(x_train, y_train)
    y_hat = rfc.predict(x_test)
    rfc_rate = show_accuracy(y_hat, y_test, 'Random forest')
    totalSurvival(y_hat, 'Random forest')
    # write_result(rfc, 2)

    # XGBoost
    data_train = xgb.DMatrix(x_train, label=y_train)
    data_test = xgb.DMatrix(x_test, label=y_test)
    watch_list = [(data_test, 'eval'), (data_train, 'train')]
    # 'verbosity' replaces the deprecated 'silent' parameter
    param = {'max_depth': 6, 'eta': 0.8, 'verbosity': 0, 'objective': 'binary:logistic'}
    bst = xgb.train(param, data_train, num_boost_round=100, evals=watch_list)
    y_hat = bst.predict(data_test)  # predicted survival probabilities
    # Threshold the probabilities at 0.5 to get 0/1 labels
    y_hat[y_hat > 0.5] = 1
    y_hat[~(y_hat > 0.5)] = 0
    xgb_rate = show_accuracy(y_hat, y_test, 'XGBoost')
    totalSurvival(y_hat, 'xgboost')

    print('Logistic regression: %.3f%%' % lr_rate)
    print('Random forest: %.3f%%' % rfc_rate)
    print('XGBoost: %.3f%%' % xgb_rate)
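The random-forest age imputation in load_data is written out twice, once per file. This is a minimal sketch of the same technique as a single reusable function, under these assumptions: the name `impute_age` is hypothetical; it uses only features present in both files (Fare, Parch, SibSp, Pclass), so the training branch's use of Survived is deliberately dropped; the feature columns are assumed to have no missing values (as after the Fare fill); and random_state is fixed so the filled ages are reproducible. The synthetic DataFrame only exercises the function and is not Titanic data.

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor


def impute_age(df, feature_cols, n_estimators=100, random_state=0):
    """Hypothetical helper: fill NaN in df['Age'] with a random-forest regression on feature_cols."""
    df = df.copy()
    known = df['Age'].notnull()
    if known.all():
        return df  # nothing to impute; avoids calling predict on an empty array
    rfr = RandomForestRegressor(n_estimators=n_estimators, random_state=random_state)
    # Train on rows whose age is known ...
    rfr.fit(df.loc[known, feature_cols].values, df.loc[known, 'Age'].values)
    # ... and predict the age of the rows where it is missing
    df.loc[~known, 'Age'] = rfr.predict(df.loc[~known, feature_cols].values)
    return df


# Synthetic demo data (assumption: 50 random passengers, three ages removed)
rng = np.random.default_rng(0)
demo = pd.DataFrame({
    'Pclass': rng.integers(1, 4, size=50),
    'Fare': rng.uniform(5, 100, size=50),
    'SibSp': rng.integers(0, 3, size=50),
    'Parch': rng.integers(0, 3, size=50),
    'Age': rng.uniform(1, 70, size=50),
})
demo.loc[[3, 10, 27], 'Age'] = np.nan

filled = impute_age(demo, ['Fare', 'Parch', 'SibSp', 'Pclass'])
print(filled.loc[[3, 10, 27], 'Age'])
```

Because a regression forest averages training targets in its leaves, every imputed age stays within the range of the known ages. Using the same feature set for both files also means the training and test ages are imputed the same way.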


Results:
Logistic regression survivors: 813
Random forest survivors: 859
xgboost survivors: 872
Accuracy:
Logistic regression: 78.770%
Random forest: 98.160%
XGBoost: 97.935%
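The random forest and XGBoost accuracies above (about 98%) are probably optimistic. load_data repeats every passenger 5 times with np.tile before train_test_split is called, so most rows in the test half have an exact copy in the training half, and tree ensembles can largely memorize them. This sketch measures that overlap on synthetic integer "passengers" rather than the Titanic data; `overlap_fraction` is a hypothetical helper.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def overlap_fraction(tile_first, n_unique=100, repeats=5, seed=1):
    """Hypothetical helper: fraction of test rows that also appear, identically, in the training set."""
    # Each integer stands for one distinct passenger (one feature row)
    rows = np.arange(n_unique).reshape(-1, 1)
    if tile_first:
        # What load_data + __main__ do: copy every row, then split
        rows = np.tile(rows, (repeats, 1))
    # 50/50 split, as in the original script
    x_train, x_test = train_test_split(rows, test_size=0.5, random_state=seed)
    seen = set(x_train.ravel().tolist())
    return float(np.mean([v in seen for v in x_test.ravel().tolist()]))


print(overlap_fraction(tile_first=True))   # most test rows also occur in training
print(overlap_fraction(tile_first=False))  # 0.0: unique rows split first
```

Splitting the original rows first, and dropping the np.tile step, removes the overlap. That gives a more honest estimate of accuracy on passengers the models have not seen.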