1. 程式人生 > >【PyTorch】PyTorch進階教程三

【PyTorch】PyTorch進階教程三

前面介紹了使用PyTorch構造CNN網路,這一節介紹點高階的東西LSTM。

關於LSTM的理論介紹請參考兩篇有名的部落格:

以及我之前的一篇中文翻譯部落格:

LSTM

class torch.nn.LSTM(*args, **kwargs)
  • Parameters

    1. input_size 輸入特徵維數
    2. hidden_size 隱層狀態的維數
    3. num_layers RNN層的個數
    4. bias 隱層狀態是否帶bias,預設為true
    5. batch_first 是否輸入輸出的第一維為batchsize
    6. dropout 是否在除最後一個RNN層外的RNN層後面加dropout層
    7. bidirectional 是否是雙向RNN,預設為false
  • Inputs: input, (h_0, c_0)

    1. input (seq_len, batch, input_size) 包含特徵的輸入序列,如果設定了batch_first,則batch為第一維
    2. (h_0, c_0) 隱層狀態
  • Outputs: output, (h_n, c_n)

    1. output (seq_len, batch, hidden_size * num_directions) 包含每一個時刻的輸出特徵,如果設定了batch_first,則batch為第一維
    2. (h_n, c_n) 隱層狀態

Model

class
RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes): super(RNN, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True
) self.fc = nn.Linear(hidden_size, num_classes) # 2 for bidirection def forward(self, x): # Forward propagate RNN out, _ = self.lstm(x) # Decode hidden state of last time step out = self.fc(out[:, -1, :]) return out rnn = RNN(input_size, hidden_size, num_layers, num_classes) rnn.cuda()

PyTorch中實現LSTM是十分方便的,只需要定義輸入維度,隱層維度,RNN個數,以及分類個數就可以了。lstm的輸入狀態如果為空的話,表示預設初始化為0。在MNIST上,只需要2個epoch就可以達到97%的正確率。