General process
Pruning by number of layers
1. Load the pretrained model;
2. Extract the weights of the layers we want and rename them. For example, to keep layer 0 and layer 11, we retain layer 11's weights and rename them with layer 1's names;
3. Modify the model config file (set the layer count to however many layers are kept), so that layer 11's weights are assigned to layer 1;
4. Save the result as pytorch_model.bin. (A minimal sketch of steps 2 and 3 follows this list.)
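Conceptually, steps 2 and 3 amount to renaming state_dict keys and lowering num_hidden_layers in config.json. A minimal sketch of just that idea (the complete, runnable script appears further below):

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-chinese')

# Keep layers 0 and 11 of the 12-layer encoder; layer 11's keys are renamed to layer 1.
renamed = {}
for name, param in model.state_dict().items():
    if name.startswith('encoder.layer.11.'):
        renamed[name.replace('encoder.layer.11.', 'encoder.layer.1.')] = param
    elif not name.startswith('encoder.layer.') or name.startswith('encoder.layer.0.'):
        renamed[name] = param  # embeddings, layer 0, pooler, position_ids, ...
# The saved config.json must then set "num_hidden_layers": 2 to match this checkpoint.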
First, let's look at which weights BERT actually contains:
import torch
from transformers import BertTokenizer, BertModel
bertModel = BertModel.from_pretrained('bert-base-chinese', output_hidden_states=True, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
for name, param in bertModel.named_parameters():
    print(name, param.shape)
embeddings.word_embeddings.weight torch.Size([21128, 768])
embeddings.position_embeddings.weight torch.Size([512, 768])
embeddings.token_type_embeddings.weight torch.Size([2, 768])
embeddings.LayerNorm.weight torch.Size([768])
embeddings.LayerNorm.bias torch.Size([768])
encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias torch.Size([768])
encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias torch.Size([768])
encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias torch.Size([768])
encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias torch.Size([3072])
encoder.layer.0.output.dense.weight torch.Size([768, 3072])
encoder.layer.0.output.dense.bias torch.Size([768])
encoder.layer.0.output.LayerNorm.weight torch.Size([768])
encoder.layer.0.output.LayerNorm.bias torch.Size([768])
encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
encoder.layer.1.attention.self.query.bias torch.Size([768])
encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
encoder.layer.1.attention.self.key.bias torch.Size([768])
encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
encoder.layer.1.attention.self.value.bias torch.Size([768])
encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.1.attention.output.dense.bias torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.1.intermediate.dense.bias torch.Size([3072])
encoder.layer.1.output.dense.weight torch.Size([768, 3072])
encoder.layer.1.output.dense.bias torch.Size([768])
encoder.layer.1.output.LayerNorm.weight torch.Size([768])
encoder.layer.1.output.LayerNorm.bias torch.Size([768])
encoder.layer.2.attention.self.query.weight torch.Size([768, 768])
encoder.layer.2.attention.self.query.bias torch.Size([768])
encoder.layer.2.attention.self.key.weight torch.Size([768, 768])
encoder.layer.2.attention.self.key.bias torch.Size([768])
encoder.layer.2.attention.self.value.weight torch.Size([768, 768])
encoder.layer.2.attention.self.value.bias torch.Size([768])
encoder.layer.2.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.2.attention.output.dense.bias torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.2.intermediate.dense.bias torch.Size([3072])
encoder.layer.2.output.dense.weight torch.Size([768, 3072])
encoder.layer.2.output.dense.bias torch.Size([768])
encoder.layer.2.output.LayerNorm.weight torch.Size([768])
encoder.layer.2.output.LayerNorm.bias torch.Size([768])
encoder.layer.3.attention.self.query.weight torch.Size([768, 768])
encoder.layer.3.attention.self.query.bias torch.Size([768])
encoder.layer.3.attention.self.key.weight torch.Size([768, 768])
encoder.layer.3.attention.self.key.bias torch.Size([768])
encoder.layer.3.attention.self.value.weight torch.Size([768, 768])
encoder.layer.3.attention.self.value.bias torch.Size([768])
encoder.layer.3.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.3.attention.output.dense.bias torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.3.intermediate.dense.bias torch.Size([3072])
encoder.layer.3.output.dense.weight torch.Size([768, 3072])
encoder.layer.3.output.dense.bias torch.Size([768])
encoder.layer.3.output.LayerNorm.weight torch.Size([768])
encoder.layer.3.output.LayerNorm.bias torch.Size([768])
encoder.layer.4.attention.self.query.weight torch.Size([768, 768])
encoder.layer.4.attention.self.query.bias torch.Size([768])
encoder.layer.4.attention.self.key.weight torch.Size([768, 768])
encoder.layer.4.attention.self.key.bias torch.Size([768])
encoder.layer.4.attention.self.value.weight torch.Size([768, 768])
encoder.layer.4.attention.self.value.bias torch.Size([768])
encoder.layer.4.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.4.attention.output.dense.bias torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.4.intermediate.dense.bias torch.Size([3072])
encoder.layer.4.output.dense.weight torch.Size([768, 3072])
encoder.layer.4.output.dense.bias torch.Size([768])
encoder.layer.4.output.LayerNorm.weight torch.Size([768])
encoder.layer.4.output.LayerNorm.bias torch.Size([768])
encoder.layer.5.attention.self.query.weight torch.Size([768, 768])
encoder.layer.5.attention.self.query.bias torch.Size([768])
encoder.layer.5.attention.self.key.weight torch.Size([768, 768])
encoder.layer.5.attention.self.key.bias torch.Size([768])
encoder.layer.5.attention.self.value.weight torch.Size([768, 768])
encoder.layer.5.attention.self.value.bias torch.Size([768])
encoder.layer.5.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.5.attention.output.dense.bias torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.5.intermediate.dense.bias torch.Size([3072])
encoder.layer.5.output.dense.weight torch.Size([768, 3072])
encoder.layer.5.output.dense.bias torch.Size([768])
encoder.layer.5.output.LayerNorm.weight torch.Size([768])
encoder.layer.5.output.LayerNorm.bias torch.Size([768])
encoder.layer.6.attention.self.query.weight torch.Size([768, 768])
encoder.layer.6.attention.self.query.bias torch.Size([768])
encoder.layer.6.attention.self.key.weight torch.Size([768, 768])
encoder.layer.6.attention.self.key.bias torch.Size([768])
encoder.layer.6.attention.self.value.weight torch.Size([768, 768])
encoder.layer.6.attention.self.value.bias torch.Size([768])
encoder.layer.6.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.6.attention.output.dense.bias torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.6.intermediate.dense.bias torch.Size([3072])
encoder.layer.6.output.dense.weight torch.Size([768, 3072])
encoder.layer.6.output.dense.bias torch.Size([768])
encoder.layer.6.output.LayerNorm.weight torch.Size([768])
encoder.layer.6.output.LayerNorm.bias torch.Size([768])
encoder.layer.7.attention.self.query.weight torch.Size([768, 768])
encoder.layer.7.attention.self.query.bias torch.Size([768])
encoder.layer.7.attention.self.key.weight torch.Size([768, 768])
encoder.layer.7.attention.self.key.bias torch.Size([768])
encoder.layer.7.attention.self.value.weight torch.Size([768, 768])
encoder.layer.7.attention.self.value.bias torch.Size([768])
encoder.layer.7.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.7.attention.output.dense.bias torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.7.intermediate.dense.bias torch.Size([3072])
encoder.layer.7.output.dense.weight torch.Size([768, 3072])
encoder.layer.7.output.dense.bias torch.Size([768])
encoder.layer.7.output.LayerNorm.weight torch.Size([768])
encoder.layer.7.output.LayerNorm.bias torch.Size([768])
encoder.layer.8.attention.self.query.weight torch.Size([768, 768])
encoder.layer.8.attention.self.query.bias torch.Size([768])
encoder.layer.8.attention.self.key.weight torch.Size([768, 768])
encoder.layer.8.attention.self.key.bias torch.Size([768])
encoder.layer.8.attention.self.value.weight torch.Size([768, 768])
encoder.layer.8.attention.self.value.bias torch.Size([768])
encoder.layer.8.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.8.attention.output.dense.bias torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.8.intermediate.dense.bias torch.Size([3072])
encoder.layer.8.output.dense.weight torch.Size([768, 3072])
encoder.layer.8.output.dense.bias torch.Size([768])
encoder.layer.8.output.LayerNorm.weight torch.Size([768])
encoder.layer.8.output.LayerNorm.bias torch.Size([768])
encoder.layer.9.attention.self.query.weight torch.Size([768, 768])
encoder.layer.9.attention.self.query.bias torch.Size([768])
encoder.layer.9.attention.self.key.weight torch.Size([768, 768])
encoder.layer.9.attention.self.key.bias torch.Size([768])
encoder.layer.9.attention.self.value.weight torch.Size([768, 768])
encoder.layer.9.attention.self.value.bias torch.Size([768])
encoder.layer.9.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.9.attention.output.dense.bias torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.9.intermediate.dense.bias torch.Size([3072])
encoder.layer.9.output.dense.weight torch.Size([768, 3072])
encoder.layer.9.output.dense.bias torch.Size([768])
encoder.layer.9.output.LayerNorm.weight torch.Size([768])
encoder.layer.9.output.LayerNorm.bias torch.Size([768])
encoder.layer.10.attention.self.query.weight torch.Size([768, 768])
encoder.layer.10.attention.self.query.bias torch.Size([768])
encoder.layer.10.attention.self.key.weight torch.Size([768, 768])
encoder.layer.10.attention.self.key.bias torch.Size([768])
encoder.layer.10.attention.self.value.weight torch.Size([768, 768])
encoder.layer.10.attention.self.value.bias torch.Size([768])
encoder.layer.10.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.10.attention.output.dense.bias torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.10.intermediate.dense.bias torch.Size([3072])
encoder.layer.10.output.dense.weight torch.Size([768, 3072])
encoder.layer.10.output.dense.bias torch.Size([768])
encoder.layer.10.output.LayerNorm.weight torch.Size([768])
encoder.layer.10.output.LayerNorm.bias torch.Size([768])
encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
encoder.layer.11.attention.self.query.bias torch.Size([768])
encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
encoder.layer.11.attention.self.key.bias torch.Size([768])
encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
encoder.layer.11.attention.self.value.bias torch.Size([768])
encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
encoder.layer.11.attention.output.dense.bias torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
encoder.layer.11.intermediate.dense.bias torch.Size([3072])
encoder.layer.11.output.dense.weight torch.Size([768, 3072])
encoder.layer.11.output.dense.bias torch.Size([768])
encoder.layer.11.output.LayerNorm.weight torch.Size([768])
encoder.layer.11.output.LayerNorm.bias torch.Size([768])
pooler.dense.weight torch.Size([768, 768])
pooler.dense.bias torch.Size([768])
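As the listing shows, the 12 encoder layers hold most of the parameters, which is why dropping layers shrinks the model so much. A quick way to check (a rough sketch, reusing bertModel from the snippet above; exact counts depend on the checkpoint):

total = sum(p.numel() for p in bertModel.parameters())
one_layer = sum(p.numel() for n, p in bertModel.named_parameters()
                if n.startswith('encoder.layer.0.'))
print(total)                   # roughly 102M for bert-base-chinese
print(one_layer)               # roughly 7M per encoder layer
print(total - 10 * one_layer)  # what a 2-layer version would keep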
The complete code:
import os
import json
import torch
import time
from transformers import BertModel, BertTokenizer


# Extract the weights of the layers we want to keep and rename them.
def get_prune_parameters(model):
    prune_parameters = {}
    for name, param in model.named_parameters():
        if 'embeddings' in name:
            prune_parameters[name] = param
        elif name.startswith('encoder.layer.0.'):
            prune_parameters[name] = param
        elif name.startswith('encoder.layer.11.'):
            # Rename layer 11 so that it becomes layer 1 of the pruned model.
            pro_name = name.split('encoder.layer.11.')
            prune_parameters['encoder.layer.1.' + pro_name[1]] = param
        elif 'pooler' in name:
            prune_parameters[name] = param
    return prune_parameters


# Modify the config file: the pruned model keeps only 2 encoder layers.
def get_prune_config(config):
    prune_config = config
    prune_config['num_hidden_layers'] = 2
    return prune_config


# Reduce the model to the kept layers and re-assign weights to the corresponding layers.
def get_prune_model(model, prune_parameters):
    prune_model = model.state_dict()
    for name in list(prune_model.keys()):
        if 'embeddings.position_ids' == name:
            continue
        if 'embeddings' in name:
            prune_model[name] = prune_parameters[name]
        elif name.startswith('encoder.layer.0.'):
            prune_model[name] = prune_parameters[name]
        elif name.startswith('encoder.layer.1.'):
            prune_model[name] = prune_parameters[name]
        elif 'pooler' in name:
            prune_model[name] = prune_parameters[name]
        else:
            del prune_model[name]
    return prune_model


def prune_main():
    model_path = '/data02/gob/project/simpleNLP/model_hub/chinese-bert-wwm-ext/'
    tokenizer = BertTokenizer.from_pretrained(model_path + 'vocab.txt')
    config = json.loads(open(model_path + 'config.json', 'r').read())
    model = BertModel.from_pretrained(model_path)
    text = '我喜歡吃魚'
    inputs = tokenizer(text, return_tensors='pt')
    # print(model(**inputs))
    out_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    prune_parameters = get_prune_parameters(model)
    prune_config = get_prune_config(config)
    prune_model = get_prune_model(model, prune_parameters)
    """
    for name, param in model.named_parameters():
        print(name)
    print("===================================")
    for k, v in model.state_dict().items():
        print(k)
    """
    torch.save(prune_model, out_path + 'pytorch_model.bin')
    with open(out_path + 'config.json', 'w') as fp:
        fp.write(json.dumps(prune_config))
    with open(out_path + 'vocab.txt', 'w') as fp:
        fp.write(open(model_path + 'vocab.txt').read())


if __name__ == '__main__':
    # prune_main()
    start_time = time.time()
    # After this, the pruned model can be loaded just like an ordinary BERT checkpoint.
    model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
    tokenizer = BertTokenizer.from_pretrained(model_path + 'vocab.txt')
    config = json.loads(open(model_path + 'config.json', 'r').read())
    model = BertModel.from_pretrained(model_path)
    text = '我喜歡吃魚'
    inputs = tokenizer(text, return_tensors='pt')
    for name, param in model.named_parameters():
        print(name, param.shape)
    end_time = time.time()
    print('預測耗時:{}s'.format(end_time - start_time))
Pruning the FFN (intermediate) dimension
1. Load the pretrained model;
2. Extract the weights of the relevant layers, keep only the top-k values along the intermediate dimension, and assign the result back to those parameters (see the sketch after this list);
3. Modify the model config file (mainly the intermediate_size);
4. Save the result as pytorch_model.bin.
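To make step 2 concrete, here is a minimal sketch of what topk does to the FFN matrices (shapes and the value 384 match the code below; note that topk is applied per row, so different rows keep values at different positions, which makes this a fairly crude magnitude-based reduction):

import torch

w = torch.randn(3072, 768)          # shape of intermediate.dense.weight in bert-base
w_pruned = w.T.topk(384).values.T   # top 384 along the 3072 dim -> [384, 768]
print(w_pruned.shape)               # torch.Size([384, 768])

b = torch.randn(3072)               # shape of intermediate.dense.bias
print(b.topk(384).values.shape)     # torch.Size([384])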
The full code:
import os
import json
import torch
import time
from pprint import pprint
from transformers import BertModel, BertTokenizer


def get_prune_ffn_parameters(model):
    prune_parameters = {}
    for name, param in model.named_parameters():
        if 'intermediate.dense.weight' in name:
            # [3072, 768] -> keep the top 384 values per input dimension -> [384, 768]
            param = torch.tensor(param.T.topk(384).values, requires_grad=True).T
            prune_parameters[name] = param
        elif 'intermediate.dense.bias' in name:
            # [3072] -> [384]
            param = torch.tensor(param.topk(384).values, requires_grad=True)
            prune_parameters[name] = param
        elif 'output.dense.weight' in name and 'attention.output.dense.weight' not in name:
            # [768, 3072] -> keep the top 384 values per output dimension -> [768, 384]
            param = torch.tensor(param.topk(384).values, requires_grad=True)
            prune_parameters[name] = param
    return prune_parameters


def get_prune_ffn_config(config):
    prune_config = config
    prune_config['intermediate_size'] = 384
    return prune_config


def get_prune_model(model, prune_parameters):
    prune_model = model.state_dict()
    for name in list(prune_model.keys()):
        if name in prune_parameters:
            prune_model[name] = prune_parameters[name]
    return prune_model


def prune_main():
    model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
    tokenizer = BertTokenizer.from_pretrained(model_path + 'vocab.txt')
    config = json.loads(open(model_path + 'config.json', 'r').read())
    model = BertModel.from_pretrained(model_path)
    out_path = '/data02/gob/project/simpleNLP/model_hub/prune-ffn-chinese-bert-wwm-ext/'
    if not os.path.exists(out_path):
        os.makedirs(out_path)

    prune_parameters = get_prune_ffn_parameters(model)
    prune_config = get_prune_ffn_config(config)
    prune_model = get_prune_model(model, prune_parameters)
    torch.save(prune_model, out_path + 'pytorch_model.bin')
    with open(out_path + 'config.json', 'w') as fp:
        fp.write(json.dumps(prune_config))
    with open(out_path + 'vocab.txt', 'w') as fp:
        fp.write(open(model_path + 'vocab.txt').read())


if __name__ == '__main__':
    # prune_main()
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    model_path = '/data02/gob/project/simpleNLP/model_hub/prune-chinese-bert-wwm-ext/'
    # model_path = '/data02/gob/project/simpleNLP/model_hub/bert-base-chinese/'
    tokenizer = BertTokenizer.from_pretrained(model_path + 'vocab.txt')
    config = json.loads(open(model_path + 'config.json', 'r').read())
    model = BertModel.from_pretrained(model_path)
    model.to(device)
    start_time = time.time()
    texts = ['我喜歡吃魚,我也喜歡打籃球,你知不知道呀。在這個陽光明媚的日子裡,我們一起去放風箏'] * 5000
    for text in texts:
        inputs = tokenizer(text, return_tensors='pt')
        for k in inputs.keys():
            inputs[k] = inputs[k].to(device)
        # pprint(inputs)
        with torch.no_grad():
            outputs = model(**inputs)  # run the forward pass so the timing below measures prediction
    # for name, param in model.named_parameters():
    #     print(name, param.shape)
    end_time = time.time()
    print('預測耗時:{}s'.format(end_time - start_time))
Pruning attention heads and pruning the hidden dimension
These are comparatively involved, so we won't cover them here; in most cases pruning whole layers is the simplest and most convenient option.
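For reference, transformers does ship a built-in prune_heads helper that removes the specified attention heads and shrinks the query/key/value/output projections accordingly. A minimal sketch (the layer and head choices are purely illustrative):

from transformers import BertModel

model = BertModel.from_pretrained('bert-base-chinese')
# Drop heads 0 and 1 in layer 0 and head 2 in layer 11; the remaining heads'
# weights are kept and the attention projections are resized in place.
model.prune_heads({0: [0, 1], 11: [2]})
print(model.encoder.layer[0].attention.self.num_attention_heads)  # 12 - 2 = 10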