机器学习简易入门(四)- logistic回归
2016-03-23 09:33
501 查看
摘要:使用logistic回归来预测某个人的入学申请是否会被接受
声明:(本文的内容非原创,但经过本人翻译和总结而来,转载请注明出处)
本文内容来源:https://www.dataquest.io/mission/59/logistic-regression
原始数据展示
这是一份美国入学申请的录取记录表,admit – 是否录取,1代表录取,0代表否定;gpa – gpa成绩,gre – 绩点
在之前已经介绍过了线性回归,现在同样使用线性回归来进行预测
在上图中可见,有些预测结果小于0,而这明显是不对的,因为预测结果应该只能为0或者1,我们现在需要获取一个介于0和1之间的概率,然后通过之前的文章中介绍过的分类算法(机器学习简易入门(二)- 分类)来确定录取一个人的概率的阀值来决定录取结果,最终生成只有0和1的结果
logistic回归函数
logistic回归产生的输出都位于0和1之间,通常用来产生预测某个事件的发生概率,该函数的格式为
,其中的e是一个无理数常量,该函数有一个很漂亮的形状
logistic回归
在线性回归方程中
,可以将该方程产生的结果y放入到logistic回归方程,从而将线性方程产生的结果转换为一个概率
,对于本文来说,这个logistic回归方程为
,现在根据这个logistic回归方程就能产生一个录取概率。
类似于之前使用scikit-learn库中的线性回归,现在也可以直接使用该库中的logistic回归
评估模型
准确率
现在假设只要录取概率大于0.5的就能录取,计算一下这个模型的准确性
ROC曲线
分别计算训练集和测试集的ROC曲线和AUC
声明:(本文的内容非原创,但经过本人翻译和总结而来,转载请注明出处)
本文内容来源:https://www.dataquest.io/mission/59/logistic-regression
原始数据展示
这是一份美国入学申请的录取记录表,admit – 是否录取,1代表录取,0代表否定;gpa – gpa成绩,gre – 绩点
import pandas admissions = pandas.read_csv('admissions.csv')
在之前已经介绍过了线性回归,现在同样使用线性回归来进行预测
from sklearn.linear_model import LinearRegression model = LinearRegression() #训练模型 model.fit(admissions[['gre', 'gpa']], admissions["admit"]) admit_prediction = model.predict(admissions[['gre', 'gpa']]) plt.xlabel('gpa') plt.ylabel('admit_prediction') plt.scatter(admissions["gpa"], admit_prediction) plt.show()
在上图中可见,有些预测结果小于0,而这明显是不对的,因为预测结果应该只能为0或者1,我们现在需要获取一个介于0和1之间的概率,然后通过之前的文章中介绍过的分类算法(机器学习简易入门(二)- 分类)来确定录取一个人的概率的阀值来决定录取结果,最终生成只有0和1的结果
logistic回归函数
logistic回归产生的输出都位于0和1之间,通常用来产生预测某个事件的发生概率,该函数的格式为
,其中的e是一个无理数常量,该函数有一个很漂亮的形状
# logistic回归函数 def logit(x): return np.exp(x) / (1 + np.exp(x)) # 在-6到6之间等差产生50个数 t = np.linspace(-6,6,50, dtype=float) ylogit = logit(t) #作图 plt.plot(t, ylogit, label="logistic") plt.ylabel("Probability") plt.xlabel("t") plt.title("Logistic Function") plt.show()
logistic回归
在线性回归方程中
,可以将该方程产生的结果y放入到logistic回归方程,从而将线性方程产生的结果转换为一个概率
,对于本文来说,这个logistic回归方程为
,现在根据这个logistic回归方程就能产生一个录取概率。
类似于之前使用scikit-learn库中的线性回归,现在也可以直接使用该库中的logistic回归
from sklearn.linear_model import LogisticRegression #对数据集进行随机重排序 admissions = admissions.loc[np.random.permutation(admissions.index)] # 将随机排序后的前700条数据作为训练集,后面的作为测试集 num_train = 700 data_train = admissions[:num_train] data_test = admissions[num_train:] logistic_model = LogisticRegression() logistic_model.fit(data_train[['gpa', 'gre']], data_train['admit']) # 进行测试 fitted_test = logistic_model.predict_proba(data_test[['gpa', 'gre']])[:, 1] #因为predict_proba返回的是一个两列的矩阵,矩阵的每一行代表的是对一个事件的预测结果,第一列代表该事件不会发生的概率,第二列代表的是该事件会发生的概率。而这里需要的是第二列的数据 plt.scatter(data_test['gre'], fitted_test) plt.xlabel('gre') plt.ylabel('probability ') plt.show()
评估模型
准确率
现在假设只要录取概率大于0.5的就能录取,计算一下这个模型的准确性
# predict()函数会自动把阀值设置为0.5 predicted = logistic_model.predict(data_train[['gpa','gre']]) # 计算在训练集中正确预测的准确率 accuracy_train = (predicted == data_train['admit']).mean() #计算在测试集中正确预测的准确率 predicted = logistic_model.predict(data_test[['gpa','gre']]) accuracy_test = (predicted == data_test['admit']).mean()
ROC曲线
分别计算训练集和测试集的ROC曲线和AUC
from sklearn.metrics import roc_curve, roc_auc_score train_probs = logistic_model.predict_proba(data_train[['gpa', 'gre']])[:,1] test_probs = logistic_model.predict_proba(data_test[['gpa', 'gre']])[:,1] #计算AUC auc_train = roc_auc_score(data_train["admit"], train_probs) auc_test = roc_auc_score(data_test["admit"], test_probs) print('Auc_train: {}'.format(auc_train)) print('Auc_test: {}'.format(auc_test)) # 计算ROC曲线 roc_train = roc_curve(data_train["admit"], train_probs) roc_test = roc_curve(data_test["admit"], test_probs) # 作图 plt.plot(roc_train[0], roc_train[1]) plt.plot(roc_test[0], roc_test[1])
相关文章推荐
- 为windows添加右键菜单
- 生产环境中centOS7最简版安装
- 在win7下安装IIS 7后打开自己的项目之后会出现500.19错误
- SharedPreferences方便存取工具类
- 当程序员面对小学数学题
- 哪些工具能有效管理Azure Active Directory?
- 使用cordova+Ionic+AngularJs进行Hybird App开发的环境搭建手册
- 使用CSS3 Media Queries实现网页自适应
- poj 1015 Jury Compromise 动态规划
- Redis开源代码读书笔记零(Ubuntu14.04 64位安装)
- unity3d中用2D背景当作图片
- balabala半年了
- 做一个合格的程序猿之浅析Spring AOP源码(十二) AOP概念理解
- Java 回顾笔记_基本数据类型对象包装类
- 与LSGO一起学“第1章 初识C++(1.1 C++简介)”
- ant命令总结
- 文件操作之fopen创建多级目录下的文件
- 如何在Windows 7平台搭建Android Cocos2d-x3.2alpha0开发环境
- 2016.03.23我的第一份代码
- java Integer.valueOf()方法