[torchtext]如何利用torchtext讀取json檔案並生成batch
阿新 • • 發佈:2018-12-16
設定Field
首先載入torchtext
from torchtext import data
設定Field
,對輸入文字資料的格式進行"預設定"
question = data.Field(sequential=True, fix_length=20, pad_token='0')
label = data.Field(sequential=False, use_vocab=False)
sequential=True | tokenizer | fix_length | pad_first=True | tensor_type | lower |
---|---|---|---|---|---|
是否為sequences | 分詞器 | 文字長度 | 是否從左補全 | Tensor type | 是否令英文字元為小寫 |
以question
為例,設定文字長度為20,超過20刪除,不足20則使用pad_token
補全。sequential
的含義為輸入文字是否是序列文字,若為True則是序列文字,需要配合tokenize
(預設使用splits,也可以用Spacy)進行分詞,若為False則輸入已經是切分好的文字或不需要進行分詞。如果處理的是中文文字,也可以自定義tokenizer
對中文進行切分:
import jieba def chinese_tokenizer(text): return [tok for tok in jieba.lcut(text)] question = data.Field(sequential=True, tokenize=chinese_tokenizer, fix_length=20)
使用torchtext.data.Tabulardataset.splits讀取檔案
同時讀取訓練集、驗證集與測試集,path
為路徑,train
、validation
和test
為檔名。
splits()
的作用為 Create train-test(-valid?) splits from the instance’s examples, and return Datasets for train, validation, and test splits in that order, if the splits are provided.
train, val, test = data.TabularDataset.splits( path = './', train = 'train.json', validation = 'val.json', test = 'test.json', format = 'json', fields = {'question': ('question',question), 'label': ('label', label)})
測試資料是否正確讀入(此處使用jieba中文分詞器)
for i in range(0, len(val)):
print(vars(val[i])
列印結果示例為
{'question': ['世界', '上', '為什麼', '有', '好人', '和', '壞人'], 'label': 'generate'}
{'question': ['為什麼', '有', '壞人', '有', '好人', '呀'], 'label': 'generate'}
構建vocab表
cache = '.vector_cache
if not os.path.exists(cache):
os.mkdir(cache)
vectors = Vectors(name=configs.embedding_path, cache = cache)
question.build_vocab(train, val, test, min_freq=5, vectors=vectors)
從預訓練的 vectors 中,將當前 corpus 詞彙表的詞向量抽取出來,構成當前 corpus 的 Vocab(詞彙表)
.build_vocab
用以構建詞彙表,將分詞結果轉化為整數(.vocab.vectors
是與此詞彙表相關聯的詞向量)此外,torchtext也提供了一些預訓練好的詞向量。max_size
設定詞彙表最大個數,min_freq
設定詞彙最低出現頻率的閾值。
使用torchtext.data.Iterator.splits生成batch
train_iter = data.Iterator(dataset=train, batch_size=256, shuffle=True, sort_within_batch=False, repeat=False, device=configs.device)
val_iter = data.Iterator(dataset=val, batch_size=256, shuffle=False, sort=False, repeat=False, device=configs.device)
test_iter = data.Iterator(dataset=test, batch_size=256, shuffle=False, sort=False, repeat=False, device=configs.device)
dataset |
batch_size |
batch_size_fn |
sort_key |
train |
repeat |
shuffle |
sort |
sort_within_batch |
---|---|---|---|---|---|---|---|---|
載入的資料集 | Batch 大小 | 產生動態的batch_size的函式 | 排序的key | 是否為訓練集 | 是否在不同epoch中重複迭代 | 是否打亂資料 | 是否對資料進行排序 | batch內部是否排序 |
附錄
全部code及輸出結果
import codecs
import jieba
import os
from config import config
from torchtext import data, datasets
from torchtext.vocab import Vectors
def chinese_tokenizer(text):
return [tok for tok in jieba.lcut(text)]
def load_data(configs):
TEXT = data.Field(sequential=True, tokenize = chinese_tokenizer, fix_length=20)
LABEL = data.Field(sequential=False, use_vocab=False)
train, val, test = data.TabularDataset.splits(
path = configs.file_path,
train = configs.train,
validation = configs.val,
test = configs.test,
format = 'json',
fields = {
'question': ('question', TEXT),
'label': ('label', LABEL)
}
)
# for i in range(0, len(val)):
# print(vars(val[i]))
print('Read {} success, {} texts in total.'.format(configs.train, len(train)))
print('Read {} success, {} texts in total.'.format(configs.val, len(val)))
print('Read {} success, {} texts in total.\n'.format(configs.test, len(test)))
cache = '.vector_cache'
if not os.path.exists(cache):
os.mkdir(cache)
vectors = Vectors(name=configs.embedding_path, cache = cache)
print('load word2vec vectors from {}.'.format(configs.embedding_path))
TEXT.build_vocab(train, val, test, min_freq=5, vectors=vectors)
train_iter = data.Iterator(dataset=train, batch_size=configs.batch_size, shuffle=True,
sort_within_batch=False, repeat=False, device=configs.device)
val_iter = data.Iterator(dataset=val, batch_size=configs.batch_size, shuffle=False,
sort=False, repeat=False, device=configs.device)
test_iter = data.Iterator(dataset=test, batch_size=configs.batch_size, shuffle=False,
sort=False, repeat=False, device=configs.device)
return train_iter, val_iter, test_iter, len(TEXT.vocab), TEXT.vocab.vectors
Field引數表