神经网络-向后传播java实现
2015-06-29 18:51
381 查看
1、Point.java
package com.network;
/**
* <p>本类描述: </p>
* <p>其他说明: </p>
* @author Wang Haiyang
* @date 2015-6-29 上午09:15:55
*/
public class Point {
public Point() {}
public Point(Integer name, Double in) {
this.name = name;
this.in = in;
}
public Point(Integer name, Double o, Double err, Double in, Double out) {
this.name = name;
this.o = o;
this.err = err;
this.in = in;
this.out = out;
}
public Point(Integer name, Double o, Double err, Double in, Double out, Integer classify) {
this.name = name;
this.o = o;
this.err = err;
this.in = in;
this.out = out;
this.classify = classify;
}
/** 点的名字 */
private Integer name;
/** 点的偏倚值 */
private Double o;
/** 点的初始值 */
private Double err;
/** 点的净输入 */
private Double in;
/** 点的输出 */
private Double out;
/** 点的类别 */
private Integer classify = 0;
public Integer getClassify() {
return classify;
}
public void setClassify(Integer classify) {
this.classify = classify;
}
public Double getIn() {
return in;
}
public void setIn(Double in) {
this.in = in;
}
public Double getOut() {
return out;
}
public void setOut(Double out) {
this.out = out;
}
public Integer getName() {
return name;
}
public void setName(Integer name) {
this.name = name;
}
public Double getO() {
return o;
}
public void setO(Double o) {
this.o = o;
}
public Double getErr() {
return err;
}
public void setErr(Double err) {
this.err = err;
}
}
2、Edge.java
package com.network;
/**
* <p>本类描述: </p>
* <p>其他说明: </p>
* @author Wang Haiyang
* @date 2015-6-29 上午09:11:42
*/
public class Edge {
/** 边的起点 */
private Point start;
/** 边的终点 */
private Point end;
/** 边的权重 */
private Double weight;
public Edge() {}
public Edge(Point start, Point end, Double weight) {
this.start = start;
this.end = end;
this.weight = weight;
}
public Point getStart() {
return start;
}
public void setStart(Point start) {
this.start = start;
}
public Point getEnd() {
return end;
}
public void setEnd(Point end) {
this.end = end;
}
public Double getWeight() {
return weight;
}
public void setWeight(Double weight) {
this.weight = weight;
}
}
3、NeuralNetwork.java
package com.network;
import java.util.ArrayList;
import java.util.List;
/**
* <p>
* 本类描述:
* 利用向后传播的神经网络方法学习,产生可预测类别的模型,本类假定隐藏层数为1(两层神经网络)
* 隐藏层包含的单元可以指定,输出层的单元也可以指定
* </p>
* <p>
* 主要步骤:
* 步骤1: 初始化网络中的权重和偏倚
* 步骤2: 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
* 步骤3: 逐层向后计算输出层和隐藏层的每个单元的误差
* 步骤4: 更新所有权重和偏倚
* </p>
* <p>
* 其他说明:对未知元组X分类
* 利用训练好的模型,计算每个单元的净输入和输出,如果每个类有一个输出节点,则具有最高输出值的
* 节点决定X的预测类标号,如果只有一个输出节点,则输出值大于或等于0.5可以视为正类,而值小于0.5
* 可以视为负类。
* </p>
* @author Wang Haiyang
* @date 2015-6-26 下午04:10:10
*/
public class NeuralNetwork {
/** 学习率 */
public static final Double study = 0.9D;
/** 样本集 */
public static List<ArrayList<Point>> samples = new ArrayList<ArrayList<Point>>();
/** 隐藏层点集 */
public static List<Point> hideLayers = new ArrayList<Point>();
/** 输出层点集 */
public static List<Point> outLayers = new ArrayList<Point>();
/** 边集 */
public static List<Edge> edges = new ArrayList<Edge>();
public static void main(String[] args) {
// 准备初始化参数
init();
// 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
compute();
// 打印
display();
}
/**
* 方法描述:打印
*/
private static void display() {
System.out.println("权重:");
for (int i = 0; i < edges.size(); i++) {
Edge edge = edges.get(i);
System.out.println("w" + edge.getStart().getName() + edge.getEnd().getName() + ": " + edge.getWeight());
}
System.out.println("隐藏层偏倚:");
for (int i = 0; i < hideLayers.size(); i++) {
Point point = hideLayers.get(i);
System.out.println("O" + point.getName() + ": " + point.getO());
}
System.out.println("输出层偏倚:");
for (int i = 0; i < outLayers.size(); i++) {
Point point = outLayers.get(i);
System.out.println("O" + point.getName() + ": " + point.getO());
}
}
/**
* 方法描述:训练模型
*/
private static void compute() {
for (ArrayList<Point> points : samples) {
// 计算输入层每个单元的输出
for (Point point1 : points) {
point1.setOut(point1.getIn());
}
// 计算隐藏层的每个单元的净输入和输出
getInOut(hideLayers, points);
// 计算输出层的每个单元的净输入和输出
getInOut(outLayers, points);
// 计算输出层的误差
for (Point point2 : outLayers) {
Double out = point2.getOut();
Double err = out * (1 - out) * (point2.getClassify() - out);
point2.setErr(err);
}
// 计算隐藏层的误差
for (Point hide : hideLayers) {
Double sum = 0D;
for (Point out : outLayers) {
sum += out.getErr() * (getWeight(hide, out));
}
Double out = hide.getOut();
Double err = out * (1 - out) * sum;
hide.setErr(err);
}
// 更新所有权重
for (Edge edge : edges) {
Double weight = edge.getWeight() + study * edge.getEnd().getErr() * edge.getStart().getOut();
edge.setWeight(weight);
}
// 更新隐藏层偏倚
updateO(hideLayers);
// 更新输出层偏倚
updateO(outLayers);
}
}
/**
* 方法描述:准备初始化参数
*/
private static void init() {
ArrayList<Point> inLayers = new ArrayList<Point>();
Point p1 = new Point(1, 1D);
inLayers.add(p1);
Point p2 = new Point(2, 0D);
inLayers.add(p2);
Point p3 = new Point(3, 1D);
inLayers.add(p3);
samples.add(inLayers);
Point p4 = new Point(4, -0.4D, 0D, 0D, 0D);
hideLayers.add(p4);
Point p5 = new Point(5, 0.2D, 0D, 0D, 0D);
hideLayers.add(p5);
Point p6 = new Point(6, 0.1D, 0D, 0D, 0D, 1);
outLayers.add(p6);
Edge edge1 = new Edge(p1, p4, 0.2D);
Edge edge2 = new Edge(p1, p5, -0.3D);
Edge edge3 = new Edge(p2, p4, 0.4D);
Edge edge4 = new Edge(p2, p5, 0.1D);
Edge edge5 = new Edge(p3, p4, -0.5D);
Edge edge6 = new Edge(p3, p5, 0.2D);
Edge edge7 = new Edge(p4, p6, -0.3D);
Edge edge8 = new Edge(p5, p6, -0.2D);
edges.add(edge1);
edges.add(edge2);
edges.add(edge3);
edges.add(edge4);
edges.add(edge5);
edges.add(edge6);
edges.add(edge7);
edges.add(edge8);
}
/**
* 方法描述:计算给定list的净输入和输出
* @param layers
* @param edges
* @param points
*/
private static void updateO(List<Point> layers) {
for (Point hide : layers) {
Double o = hide.getO() + study * hide.getErr();
hide.setO(o);
}
}
/**
* 方法描述:计算给定list的净输入和输出
* @param layers
* @param edges
* @param points
*/
private static void getInOut(List<Point> layers, ArrayList<Point> points) {
for (int i = 0; i< layers.size(); i++) {
Point hide = layers.get(i);
Double in = 0D;
Double out = 0D;
Double sum = 0D;
for (Point point3 : points) {
sum += getWeight(point3, hide) * point3.getOut();
}
in = sum + hide.getO();
hide.setIn(in);
out = 1.0 / (1 + Math.pow(Math.E, (-in)));
hide.setOut(out);
}
}
/**
* 方法描述:根据给定的两个点得到这条边的权重
* @param point3
* @param hide
* @return
*/
private static Double getWeight(Point point3, Point hide) {
Double weight = 0D;
for (Edge edge : edges) {
if (point3.getName() == edge.getStart().getName() && hide.getName() == edge.getEnd().getName()) {
weight = edge.getWeight();
break;
}
}
return weight;
}
}
package com.network;
/**
* <p>本类描述: </p>
* <p>其他说明: </p>
* @author Wang Haiyang
* @date 2015-6-29 上午09:15:55
*/
public class Point {
public Point() {}
public Point(Integer name, Double in) {
this.name = name;
this.in = in;
}
public Point(Integer name, Double o, Double err, Double in, Double out) {
this.name = name;
this.o = o;
this.err = err;
this.in = in;
this.out = out;
}
public Point(Integer name, Double o, Double err, Double in, Double out, Integer classify) {
this.name = name;
this.o = o;
this.err = err;
this.in = in;
this.out = out;
this.classify = classify;
}
/** 点的名字 */
private Integer name;
/** 点的偏倚值 */
private Double o;
/** 点的初始值 */
private Double err;
/** 点的净输入 */
private Double in;
/** 点的输出 */
private Double out;
/** 点的类别 */
private Integer classify = 0;
public Integer getClassify() {
return classify;
}
public void setClassify(Integer classify) {
this.classify = classify;
}
public Double getIn() {
return in;
}
public void setIn(Double in) {
this.in = in;
}
public Double getOut() {
return out;
}
public void setOut(Double out) {
this.out = out;
}
public Integer getName() {
return name;
}
public void setName(Integer name) {
this.name = name;
}
public Double getO() {
return o;
}
public void setO(Double o) {
this.o = o;
}
public Double getErr() {
return err;
}
public void setErr(Double err) {
this.err = err;
}
}
2、Edge.java
package com.network;
/**
* <p>本类描述: </p>
* <p>其他说明: </p>
* @author Wang Haiyang
* @date 2015-6-29 上午09:11:42
*/
public class Edge {
/** 边的起点 */
private Point start;
/** 边的终点 */
private Point end;
/** 边的权重 */
private Double weight;
public Edge() {}
public Edge(Point start, Point end, Double weight) {
this.start = start;
this.end = end;
this.weight = weight;
}
public Point getStart() {
return start;
}
public void setStart(Point start) {
this.start = start;
}
public Point getEnd() {
return end;
}
public void setEnd(Point end) {
this.end = end;
}
public Double getWeight() {
return weight;
}
public void setWeight(Double weight) {
this.weight = weight;
}
}
3、NeuralNetwork.java
package com.network;
import java.util.ArrayList;
import java.util.List;
/**
* <p>
* 本类描述:
* 利用向后传播的神经网络方法学习,产生可预测类别的模型,本类假定隐藏层数为1(两层神经网络)
* 隐藏层包含的单元可以指定,输出层的单元也可以指定
* </p>
* <p>
* 主要步骤:
* 步骤1: 初始化网络中的权重和偏倚
* 步骤2: 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
* 步骤3: 逐层向后计算输出层和隐藏层的每个单元的误差
* 步骤4: 更新所有权重和偏倚
* </p>
* <p>
* 其他说明:对未知元组X分类
* 利用训练好的模型,计算每个单元的净输入和输出,如果每个类有一个输出节点,则具有最高输出值的
* 节点决定X的预测类标号,如果只有一个输出节点,则输出值大于或等于0.5可以视为正类,而值小于0.5
* 可以视为负类。
* </p>
* @author Wang Haiyang
* @date 2015-6-26 下午04:10:10
*/
public class NeuralNetwork {
/** 学习率 */
public static final Double study = 0.9D;
/** 样本集 */
public static List<ArrayList<Point>> samples = new ArrayList<ArrayList<Point>>();
/** 隐藏层点集 */
public static List<Point> hideLayers = new ArrayList<Point>();
/** 输出层点集 */
public static List<Point> outLayers = new ArrayList<Point>();
/** 边集 */
public static List<Edge> edges = new ArrayList<Edge>();
public static void main(String[] args) {
// 准备初始化参数
init();
// 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出
compute();
// 打印
display();
}
/**
* 方法描述:打印
*/
private static void display() {
System.out.println("权重:");
for (int i = 0; i < edges.size(); i++) {
Edge edge = edges.get(i);
System.out.println("w" + edge.getStart().getName() + edge.getEnd().getName() + ": " + edge.getWeight());
}
System.out.println("隐藏层偏倚:");
for (int i = 0; i < hideLayers.size(); i++) {
Point point = hideLayers.get(i);
System.out.println("O" + point.getName() + ": " + point.getO());
}
System.out.println("输出层偏倚:");
for (int i = 0; i < outLayers.size(); i++) {
Point point = outLayers.get(i);
System.out.println("O" + point.getName() + ": " + point.getO());
}
}
/**
* 方法描述:训练模型
*/
private static void compute() {
for (ArrayList<Point> points : samples) {
// 计算输入层每个单元的输出
for (Point point1 : points) {
point1.setOut(point1.getIn());
}
// 计算隐藏层的每个单元的净输入和输出
getInOut(hideLayers, points);
// 计算输出层的每个单元的净输入和输出
getInOut(outLayers, points);
// 计算输出层的误差
for (Point point2 : outLayers) {
Double out = point2.getOut();
Double err = out * (1 - out) * (point2.getClassify() - out);
point2.setErr(err);
}
// 计算隐藏层的误差
for (Point hide : hideLayers) {
Double sum = 0D;
for (Point out : outLayers) {
sum += out.getErr() * (getWeight(hide, out));
}
Double out = hide.getOut();
Double err = out * (1 - out) * sum;
hide.setErr(err);
}
// 更新所有权重
for (Edge edge : edges) {
Double weight = edge.getWeight() + study * edge.getEnd().getErr() * edge.getStart().getOut();
edge.setWeight(weight);
}
// 更新隐藏层偏倚
updateO(hideLayers);
// 更新输出层偏倚
updateO(outLayers);
}
}
/**
* 方法描述:准备初始化参数
*/
private static void init() {
ArrayList<Point> inLayers = new ArrayList<Point>();
Point p1 = new Point(1, 1D);
inLayers.add(p1);
Point p2 = new Point(2, 0D);
inLayers.add(p2);
Point p3 = new Point(3, 1D);
inLayers.add(p3);
samples.add(inLayers);
Point p4 = new Point(4, -0.4D, 0D, 0D, 0D);
hideLayers.add(p4);
Point p5 = new Point(5, 0.2D, 0D, 0D, 0D);
hideLayers.add(p5);
Point p6 = new Point(6, 0.1D, 0D, 0D, 0D, 1);
outLayers.add(p6);
Edge edge1 = new Edge(p1, p4, 0.2D);
Edge edge2 = new Edge(p1, p5, -0.3D);
Edge edge3 = new Edge(p2, p4, 0.4D);
Edge edge4 = new Edge(p2, p5, 0.1D);
Edge edge5 = new Edge(p3, p4, -0.5D);
Edge edge6 = new Edge(p3, p5, 0.2D);
Edge edge7 = new Edge(p4, p6, -0.3D);
Edge edge8 = new Edge(p5, p6, -0.2D);
edges.add(edge1);
edges.add(edge2);
edges.add(edge3);
edges.add(edge4);
edges.add(edge5);
edges.add(edge6);
edges.add(edge7);
edges.add(edge8);
}
/**
* 方法描述:计算给定list的净输入和输出
* @param layers
* @param edges
* @param points
*/
private static void updateO(List<Point> layers) {
for (Point hide : layers) {
Double o = hide.getO() + study * hide.getErr();
hide.setO(o);
}
}
/**
* 方法描述:计算给定list的净输入和输出
* @param layers
* @param edges
* @param points
*/
private static void getInOut(List<Point> layers, ArrayList<Point> points) {
for (int i = 0; i< layers.size(); i++) {
Point hide = layers.get(i);
Double in = 0D;
Double out = 0D;
Double sum = 0D;
for (Point point3 : points) {
sum += getWeight(point3, hide) * point3.getOut();
}
in = sum + hide.getO();
hide.setIn(in);
out = 1.0 / (1 + Math.pow(Math.E, (-in)));
hide.setOut(out);
}
}
/**
* 方法描述:根据给定的两个点得到这条边的权重
* @param point3
* @param hide
* @return
*/
private static Double getWeight(Point point3, Point hide) {
Double weight = 0D;
for (Edge edge : edges) {
if (point3.getName() == edge.getStart().getName() && hide.getName() == edge.getEnd().getName()) {
weight = edge.getWeight();
break;
}
}
return weight;
}
}
相关文章推荐
- 谈谈一个重要的http协议头标:X-Forwarded-For
- 华硕X550C 安装Ubuntu 14.10 无线网络显示硬件被禁用的解决方法
- java 访问https忽略证书
- HTTP服务器nginx在android平台的使用(用于在线播放本地视频)
- Netty4 之 简单搭建HTTP服务
- OKHttp源码解析-ConnectionPool对Connection重用机制&Http/Https/SPDY协议选择
- [转载]tcp粘包分析
- HttpClient (HTTP 请求工具类)
- TCP/UDP常见端口参考
- Xcode7 使用NSURLSession发送HTTP请求报错
- https协议支持get/post方法
- 网络远程唤醒 WOL Magic Packet
- Python 下载网络mp4视频资源
- 一个简单的http请求
- iOS开发工具-网络封包分析工具Charles
- DBN深信度网络
- 虚拟机桥接网络设置(转)
- 基于IHttpAsyncHandler的实时大文件传送器
- 初识贝叶斯网络
- HttpWebRequest