
Implementing a self-attention mechanism in PyTorch, with visualization



  • Python 3
  • PyTorch 0.4.0


Model

import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(True),
            nn.Linear(64, 1)
        )

    def forward(self, encoder_outputs):
        # (B, L, H) -> (B, L, 1)
        energy = self.projection(encoder_outputs)
        # normalize over the sequence dimension -> (B, L)
        weights = F.softmax(energy.squeeze(-1), dim=1)
        # weighted sum of encoder outputs: (B, L, H) * (B, L, 1) -> (B, H)
        outputs = (encoder_outputs * weights.unsqueeze(-1)).sum(dim=1)
        return outputs, weights


class AttnClassifier(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True)
        self.attention = SelfAttention(hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    def set_embedding(self, vectors):
        # load pretrained word vectors into the embedding layer
        self.embedding.weight.data.copy_(vectors)

    def forward(self, inputs, lengths):
        batch_size = inputs.size(1)                # inputs: (L, B)
        embedded = self.embedding(inputs)          # (L, B, E)
        packed_emb = nn.utils.rnn.pack_padded_sequence(embedded, lengths)
        out, hidden = self.lstm(packed_emb)
        out = nn.utils.rnn.pad_packed_sequence(out)[0]
        # sum the forward and backward directions: (L, B, 2H) -> (L, B, H)
        out = out[:, :, :self.hidden_dim] + out[:, :, self.hidden_dim:]
        # attention pooling over the time dimension -> (B, H)
        embedding, attn_weights = self.attention(out.transpose(0, 1))
        outputs = self.fc(embedding.view(batch_size, -1))  # (B, 1)
        return outputs, attn_weights
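
The classifier embeds the token ids, runs them through a bidirectional LSTM, sums the two directions, and pools the per-step outputs with the attention layer above. As a quick sanity check, here is a minimal sketch of driving it with dummy data; the vocabulary size, dimensions, and inputs below are made up for illustration, and pack_padded_sequence expects the lengths sorted in descending order:

import torch

# Hypothetical sizes and dummy data, for illustration only.
VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM = 1000, 100, 256
SEQ_LEN, BATCH_SIZE = 12, 4

model = AttnClassifier(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM)

# Token ids shaped (L, B) plus the true length of each sequence,
# sorted in descending order as pack_padded_sequence requires.
inputs = torch.randint(0, VOCAB_SIZE, (SEQ_LEN, BATCH_SIZE)).long()
lengths = [12, 10, 8, 5]

logits, attn_weights = model(inputs, lengths)
print(logits.shape)        # torch.Size([4, 1])
print(attn_weights.shape)  # torch.Size([4, 12])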

Visualization

The result looks like this:
(Image: visualization of the attention weights.)
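
The plotting code itself lives in the linked full source. As a rough idea, one common way to render a sentence's attention weights as a single-row heatmap is sketched below; it assumes matplotlib, and show_attention and its arguments are illustrative rather than the original post's code:

import matplotlib.pyplot as plt

def show_attention(tokens, weights):
    # Render one sentence's attention weights as a single-row heatmap.
    fig, ax = plt.subplots(figsize=(len(tokens), 1.5))
    ax.imshow(weights[None, :], cmap='Reds', aspect='auto')
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45)
    ax.set_yticks([])
    plt.tight_layout()
    plt.show()

# e.g. with the first example from the forward pass above:
# show_attention(sentence_tokens, attn_weights[0].detach().numpy())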
The complete code is available HERE.