您的位置:首页 > 其它

机器学习基石---第二周PLA

2017-12-19 15:27 330 查看
knitr::opts_chunk$set(echo = TRUE)


  台大《机器学习基石》第二周课的笔记,只整理部分重要内容。希望能把课上学的,做一个精简的记录。

变量说明

  存在两类数据,标记为y,取值为−1,1。特征向量记为x,x=(x0,x1,x2,...,xd)。其中x0为常量1,其余为具体特征值。存在超平面wTx=0,其中w=(w0,w1,...,wd),可以正确分开两类数据。共有N个样本数据。

迭代过程

  PLA采取知错就改的策略。遍历所有样本,如果发现分类错误,采用如下方式如下方式更新w

  Fort=0,1,...N1.findamistakeofwtcalled(xn(t),yn(t))sign(wTtxn(t))≠yn(t)2.(tryto)correctthemistakebywt+1←wt+yn(t)xn(t)...untilnomoremistakesreturnlastw(calledwPLA)asg

更新理由



  判断类别的公式:

sign(wTtxn(t))=sign(∥∥wTt∥∥∥∥xn(t)∥∥cos(θ))

  如果正类被误判,则cos(θ)<0,即θ∈(π2,π),所以要缩小法向量和特征向量之间的夹角。故采用上图方法迭代w的值。

证明

  证明线性可分数据集,PLA算法一定能够经过有限次的迭代,得到一个完美的分割超平面。

每一次迭代wt更接近wf

  1. wf为完美分类器

  2. (xn,yn)为错分的样本

  3. (xn(t),yn(t))为第t次迭代时,wt错分的样本

  因为wf是完美分类器,则一定有:

yn(t)wTfxn(t)≥minnynwTfxn>0

  利用任意一个错判样本(xn(t),yn(t))进行第t+1次迭代之后,计算:

wTfwt+1∥∥wTf∥∥∥wt+1∥=wTf(wt+yn(t)xn(t))∥∥wTf∥∥∥wt+1∥=wTfwt+yn(t)wTfxn(t)∥∥wTf∥∥∥wt+1∥≥wTfwt+minnyn(t)wTfxn(t)∥∥wTf∥∥∥wt+1∥>wTfwt+0∥∥wTf∥∥∥wt+1∥=wTfwt∥∥wTf∥∥∥wt+1∥

  从余弦相似度的角度看,通过错判样本对wt的修正,使得迭代后的w更接近于完美的分割超平面。

每一次迭代wt的模增长较小

∥wt+1∥2=∥∥wt+yn(t)xn(t)∥∥2=∥wt∥2+2yn(t)wTtxn(t)+∥∥yn(t)xn(t)∥∥2≤∥wt∥2+0+∥∥yn(t)xn(t)∥∥2≤∥wt∥2+maxn∥ynxn∥2

迭代次数有限

  假设w0=0,经过T次迭代之后:

wTfwT∥∥wf∥∥∥wT∥=wTf(wT−1+yn(T−1)xn(T−1))∥∥wf∥∥∥wT∥=wTf(wT−1+yn(T−1)xn(T−1))∥∥wf∥∥∥wT∥=wTfwT−1+yn(T−1)wTfxn(T−1)∥∥wf∥∥∥wT∥≥wTfwT−1+minnynwTfxn∥∥wf∥∥∥wT∥≥wTfwT−2+yn(T−2)wTfxn(T−2)+minnynwTfxn∥∥wf∥∥∥wT∥≥wTfwT−2+2minnynwTfxn∥∥wf∥∥∥wT∥⋯≥TminnynwTfxn∥∥wf∥∥∥wT∥Further:wTfwT≥TminnynwTfxnT≤wTfwTminnynwTfxnT2≤(wTfwT)2(minnynwTfxn)2=∥∥wf∥∥2∥wT∥2sin2(θ)(minnynwTfxn)2≤∥∥wf∥∥2∥wT∥2(minnynwTfxn)2≤∥∥wf∥∥2∗max∥ynxn∥n2(minnynwTfxn)2=∥∥wf∥∥2∗max∥xn∥n2(minnynwTfxn)2

  所以迭代次数T有上界。

