1. 程式人生 > >[torchtext]如何利用torchtext讀取json檔案並生成batch

[torchtext]如何利用torchtext讀取json檔案並生成batch

在這裡插入圖片描述

設定Field

首先載入torchtext

from torchtext import data

設定Field,對輸入文字資料的格式進行"預設定"

question = data.Field(sequential=True, fix_length=20, pad_token='0')
label    = data.Field(sequential=False, use_vocab=False)
sequential=True tokenizer fix_length pad_first=True tensor_type lower
是否為sequences 分詞器 文字長度 是否從左補全 Tensor type 是否令英文字元為小寫

question為例,設定文字長度為20,超過20刪除,不足20則使用pad_token補全。sequential的含義為輸入文字是否是序列文字,若為True則是序列文字,需要配合tokenize(預設使用splits,也可以用Spacy)進行分詞,若為False則輸入已經是切分好的文字或不需要進行分詞。如果處理的是中文文字,也可以自定義tokenizer對中文進行切分:

import jieba

def chinese_tokenizer(text):
    return [tok for tok in jieba.lcut(text)]
    
question = data.Field(sequential=True, tokenize=chinese_tokenizer, fix_length=20)

使用torchtext.data.Tabulardataset.splits讀取檔案

同時讀取訓練集、驗證集與測試集,path為路徑,trainvalidationtest為檔名。 splits()的作用為 Create train-test(-valid?) splits from the instance’s examples, and return Datasets for train, validation, and test splits in that order, if the splits are provided.

train, val, test = data.TabularDataset.splits(
                   path = './',
                   train = 'train.json',
                   validation = 'val.json',
                   test = 'test.json',
                   format = 'json',
                   fields = {'question': ('question',question),
                             'label': ('label', label)})

測試資料是否正確讀入(此處使用jieba中文分詞器)

for i in range(0, len(val)):
    print(vars(val[i])

列印結果示例為

{'question': ['世界', '上', '為什麼', '有', '好人', '和', '壞人'], 'label': 'generate'}
{'question': ['為什麼', '有', '壞人', '有', '好人', '呀'], 'label': 'generate'}

構建vocab表

cache = '.vector_cache
if not os.path.exists(cache):
    os.mkdir(cache) 
vectors = Vectors(name=configs.embedding_path, cache = cache)

question.build_vocab(train, val, test, min_freq=5, vectors=vectors)

從預訓練的 vectors 中,將當前 corpus 詞彙表的詞向量抽取出來,構成當前 corpus 的 Vocab(詞彙表) .build_vocab用以構建詞彙表,將分詞結果轉化為整數(.vocab.vectors是與此詞彙表相關聯的詞向量)此外,torchtext也提供了一些預訓練好的詞向量max_size設定詞彙表最大個數,min_freq設定詞彙最低出現頻率的閾值。

使用torchtext.data.Iterator.splits生成batch

train_iter = data.Iterator(dataset=train, batch_size=256, shuffle=True,  sort_within_batch=False, repeat=False, device=configs.device)
val_iter = data.Iterator(dataset=val, batch_size=256, shuffle=False,  sort=False, repeat=False, device=configs.device)
test_iter = data.Iterator(dataset=test, batch_size=256, shuffle=False, sort=False, repeat=False, device=configs.device)
dataset batch_size batch_size_fn sort_key train repeat shuffle sort sort_within_batch
載入的資料集 Batch 大小 產生動態的batch_size的函式 排序的key 是否為訓練集 是否在不同epoch中重複迭代 是否打亂資料 是否對資料進行排序 batch內部是否排序

附錄

全部code及輸出結果

import codecs
import jieba
import os
from config import config
from torchtext import data, datasets
from torchtext.vocab import Vectors

def chinese_tokenizer(text):
    return [tok for tok in jieba.lcut(text)]

def load_data(configs):
    TEXT = data.Field(sequential=True, tokenize = chinese_tokenizer, fix_length=20)
    LABEL = data.Field(sequential=False, use_vocab=False)

    train, val, test = data.TabularDataset.splits(
            path = configs.file_path,
            train = configs.train,
            validation = configs.val,
            test = configs.test,
            format = 'json',
            fields = {
                    'question': ('question', TEXT),
                    'label': ('label', LABEL)
                     }
            )

#    for i in range(0, len(val)):
#        print(vars(val[i])) 
    print('Read {} success, {} texts in total.'.format(configs.train, len(train)))
    print('Read {} success, {} texts in total.'.format(configs.val, len(val)))
    print('Read {} success, {} texts in total.\n'.format(configs.test, len(test)))
   
    cache = '.vector_cache'
    if not os.path.exists(cache):
        os.mkdir(cache) 
    vectors = Vectors(name=configs.embedding_path, cache = cache)
    print('load word2vec vectors from {}.'.format(configs.embedding_path))

    TEXT.build_vocab(train, val, test, min_freq=5, vectors=vectors)
    
    train_iter = data.Iterator(dataset=train, batch_size=configs.batch_size, shuffle=True,
            sort_within_batch=False, repeat=False, device=configs.device)
    val_iter = data.Iterator(dataset=val, batch_size=configs.batch_size, shuffle=False, 
            sort=False, repeat=False, device=configs.device)
    test_iter = data.Iterator(dataset=test, batch_size=configs.batch_size, shuffle=False,
            sort=False, repeat=False, device=configs.device)
         
    return train_iter, val_iter, test_iter, len(TEXT.vocab), TEXT.vocab.vectors

Field引數表

torchtext.data.Field

Iterator引數表

在這裡插入圖片描述