程式人生 > 長短期記憶網路 LSTM(MATLAB)

長短期記憶網路 LSTM(MATLAB)

% LSTM back-propagation weight update for time step n.
% For every output unit num, the squared-error gradient
% 2*error(num,1)*W_preh_h(:,num) is pushed back through the output-gate and
% cell-state paths into each weight matrix. Each weight is then moved by
% -lr*delta once per output unit. At n == 1 there is no previous state, so only
% W_input_x is updated. Finally W_preh_h is replaced by weight_preh_h_temp.
% NOTE(review): output_gate, input_gate, forget_gate, gate and the
% *_gate_input pre-activations are assumed to be 1-by-cell_num row vectors
% from the forward pass of step n -- confirm against the caller.
% NOTE(review): the exp(-z).*s.^2 factors equal s.*(1-s) only if each gate is
% s = 1./(1+exp(-z)) (logistic sigmoid) -- confirm against the forward pass.
% NOTE(review): `error` appears to be a variable here, shadowing MATLAB's
% built-in error() inside this function.

% Derivative of tanh at the current cell state, as a row vector. It is
% invariant across all loops below, so compute it once.
dtanh_cell=(ones(size(cell_state(:,n)))-tanh(cell_state(:,n)).^2)';
if(n~=1)
    %% Update W_input_x
    temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
    % d tanh(z)/dz = 1 - tanh(z).^2. The earlier code computed
    % 1 - tanh(z.^2), which is not the tanh derivative.
    dtanh_temp=ones(size(temp))-tanh(temp).^2;
    for num=1:output_num
        % Gradient of output num's squared error w.r.t. the cell state.
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        % Outer product == the former row-by-row loop over m=1:data_length.
        % NOTE(review): rows 1:data_length of train are used while temp uses
        % rows 1:input_num; the update only fits W_input_x if these agree.
        delta_weight_input_x(1:data_length,:)=train(1:data_length,n)*(dE_dc.*input_gate.*dtanh_temp);
        W_input_x=W_input_x-lr*delta_weight_input_x;
    end
    %% Update W_forgetgate_x
    for num=1:output_num
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        % Chain through c_t = f.*c_{t-1} + ... (dc/df = c_{t-1}) and the gate derivative.
        dE_dzf=dE_dc.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2);
        delta_weight_forgetgate_x(1:data_length,:)=train(1:data_length,n)*dE_dzf;
        W_forgetgate_x=W_forgetgate_x-lr*delta_weight_forgetgate_x;
    end
    %% Update W_inputgate_c
    for num=1:output_num
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        % dc/di = gate (the candidate value), times the input-gate derivative.
        dE_dzi=dE_dc.*gate.*exp(-input_gate_input).*(input_gate.^2);
        delta_weight_inputgate_c(1:cell_num,:)=cell_state(1:cell_num,n-1)*dE_dzi;
        W_inputgate_c=W_inputgate_c-lr*delta_weight_inputgate_c;
    end
    %% Update W_forgetgate_c
    for num=1:output_num
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        dE_dzf=dE_dc.*cell_state(:,n-1)'.*exp(-forget_gate_input).*(forget_gate.^2);
        delta_weight_forgetgate_c(1:cell_num,:)=cell_state(1:cell_num,n-1)*dE_dzf;
        W_forgetgate_c=W_forgetgate_c-lr*delta_weight_forgetgate_c;
    end
    %% Update W_outputgate_c
    for num=1:output_num
        % The output gate multiplies tanh(c_t), so dh/do = tanh(c_t).
        dE_dzo=2*(W_preh_h(:,num)*error(num,1))'.*tanh(cell_state(:,n))'.*exp(-output_gate_input).*(output_gate.^2);
        delta_weight_outputgate_c(1:cell_num,:)=cell_state(1:cell_num,n-1)*dE_dzo;
        W_outputgate_c=W_outputgate_c-lr*delta_weight_outputgate_c;
    end
    %% Update W_input_h
    % Recomputed on purpose: W_input_x has already been updated above.
    temp=train(1:input_num,n)'*W_input_x+h_state(:,n-1)'*W_input_h;
    dtanh_temp=ones(size(temp))-tanh(temp).^2;
    for num=1:output_num
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        delta_weight_input_h(1:output_num,:)=h_state(1:output_num,n-1)*(dE_dc.*input_gate.*dtanh_temp);
        W_input_h=W_input_h-lr*delta_weight_input_h;
    end
else
    %% Update W_input_x (first step: no previous hidden state contributes)
    temp=train(1:input_num,n)'*W_input_x;
    dtanh_temp=ones(size(temp))-tanh(temp).^2;
    for num=1:output_num
        dE_dc=2*(W_preh_h(:,num)*error(num,1))'.*output_gate.*dtanh_cell;
        delta_weight_input_x(1:data_length,:)=train(1:data_length,n)*(dE_dc.*input_gate.*dtanh_temp);
        W_input_x=W_input_x-lr*delta_weight_input_x;
    end
end
W_preh_h=weight_preh_h_temp;
end