您的位置:首页 > 理论基础 > 计算机网络

神经网络-向后传播java实现

2015-06-29 18:51 381 查看
1、Point.java

package com.network;

/**

* <p>本类描述: </p>

* <p>其他说明: </p>

* @author Wang Haiyang

* @date 2015-6-29 上午09:15:55

*/

public class Point {

public Point() {}

public Point(Integer name, Double in) {

this.name = name;

this.in = in;

}

public Point(Integer name, Double o, Double err, Double in, Double out) {

this.name = name;

this.o = o;

this.err = err;

this.in = in;

this.out = out;

}

public Point(Integer name, Double o, Double err, Double in, Double out, Integer classify) {

this.name = name;

this.o = o;

this.err = err;

this.in = in;

this.out = out;

this.classify = classify;

}

/** 点的名字 */

private Integer name;

/** 点的偏倚值 */

private Double o;

/** 点的初始值 */

private Double err;

/** 点的净输入 */

private Double in;

/** 点的输出 */

private Double out;

/** 点的类别 */

private Integer classify = 0;

public Integer getClassify() {

return classify;

}

public void setClassify(Integer classify) {

this.classify = classify;

}

public Double getIn() {

return in;

}

public void setIn(Double in) {

this.in = in;

}

public Double getOut() {

return out;

}

public void setOut(Double out) {

this.out = out;

}

public Integer getName() {

return name;

}

public void setName(Integer name) {

this.name = name;

}

public Double getO() {

return o;

}

public void setO(Double o) {

this.o = o;

}

public Double getErr() {

return err;

}

public void setErr(Double err) {

this.err = err;

}

}

2、Edge.java

package com.network;

/**

* <p>本类描述: </p>

* <p>其他说明: </p>

* @author Wang Haiyang

* @date 2015-6-29 上午09:11:42

*/

public class Edge {

/** 边的起点 */

private Point start;

/** 边的终点 */

private Point end;

/** 边的权重 */

private Double weight;

public Edge() {}

public Edge(Point start, Point end, Double weight) {

this.start = start;

this.end = end;

this.weight = weight;

}

public Point getStart() {

return start;

}

public void setStart(Point start) {

this.start = start;

}

public Point getEnd() {

return end;

}

public void setEnd(Point end) {

this.end = end;

}

public Double getWeight() {

return weight;

}

public void setWeight(Double weight) {

this.weight = weight;

}

}

3、NeuralNetwork.java

package com.network;

import java.util.ArrayList;

import java.util.List;

/**

* <p>

* 本类描述:

* 利用向后传播的神经网络方法学习,产生可预测类别的模型,本类假定隐藏层数为1(两层神经网络)

* 隐藏层包含的单元可以指定,输出层的单元也可以指定

* </p>

* <p>

* 主要步骤:

* 步骤1: 初始化网络中的权重和偏倚

* 步骤2: 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出

* 步骤3: 逐层向后计算输出层和隐藏层的每个单元的误差

* 步骤4: 更新所有权重和偏倚

* </p>

* <p>

* 其他说明:对未知元组X分类

* 利用训练好的模型,计算每个单元的净输入和输出,如果每个类有一个输出节点,则具有最高输出值的

* 节点决定X的预测类标号,如果只有一个输出节点,则输出值大于或等于0.5可以视为正类,而值小于0.5

* 可以视为负类。

* </p>

* @author Wang Haiyang

* @date 2015-6-26 下午04:10:10

*/

public class NeuralNetwork {

/** 学习率 */

public static final Double study = 0.9D;

/** 样本集 */

public static List<ArrayList<Point>> samples = new ArrayList<ArrayList<Point>>();

/** 隐藏层点集 */

public static List<Point> hideLayers = new ArrayList<Point>();

/** 输出层点集 */

public static List<Point> outLayers = new ArrayList<Point>();

/** 边集 */

public static List<Edge> edges = new ArrayList<Edge>();

public static void main(String[] args) {

// 准备初始化参数

init();

// 针对每个元组,计算输入层、隐藏层和输出层的每个单元的净输入和输出

compute();

// 打印

display();

}

/**

* 方法描述:打印

*/

private static void display() {

System.out.println("权重:");

for (int i = 0; i < edges.size(); i++) {

Edge edge = edges.get(i);

System.out.println("w" + edge.getStart().getName() + edge.getEnd().getName() + ": " + edge.getWeight());

}

System.out.println("隐藏层偏倚:");

for (int i = 0; i < hideLayers.size(); i++) {

Point point = hideLayers.get(i);

System.out.println("O" + point.getName() + ": " + point.getO());

}

System.out.println("输出层偏倚:");

for (int i = 0; i < outLayers.size(); i++) {

Point point = outLayers.get(i);

System.out.println("O" + point.getName() + ": " + point.getO());

}

}

/**

* 方法描述:训练模型

*/

private static void compute() {

for (ArrayList<Point> points : samples) {

// 计算输入层每个单元的输出

for (Point point1 : points) {

point1.setOut(point1.getIn());

}

// 计算隐藏层的每个单元的净输入和输出

getInOut(hideLayers, points);

// 计算输出层的每个单元的净输入和输出

getInOut(outLayers, points);

// 计算输出层的误差

for (Point point2 : outLayers) {

Double out = point2.getOut();

Double err = out * (1 - out) * (point2.getClassify() - out);

point2.setErr(err);

}

// 计算隐藏层的误差

for (Point hide : hideLayers) {

Double sum = 0D;

for (Point out : outLayers) {

sum += out.getErr() * (getWeight(hide, out));

}

Double out = hide.getOut();

Double err = out * (1 - out) * sum;

hide.setErr(err);

}

// 更新所有权重

for (Edge edge : edges) {

Double weight = edge.getWeight() + study * edge.getEnd().getErr() * edge.getStart().getOut();

edge.setWeight(weight);

}

// 更新隐藏层偏倚

updateO(hideLayers);

// 更新输出层偏倚

updateO(outLayers);

}

}

/**

* 方法描述:准备初始化参数

*/

private static void init() {

ArrayList<Point> inLayers = new ArrayList<Point>();

Point p1 = new Point(1, 1D);

inLayers.add(p1);

Point p2 = new Point(2, 0D);

inLayers.add(p2);

Point p3 = new Point(3, 1D);

inLayers.add(p3);

samples.add(inLayers);

Point p4 = new Point(4, -0.4D, 0D, 0D, 0D);

hideLayers.add(p4);

Point p5 = new Point(5, 0.2D, 0D, 0D, 0D);

hideLayers.add(p5);

Point p6 = new Point(6, 0.1D, 0D, 0D, 0D, 1);

outLayers.add(p6);

Edge edge1 = new Edge(p1, p4, 0.2D);

Edge edge2 = new Edge(p1, p5, -0.3D);

Edge edge3 = new Edge(p2, p4, 0.4D);

Edge edge4 = new Edge(p2, p5, 0.1D);

Edge edge5 = new Edge(p3, p4, -0.5D);

Edge edge6 = new Edge(p3, p5, 0.2D);

Edge edge7 = new Edge(p4, p6, -0.3D);

Edge edge8 = new Edge(p5, p6, -0.2D);

edges.add(edge1);

edges.add(edge2);

edges.add(edge3);

edges.add(edge4);

edges.add(edge5);

edges.add(edge6);

edges.add(edge7);

edges.add(edge8);

}

/**

* 方法描述:计算给定list的净输入和输出

* @param layers

* @param edges

* @param points

*/

private static void updateO(List<Point> layers) {

for (Point hide : layers) {

Double o = hide.getO() + study * hide.getErr();

hide.setO(o);

}

}

/**

* 方法描述:计算给定list的净输入和输出

* @param layers

* @param edges

* @param points

*/

private static void getInOut(List<Point> layers, ArrayList<Point> points) {

for (int i = 0; i< layers.size(); i++) {

Point hide = layers.get(i);

Double in = 0D;

Double out = 0D;

Double sum = 0D;

for (Point point3 : points) {

sum += getWeight(point3, hide) * point3.getOut();

}

in = sum + hide.getO();

hide.setIn(in);

out = 1.0 / (1 + Math.pow(Math.E, (-in)));

hide.setOut(out);

}

}

/**

* 方法描述:根据给定的两个点得到这条边的权重

* @param point3

* @param hide

* @return

*/

private static Double getWeight(Point point3, Point hide) {

Double weight = 0D;

for (Edge edge : edges) {

if (point3.getName() == edge.getStart().getName() && hide.getName() == edge.getEnd().getName()) {

weight = edge.getWeight();

break;

}

}

return weight;

}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: