1. 程式人生 > >矩陣標準差在神經網路中的反向傳播以及數值微分梯度驗證

矩陣標準差在神經網路中的反向傳播以及數值微分梯度驗證

最近開腦洞想訓練一個關於球面擬合的模型於是用到了標準差作為輸出層的損失函式,所以就對於標準差方程進行反向傳播推導了一下。

現在分享一下推導過程和結果和用數值微分方法對於結果正確性的驗證,順便記錄一下以免忘記了。

這是標準差方程

標準差主要是用來描述資料離散程度,其實就是方差的開平方

 

首先若a為矩陣,那麼標準差計算可用numpy實現如下

np.sqrt(np.sum((a - np.mean(a)) ** 2) / a.size);

矩陣標準差數值微分求梯度如下,(這個函式主要用來驗證反向傳播推導結果)

# 數值微分求標準差梯度
def gradient ():
    d 
= 1e-5; grad = np.zeros(a.size); func = lambda : np.sqrt(np.sum((a - np.mean(a)) ** 2) / a.size); # func = lambda : np.std(a, ddof = 1); # func = lambda : np.mean(a); for index, value in enumerate(a): bak = value; a[index] -= d; leftv = func(); a[index]
= bak; a[index] += d; rightv = func(); a[index] = bak; grad[index] = (rightv - leftv) / (d * 2); return grad;

接下來是標準差方程的反向傳播推導過程,直接上草稿紙

 

這裡初步推匯出結果

所以,反向傳播求標準差方程的Python實現程式碼如下

這裡傳入索引可計算矩陣中每一個元素相對於標準差方程的導數,這裡沒用numpy陣列作為引數,可自己修改程式碼支援矩陣,我就不附上了

def func2 (index):
    
# x x = a[index]; # 平均數 avg = np.mean(a); # 平方和 sqsum = np.sum((a - avg) ** 2); # N n = a.size; print((np.power(sqsum / n, -0.5) * (x - avg)) / n);

 看一下結果

上面是數值微分的結果,下面是反向傳播的結果,基本一致,可以證明反向傳播推導正確

 

附上全部程式碼

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np;

a = np.array([3.0, 3.0, 2.0, 4.9, 100.2, -8.9]);

# 數值微分求標準差梯度
def gradient ():
    d = 1e-5;
    grad = np.zeros(a.size);
    func = lambda : np.sqrt(np.sum((a - np.mean(a)) ** 2) / a.size);
    # func = lambda : np.std(a, ddof = 1);
    # func = lambda : np.mean(a);
    for index, value in enumerate(a):
        bak = value;
        a[index] -= d;
        leftv = func();
        a[index] = bak;
        a[index] += d;
        rightv = func();
        a[index] = bak;
        grad[index] = (rightv - leftv) / (d * 2);
    return grad;

grad = gradient();

def func2 (index):
    # x
    x = a[index];
    # 平均數
    avg = np.mean(a);
    # 平方和
    sqsum = np.sum((a - avg) ** 2);
    # N
    n = a.size;
    return (np.power(sqsum / n, -0.5) * (x - avg)) / n;

print(grad);
n1 = func2(0);
n2 = func2(1);
n3 = func2(2);
n4 = func2(3);
n5 = func2(4);
n6 = func2(5);
b = [n1, n2, n3, n4, n5, n6];
print(b);