您的位置:首页 > 编程语言 > Java开发

基于Java实现机器学习的感知机(可视化界面)

2018-09-06 18:26 316 查看

前言:本人也是刚刚入门机器学习,就像入门很多语言一样,第一个程序总是Hello World 。机器学习也不然,入门机器学习的第一个程序就是感知机啦。感知机是二类线性分类模型,输出的值为{+1, -1}两种类型,感知机是利用超平面将两类分离,多个不同的感知机就可以组成一张神经元网络,再往上就是人工智能系统然后就是终结者阿诺......越说越离题了。

好啦步入正题吧。下面我从三个方面简单阐述一下这个感知机到底是个什么妖怪:感知机模型,感知机学习策略,感知机学习算法。(部分代码以及资料引用了网上的)

一、感知机模型:

假设输入空间(特征空间)是X∈Rn,输出空间是Y={+1, -1},仅有两种结果,就好比一条线,位于线上方的点带入该线的方程得到的y值总是大于0,所以感知机是一种线性分类模型,属于判别模型。输入x∈X表示实例特征向量。对应于输出空间(特征空间)的点:输出y∈Y表示实例类别,由输入空间到输出空间的如下函数:

f(x)=sign(w*x+b)={-1,+1}。

线性方程w*x+b 其中w称为权值,b称为偏置。咱们的感知机呢正是通过很多训练集来训练自己,从而不断更新w和b,直到找到一个最优的分类位置。

对应于特征空间Rn 中的一个超平面S,其中w是超平面的法向量,b是超平面的截距。根据这个原理我们可以推导出计算距离 4000 的公式:

能够将数据集的正实例和负实例完全正确的划分到超平面的两侧,则称数据集T是线性可分数据集,否则称线性不可分数据集。

二、感知机学习策略:

就像咱们有自己的学习方法一样,感知机也有自己的学习方法。而感知机的学习方法我们常称为损失函数。同时我们要将这个损失函数极小化,这就要求它是连续可导的。损失函数有两种选择:一、误分类点的总数;二、误分类点到超平面S的距离;第一种不易于优化,因此我们通常选择第二种。什么叫误分类点呢?如图:

我画了这样子的一条线,意在将两种颜色的圆分类,但蓝色类里面多了一个红色的,这个红色的就称之为误分类点啦。

对于误分类点来说,它到超平面的距离计算就相当于蓝色圈到超平面的距离取反,因为它代入超平面方程得到的y值应该是负值。也就是这个:

那么总距离就是:

因此,感知机sign(wx+b)的损失函数可以简写为:

三、感知机学习算法:

感知机学习问题转化为求解损失函数最优化问题,最优化的方法是随机梯度下降法。感知机学习算法有两种形式:原始形式和对偶形式。在训练数据线性可分的条件下,感知机学习算法是收敛的。

原始形式

我的感知机采用的正是原始形式,原始形式是通过给定的训练数据集T={(x1,y1), {x2,y2},…..,{xN,yN}}去求解参数w和b使得: 

这个损失函数极小。其中M是误分类点的集合。

感知机学习算法是误分类数据驱动的,采用随机梯度下降法,即随机选取一个超平面w0和b0,使用梯度下降法对损失函数进行极小化。极小化不是一次使得所有M集合误分类点梯度下降,而是一次随机选取一个点使其梯度下降。 
假设M集合是固定,那么损失函数的梯度为: 

然后呢再随机选取一个误分类点(xi,yi)对wi和b进行更新。更新方程如下:(w和b的初始值可以随便给,但尽量不要太大,否则会影响计算的时间。)

η(0<=η<=1)是步长,统计学习中称为学习率。通过不断迭代,损失函数不断减小,直到为0。 

所以对原始类的总结如下:

1、 随机选取w0和b0 ;
2、 在训练数据中选取(xi,yi) ;
3、 如果yi(w*xi+b) <= 0;

4、 转2,直到训练数据中,没有误分类点。 

对偶形式我就简单提一下吧,毕竟我还没代码实践。

对偶形似的基本想法是:将w和b表示实例xi和标记yi的线性组合形式,通过求解器系数的到w和b。在原始形式中,通过 w和b的更新方程不断修改w和b,假设修改了n次,则w和b关于(xi, yi)的增量分别是aixiyi和aiyi,这里的ai=niη。最后学习到的w和b是: 

其中ai > =0, i =1,2,….,N。当时,表示第i个实例有误分类而进行更新的次数。实例点更新次数越多,则它里超平面的距离就越近,也就越难正确分类。 

最后附上完整的Demo代码:

//PerceptronClassifier类:

[code]package 感知机;

import java.awt.Color;
import java.awt.Dimension;
import java.awt.Font;
import java.awt.TextField;
import java.util.ArrayList;
import java.util.Arrays;

import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JLabel;

import java.awt.Graphics;

public class PerceptronClassifier extends JFrame{

//分类器参数
private double[]w;//权值数组
private double b = 0 ; //阈值
private double eta = 1;//学习率
ArrayList<Point>arrayList;

public double getW(int i)
{
return w[i];
}
public double getB()
{
return b;
}
//初始化分类器,读入我们要分类的数据
public  PerceptronClassifier(ArrayList<Point>arrayList,double eta)
{
this.arrayList = arrayList;
w = new double[arrayList.get(0).x.length];
this.eta = eta;
}
// 分类器初始化

public PerceptronClassifier(ArrayList<Point> arrayList)
{
this.arrayList = arrayList;
w = new double[arrayList.get(0).x.length];
this.eta = 1;
}
/********************************************************/
/*开始分类计算*/
public boolean Classify()
{
boolean flag = false;
while(!flag)//遍历所有的样本
{

for(int i=0;i<arrayList.size();i++)//所有的训练集
{
if(LearnAnswer(arrayList.get(i))<=0)//分类错误的点
{
UpdateWAndB(arrayList.get(i));//更新需要学习的点
this.paint(this.getGraphics());  //动态更新,一旦出错马上重新遍历
try {
Thread.sleep(300);
} catch (InterruptedException e) {
e.printStackTrace();
}
break;
}
if(i==arrayList.size()-1)//已经遍历到最后一个训练集
{
flag = true;
}
}

}
System.out.println("学习后:");
System.out.println(Arrays.toString(w));//输出一轮学习后找到的权值和阈值
System.out.println(b);
return true;

}

private double LearnAnswer(Point point) //计算结果,用于判断分类是否正确
{
System.out.println(Arrays.toString(w));
System.out.println(b);
return point.y * (DotProduct(w, point.x) + b);
}
private void UpdateWAndB(Point point) //更新w 和 b 的值(随机梯度下降)
{
System.out.println("分类出错!更新w、b!");
for (int i = 0; i < w.length; i++) {
w[i] += eta * point.y * point.x[i];
}
b += eta * point.y;
return;

}

private double DotProduct(double[] x1, double[] x2) //点乘函数
{
int len = x1.length;
double sum = 0;
for (int i = 0; i < len; i++) {
sum += x1[i] * x2[i];
}
return sum;

}

public void InitUI()
{
this.setTitle("机器学习感知机");
this.setSize(800, 600);
this.setDefaultCloseOperation(3);
this.setLocationRelativeTo(null);//窗口居中
this.setResizable(false);//禁止最小化
this.setLayout(null);//关闭流式布局

JButton butstart = new JButton("开始训练");
//	butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式
butstart.setBounds(150, 480, 100, 60);
butstart.setContentAreaFilled(false);  //消除按钮背景颜色
butstart.setOpaque(false); //除去边框
butstart.setFocusPainted(false);//出去突起
this.add(butstart);

JButton butpro = new JButton("预测颜色");
//	butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式
butpro.setBounds(300, 480, 100, 60);
butpro.setContentAreaFilled(false);  //消除按钮背景颜色
butpro.setOpaque(false); //除去边框
butpro.setFocusPainted(false);//出去突起
this.add(butpro);

JLabel label1 = new JLabel("X:");
label1.setFont(new Font("宋体",Font.BOLD,30));
label1.setBounds(400, 500, 50, 50);
this.add(label1);

JLabel label2 = new JLabel("Y:");
label2.setFont(new Font("宋体",Font.BOLD,30));
label2.setBounds(540, 500, 50, 50);
this.add(label2);

TextField text1 = new TextField("1");
text1.setFont(new Font("宋体",Font.BOLD,30));
text1.setBounds(450, 500, 70, 40);
this.add(text1);

TextField text2 = new TextField("1");
text2.setFont(new Font("宋体",Font.BOLD,30));
text2.setBounds(590, 500, 70, 40);
this.add(text2);
this.setVisible(true);//设置窗体可见

//添加监听
ButtonListener BL = new ButtonListener(this,text1,text2);
butstart.addActionListener(BL);
butpro.addActionListener(BL);
}

//为了更形象,先画个坐标轴吧。重写paint函数就可以了。
public void paint( Graphics g)
{
super.paint(g);
//绘制坐标轴
g.setColor(Color.black);
g.drawLine(100, 100, 100, 480);
g.drawLine(100, 480, 700, 480);

//接下来从arrayList里面取点画出来,此时要注意颜色的设置,比如y为1设置蓝色,y为-1设置红色
for(int i=0;i<arrayList.size();i++)
{
if(arrayList.get(i).y==1)
{
g.setColor(Color.BLUE);
}
else
{
g.setColor(Color.RED);
}
//位置可能需要进行适当的放大处理
//				g.drawLine((int)arrayList.get(i).x[0]*200+10,(int) arrayList.get(i).x[1]*200,
//						(int)arrayList.get(i).x[0]*200+10,(int)arrayList.get(i).x[1]*200+10);
g.drawOval((int)arrayList.get(i).x[0]*100+200, (int)arrayList.get(i).x[1]*100+200, 15, 15);
}
//接下来是区分线
//说白了就是计算点到直线的距离
int x1=0,y2=0;
System.out.println(this.getB()+" "+this.getW(1));
int y1 = (int)((-1)*this.getB()/this.getW(1));
int x2 = (int)((-1)*this.getB()/this.getW(0));
System.out.println("开始画标准线!");
g.setColor(Color.BLACK);
System.out.println(x1*100+200+","+y1*100+200+","+x2*100+200+","+y2*100+200);
g.drawLine(x1*100+200, y1*100+200, x2*100+200, y2*100+200);//跟上面保持一样的放大比例
}

public static void main(String[] args)
{

Point p1 = new Point(new double[] { 0,1.1 }, -1);//训练集
Point p2 = new Point(new double[] { 1.2,0 }, -1);
Point p3 = new Point(new double[] { 2.16,1 }, -1);
Point p4 = new Point(new double[] { 1,2.64 }, -1);
Point p5 = new Point(new double[] { 3.14,1.2 }, 1);
Point p6 = new Point(new double[] { 1.32,3.4 }, 1);
Point p7 = new Point(new double[] { 3.32,2.23 }, 1);
Point p8 = new Point(new double[] { 2.71,2.4 }, 1);

ArrayList<Point> list = new ArrayList<Point>();
list.add(p1);
list.add(p2);
list.add(p3);
list.add(p4);
list.add(p5);
list.add(p6);
list.add(p7);
list.add(p8);

PerceptronClassifier classifier = new PerceptronClassifier(list);
//	classifier.Classify();
classifier.InitUI();

}

}

//Point类:

[code]package 感知机;

public class Point {

double[] x = new double[2];
double y =0;
Point(double[]x ,double y)
{
this.x = x;
this.y = y;
}

Point()
{

}

}

//ButtonListener:

[code]package 感知机;

import java.awt.TextField;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;

import javax.swing.JOptionPane;

public class ButtonListener implements ActionListener{

public PerceptronClassifier classifier;
public TextField text1,text2;

public ButtonListener(PerceptronClassifier classifier,TextField text1,TextField text2) {
this.classifier =classifier;
this.text1 = text1;
this.text2 = text2;
}
public void actionPerformed(ActionEvent e) {
if(e.getActionCommand().equals("开始训练"))
{
classifier.Classify();//启动训练方法
}
else if(e.getActionCommand().equals("预测颜色"))
{
//首先拿到文本框输入的坐标
String x1 = text1.getText();
String x2 = text2.getText();
float xx1,xx2;
//由于是string,我们需要强制转换为数字
if(x1==""|| x2=="")
{
xx1=(float) 1.0;
xx2=(float) 1.0;
}
xx1 = new Float(x1);
xx2 = new Float(x2);
System.out.println("拿到的XY为:"+xx1+" "+xx2);
//将坐标点带入我们得到的方程,不同的结果代表不同的颜色,结果只有1和-1.
//xx1*w1+xx2*w2+b
if(xx1*classifier.getW(0)+xx2*classifier.getW(1)+classifier.getB()>=0)
{
JOptionPane.showMessageDialog(null,"该图形为蓝色");//消息框弹出
}
else
{
JOptionPane.showMessageDialog(null,"该图形为红色");
}
}

}

}

有不妥之处欢迎指出!!Demo使用说明:点击开始训练即可进行分类。在文本框X和Y处输入相关点的坐标就会预测它在这个平面内的类别,也就是属于哪种颜色。

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: