Recurrent Neural Networks: LSTM Principles, with a MATLAB Implementation on a Worked Example
2016-10-12 15:12
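I have recently been studying recurrent neural networks. The blog posts I found online vary widely in quality. Almost all of them either use Python or rely on MATLAB toolbox functions, which is inconvenient for those of us who work in MATLAB. So in this post I use a worked example to write an LSTM recurrent network from scratch in plain MATLAB. I am still a beginner, so corrections from more experienced readers are very welcome. The problem and data are the same as in my earlier post on recurrent networks trained with BPTT, so the two methods can be compared directly; see "MATLAB Implementation of a Recurrent Neural Network with BPTT". I will add the theoretical derivation and the algorithm steps in a few days, when I have time.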
I. Problem Description
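[Figure: problem description]

In brief, as the header comment of LSTM_mian.m states, the data are air-conditioner power consumption readings taken at four time points on each of seven days. Training uses a sliding window: three consecutive days (12 values) are used to predict the next day's four values, and the window then moves forward one day. Day 7 is held out to test the trained network.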
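The sliding-window setup can be built mechanically from a table with one row per day. The sketch below is a hypothetical helper, not part of the author's program; the names days, window, X and Y are illustrative. It uses the six-day table that appears commented out in LSTM_data_process.m. Each column of X stacks three consecutive days, and the matching column of Y is the following day.

```matlab
% Hypothetical helper: build sliding-window samples from a one-row-per-day table.
% Rows are days 1-6, columns are the four daily time points (from the article).
days = [0.4413 0.4707 0.6953 0.8133;
        0.4379 0.4677 0.6981 0.8002;
        0.4517 0.4725 0.7006 0.8201;
        0.4557 0.4790 0.7019 0.8211;
        0.4601 0.4811 0.7101 0.8298;
        0.4612 0.4845 0.7188 0.8312];
window = 3;                              % three input days per sample
num_samples = size(days,1) - window;     % 6 - 3 = 3 training samples
X = zeros(4*window, num_samples);        % 12 x 3, one sample per column
Y = zeros(4, num_samples);               % 4 x 3, next-day targets
for k = 1:num_samples
    block = days(k:k+window-1, :)';      % 4 x 3, one column per day
    X(:,k) = block(:);                   % day k, then k+1, then k+2 (column-major)
    Y(:,k) = days(k+window, :)';         % day k+3 is the target
end
```

This is the 12-values-per-column layout that the program hard-codes in train_data_initial.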
II. Data
[Figure: data table]

The same readings are hard-coded in LSTM_data_process.m below.
III. Program Code
LSTM_mian.m
```matlab
%%% LSTM network simulation on a worked example %%%
%%% Author: xd.wp %%%
%%% Date: 2016.10.08 12:06 %%%
%% Program notes
% 1. The data are air-conditioner power readings at four time points per day over
%    7 days. Three consecutive days are used to predict the fourth during training,
%    sliding forward one day at a time; day 7 is used for testing.
% 2. The LSTM has 12 input nodes, 4 output nodes and 18 hidden (cell) nodes.
clear all; clc;

%% Load the data and normalize it
[train_data,test_data]=LSTM_data_process();
data_length=size(train_data,1);
data_num=size(train_data,2);

%% Initialize network parameters
% Node counts
input_num=12;
cell_num=18;
output_num=4;
% Gate biases
bias_input_gate=rand(1,cell_num);
bias_forget_gate=rand(1,cell_num);
bias_output_gate=rand(1,cell_num);
% ab=1.2;
% bias_input_gate=ones(1,cell_num)/ab;
% bias_forget_gate=ones(1,cell_num)/ab;
% bias_output_gate=ones(1,cell_num)/ab;
% Weight initialization (small random values)
ab=20;
weight_input_x=rand(input_num,cell_num)/ab;
weight_input_h=rand(output_num,cell_num)/ab;
weight_inputgate_x=rand(input_num,cell_num)/ab;
weight_inputgate_c=rand(cell_num,cell_num)/ab;
weight_forgetgate_x=rand(input_num,cell_num)/ab;
weight_forgetgate_c=rand(cell_num,cell_num)/ab;
weight_outputgate_x=rand(input_num,cell_num)/ab;
weight_outputgate_c=rand(cell_num,cell_num)/ab;
% Hidden-to-output weights
weight_preh_h=rand(cell_num,output_num);
% Network state initialization
cost_gate=1e-6;                      % stop once the squared error falls below this
h_state=rand(output_num,data_num);
cell_state=rand(cell_num,data_num);

%% Training
for iter=1:3000
    yita=0.01;                       % learning rate (step size of each weight update)
    for m=1:data_num
        % Forward pass
        if(m==1)
            gate=tanh(train_data(:,m)'*weight_input_x);
            input_gate_input=train_data(:,m)'*weight_inputgate_x+bias_input_gate;
            output_gate_input=train_data(:,m)'*weight_outputgate_x+bias_output_gate;
            input_gate=1./(1+exp(-input_gate_input));    % element-wise sigmoid
            output_gate=1./(1+exp(-output_gate_input));
            forget_gate=zeros(1,cell_num);                % no previous cell state at m=1
            forget_gate_input=zeros(1,cell_num);
            cell_state(:,m)=(input_gate.*gate)';
        else
            gate=tanh(train_data(:,m)'*weight_input_x+h_state(:,m-1)'*weight_input_h);
            input_gate_input=train_data(:,m)'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
            forget_gate_input=train_data(:,m)'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
            output_gate_input=train_data(:,m)'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
            input_gate=1./(1+exp(-input_gate_input));
            forget_gate=1./(1+exp(-forget_gate_input));
            output_gate=1./(1+exp(-output_gate_input));
            cell_state(:,m)=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
        end
        pre_h_state=tanh(cell_state(:,m)').*output_gate;
        h_state(:,m)=(pre_h_state*weight_preh_h)';
        % Error (Error_Cost keeps the squared error of the most recent sample)
        Error=h_state(:,m)-test_data(:,m);
        Error_Cost(1,iter)=sum(Error.^2);
        if(Error_Cost(1,iter)<cost_gate)
            flag=1;
            break;
        else
            [weight_input_x,weight_input_h,...
             weight_inputgate_x,weight_inputgate_c,...
             weight_forgetgate_x,weight_forgetgate_c,...
             weight_outputgate_x,weight_outputgate_c,...
             weight_preh_h]=LSTM_updata_weight(m,yita,Error,...
                weight_input_x,weight_input_h,...
                weight_inputgate_x,weight_inputgate_c,...
                weight_forgetgate_x,weight_forgetgate_c,...
                weight_outputgate_x,weight_outputgate_c,...
                weight_preh_h,cell_state,h_state,...
                input_gate,forget_gate,output_gate,gate,...
                train_data,pre_h_state,...
                input_gate_input,output_gate_input,forget_gate_input);
        end
    end
    if(Error_Cost(1,iter)<cost_gate)
        break;
    end
end

%% Plot the Error-Cost curve
% for n=1:1:iter
%     text(n,Error_Cost(1,n),'*');
%     axis([0,iter,0,1]);
%     title('Error-Cost curve');
% end
semilogy(1:iter,Error_Cost(1,1:iter),'*');   % one marker per iteration, log-scale y axis
title('Error-Cost curve');

%% Test on day 7
% Input: days 4, 5 and 6 concatenated
test_final=[0.4557 0.4790 0.7019 0.8211 0.4601 0.4811 0.7101 0.8298 0.4612 0.4845 0.7188 0.8312]';
test_final=test_final/sqrt(sum(test_final.^2));
test_output=test_data(:,4);                  % actual (normalized) day-7 values
% Forward pass, continuing from the state after the last training sample
m=4;
gate=tanh(test_final'*weight_input_x+h_state(:,m-1)'*weight_input_h);
input_gate_input=test_final'*weight_inputgate_x+cell_state(:,m-1)'*weight_inputgate_c+bias_input_gate;
forget_gate_input=test_final'*weight_forgetgate_x+cell_state(:,m-1)'*weight_forgetgate_c+bias_forget_gate;
output_gate_input=test_final'*weight_outputgate_x+cell_state(:,m-1)'*weight_outputgate_c+bias_output_gate;
input_gate=1./(1+exp(-input_gate_input));
forget_gate=1./(1+exp(-forget_gate_input));
output_gate=1./(1+exp(-output_gate_input));
cell_state_test=(input_gate.*gate+cell_state(:,m-1)'.*forget_gate)';
pre_h_state=tanh(cell_state_test').*output_gate;
h_state_test=(pre_h_state*weight_preh_h)'    % predicted day-7 values (printed)
test_output                                  % actual day-7 values (printed)
```
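The following sketch isolates one forward step of the cell for samples after the first (the m > 1 branch of the training loop), written in vectorized form so the data flow is easier to follow. It makes two assumptions:

* The weights are random placeholders with the article's sizes.
* The short names (Wx, Wh, Wix, ig, fg, og, W_out, and so on) are illustrative stand-ins for weight_input_x, weight_input_h, weight_inputgate_x, and so on.

In this particular cell, the candidate value g receives the previous network output h_prev. The three gates, by contrast, look at the previous cell state c_prev rather than at a hidden state.

```matlab
% Vectorized sketch of one forward step of the article's cell (m > 1 branch).
% Sizes follow the article; all weights are random placeholders, not trained values.
input_num = 12; cell_num = 18; output_num = 4;
sigmoid = @(z) 1 ./ (1 + exp(-z));        % element-wise logistic function
x      = rand(input_num, 1);              % current input column (train_data(:,m))
h_prev = rand(output_num, 1);             % previous network output (h_state(:,m-1))
c_prev = rand(cell_num, 1);               % previous cell state (cell_state(:,m-1))
Wx  = rand(input_num, cell_num)/20;  Wh  = rand(output_num, cell_num)/20;
Wix = rand(input_num, cell_num)/20;  Wic = rand(cell_num, cell_num)/20;
Wfx = rand(input_num, cell_num)/20;  Wfc = rand(cell_num, cell_num)/20;
Wox = rand(input_num, cell_num)/20;  Woc = rand(cell_num, cell_num)/20;
bi = rand(1, cell_num); bf = rand(1, cell_num); bo = rand(1, cell_num);
W_out = rand(cell_num, output_num);       % plays the role of weight_preh_h

g  = tanh(x'*Wx + h_prev'*Wh);            % candidate value, fed by the previous output
ig = sigmoid(x'*Wix + c_prev'*Wic + bi);  % input gate, looks at c_prev
fg = sigmoid(x'*Wfx + c_prev'*Wfc + bf);  % forget gate
og = sigmoid(x'*Wox + c_prev'*Woc + bo);  % output gate
c  = ig.*g + fg.*c_prev';                 % new cell state, 1 x cell_num
pre_h = tanh(c).*og;                      % gated cell output, 1 x cell_num
h  = (pre_h*W_out)';                      % network output, output_num x 1
```

The gates can equally be computed with per-element for loops. MATLAB's exp and ./ already act element-wise, so this form gives the same values with less code.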
LSTM_data_process.m
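```matlab
function [train_data,test_data]=LSTM_data_process()
%% Load the data and normalize it
% Each column of train_data_initial is one training sample: three consecutive
% days of 4 readings each, stacked into a 12-element vector.
train_data_initial=[0.4413 0.4707 0.6953 0.8133 0.4379 0.4677 0.6981 0.8002 0.4517 0.4725 0.7006 0.8201;   % days 1-3
                    0.4379 0.4677 0.6981 0.8002 0.4517 0.4725 0.7006 0.8201 0.4557 0.4790 0.7019 0.8211;   % days 2-4
                    0.4517 0.4725 0.7006 0.8201 0.4557 0.4790 0.7019 0.8211 0.4601 0.4811 0.7101 0.8298]'; % days 3-5
% The same readings one day per row (days 1-6), for reference:
% train_data_initial=[0.4413 0.4707 0.6953 0.8133;
%                     0.4379 0.4677 0.6981 0.8002;
%                     0.4517 0.4725 0.7006 0.8201;
%                     0.4557 0.4790 0.7019 0.8211;
%                     0.4601 0.4811 0.7101 0.8298;
%                     0.4612 0.4845 0.7188 0.8312]';
% Column k of test_data_initial is the target (the following day) for training
% sample k; the last column is the actual day-7 reading used for the final test.
test_data_initial=[0.4557 0.4790 0.7019 0.8211;    % day 4
                   0.4601 0.4811 0.7101 0.8298;    % day 5
                   0.4612 0.4845 0.7188 0.8312;    % day 6
                   0.4615 0.4891 0.7201 0.8330]';  % day 7
data_length=size(train_data_initial,1);   % length of each sample
data_num=size(train_data_initial,2);      % number of samples
%% Normalization: scale every column to unit Euclidean length
train_data=zeros(size(train_data_initial));
for n=1:data_num
    train_data(:,n)=train_data_initial(:,n)/sqrt(sum(train_data_initial(:,n).^2));
end
test_data=zeros(size(test_data_initial));
for m=1:size(test_data_initial,2)
    test_data(:,m)=test_data_initial(:,m)/sqrt(sum(test_data_initial(:,m).^2));
end
end
```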
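Both normalization loops scale each column (one sample) to unit Euclidean length, and the main script does the same to test_final. Below is a vectorized sketch for comparison. It assumes implicit expansion, which requires MATLAB R2016b or later or GNU Octave. The matrix A simply reuses two day columns from the article.

```matlab
% Vectorized column normalization (sketch), compared with the explicit loop.
A = [0.4557 0.4790 0.7019 0.8211;
     0.4601 0.4811 0.7101 0.8298]';       % 4 x 2: two day columns from the article
col_norms = sqrt(sum(A.^2, 1));          % 1 x 2 row of column lengths
A_unit = A ./ col_norms;                 % implicit expansion divides each column by its length
% The explicit loop used in LSTM_data_process.m:
A_loop = zeros(size(A));
for n = 1:size(A, 2)
    A_loop(:,n) = A(:,n) / sqrt(sum(A(:,n).^2));
end
```

The network is trained on unit-length vectors, so its predictions are on that normalized scale as well. This is why the script compares h_state_test with the normalized test_output rather than with raw readings.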
LSTM_updata_weight.m
```matlab
function [weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,...
          weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,...
          weight_preh_h]=LSTM_updata_weight(n,yita,Error,...
          weight_input_x,weight_input_h,weight_inputgate_x,weight_inputgate_c,...
          weight_forgetgate_x,weight_forgetgate_c,weight_outputgate_x,weight_outputgate_c,...
          weight_preh_h,cell_state,h_state,input_gate,forget_gate,output_gate,gate,...
          train_data,pre_h_state,input_gate_input,output_gate_input,forget_gate_input)
%%% Weight update: one gradient-descent step on sum(Error.^2) for sample n.
% Derivatives used: sigmoid'(z)=exp(-z).*sigmoid(z).^2 and tanh'(z)=1-tanh(z).^2.
% The previous cell state and output are treated as constants (no backpropagation
% through earlier time steps), and the gate biases are not updated.
input_num=12;
cell_num=18;
output_num=4;
data_length=size(train_data,1);
data_num=size(train_data,2);
weight_preh_h_temp=weight_preh_h;
%% Update weight_preh_h (hidden-to-output weights)
delta_weight_preh_h_temp=zeros(cell_num,output_num);
for m=1:output_num
    delta_weight_preh_h_temp(:,m)=2*Error(m,1)*pre_h_state';
end
weight_preh_h_temp=weight_preh_h_temp-yita*delta_weight_preh_h_temp;
%% Update weight_outputgate_x
for num=1:output_num
    for m=1:data_length
        delta_weight_outputgate_x(m,:)=(2*weight_preh_h(:,num)*Error(num,1).*tanh(cell_state(:,n)))'.*exp(-output_gate_input).*(output_gate.^2)*train_data(m,n);
    end
    weight_outputgate_x=weight_outputgate_x-yita*delta_weight_outputgate_x;
end
%% Update weight_inputgate_x
for num=1:output_num
    for m=1:data_length
        delta_weight_inputgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*train_data(m,n);
    end
    weight_inputgate_x=weight_inputgate_x-yita*delta_weight_inputgate_x;
end
if(n~=1)
    %% Update weight_input_x
    temp=train_data(:,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
        for m=1:data_length
            delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);
        end
        weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
    %% Update weight_forgetgate_x
    for num=1:output_num
        for m=1:data_length
            delta_weight_forgetgate_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*train_data(m,n);
        end
        weight_forgetgate_x=weight_forgetgate_x-yita*delta_weight_forgetgate_x;
    end
    %% Update weight_inputgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_inputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*gate.*exp(-input_gate_input).*(input_gate.^2)*cell_state(m,n-1);
        end
        weight_inputgate_c=weight_inputgate_c-yita*delta_weight_inputgate_c;
    end
    %% Update weight_forgetgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_forgetgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2)*cell_state(m,n-1);
        end
        weight_forgetgate_c=weight_forgetgate_c-yita*delta_weight_forgetgate_c;
    end
    %% Update weight_outputgate_c
    for num=1:output_num
        for m=1:cell_num
            delta_weight_outputgate_c(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*tanh(cell_state(:,n))'.*exp(-output_gate_input).*(output_gate.^2)*cell_state(m,n-1);
        end
        weight_outputgate_c=weight_outputgate_c-yita*delta_weight_outputgate_c;
    end
    %% Update weight_input_h
    temp=train_data(:,n)'*weight_input_x+h_state(:,n-1)'*weight_input_h;
    for num=1:output_num
        for m=1:output_num
            delta_weight_input_h(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*h_state(m,n-1);
        end
        weight_input_h=weight_input_h-yita*delta_weight_input_h;
    end
else
    %% Update weight_input_x (first sample: no previous output is fed back)
    temp=train_data(:,n)'*weight_input_x;
    for num=1:output_num
        for m=1:data_length
            delta_weight_input_x(m,:)=2*(weight_preh_h(:,num)*Error(num,1))'.*output_gate.*(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)'.*input_gate.*(ones(size(temp))-tanh(temp).^2)*train_data(m,n);
        end
        weight_input_x=weight_input_x-yita*delta_weight_input_x;
    end
end
weight_preh_h=weight_preh_h_temp;
end
```
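Hand-derived gradients like these are easy to get wrong, and a finite-difference check is a cheap safeguard. The sketch below is a hypothetical test script, not part of the original program. It does three things:

1. It checks that the tanh derivative 1 - tanh(z).^2 matches central differences, and shows that the form 1 - tanh(z.^2) does not.
2. It checks the sigmoid derivative exp(-z).*sigmoid(z).^2, which the code uses for all three gates.
3. It compares the weight_preh_h gradient 2*Error(m)*pre_h_state with a numerical gradient. Small random sizes are used here as an assumption to keep the check fast.

```matlab
% Hypothetical gradient-check script (not part of the original program).
sigmoid = @(z) 1 ./ (1 + exp(-z));     % element-wise logistic function
step = 1e-6;                           % central-difference step for scalar checks
z = linspace(-2, 2, 9);                % sample points

% 1) tanh'(z) = 1 - tanh(z).^2, versus the incorrect 1 - tanh(z.^2)
num_dtanh = (tanh(z + step) - tanh(z - step)) / (2*step);
err_tanh  = max(abs(num_dtanh - (1 - tanh(z).^2)));   % tiny (rounding only)
err_wrong = max(abs(num_dtanh - (1 - tanh(z.^2))));   % clearly nonzero

% 2) sigmoid'(z) = exp(-z).*sigmoid(z).^2
num_dsig    = (sigmoid(z + step) - sigmoid(z - step)) / (2*step);
err_sigmoid = max(abs(num_dsig - exp(-z).*sigmoid(z).^2));

% 3) Gradient of E = sum(Error.^2), Error = (pre_h*W)' - y, with respect to W
cell_num = 6; output_num = 4;          % small sizes for the check
pre_h = rand(1, cell_num);             % stands in for pre_h_state
W = rand(cell_num, output_num);        % stands in for weight_preh_h
y = rand(output_num, 1);               % stands in for the target column
E = @(Wt) sum(((pre_h*Wt)' - y).^2);   % squared-error cost as a function of W
Err = (pre_h*W)' - y;                  % corresponds to Error in the article
G = zeros(cell_num, output_num);       % analytic gradient, same formula as the article
for m = 1:output_num
    G(:,m) = 2*Err(m)*pre_h';
end
G_num = zeros(size(W));                % numerical gradient by central differences
h = 1e-5;
for k = 1:numel(W)
    Wp = W; Wp(k) = Wp(k) + h;
    Wm = W; Wm(k) = Wm(k) - h;
    G_num(k) = (E(Wp) - E(Wm)) / (2*h);
end
grad_err = max(abs(G(:) - G_num(:)));
```

The same pattern can be applied to any of the other weight matrices. Perturb one entry, rerun the forward pass, and compare the resulting change in cost with the corresponding entry of the analytic delta.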
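IV. Results
[Figure: Error_Cost curve]
[Figure: day-7 prediction vs. actual values; the first group is the predicted output (h_state_test), the second the actual values (test_output)]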