1. 程式人生 > >PyTorch 中 weight decay 的設定

PyTorch 中 weight decay 的設定

先介紹一下 Caffe 和 TensorFlow 中 weight decay 的設定: - 在 **Caffe** 中, `SolverParameter.weight_decay` 可以作用於所有的可訓練引數, 不妨稱為 global weight decay, 另外還可以為各層中的每個可訓練引數設定獨立的 `decay_mult`, global weight decay 和當前可訓練引數的 `decay_mult` 共同決定了當前可訓練引數的 weight decay. - 在 **TensorFlow** 中, 某些介面可以為其下建立的可訓練引數設定獨立的 weight decay (如 `slim.conv2d` 可通過 `weights_regularizer`, `bias_regularizer` 分別為其下定義的 weight 和 bias 設定不同的 regularizer). 在 PyTorch 中, 模組 (`nn.Module`) 和引數 (`nn.Parameter`) 的定義沒有暴露與 weight decay 設定相關的 argument, 它把 weight decay 的設定放到了 `torch.optim.Optimizer` (嚴格地說, 是 `torch.optim.Optimizer` 的子類, 下同) 中. 在 `torch.optim.Optimizer` 中直接設定 `weight_decay`, 其將作用於該 optimizer 負責優化的所有可訓練引數 (和 Caffe 中 `SolverParameter.weight_decay` 的作用類似), 這往往不是所期望的: BatchNorm 層的 $\gamma$ 和 $\beta$ 就不應該新增正則化項, 卷積層和全連線層的 bias 也往往不用加正則化項. 幸運地是, `torch.optim.Optimizer` 支援為不同的可訓練引數設定不同的 weight_decay (`params` 支援 dict 型別), 於是問題轉化為如何將不期望新增正則化項的可訓練引數 (如 BN 層的可訓練引數及卷積層和全連線層的 bias) 從可訓練引數列表中分離出來. 筆者借鑑網上的一些方法, 寫了一個滿足該功能的函式, 沒有經過嚴格測試, 僅供參考. ```python """ 作者: 採石工 部落格: http://www.cnblogs.com/quarryman/ 釋出時間: 2020年10月21日 版權宣告: 自由分享, 保持署名-非商業用途-非衍生, 知識共享3.0協議. """ import torch from torchvision import models def split_parameters(module): params_decay = [] params_no_decay = [] for m in module.modules(): if isinstance(m, torch.nn.Linear): params_decay.append(m.weight) if m.bias is not None: params_no_decay.append(m.bias) elif isinstance(m, torch.nn.modules.conv._ConvNd): params_decay.append(m.weight) if m.bias is not None: params_no_decay.append(m.bias) elif isinstance(m, torch.nn.modules.batchnorm._BatchNorm): params_no_decay.extend([*m.parameters()]) elif len(list(m.children())) == 0: params_decay.extend([*m.parameters()]) assert len(list(module.parameters())) == len(params_decay) + len(params_no_decay) return params_decay, params_no_decay def print_parameters_info(parameters): for k, param in enumerate(parameters): print('[{}/{}] {}'.format(k+1, len(parameters), param.shape)) if __name__ == '__main__': model = models.resnet18(pretrained=False) params_decay, params_no_decay = split_parameters(model) print_parameters_info(params_decay) print_parameters_info(params_no_decay) ``` ## 參考 - https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994 - https://discuss.pytorch.org/t/changing-the-weight-decay-on-bias-using-named-parameters/19132/4 - https://discuss.pytorch.org/t/how-to-set-different-learning-rate-for-weight-and-bias-in-one-layer/13450 - [Allow to set 0 weight decay for biases and params in batch norm #1402](https://github.com/pytorch/pytorch/issues/1402) ## 版權宣告 版權宣告:自由分享,保持署名-非商業用途-非衍生,知識共享3.0協議。 如果你對本文有疑問或建議,歡迎留言!轉載請保留版權宣告! 如果你覺得本文不錯, 也可以用微信讚賞一下哈. ![](https://files-cdn.cnblogs.com/files/quarryman/wechat_p