
[python|ML] Implementing k-fold and leave-one-out cross-validation for logistic regression (Watermelon Book exercise 3.4, UCI iris data)

2016-12-19 21:22
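# duilvhuigui.py (imported below via `from duilvhuigui import *`)
# Lesson learned: gradient descent must actually decrease the objective (or increase it, for ascent).
# In this program tempsum has to decrease, so check that carefully!!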

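import numpy as np

def dlhg(x, d):
    """Logistic regression fitted by batch gradient descent.

    x: (n, m) sample matrix; d: (n,) labels in {0, 1}.
    Returns the weight vector w of length m + 1 (the last entry is the bias).
    """
    # Append a constant 1 to every sample so the bias is learned as part of w
    x = np.c_[x, np.ones(d.shape[0])]
    w = np.random.randn(x[0].shape[0])   # random initial weights
    miu = 0.01                           # learning rate
    temp_sum_old = np.inf
    count = 0
    while True:
        tempsum = 0   # objective: sum_i [log(1 + exp(w.x_i)) - d_i * w.x_i]
        dw = 0        # gradient of the objective with respect to w
        for i in range(d.shape[0]):
            pdt = np.sum(w * x[i])
            # np.logaddexp(0, pdt) == log(1 + exp(pdt)), without overflow for large pdt
            tempsum = tempsum + np.logaddexp(0, pdt) - d[i] * pdt
            # 1 / (1 + exp(-pdt)) == exp(pdt) / (1 + exp(pdt)), but never inf / inf
            dw = dw + (1.0 / (1.0 + np.exp(-pdt)) - d[i]) * x[i]
        # Stop once the objective changes by less than 0.01 between iterations
        if np.abs(tempsum - temp_sum_old) < 0.01:
            break
        w = w - miu * dw
        temp_sum_old = tempsum
        count = count + 1
    # Debug output (disabled):
    # print(tempsum)
    # print(count)
    # for i in range(d.shape[0]):
    #     # i is the index of the i-th sample
    #     if np.sum(w * x[i]) > 0:
    #         print(1)
    #     else:
    #         print(0)
    # print(w)
    return w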
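You can check the lesson at the top of the code automatically: gradient descent has to actually decrease the objective. Below is a minimal vectorized sketch of the same batch gradient descent as dlhg, using only NumPy. It records the objective sum_i [log(1 + exp(w·x_i)) - d_i·w·x_i] at every iteration, so you can assert that it never goes up. The name dlhg_vectorized, the max_iter safeguard and the seed parameter are hypothetical additions; they are not part of the original post.

```python
import numpy as np


def dlhg_vectorized(x, d, miu=0.01, tol=0.01, max_iter=100_000, seed=None):
    """Batch gradient descent for logistic regression, vectorized (sketch).

    Minimizes sum_i [log(1 + exp(w.x_i)) - d_i * w.x_i], where each x_i has a
    constant 1 appended for the bias. Returns (w, history), history being the
    objective value at every iteration.
    """
    X = np.c_[x, np.ones(x.shape[0])]                  # append the bias column
    w = np.random.default_rng(seed).standard_normal(X.shape[1])
    history = []
    old = np.inf
    for _ in range(max_iter):                          # hypothetical safeguard
        z = X @ w                                      # w.x_i for all samples
        loss = np.sum(np.logaddexp(0.0, z) - d * z)    # stable log(1 + e^z)
        history.append(loss)
        if abs(loss - old) < tol:                      # same stop rule as dlhg
            break
        p = 1.0 / (1.0 + np.exp(-z))                   # sigmoid(z)
        w = w - miu * (X.T @ (p - d))                  # step against the gradient
        old = loss
    return w, history


if __name__ == "__main__":
    # Toy data (not the iris set): two well-separated Gaussian clouds
    rng = np.random.default_rng(0)
    x = np.r_[rng.normal(-2, 0.5, (20, 2)), rng.normal(2, 0.5, (20, 2))]
    d = np.r_[np.zeros(20), np.ones(20)]
    w, history = dlhg_vectorized(x, d, seed=1)
    # The lesson from the post, as a check: the objective must never go up
    assert all(b <= a for a, b in zip(history, history[1:])), "objective increased"
    print(len(history), "iterations, final objective", history[-1])
```

X @ w and X.T @ (p - d) compute the same sums as the per-sample loop in a single call. If history ever increases, the step size miu is too large for the scale of the features.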
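# UCIdata.py: data loader (imported below via `from UCIdata import *`)
import numpy as np

def load_iris():
    iris_data_file = "D:/WORKSPACE/ML/DATA/UCI/_iris.data"
    # Assumes the 5th column of _iris.data already holds numeric class labels
    # (0/1 for the first 100 rows, setosa and versicolor). With the unmodified
    # UCI iris.data, whose 5th column holds names such as "Iris-setosa",
    # genfromtxt's default float dtype would turn every label into nan.
    _x = np.genfromtxt(iris_data_file, delimiter=',', usecols=(0, 1, 2, 3))
    _y = np.genfromtxt(iris_data_file, delimiter=',', usecols=(4))
    return _x, _y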
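The loader above reads `_iris.data` with np.genfromtxt's default float dtype. That only works if the fifth column is already numeric, and the post does not say how that file was prepared. The unmodified UCI file has rows like `5.1,3.5,1.4,0.2,Iris-setosa`. For that file, here is a minimal sketch using only the standard csv module and NumPy. The function names and the LABELS mapping (setosa 0, versicolor 1, virginica 2) are assumptions. They are chosen so that the first 100 rows get the 0/1 labels the test script expects.

```python
import csv

import numpy as np

# Assumed label coding (hypothetical): the script only uses the first 100 rows
# and predicts 0 or 1, so setosa -> 0 and versicolor -> 1.
LABELS = {"Iris-setosa": 0.0, "Iris-versicolor": 1.0, "Iris-virginica": 2.0}


def parse_iris_lines(lines):
    """Parse rows in the raw UCI format: four floats followed by a class name."""
    xs, ys = [], []
    for row in csv.reader(lines):
        if not row:                      # skip blank lines (the UCI file ends with one)
            continue
        xs.append([float(v) for v in row[:4]])
        ys.append(LABELS[row[4].strip()])
    return np.array(xs), np.array(ys)


def load_iris_raw(path):
    """Hypothetical alternative to load_iris() for the unmodified iris.data."""
    with open(path, newline="") as f:
        return parse_iris_lines(f)
```

A call such as `x, d = load_iris_raw("iris.data")` (path hypothetical) could then stand in for `load_iris()` in the script. Rows stay in file order, so `x[:100]` still holds 50 setosa samples followed by 50 versicolor samples.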
# Example (test script):
import numpy as np
from UCIdata import *
from duilvhuigui import *

x, d = load_iris()
# Keep the first 100 samples: two classes of 50 (setosa, versicolor)
x = x[:100]
d = d[:100]

# 10-fold cross-validation
# Fold k tests on rows 5k..5k+4 of each 50-row class block and trains on the other 90 rows
err_count = 0
print("\td\t\ty")   # d: true label, y: predicted label
for k in range(10):
    xte = np.r_[x[5 * k:5 * (k + 1)], x[50 + 5 * k:50 + 5 * (k + 1)]]
    dte = np.r_[d[5 * k:5 * (k + 1)], d[50 + 5 * k:50 + 5 * (k + 1)]]
    xtr = np.r_[x[:5 * k], x[5 * (k + 1):50 + 5 * k], x[50 + 5 * (k + 1):]]
    dtr = np.r_[d[:5 * k], d[5 * (k + 1):50 + 5 * k], d[50 + 5 * (k + 1):]]
    w = dlhg(xtr, dtr)
    xte = np.c_[xte, np.ones(dte.shape[0])]   # append the bias column
    print("{}-fold\t".format(k))
    for i in range(dte.shape[0]):
        # Predict 1 when w.x > 0, i.e. when the sigmoid output exceeds 0.5
        if np.sum(w * xte[i]) > 0:
            y = 1
        else:
            y = 0
        print("\t{}\t{}".format(dte[i], y))
        if dte[i] != y:
            err_count += 1
print("Total error:{}\t".format(err_count))

# leave-one-out
err_count = 0
print("Leave-one-out")
for k in range(100):
    xte = x[k]
    dte = d[k]
    xtr = np.r_[x[:k], x[k + 1:]]
    dtr = np.r_[d[:k], d[k + 1:]]
    w = dlhg(xtr, dtr)
    # Special case: xte is a single 1-D sample here, so np.c_[] would raise an
    # error; append the bias term with np.r_ instead
    xte = np.r_[xte, 1]
    if np.sum(w * xte) > 0:
        y = 1
    else:
        y = 0
    print("\t{}\t{}".format(dte, y))
    if dte != y:
        err_count += 1
print("Total error:{}\t".format(err_count))
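The slicing in the script hard-codes the iris layout: two consecutive class blocks of 50 samples, with 5 samples per class per fold. That makes it a stratified 10-fold split. Leave-one-out is simply k-fold with k equal to the number of samples. Below is a minimal NumPy sketch that builds the same index sets for any label vector whose classes appear in file order. The helper names stratified_folds and leave_one_out are hypothetical; they do not come from the post.

```python
import numpy as np


def stratified_folds(d, k):
    """Return k (train_idx, test_idx) pairs; every test fold takes an equal,
    contiguous share of each class (sizes differ by at most one when a class
    size is not divisible by k)."""
    folds = [[] for _ in range(k)]
    for label in np.unique(d):
        idx = np.flatnonzero(d == label)              # this class, in file order
        for j, chunk in enumerate(np.array_split(idx, k)):
            folds[j].extend(chunk.tolist())
    all_idx = np.arange(len(d))
    return [(np.setdiff1d(all_idx, f), np.array(sorted(f))) for f in folds]


def leave_one_out(n):
    """Leave-one-out is k-fold with k = n: each sample is the test set once."""
    all_idx = np.arange(n)
    return [(np.delete(all_idx, i), np.array([i])) for i in range(n)]


if __name__ == "__main__":
    d = np.r_[np.zeros(50), np.ones(50)]              # same layout as d[:100]
    train, test = stratified_folds(d, 10)[3]
    print(test)          # rows 15..19 and 65..69, like the script's fold k = 3
    print(len(train), len(leave_one_out(100)))
```

With these helpers, the four np.r_ lines per fold become `for tr, te in stratified_folds(d, 10): w = dlhg(x[tr], d[tr])`, with test rows `x[te]` in the same order as in the script. Likewise, `leave_one_out(100)` drives the second loop. In scikit-learn, `StratifiedKFold` and `LeaveOneOut` (in `sklearn.model_selection`) are library implementations of the same ideas.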
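Conclusion: both the 10-fold and the leave-one-out runs report 0 errors, and I haven't found anything wrong... wow. Zero error is in fact plausible. The first 100 samples are Iris setosa and Iris versicolor, and setosa is linearly separable from the other two species, so a linear classifier such as logistic regression can label every one of them correctly. This post is purely a personal record, a way of accumulating knowledge; discussion and criticism are, of course, welcome.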
Tags: machine learning, data