
Mofan (莫煩) Course: Batch Normalization


    for i in range(N_HIDDEN):               # build hidden layers and BN layers
        input_size = 1 if i == 0 else 10
        fc = nn.Linear(input_size, 10)
        setattr(self, 'fc%i' % i, fc)       # IMPORTANT: register the layer on the Module
        self._set_init(fc)                  # parameter initialization
        self.fcs.append(fc)
        if self.do_bn:
            bn = nn.BatchNorm1d(10, momentum=0.5)
            setattr(self, 'bn%i' % i, bn)   # IMPORTANT: register the layer on the Module
            self.bns.append(bn)

The code above adds a batch-normalization layer after every hidden layer. `setattr(self, 'fc%i' % i, fc)` is equivalent to writing `self.fc0 = fc`, `self.fc1 = fc`, and so on for each `i`, which registers the layer (and its parameters) on the Module.
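Why the `setattr` matters: appending a layer to a plain Python list does not register it with the Module, so its parameters would be invisible to the optimizer. A minimal standalone demonstration (the `Demo` class and its names are hypothetical, not from the course code):

```python
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = []                 # plain list: does NOT register modules
        fc = nn.Linear(1, 10)
        self.layers.append(fc)           # kept only for convenient access
        setattr(self, 'fc0', fc)         # same as self.fc0 = fc; registers the layer

demo = Demo()
print([name for name, _ in demo.named_parameters()])  # ['fc0.weight', 'fc0.bias']
```

Without the `setattr` call, `named_parameters()` would return an empty list even though `self.layers` holds the layer.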

Each layer created in the loop is also appended to the end of `fcs`/`bns`; every hidden `nn.Linear` is 10×10, and keeping the layers in lists makes it very convenient to retrieve them later (e.g. inside `forward`).
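For context, the loop above sits inside the network's `__init__`. A sketch of the surrounding module follows; `N_HIDDEN = 8`, the output layer shape, and the initialization values are assumptions in the spirit of the course code, not verbatim from it:

```python
import torch
import torch.nn as nn

N_HIDDEN = 8  # number of hidden layers (value assumed)

class Net(nn.Module):
    def __init__(self, do_bn=False):
        super().__init__()
        self.do_bn = do_bn
        self.fcs, self.bns = [], []
        self.bn_input = nn.BatchNorm1d(1, momentum=0.5)  # BN for the raw input

        for i in range(N_HIDDEN):               # build hidden layers and BN layers
            input_size = 1 if i == 0 else 10
            fc = nn.Linear(input_size, 10)
            setattr(self, 'fc%i' % i, fc)       # register the layer on the Module
            self._set_init(fc)                  # parameter initialization
            self.fcs.append(fc)
            if self.do_bn:
                bn = nn.BatchNorm1d(10, momentum=0.5)
                setattr(self, 'bn%i' % i, bn)   # register the layer on the Module
                self.bns.append(bn)

        self.predict = nn.Linear(10, 1)         # output layer (shape assumed)
        self._set_init(self.predict)

    def _set_init(self, layer):                 # init values are assumptions
        nn.init.normal_(layer.weight, mean=0., std=.1)
        nn.init.constant_(layer.bias, -0.2)
```

With this structure, `net.named_parameters()` exposes `fc0` … `fc7` and, when `do_bn=True`, `bn0` … `bn7` as well.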

def forward(self, x):
    pre_activation = [x]
    if self.do_bn: x = self.bn_input(x)     # batch-normalize the raw input
    layer_input = [x]
    for i in range(N_HIDDEN):
        x = self.fcs[i](x)
        pre_activation.append(x)
        if self.do_bn: x = self.bns[i](x)   # batch normalization
        x = ACTIVATION(x)
        layer_input.append(x)
    out = self.predict(x)
    return out, layer_input, pre_activation
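One point worth noting about this forward pass: `nn.BatchNorm1d` behaves differently in training and evaluation mode. In `train()` mode it normalizes with the current batch's statistics and updates its running estimates (at the `momentum=0.5` rate used above); in `eval()` mode it normalizes with the stored running statistics instead. A small standalone demonstration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(10, momentum=0.5)     # same size/momentum as the hidden layers
x = torch.randn(32, 10) * 3 + 5           # a batch far from zero mean / unit variance

bn.train()                                # use batch statistics, update running stats
y_train = bn(x)
print(y_train.mean().item())              # close to 0
print(y_train.std().item())               # close to 1

bn.eval()                                 # use the stored running statistics instead
with torch.no_grad():
    y_eval = bn(x)
```

This is why models containing BN layers must be switched with `net.train()` / `net.eval()` around training and testing; forgetting the switch is a common source of accuracy drops at test time.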

Full source code
