1. 程式人生 > >pytorch學習筆記(二十一): 使用 pack_padded_sequence

pytorch學習筆記(二十一): 使用 pack_padded_sequence

在使用 pytorch 的 RNN 模組的時候, 有時會不可避免的使用到 pack_padded_sequencepad_packed_sequence, 當使用雙向RNN的時候, 必須要使用 pack_padded_sequence !! .否則的話, pytorch 是無法獲得 序列的長度, 這樣也無法正確的計算雙向 RNN/GRU/LSTM 的結果.

但是在使用 pack_padded_sequence 時有個問題, 即輸入 mini-batch 序列的長度必須是從長到短排序好的, 當mini-batch 中的樣本的順序非常的重要的話, 這就有點棘手了. 比如說, 每個 sample 是個 單詞的 字母級表示, 一個 mini-batch 儲存了一句話的 words.

在這種情況下, 我們依然要使用 pack_padded_sequence, 所以需要先將 mini-batch 中樣本排序, 然後 RNN/LSTM/GRU 計算完之後再恢復成以前的順序.

下面的程式碼將用來實現這種方法:

import torch
from torch import nn
from torch.autograd import Variable

def rnn_forwarder(rnn, inputs, seq_lengths):
    """
    :param rnn: RNN instance
    :param inputs: FloatTensor, shape [batch, time, dim] if rnn.batch_first else [time, batch, dim]
    :param seq_lengths: LongTensor shape [batch]
    :return: the result of rnn layer,
    """
batch_first = rnn.batch_first # assume seq_lengths = [3, 5, 2] # 對序列長度進行排序(降序), sorted_seq_lengths = [5, 3, 2] # indices 為 [1, 0, 2], indices 的值可以這麼用語言表述 # 原來 batch 中在 0 位置的值, 現在在位置 1 上. # 原來 batch 中在 1 位置的值, 現在在位置 0 上. # 原來 batch 中在 2 位置的值, 現在在位置 2 上. sorted_seq_lengths, indices = torch.sort(seq_lengths, descending=True
) # 如果我們想要將計算的結果恢復排序前的順序的話, # 只需要對 indices 再次排序(升序),會得到 [0, 1, 2], # desorted_indices 的結果就是 [1, 0, 2] # 使用 desorted_indices 對計算結果進行索引就可以了. _, desorted_indices = torch.sort(indices, descending=False) # 對原始序列進行排序 if batch_first: inputs = inputs[indices] else: inputs = inputs[:, indices] packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs, sorted_seq_lengths.cpu().numpy(), batch_first=batch_first) res, state = rnn(packed_inputs) padded_res, _ = nn.utils.rnn.pad_packed_sequence(res, batch_first=batch_first) # 恢復排序前的樣本順序 if batch_first: desorted_res = padded_res[desorted_indices] else: desorted_res = padded_res[:, desorted_indices] return desorted_res if __name__ == "__main__": bs = 3 max_time_step = 5 feat_size = 15 hidden_size = 7 seq_lengths = [3, 5, 2] rnn = nn.GRU(input_size=feat_size, hidden_size=hidden_size, batch_first=True, bidirectional=True) x = Variable(torch.FloatTensor(bs, max_time_step, feat_size).normal_()) using_packed_res = rnn_forwarder(rnn, x, seq_lengths) print(using_packed_res) # 不使用 pack_paded, 用來和上面的結果對比一下. not_packed_res, _ = rnn(x) print(not_packed_res)