1. 程式人生 > >基於BERT命名實體識別程式碼的理解

基於BERT命名實體識別程式碼的理解

我一直做的是有關實體識別的任務,BERT已經火了有一段時間,也研究過一點,今天將自己對bert對識別實體的簡單認識記錄下來,希望與大家進行來討論

BERT官方Github地址:https://github.com/google-research/bert ,其中對BERT模型進行了詳細的介紹,更詳細的可以查閱原文獻:https://arxiv.org/abs/1810.04805 

bert可以簡單地理解成兩段式的nlp模型,(1)pre_training:即預訓練,相當於wordembedding,利用沒有任何標記的語料訓練一個模型;(2)fine-tuning:即微調,利用現有的訓練好的模型,根據不同的任務,輸入不同,修改輸出的部分,即可完成下游的一些任務(如命名實體識別、文字分類、相似度計算等等)
本文是在官網上給定的run_classifier.py中進行修改從而完成命名實體識別的任務

程式碼的解讀,將主要的幾個程式碼進行簡單的解讀

1、主函式

if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("task_name")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()


主函式中指定了一些必須不能少的引數
data_dir:指的是我們的輸入資料的資料夾路徑
task_name:任務的名字
vocab_file:字典,一般從下載的模型中直接包含這個字典,名字“vocab.txt”
bert_config_file:一些預訓練好的配置引數,同樣在下載的模型資料夾中,名字為“bert_config.json”
output_dir:輸出檔案儲存的位置

2、main(_)函式

processors = {
        "ner": NerProcessor
    }
task_name = FLAGS.task_name.lower()  
processor = processors[task_name]()

上面程式碼中的task_name是用來選擇processor的
NerProcessor的程式碼如下:

class NerProcessor(DataProcessor):  ##資料的讀入
    def get_train_examples(self, data_dir):
        return self._create_example(
            self._read_data(os.path.join(data_dir, "train.txt")), "train"
        )

    def get_dev_examples(self, data_dir):
        return self._create_example(
            self._read_data(os.path.join(data_dir, "dev.txt")), "dev"
        )

    def get_test_examples(self, data_dir):
        return self._create_example(
            self._read_data(os.path.join(data_dir, "test.txt")), "test")

    def get_labels(self):

        # 9個類別
        return ["O", "B-dizhi", "I-dizhi", "B-shouduan", "I-shouduan", "B-caiwu", "I-caiwu", "B-riqi", "I-riqi", "X",
                "[CLS]", "[SEP]"]

    def _create_example(self, lines, set_type):
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text = tokenization.convert_to_unicode(line[1])
            label = tokenization.convert_to_unicode(line[0])
            if i == 0:
            examples.append(InputExample(guid=guid, text=text, label=label))
        return examples

上面的程式碼主要是完成了資料的讀入,且繼承了DataProcessor這個類,_read_data()函式是在父類DataProcessor中實現的,具體的程式碼如下所示:

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_data(cls, input_file):
        """Reads a BIO data."""
        with codecs.open(input_file, 'r', encoding='utf-8') as f:
            lines = []
            words = []
            labels = []
            for line in f:
                contends = line.strip()
                tokens = contends.split()  ##根據不同的語料,此處的split()劃分標誌需要進行更改
                # print(len(tokens))
                if len(tokens) == 2:
                    word = line.strip().split()[0]  ##根據不同的語料,此處的split()劃分標誌需要進行更改
                    label = line.strip().split()[-1]  ##根據不同的語料,此處的split()劃分標誌需要進行更改
                else:
                    if len(contends) == 0:
                        l = ' '.join([label for label in labels if len(label) > 0])
                        w = ' '.join([word for word in words if len(word) > 0])
                        lines.append([l, w])
                        words = []
                        labels = []
                        continue
                if contends.startswith("-DOCSTART-"):
                    words.append('')
                    continue
                words.append(word)
                labels.append(label)

            return lines  ##(label,word)

_read_data()函式:主要是針對NER的任務進行改寫的,將輸入的資料中的字儲存到words中,標籤儲存到labels中,將一句話中所有字以空格隔開組成一個字串放入到w中,同理標籤放到l中,同時將w與l放到lines中,具體的程式碼如下所示:

l = ' '.join([label for label in labels if len(label) > 0])
w = ' '.join([word for word in words if len(word) > 0])
lines.append([l, w])

def get_labels(self):是將標籤返回,會在原來標籤的基礎之上多新增"X","[CLS]", "[SEP]"這三個標籤,句子開始設定CLS 標誌,句尾新增[SEP] 標誌,"X"表示的是英文中縮寫拆分時,拆分出的幾個部分,除了第1部分,其他的都標記為"X"

程式碼中使用了InputExample類

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text, label=None):
        """Constructs a InputExample. ##構造BLSTM_CRF一個輸入的例子
        Args:
          guid: Unique id for the example.
          text: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
          label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text = text
        self.label = label

我的理解是這個是輸入資料的一個封裝,不管要處理的是什麼任務,需要經過這一步,對輸入的格式進行統一一下
guid是一種標識,標識的是test、train、dev

 

暫時更新到這個地方,後續會繼續更新