您的位置:首页 > 编程语言 > Python开发

逻辑回归python实现(随机增量梯度下降,变步长)

2017-04-25 10:26 816 查看
关于逻辑回归的学习,建议大家看看这篇blog,讲的很清楚:点击打开链接

代码实现是根据机器学习实战,照着代码自己来了一遍

逻辑回归,实际上就是对线性回归多增加了一个函数映射,使其值域由无穷区间映射到[0,1]区间

在线性回归中,估计函数为

  其中delta是参数向量,x是输入样本的特征向量

而在逻辑回归中,估计函数实际上就是在线性回归的基础上,嵌套了一个sigmoid函数。

逻辑回归的估计函数为

   其中,e的指数部分就是线性回归的输出,而可以看出,逻辑回归函数的值域是(0,1),并且图像过(0,1/2)这个点,图像在x=0处很陡峭。

也就是说,在逻辑回归中,我们能将h=0.5作为一个阀值,当估计值大于0.5时把样本分为1类,估计值小于0.5时把样本分为0类。那么只要我们得到了这个估计函数,就能够实现0-1分类了。

估计函数的求解,实际上就是对delta参数向量进行求解。在这里我使用的时梯度下降法,也就是先求出极大似然估计,然后求出极大似然估计的梯度,然后进行多次迭代,每次迭代将参数向量沿着梯度下降最陡峭的方向增加一个步长alpha。alpha如果过小,迭代的速度会很慢,但是如果alpha过大,容易使步子迈得太大,使得我们总是在结果附件徘徊。这里我使用的是变步长法,也就是一开始让步长尽量大,在迭代的过程中,慢慢缩小步长。

另外,我使用的是随机增量梯度下降法。所谓增量,就是指每一次迭代我只考虑一组样本来进行参数更新,而不是遍历所有样本。这样能够降低算法的时间复杂度,而且精确度也还不错。所谓随机,就是指每次迭代选取样本不是按照固定的顺序,而是随机抽取样本进行参数更新,这样做是为了防止迭代到后期,让参数陷入一个循环节,具体的情况可参照本文开头链接的那篇blog:

接下来附上我的python实现,使用了numpy库和matplotlib库

logRegression.py

# coding=UTF-8
import numpy as ny
import matplotlib.pyplot as plt
import time
import math

def sigmoid(x): #sigmoid函数
return 1.0/(1+math.exp(-x))

def loadData(): #读取数据
train_x = []
train_y = []
fileIn = open('testSet.txt')
for line in fileIn.readlines():
lineArr = line.strip().split()
train_x.append([1.0, float(lineArr[0]), float(lineArr[1])]) #1.0代表x的0次项
train_y.append(float(lineArr[2]))
return ny.mat(train_x), ny.mat(train_y).transpose()

def logRegression(train_x,train_y):
iteration_time = 600 #最大迭代次数
delta=ny.zeros((3,1)) #初始化参数为0向量
numSamples,numFeatures=ny.shape(train_x) #获取训练样本的规模
alpha=0.01 #迭代步长
for k in range(iteration_time):
alpha = 4.0 / (1.0 + k) + 0.01
i=ny.random.randint(0, numSamples)
h=sigmoid(train_x[i,:]*delta)
error=train_y[i,0]-h
delta+=alpha*train_x[i,:].transpose()*error
return delta

def calAccuracyRate(train_x,train_y,delta):
count=0 #记录划分正确的样本数
numSamples,numFeatures=ny.shape(train_x) #获取训练样本的规模
for i in range(numSamples):
h=sigmoid(train_x[i,:]*delta)
if h>=0.5 and int(train_y[i,0])==1 :
count=count+1
elif h<0.5 and train_y[i,0]==0 :
count=count+1
return count

def showGraph(train_x,train_y,delta):
numSamples,numFeatures=ny.shape(train_x) #获取训练样本的规模
# 画出样本点
plt.figure(figsize=(12,8)) #设置绘制尺寸
for i in range(numSamples):
if int(train_y[i, 0]) == 0:
plt.plot(train_x[i, 1], train_x[i, 2], 'or')
elif int(train_y[i, 0]) == 1:
plt.plot(train_x[i, 1], train_x[i, 2], 'ob')

