您的位置:首页 > 编程语言 > C语言/C++

C++中反向传播算法的简单实现

2007-02-10 13:48 579 查看

#include <cmath>


#include <ctime>




#define frand() ((double)rand()/(double)RAND_MAX)


#define sigmoid(x) (1/(1+exp(-(x))))


#define sqr(x) ((x)*(x))




bool ANN_Predict(int i_d, int h_d, int o_d, double *w_i_h, double *w_h_o, double *x, double *h, double *y)




...{


    double *fw_i_h = w_i_h, sum = 0;


    double *fx = x, *fh = h;


    for (int i = 0; i < h_d; i++)




    ...{


        sum = 0; fx = x;


        for (int j = 0; j < i_d; j++)


            sum+=*(fw_i_h++) * *(fx++);


        *(fh++) = sigmoid(sum);


    }


    double *fw_h_o = w_h_o, *fy = y;


    for (int i = 0; i < o_d; i++)




    ...{


        sum = 0; fh = h;


        for (int j = 0; j < h_d; j++)


            sum+=*(fw_h_o++) * *(fh++);


        *(fy++) = sigmoid(sum);


    }


    return 1;


}




bool back_propagate(double *o_delta, double *h_delta, int i_d, int h_d, int o_d, double *x, double *h, double *y, double *ty, double *w_i_h, double *w_h_o, double *c_i_h, double *c_h_o, double n, double m)




...{


    double *fo_delta = o_delta;


    double *fy = y, *fty = ty;


    for (int i = 0; i < o_d; i++, fy++, fty++, fo_delta++)


        *fo_delta = *fty*(1-*fty)*(*fy-*fty);




    double sum = 0, change = 0;


    double *fw_h_o = w_h_o, *ffw_h_o = NULL;


    double *fh_delta = h_delta, *fh = h;


    for (int i = 0; i < h_d; i++, fw_h_o++, fh_delta++, fh++)




    ...{


        sum = 0; ffw_h_o = fw_h_o; fo_delta = o_delta;


        for (int j = 0; j < o_d; j++, fo_delta++, ffw_h_o+=h_d)


            sum+=*ffw_h_o * *fo_delta;


        *fh_delta = *fh*(1-*fh)*sum;


    }




    fw_h_o = w_h_o; double *fc_h_o = c_h_o; fh = h;


    for (int i = 0; i < h_d; i++, fh++, fw_h_o++)




    ...{


        ffw_h_o = fw_h_o; fo_delta = o_delta;


        for (int j = 0; j < o_d; j++, fo_delta++, ffw_h_o+=h_d)




        ...{


            change = *fo_delta * *fh;


            *ffw_h_o+=change*n+*fc_h_o*m;


            *(fc_h_o++) = change;


        }


    }


   


    double *fw_i_h = w_i_h, *ffw_i_h = NULL;


    double *fc_i_h = c_i_h, *fx = x;


    for (int i = 0; i < i_d; i++, fx++, fw_i_h++)




    ...{


        ffw_i_h = fw_i_h; fh_delta = h_delta;


        for (int j = 0; j < h_d; j++, fh_delta++, ffw_i_h+=i_d)




        ...{


            change = *fh_delta * *fx;


            *ffw_i_h+=change*n+*fc_i_h*m;


            *(fc_i_h++) = change;


        }


    }


}




bool ANN_Training(int i_d, int h_d, int o_d, int C, int T, double *x, double *w, double *y, double *w_i_h, double *w_h_o, double n, double m)




...{


    double *c_i_h = (double*)malloc(i_d*h_d*sizeof(double));


    double *c_h_o = (double*)malloc(h_d*o_d*sizeof(double));


    double *h = (double*)malloc(h_d*sizeof(double));


    double *ty = (double*)malloc(o_d*sizeof(double));


    double *h_delta = (double*)malloc(h_d*sizeof(double));


    double *o_delta = (double*)malloc(o_d*sizeof(double));


    double *fw_i_h = w_i_h, *fw_h_o = w_h_o, *fw = w;


    srand(time(NULL));


    for (int i = 0; i < i_d*h_d; i++, fw_i_h++) *fw_i_h = frand()*4-2;


    for (int i = 0; i < h_d*o_d; i++, fw_h_o++) *fw_h_o = frand()*4-2;


    double *fx = x, *fy = y;




    for (int i = 0; i < C; i++)




    ...{


        fx = x; fy = y; fw = w;


        for (int j = 0; j < T; j++, fx+=i_d, fy+=o_d, fw++)




        ...{


            ANN_Predict(i_d, h_d, o_d, w_i_h, w_h_o, fx, h, ty);


            back_propagate(o_delta, h_delta, i_d, h_d, o_d, fx, h, fy, ty, w_i_h, w_h_o, c_i_h, c_h_o, *fw*n, ((j > 0) ? *(fw-1):*(w+T))*m);


        }


    }


    free(c_i_h); free(c_h_o); free(h); free(ty); free(h_delta); free(o_delta);




    return 1;


}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息