基于Java实现机器学习的感知机(可视化界面)
前言:本人也是刚刚入门机器学习,就像入门很多语言一样,第一个程序总是Hello World 。机器学习也不然,入门机器学习的第一个程序就是感知机啦。感知机是二类线性分类模型,输出的值为{+1, -1}两种类型,感知机是利用超平面将两类分离,多个不同的感知机就可以组成一张神经元网络,再往上就是人工智能系统然后就是终结者阿诺......越说越离题了。
好啦步入正题吧。下面我从三个方面简单阐述一下这个感知机到底是个什么妖怪:感知机模型,感知机学习策略,感知机学习算法。(部分代码以及资料引用了网上的)
一、感知机模型:
假设输入空间(特征空间)是X∈Rn,输出空间是Y={+1, -1},仅有两种结果,就好比一条线,位于线上方的点带入该线的方程得到的y值总是大于0,所以感知机是一种线性分类模型,属于判别模型。输入x∈X表示实例特征向量。对应于输出空间(特征空间)的点:输出y∈Y表示实例类别,由输入空间到输出空间的如下函数:
f(x)=sign(w*x+b)={-1,+1}。
线性方程w*x+b 其中w称为权值,b称为偏置。咱们的感知机呢正是通过很多训练集来训练自己,从而不断更新w和b,直到找到一个最优的分类位置。
对应于特征空间Rn 中的一个超平面S,其中w是超平面的法向量,b是超平面的截距。根据这个原理我们可以推导出计算距离 4000 的公式:
能够将数据集的正实例和负实例完全正确的划分到超平面的两侧,则称数据集T是线性可分数据集,否则称线性不可分数据集。
二、感知机学习策略:
就像咱们有自己的学习方法一样,感知机也有自己的学习方法。而感知机的学习方法我们常称为损失函数。同时我们要将这个损失函数极小化,这就要求它是连续可导的。损失函数有两种选择:一、误分类点的总数;二、误分类点到超平面S的距离;第一种不易于优化,因此我们通常选择第二种。什么叫误分类点呢?如图:
我画了这样子的一条线,意在将两种颜色的圆分类,但蓝色类里面多了一个红色的,这个红色的就称之为误分类点啦。
对于误分类点来说,它到超平面的距离计算就相当于蓝色圈到超平面的距离取反,因为它代入超平面方程得到的y值应该是负值。也就是这个:
那么总距离就是:
因此,感知机sign(wx+b)的损失函数可以简写为:
三、感知机学习算法:
感知机学习问题转化为求解损失函数最优化问题,最优化的方法是随机梯度下降法。感知机学习算法有两种形式:原始形式和对偶形式。在训练数据线性可分的条件下,感知机学习算法是收敛的。
原始形式
我的感知机采用的正是原始形式,原始形式是通过给定的训练数据集T={(x1,y1), {x2,y2},…..,{xN,yN}}去求解参数w和b使得:
这个损失函数极小。其中M是误分类点的集合。
感知机学习算法是误分类数据驱动的,采用随机梯度下降法,即随机选取一个超平面w0和b0,使用梯度下降法对损失函数进行极小化。极小化不是一次使得所有M集合误分类点梯度下降,而是一次随机选取一个点使其梯度下降。
假设M集合是固定,那么损失函数的梯度为:
然后呢再随机选取一个误分类点(xi,yi)对wi和b进行更新。更新方程如下:(w和b的初始值可以随便给,但尽量不要太大,否则会影响计算的时间。)
η(0<=η<=1)是步长,统计学习中称为学习率。通过不断迭代,损失函数不断减小,直到为0。
所以对原始类的总结如下:
1、 随机选取w0和b0 ;
2、 在训练数据中选取(xi,yi) ;
3、 如果yi(w*xi+b) <= 0;
4、 转2,直到训练数据中,没有误分类点。
对偶形式我就简单提一下吧,毕竟我还没代码实践。
对偶形似的基本想法是:将w和b表示实例xi和标记yi的线性组合形式,通过求解器系数的到w和b。在原始形式中,通过 w和b的更新方程不断修改w和b,假设修改了n次,则w和b关于(xi, yi)的增量分别是aixiyi和aiyi,这里的ai=niη。最后学习到的w和b是:
其中ai > =0, i =1,2,….,N。当时,表示第i个实例有误分类而进行更新的次数。实例点更新次数越多,则它里超平面的距离就越近,也就越难正确分类。
最后附上完整的Demo代码:
//PerceptronClassifier类:
[code]package 感知机; import java.awt.Color; import java.awt.Dimension; import java.awt.Font; import java.awt.TextField; import java.util.ArrayList; import java.util.Arrays; import javax.swing.JButton; import javax.swing.JFrame; import javax.swing.JLabel; import java.awt.Graphics; public class PerceptronClassifier extends JFrame{ //分类器参数 private double[]w;//权值数组 private double b = 0 ; //阈值 private double eta = 1;//学习率 ArrayList<Point>arrayList; public double getW(int i) { return w[i]; } public double getB() { return b; } //初始化分类器,读入我们要分类的数据 public PerceptronClassifier(ArrayList<Point>arrayList,double eta) { this.arrayList = arrayList; w = new double[arrayList.get(0).x.length]; this.eta = eta; } // 分类器初始化 public PerceptronClassifier(ArrayList<Point> arrayList) { this.arrayList = arrayList; w = new double[arrayList.get(0).x.length]; this.eta = 1; } /********************************************************/ /*开始分类计算*/ public boolean Classify() { boolean flag = false; while(!flag)//遍历所有的样本 { for(int i=0;i<arrayList.size();i++)//所有的训练集 { if(LearnAnswer(arrayList.get(i))<=0)//分类错误的点 { UpdateWAndB(arrayList.get(i));//更新需要学习的点 this.paint(this.getGraphics()); //动态更新,一旦出错马上重新遍历 try { Thread.sleep(300); } catch (InterruptedException e) { e.printStackTrace(); } break; } if(i==arrayList.size()-1)//已经遍历到最后一个训练集 { flag = true; } } } System.out.println("学习后:"); System.out.println(Arrays.toString(w));//输出一轮学习后找到的权值和阈值 System.out.println(b); return true; } private double LearnAnswer(Point point) //计算结果,用于判断分类是否正确 { System.out.println(Arrays.toString(w)); System.out.println(b); return point.y * (DotProduct(w, point.x) + b); } private void UpdateWAndB(Point point) //更新w 和 b 的值(随机梯度下降) { System.out.println("分类出错!更新w、b!"); for (int i = 0; i < w.length; i++) { w[i] += eta * point.y * point.x[i]; } b += eta * point.y; return; } private double DotProduct(double[] x1, double[] x2) //点乘函数 { int len = x1.length; double sum = 0; for (int i = 0; i < len; i++) { sum += x1[i] * x2[i]; } return sum; } public void InitUI() { this.setTitle("机器学习感知机"); this.setSize(800, 600); this.setDefaultCloseOperation(3); this.setLocationRelativeTo(null);//窗口居中 this.setResizable(false);//禁止最小化 this.setLayout(null);//关闭流式布局 JButton butstart = new JButton("开始训练"); // butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式 butstart.setBounds(150, 480, 100, 60); butstart.setContentAreaFilled(false); //消除按钮背景颜色 butstart.setOpaque(false); //除去边框 butstart.setFocusPainted(false);//出去突起 this.add(butstart); JButton butpro = new JButton("预测颜色"); // butstart.setPreferredSize(new Dimension(100,60));//设置按钮样式 butpro.setBounds(300, 480, 100, 60); butpro.setContentAreaFilled(false); //消除按钮背景颜色 butpro.setOpaque(false); //除去边框 butpro.setFocusPainted(false);//出去突起 this.add(butpro); JLabel label1 = new JLabel("X:"); label1.setFont(new Font("宋体",Font.BOLD,30)); label1.setBounds(400, 500, 50, 50); this.add(label1); JLabel label2 = new JLabel("Y:"); label2.setFont(new Font("宋体",Font.BOLD,30)); label2.setBounds(540, 500, 50, 50); this.add(label2); TextField text1 = new TextField("1"); text1.setFont(new Font("宋体",Font.BOLD,30)); text1.setBounds(450, 500, 70, 40); this.add(text1); TextField text2 = new TextField("1"); text2.setFont(new Font("宋体",Font.BOLD,30)); text2.setBounds(590, 500, 70, 40); this.add(text2); this.setVisible(true);//设置窗体可见 //添加监听 ButtonListener BL = new ButtonListener(this,text1,text2); butstart.addActionListener(BL); butpro.addActionListener(BL); } //为了更形象,先画个坐标轴吧。重写paint函数就可以了。 public void paint( Graphics g) { super.paint(g); //绘制坐标轴 g.setColor(Color.black); g.drawLine(100, 100, 100, 480); g.drawLine(100, 480, 700, 480); //接下来从arrayList里面取点画出来,此时要注意颜色的设置,比如y为1设置蓝色,y为-1设置红色 for(int i=0;i<arrayList.size();i++) { if(arrayList.get(i).y==1) { g.setColor(Color.BLUE); } else { g.setColor(Color.RED); } //位置可能需要进行适当的放大处理 // g.drawLine((int)arrayList.get(i).x[0]*200+10,(int) arrayList.get(i).x[1]*200, // (int)arrayList.get(i).x[0]*200+10,(int)arrayList.get(i).x[1]*200+10); g.drawOval((int)arrayList.get(i).x[0]*100+200, (int)arrayList.get(i).x[1]*100+200, 15, 15); } //接下来是区分线 //说白了就是计算点到直线的距离 int x1=0,y2=0; System.out.println(this.getB()+" "+this.getW(1)); int y1 = (int)((-1)*this.getB()/this.getW(1)); int x2 = (int)((-1)*this.getB()/this.getW(0)); System.out.println("开始画标准线!"); g.setColor(Color.BLACK); System.out.println(x1*100+200+","+y1*100+200+","+x2*100+200+","+y2*100+200); g.drawLine(x1*100+200, y1*100+200, x2*100+200, y2*100+200);//跟上面保持一样的放大比例 } public static void main(String[] args) { Point p1 = new Point(new double[] { 0,1.1 }, -1);//训练集 Point p2 = new Point(new double[] { 1.2,0 }, -1); Point p3 = new Point(new double[] { 2.16,1 }, -1); Point p4 = new Point(new double[] { 1,2.64 }, -1); Point p5 = new Point(new double[] { 3.14,1.2 }, 1); Point p6 = new Point(new double[] { 1.32,3.4 }, 1); Point p7 = new Point(new double[] { 3.32,2.23 }, 1); Point p8 = new Point(new double[] { 2.71,2.4 }, 1); ArrayList<Point> list = new ArrayList<Point>(); list.add(p1); list.add(p2); list.add(p3); list.add(p4); list.add(p5); list.add(p6); list.add(p7); list.add(p8); PerceptronClassifier classifier = new PerceptronClassifier(list); // classifier.Classify(); classifier.InitUI(); } }
//Point类:
[code]package 感知机; public class Point { double[] x = new double[2]; double y =0; Point(double[]x ,double y) { this.x = x; this.y = y; } Point() { } }
//ButtonListener:
[code]package 感知机; import java.awt.TextField; import java.awt.event.ActionEvent; import java.awt.event.ActionListener; import javax.swing.JOptionPane; public class ButtonListener implements ActionListener{ public PerceptronClassifier classifier; public TextField text1,text2; public ButtonListener(PerceptronClassifier classifier,TextField text1,TextField text2) { this.classifier =classifier; this.text1 = text1; this.text2 = text2; } public void actionPerformed(ActionEvent e) { if(e.getActionCommand().equals("开始训练")) { classifier.Classify();//启动训练方法 } else if(e.getActionCommand().equals("预测颜色")) { //首先拿到文本框输入的坐标 String x1 = text1.getText(); String x2 = text2.getText(); float xx1,xx2; //由于是string,我们需要强制转换为数字 if(x1==""|| x2=="") { xx1=(float) 1.0; xx2=(float) 1.0; } xx1 = new Float(x1); xx2 = new Float(x2); System.out.println("拿到的XY为:"+xx1+" "+xx2); //将坐标点带入我们得到的方程,不同的结果代表不同的颜色,结果只有1和-1. //xx1*w1+xx2*w2+b if(xx1*classifier.getW(0)+xx2*classifier.getW(1)+classifier.getB()>=0) { JOptionPane.showMessageDialog(null,"该图形为蓝色");//消息框弹出 } else { JOptionPane.showMessageDialog(null,"该图形为红色"); } } } }
有不妥之处欢迎指出!!Demo使用说明:点击开始训练即可进行分类。在文本框X和Y处输入相关点的坐标就会预测它在这个平面内的类别,也就是属于哪种颜色。
阅读更多- [java] 可视化日历的实现(基于Calendar类 )
- java模拟实现一个基于文本界面的——客户信息管理系
- 一种基于flex的可视化多层流量切分界面的实现
- 基于Java的界面布局DSL的设计与实现
- 基于 Java 的界面布局 DSL 的设计与实现
- 【源码】基于SQLite实现CMS论坛(BBS)----附件SQLite可视化界面客户端
- 利用java实现基于文本的图书管理系统(有界面)
- 基于Java的界面布局DSL的设计与实现
- JAVA实现基于可视化的代码,可以实现商品总计,很方便。
- 机器学习知识点(八)感知机模型Java实现
- java实现基于TCP协议带界面的多人聊天代码
- 基于Java的界面布局DSL的设计与实现
- 基于Java的界面布局DSL的设计与实现
- dui框架开发系列:基于控件组合或继承实现 可视化界面编辑工具 的优劣
- 基于Java的界面布局DSL的设计与实现
- 机器学习入门 一、理解机器学习+简单感知机(JAVA实现)
- 用Eclipse进行可视化Java界面设计
- 用Eclipse进行可视化Java界面设计
- 基于JAVA技术的搜索引擎的研究与实现
- java实现的基于RSA的公钥加密算法