您的位置:首页 > 理论基础 > 计算机网络

bp神经网络的java实现

2013-04-22 10:59 232 查看
这两天开始研究BPNN。先阅读的这篇文章:http://www.codeproject.com/Articles/16508/AI-Neural-Network-for-beginners-Part-2-of-3;然后我把这篇文章里的代码按照面向对象的方式重写了一遍。在测试过程中发现一个比较奇怪的问题,两种实现的过程数据是一样的,但最后的计算结果却不一样,这个困扰我好几天了,我把代码贴出来,大家一起研究研究。

第一个类,神经元

public class Neuron {

static int counter = 0;

final public int id; // auto increment, starts at 0

NeuronConnection biasConnection;

final double bias = -1;

double output;

double deltaOutput = 0;

List<NeuronConnection> Inconnections = new ArrayList<NeuronConnection>();

Map<Integer, NeuronConnection> connMap = new HashMap<Integer, NeuronConnection>();

public Neuron(){

id = counter;

counter++;

}

/**

* Compute Sj = Wij*Aij + w0j*bias

*/

public void calculateOutput(){

double s = 0;

for(NeuronConnection con : Inconnections){

Neuron leftNeuron = con.getFromNeuron();

double weight = con.getWeight();

double a = leftNeuron.getOutput(); //output from previous layer

s = s + (weight*a);

}

s = s + (biasConnection.getWeight()*bias);

output = g(s);

}

double g(double x) {

return sigmoid(x);

}

double sigmoid(double x) {

return 1.0 / (1.0 + (Math.exp(-x)));

}

public void addInConnectionsS(List<Neuron> inNeurons){

for(Neuron n: inNeurons){

NeuronConnection con = new NeuronConnection(n,this);

Inconnections.add(con);

connMap.put(n.id, con);

}

}

public NeuronConnection getNeuronConn(int id){

return connMap.get(id);

}

public void addInConnection(NeuronConnection con){

Inconnections.add(con);

}

public void addBiasConnection(Neuron n){

NeuronConnection con = new NeuronConnection(n,this);

biasConnection = con;

Inconnections.add(con);

connMap.put(n.id, con);

}

public List<NeuronConnection> getAllInConnections(){

return Inconnections;

}

public double getBias() {

return bias;

}

public double getOutput() {

return output;

}

public void setOutput(double o){

output = o;

}

}

第二个类,神经元之间的连接类

public class NeuronConnection {

double weight = 0;

double deltaWeight = 0;

double prevDeltaWeight = 0; // for momentum

final Neuron leftNeuron;

final Neuron rightNeuron;

static int counter = 0;

final public int id; // auto increment, starts at 0

public NeuronConnection(Neuron fromN, Neuron toN) {

leftNeuron = fromN;

rightNeuron = toN;

id = counter;

counter++;

}

public double getWeight() {

return weight;

}

public void setWeight(double w) {

weight = w;

}

public void setDeltaWeight(double w) {

prevDeltaWeight = deltaWeight;

deltaWeight = w;

}

public double getPrevDeltaWeight() {

return prevDeltaWeight;

}

public Neuron getFromNeuron() {

return leftNeuron;

}

public Neuron getToNeuron() {

return rightNeuron;

}

}

第三个类,神经网络

public class BPNN {

private int[] layers;

private Neuron bias = new Neuron();

private double learningRate = 10f;

private double momentum = 0.7f;

private List<List<Neuron>> layerList = new ArrayList<List<Neuron>>();

private Random random = new Random();

public BPNN(int[] layers){

this.layers = layers;

init();

}

public BPNN(int[] layers, double learningRate, double momentum){

this.layers = layers;

this.learningRate = learningRate;

this.momentum = momentum;

init();

}

public String toString(){

int i;

List<Neuron> list0 = layerList.get(0);

List<Neuron> list1 = layerList.get(layerList.size() - 1);

for(i = 0; i < list0.size(); i++){

System.out.print(list0.get(i).getOutput() + " ");

}

for(i = 0; i < list1.size(); i++){

System.out.println(list1.get(i).getOutput() + " " + list1.get(i).deltaOutput);

}

return null;

}

private void init(){

int i;

/*初始化神经网络结构*/

List<Neuron> list0 = new ArrayList<Neuron>();

for(i = 0; i < layers[0]; i++){

Neuron neuron = new Neuron();

list0.add(neuron);

}

layerList.add(list0);

for(int j = 1; j < layers.length; j++){

List<Neuron> list = new ArrayList<Neuron>();

for(i = 0; i < layers[j]; i++){

Neuron neuron = new Neuron();

neuron.addInConnectionsS(layerList.get(j - 1));

neuron.addBiasConnection(bias);

list.add(neuron);

}

layerList.add(list);

}

/*初始化权重*/

for(List<Neuron> list : layerList){

for(Neuron neuron : list){

for(NeuronConnection conn : neuron.getAllInConnections()){

conn.setWeight(2 * random.nextDouble() - 1);

}

}

}

System.out.println();

}

/*接收输入参数*/

public void setInput(double[] inputs){

List<Neuron> list = layerList.get(0);

if(list.size() != inputs.length){

System.err.println("入参数量不对");

}

for(int i = 0; i < inputs.length; i++){

list.get(i).setOutput(inputs[i]);

}

}

/*前向计算输出,不计算输入层的输出*/

public void forward(){

int size = layerList.size();

for(int i = 1; i < size; i++){

List<Neuron> list = layerList.get(i);

for(int j = 0; j < list.size(); j++){

Neuron n = list.get(j);

n.calculateOutput();

}

}

}

/*获取计算结果*/

public double[] getOutput(){

double[] ret = new double[layers[layers.length - 1]];

List<Neuron> list = layerList.get(layerList.size() - 1);

for(int i = 0; i < ret.length; i++){

ret[i] = list.get(i).getOutput();

}

return ret;

}

/*根据误差反馈修正权重*/

public void updateWeights(double[] target){

double e = 1;

if(target.length != layers[layers.length - 1]){

System.err.println("期望输出数量错误");

}

int size = layerList.size();

List<Neuron> outputLayer = layerList.get(size - 1);

int i;

double delta_output = 0;

for(i = 0; i < outputLayer.size(); i++){

Neuron n = outputLayer.get(i);

double output = n.getOutput();

e += target[i] - output;

delta_output = output * (1 - output) * (target[i] - output);

for(NeuronConnection conn : n.getAllInConnections()){

double leftNoutput = conn.leftNeuron.getOutput();

double delta_weight = this.learningRate * delta_output * leftNoutput;

conn.setDeltaWeight(delta_weight);

//this.momentum * conn.getPrevDeltaWeight()添加动量影响

conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());

}

n.deltaOutput = delta_output;

}

e = Math.abs(e - 1)/outputLayer.size();

//TODO反馈修正

for(i = size - 2; i == 1; i--){

List<Neuron> hiddenLayer = layerList.get(i);

List<Neuron> rightLayer = layerList.get(i + 1);

int j;

for (j = 0; j < hiddenLayer.size(); j++) {

Neuron n = hiddenLayer.get(j);

double output = n.getOutput();

double error = 0;

for(int h = 0; h < rightLayer.size(); h++){

Neuron rightN = rightLayer.get(h);

NeuronConnection conn = rightN.getNeuronConn(n.id);

error += rightN.deltaOutput * conn.getWeight();

}

n.deltaOutput = output * (1 - output) * error;

for(int m = 0; m < n.getAllInConnections().size(); m++){

NeuronConnection conn = n.getAllInConnections().get(m);

double leftNOutput = conn.leftNeuron.getOutput();

double delta_weight = this.learningRate * n.deltaOutput * leftNOutput;

conn.setDeltaWeight(delta_weight);

conn.setWeight(conn.getWeight() + delta_weight + this.momentum * conn.getPrevDeltaWeight());

}

}

}

}

/*对外训练接口*/

public void train(double[] inputs, double[] output){

setInput(inputs);

forward();

updateWeights(output);

}

/*获取计算结果*/

public double[] result(double[] inputs){

setInput(inputs);

forward();

return getOutput();

}

/*获取权值矩阵*/

public List<NeuronConnection> getWeights(){

List<NeuronConnection> ret = new ArrayList<NeuronConnection>();

for(List<Neuron> list : layerList){

for(Neuron neuron : list){

for(NeuronConnection conn : neuron.getAllInConnections()){

ret.add(conn);

}

}

}

return ret;

}

public static void main(String[] args){

int train = 100;

BPNN n = new BPNN(new int[]{2,2,1});

while(train-- > 0){

n.train( new double[]{0,0}, new double[]{0});

n.toString();

n.train( new double[]{1,0}, new double[]{1});

n.toString();

n.train( new double[]{0,1}, new double[]{1});

n.toString();

n.train( new double[]{1,1}, new double[]{0});

n.toString();

System.out.println("****************************************");

}

for(double d : n.result( new double[]{0,0})){

System.out.println(d);

}

for(double d : n.result( new double[]{1,0})){

System.out.println(d);

}

for(double d : n.result( new double[]{0,1})){

System.out.println(d);

}

for(double d : n.result( new double[]{1,1})){

System.out.println(d);

}

}

}

请大家告诉我到底是哪个地方出现了问题,谢谢您指导。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: