權重初始化方式(based on FSRCNN)
阿新 • 發佈:2018-11-30
G網路其實就是SR網路,D網路是對抗用的,作為GAN的。在原始碼中的network.py可以改變權重初始化的方式(關於程式碼,請參考博文基於pytorch的FSRCNN)
def weights_init_normal(m, std=0.02):
    """Initialize module ``m`` with a normal distribution.

    Conv/Linear weights ~ N(0, std) with biases zeroed; BatchNorm2d
    weights ~ N(1, std) (centered at 1 so the layer starts near identity)
    with bias 0. Intended for use via ``net.apply(...)``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        # BN also uses normal init, but centered at 1.0 rather than 0.
        init.normal_(m.weight.data, 1.0, std)
        init.constant_(m.bias.data, 0.0)


def weights_init_kaiming(m, scale=1):
    """Kaiming (He) fan-in initialization, with the weight scaled by ``scale``.

    Conv/Linear weights get kaiming_normal_ then multiplied by ``scale``
    (scale < 1 damps the residual branches); biases zeroed. BatchNorm2d
    gets weight 1, bias 0. Intended for use via ``net.apply(...)``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data *= scale
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """Orthogonal initialization (gain=1) for Conv/Linear weights.

    Biases zeroed; BatchNorm2d gets weight 1, bias 0. Intended for use
    via ``net.apply(...)``.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm2d') != -1:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Initialize all submodules of ``net`` in place.

    Args:
        net: the network whose parameters are (re)initialized via
            ``net.apply``.
        init_type: one of 'normal', 'kaiming', 'orthogonal'.
        scale: weight scale, used only by 'kaiming'.
        std: standard deviation, used only by 'normal'.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the above.
    """
    print('initialization method [{:s}]'.format(init_type))
    if init_type == 'normal':
        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
        net.apply(weights_init_normal_)
    elif init_type == 'kaiming':
        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
        net.apply(weights_init_kaiming_)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))


####################
# define network
####################


# Generator
def define_G(opt):
    """Build the generator network described by ``opt['network_G']``.

    The 'which_model_G' key (set in the .json options file) selects the
    architecture; when adding a new model, extend the dispatch below.
    In training mode the weights are kaiming-initialized (scale=0.1);
    with gpu_ids set, the net is wrapped in nn.DataParallel.
    """
    gpu_ids = opt['gpu_ids']
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    if which_model == 'sr_resnet':  # SRResNet
        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                             nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                             act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'fsrcnn':  # FSRCNN
        netG = arch.FSRCNN(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                           nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                           act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
    elif which_model == 'sft_arch':  # SFT-GAN
        netG = sft_arch.SFT_Net()
    elif which_model == 'RRDB_net':  # RRDB
        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                            norm_type=opt_net['norm_type'], act_type='leakyrelu',
                            mode=opt_net['mode'], upsample_mode='upconv')
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
    if opt['is_train']:
        # Change init_type/scale here to switch the weight-init scheme.
        init_weights(netG, init_type='kaiming', scale=0.1)
    if gpu_ids:
        assert torch.cuda.is_available()
        netG = nn.DataParallel(netG)
    return netG
之前做的都是kaiming權重初始化,現在試試其他兩種:
結果如下圖所示
關於pytorch中的init
https://www.pytorchtutorial.com/docs/package_references/nn_init/(官網)
https://blog.csdn.net/qq_19598705/article/details/80396047