# 绘制分割线
min_x = min(train_x[:, 1])[0,0]-1
max_x = max(train_x[:, 1])[0,0]+1
y_min_x = float(-delta[0,0] - delta[1,0] * min_x) / delta[2,0]
y_max_x = float(-delta[0,0] - delta[1,0] * max_x) / delta[2,0]
plt.plot([min_x, max_x], [y_min_x, y_max_x], 'y')
plt.xlabel('X1'); plt.ylabel('X2')
plt.show()

def testingLogR():
train_x,train_y=loadData()
maxx=0.0
numBegin=20 #起点数量
for i in range(numBegin):
delta=logRegression(train_x,train_y)
cur=calAccuracyRate(train_x,train_y,delta)
if cur>maxx:
maxx=cur
ans=delta
numSamples,numFeatures=ny.shape(train_x)
print("样本准确率为:",maxx*100/numSamples,"%")
showGraph(train_x,train_y,ans)

testingLogR()


testSet.txt                 //训练样本 

-0.017612   14.053064   0
-1.395634   4.662541    1
-0.752157   6.538620    0
-1.322371   7.152853    0
0.423363    11.054677   0
0.406704    7.067335    1
0.667394    12.741452   0
-2.460150   6.866805    1
0.569411    9.548755    0
-0.026632   10.427743   0
0.850433    6.920334    1
1.347183    13.175500   0
1.176813    3.167020    1
-1.781871   9.097953    0
-0.566606   5.749003    1
0.931635    1.589505    1
-0.024205   6.151823    1
-0.036453   2.690988    1
-0.196949   0.444165    1
1.014459    5.754399    1
1.985298    3.230619    1
-1.693453   -0.557540   1
-0.576525   11.778922   0
-0.346811   -1.678730   1
-2.124484   2.672471    1
1.217916    9.597015    0
-0.733928   9.098687    0
-3.642001   -1.618087   1
0.315985    3.523953    1
1.416614    9.619232    0
-0.386323   3.989286    1
0.556921    8.294984    1
1.224863    11.587360   0
-1.347803   -2.406051   1
1.196604    4.951851    1
0.275221    9.543647    0
0.470575    9.332488    0
-1.889567   9.542662    0
-1.527893   12.150579   0
-1.185247   11.309318   0
-0.445678   3.297303    1
1.042222    6.105155    1
-0.618787   10.320986   0
1.152083    0.548467    1
0.828534    2.676045    1
-1.237728   10.549033   0
-0.683565   -2.166125   1
0.229456    5.921938    1
-0.959885   11.555336   0
0.492911    10.993324   0
0.184992    8.721488    0
-0.355715   10.325976   0
-0.397822   8.058397    0
0.824839    13.730343   0
1.507278    5.027866    1
0.099671    6.835839    1
-0.344008   10.717485   0
1.785928    7.718645    1
-0.918801   11.560217   0
-0.364009   4.747300    1
-0.841722   4.119083    1
0.490426    1.960539    1
-0.007194   9.075792    0
0.356107    12.447863   0
0.342578    12.281162   0
-0.810823   -1.466018   1
2.530777    6.476801    1
1.296683    11.607559   0
0.475487    12.040035   0
-0.783277   11.009725   0
0.074798    11.023650   0
-1.337472   0.468339    1
-0.102781   13.763651   0
-0.147324   2.874846    1
0.518389    9.887035    0
1.015399    7.571882    0
-1.658086   -0.027255   1
1.319944    2.171228    1
2.056216    5.019981    1
-0.851633   4.375691    1
-1.510047   6.061992    0
-1.076637   -3.181888   1
1.821096    10.283990   0
3.010150    8.401766    1
-1.099458   1.688274    1
-0.834872   -1.733869   1
-0.846637   3.849075    1
1.400102    12.628781   0
1.752842    5.468166    1
0.078557    0.059736    1
0.089392    -0.715300   1
1.825662    12.693808   0
0.197445    9.744638    0
0.126117    0.922311    1
-0.679797   1.220530    1
0.677983    2.556666    1
0.761349    10.693862   0
-2.168791   0.143632    1
1.388610    9.341997    0
0.317029    14.739025   0


运行后的结果如下图:

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: