1. 程式人生 > >pytorch中獲取指定位置元素

pytorch中獲取指定位置元素

  這段程式碼的應用場景是:某個batch的sentence,有的經過了padding操作,如果獲取每句話中實際的最後一個單詞。

A = torch.Tensor([[[2, 3, 1], [1, 4, 0], [1, 0, 0]], [[2, 2, 0], [2, 0, 0], [3, 1, 4]]])
print(A.size())

B = torch.Tensor([[3, 2, 1], [2, 1, 3]]).long()
print(B.size())
B = B.view(2, 3, -1)
B = B - 1

C = torch.gather(A, 2, B)
print(C)