案例

构造数据集

  构造数据集,验证算法。

x11 <- 1:10
x21 <- x11 + runif(10, 0, 1) + 3
x22 <- x11 - runif(10, 0, 1)
example_data <- data.frame(x1 = rep(x11, 2),
x2 = c(x21, x22),
label = rep(c(1, -1), each = 10))
example_data$label <- as.factor(example_data$label)
library(ggplot2)
ggplot(data = example_data, aes(
x = x1,
y = x2,
color = label,
shape = label
)) +
geom_point()




PLA算法

## 参数:数据集、标签名称

PLA_f <- function(dataset, label) {
## 样本数
row_num <-  nrow(dataset)
w <- rep(1, ncol(dataset))
w0 <- matrix(w, 1, 3, byrow = T)
real_label <- as.numeric(as.vector(dataset[, label]))
feature_matrix <-
as.matrix(data.frame(x0 = rep(1, row_num), cbind(dataset[, setdiff(colnames(dataset), label)])))
i <- 1
j <- 0
while (i < row_num & j == 0) {
i <- 1
j <- 0
for (i in 1:row_num) {
## 判断是否有误判
if (as.vector(feature_matrix[i,] %*% t(w0)) * real_label[i] <= 0) {
## 存在误判,修正w0
w0 <- w0 + real_label[i] * feature_matrix[i,]
w <- c(w, w0)
j <- 1
}
if(j == 1){
j <- 0
i <- row_num-1
break()}
}
}
w_data <- data.frame(matrix(w,ncol=ncol(dataset),byrow = TRUE))
colnames(w_data) <- paste0("x",0:(ncol(feature_matrix)-1))
w_data <- dplyr::mutate(w_data,
slope = -x1 / x2,
intercept = -x0 / x2)
return(w_data)
}


求解

w_data <- PLA_f(dataset = example_data, label = "label")
w_data


x0 x1           x2        slope    intercept
1   1  1  1.000000000   -1.0000000   -1.0000000
2   0  0  0.495471116    0.0000000    0.0000000
3  -1 -1 -0.009057768 -110.4024725 -110.4024725
4   0  0  4.912654036    0.0000000    0.0000000
5  -1 -1  4.408125152    0.2268538    0.2268538
6  -2 -2  3.903596268    0.5123481    0.5123481
7  -3 -4  1.915120282    2.0886417    1.5664812
8  -2 -1  8.363856425    0.1195621    0.2391241
9  -3 -2  7.859327541    0.2544747    0.3817120
10 -4 -4  5.870851555    0.6813322    0.6813322
11 -5 -9  1.747566727    5.1500179    2.8611211
12 -4 -8  6.669278532    1.1995300    0.5997650


动图

library(animation)
## 指定ImageMagic目录位置,注意是magick.exe,之前版本貌似一致是convert.exe
ani.options(convert = "D:/ImageMagic/ImageMagick-7.0.7-Q16/magick.exe")
saveGIF(
expr = {
library(ggplot2)
for (i in 1:nrow(w_data)) {plot(
x = example_data$x1[1:10],
y = example_data$x2[1:10],
pch = 15,
col = "red",
xlim = c(0, 20),
ylim = c(0, 15),
xlab = "x1",
ylab = "x2",main = paste0("Picture",i)
)
lines(x = example_data$x1[11:20],
y = example_data$x2[11:20],
type = "p",
pch = 17,
col = "blue")
abline(coef=c(w_data$intercept[i],w_data$slope[i]),lwd=2)
}
},
## GIF文件名,注意文件后缀名要加上
movie.name = "PLA.gif",
## 时间间隔
interval = 1,
## 图形设置
ani.width = 600,
ani.height = 600,
## 文件输出在当前目录
outdir = getwd()
)




Ref

[1]课程PPT

2017-12-19于杭州
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: