Java 神经网络模型 待续
2015-12-22 12:43
543 查看
package com.bioevent.neuralnet; import java.util.Random; /** * 人工神经网络的结构 */ public class NeuralNet { //输入层,隐层以及输出层神经元的个数 public int inputNeuronNum; public int hiddenNeuronNum; public int outputNeuronNum; //输入层,隐层以及输出层的值 double[] inputValue; double[] hiddenInputValue; double[] hiddenOutputValue; double[] hiddenBias; double[] outputInputValue; double[] outputFinalValue; //输入层到隐层的权重,隐层到输出层的权重 public double[][] input_hidden_weight; public double[][] hidden_output_weight; //更新的权重 public double[][] input_hidden_prev_weight; public double[][] hidden_output_prev_weight; //梯度 public double[][] grad_inputhidden_weight; public double[][] grad_hiddenoutput_weight; public double[] gradInput; public double[] gradHiddenBias; //差值 public double[] hiddenDelta; public double[] outputDelta; //第一个是学习率,第二个是动量 public double eta; public double momentum; public NeuralNet(int inputNeuronNum, int hiddenNeuronNum, int outputNeuronNum, double eta, double momentum) { this.inputNeuronNum = inputNeuronNum; this.hiddenNeuronNum = hiddenNeuronNum; this.outputNeuronNum = outputNeuronNum; inputValue = new double[inputNeuronNum]; hiddenInputValue = new double[hiddenNeuronNum]; hiddenOutputValue = new double[hiddenNeuronNum]; hiddenBias = new double[hiddenNeuronNum]; outputInputValue = new double[outputNeuronNum]; outputFinalValue = new double[outputNeuronNum]; this.eta = eta; this.momentum = momentum; //输入层到隐含层的连接权重 initializeWeight(input_hidden_weight); //隐含层到输出层的连接权重 initializeWeight(hidden_output_weight); //隐层偏置初始化 Random random = new Random(); for(int p = 0; p < hiddenNeuronNum; p++) { hiddenBias[p] = random.nextGaussian(); } } /** * 按正态分布初始化权重,或则按照[-0.01,0.01]之间初始化 * @param matrix */ public void initializeWeight(double[][] matrix) { Random random = new Random(); for(int i = 0; i < matrix.length; i++) { for(int j = 0; j < matrix[i].length; j++) { matrix[i][j] = random.nextGaussian(); } } } /** * 输入层到隐层的前向计算 */ public double[] input_hidden_forward(double[] inputLayer, int hiddenNeuronNum, double[][] weight, double[] bias) { double[] input_hidden = new double[hiddenNeuronNum]; for(int j = 0; j < hiddenNeuronNum; j++) { double sum = 0; for(int i = 0; i < inputLayer.length; i++ ) { sum += inputLayer[i] * weight[i][j] + bias[j]; } hiddenInputValue[j] = sum + bias[j]; //加上偏置 input_hidden[j] = Math.pow(hiddenInputValue[j], 3); //求立方和 } return input_hidden; } /** * 隐层到输出层的计算 * 识别trigger词 */ public double[] hidden_output_forward(double[] hiddenLayer, int outputNeuronNum, double[] outputLayer, double[][] weight) { double[] hidden_output = new double[outputNeuronNum]; for(int j = 0; j < outputLayer.length; j++) { for(int i = 0; i < hiddenLayer.length; i++ ) { hidden_output[j] += hiddenLayer[i] * weight[i][j]; } } double sum = 0; for(int p = 0; p < outputLayer.length; p++) { sum += hidden_output[p]; } // softmax函数 for(int q = 0; q < hidden_output.length; q++) { outputLayer[q] = hidden_output[q] / sum; } return hidden_output; } /** * 输出层的误差计算 * predictedResult[] 就是hidden_output_forward()返回值 */ public double outputLayerErro(double[] goldResult, double[] predictedResult) { double predictedR = 0; for(int sortId = 0; sortId < goldResult.length; sortId++) { double goldR = goldResult[sortId]; if(goldR == 1) { predictedR = predictedResult[sortId]; } } return predictedR; } /** * 隐层的误差计算 */ public void hiddenLayerErr() { } /** * 调整权重 * wi = wi - eta*delta(wi) */ public double[][] adjustWeight(double[] delta, double[] layer, double[][] prevWeight) { double[][] weight = new double[prevWeight.length][prevWeight[0].length]; //layer[0] = 1; for(int i = 0; i < delta.length; i++) { for(int j = 0; j < layer.length; j++) { double newVal = momentum * prevWeight[j][i] + eta * delta[i] * layer[j]; weight[j][i] += newVal; prevWeight[j][i] = newVal; } } return weight; } /** * 调整所有权重矩阵 */ public void adjustWeight() { adjustWeight(outputDelta, hiddenValue, hidden_output_weight, hidden_output_prev_weight); adjustWeight(hiddenDelta, inputValue, input_hidden_weight, input_hidden_prev_weight); } /** * 计算输出层残差 */ public double[] computeDelta() { double[] delta = new double[outputNeuronNum]; return delta; } }
相关文章推荐
- java对世界各个时区(TimeZone)的通用转换处理方法(转载)
- java-注解annotation
- java-模拟tomcat服务器
- java-用HttpURLConnection发送Http请求.
- java-WEB中的监听器Lisener
- Android IPC进程间通讯机制
- Android Native 绘图方法
- Android java 与 javascript互访(相互调用)的方法例子
- 介绍一款信息管理系统的开源框架---jeecg
- 聚类算法之kmeans算法java版本
- java实现 PageRank算法
- PropertyChangeListener简单理解
- c++11 + SDL2 + ffmpeg +OpenAL + java = Android播放器
- 插入排序
- 冒泡排序
- 堆排序
- 快速排序
- 二叉查找树