1. 程式人生 > >深度學習:從MLP到樸素的神經網路

深度學習:從MLP到樸素的神經網路

文章目錄

神經網路都幹了些什麼?

分類的理解:

首先我們需要了解一下,分類的過程究竟幹了些什麼?
一個數據,m維features,最後給這個陣列一個類別class,其中class的數目是遠遠小於features的數目的。
從這個角度來說,分類過程就是提煉或者成為簡化features的過程。

從這個角度再看神經網路:

input_layer --> NN -->output_layer,輸入N個數據,輸出K個類別,那麼NN都幹了些什麼?
我們知道NN就是在訓練W,W為(input_dim, output_dim),也可以寫成(feature_dim,class_dim),其中feature_dim>>class_dim,也就是啥NN乾的就是提純features的活,說提純也是不準確,最後是為了提純,但是過程中可能會稀釋,多層NN就是不斷伸縮變換features的過程,每一層神經元的個數就是當前層features的數目。
需要提出的是我們輸入的都是低階的靠近物理的features,NN會不斷的高階化抽象化features。在這裡手工特徵工程不再重要,這是非常令人振奮人心的。
再次提醒一下:W是features和features的map。我們把W看成features的權重,我們每次餵給NN的是features,千萬不要混淆。

神經網路實現:

求∂L/∂w1,∂L/∂w2:

∂L/∂w2使用公式1,∂L/∂w1使用公式1和反向傳遞的誤差
在這裡插入圖片描述

實現程式碼:

class NaiveNN:
    def __init__(self, ws=None):
        self._ws = ws
        
    @staticmethod
    def relu(x):
        return np.maximum(0,x)
    
    # hidden_dim is the hidden units m
    def fit(self, x, y, hidden_dim=4, lr=
1e-3, epoch=1000): input_dim, output_dim = x.shape[1], y.shape[1] if self._ws is None: self._ws = [ np.random.random([input_dim, hidden_dim]), np.random.random([hidden_dim, output_dim])] losses = [] for _ in
range(epoch): # forward pass h = x.dot(self._ws[0]) h_relu = NaiveNN.relu(h) y_pred = h_relu.dot(self._ws[1]) # np.linalg.norm(求範數) losses.append(np.linalg.norm(y_pred-y,ord="fro")) # backford pass # ∂L/∂y_pred d1 = 2*(y_pred-y) # ∂L/∂w2 = ∂y_pred/∂w2* ∂L/∂y_pred # ∂y_pred/∂w2= h_relu.T dw2 = h_relu.T.dot(d1) # ∂L/∂w2 = ∂H/∂w2* ∂L/∂H # ∂L/∂H = ∂L/∂y_pred * w2^T * relu' dw1 = x.T.dot(d1.dot(self._ws[1].T)*(h_relu != 0)) # uodate w self._ws[0] -= lr*dw1 self._ws[1] -= lr*dw2 return losses def predict(self,x): h = x.dot(self._ws[0]) h_relu = NaiveNN.relu(h) y_pred = h_relu.dot(self._ws[1]) return np.argmax(y_pred, axis=1)

測試程式碼:

x, y = gen_five_clusters()
label = np.argmax(y, axis=1)
nn = NaiveNN()
losses = nn.fit(x, y, 32, 1e-5)
visualize2d(nn, x, label, draw_background=True)
print("準確率:{:8.6} %".format((nn.predict(x) == label).mean() * 100))

plt.figure()
plt.plot(np.arange(1, len(losses)+1), losses)
plt.show()

在這裡插入圖片描述在這裡插入圖片描述