1. 程式人生 > >pytorch: 如何優雅的將 int list 轉成 one-hot形式

pytorch: 如何優雅的將 int list 轉成 one-hot形式

雖然 pytorch 已經升級到 0.2.0 了,但是,貌似依舊沒有簡單的 api 來幫助我們快速將 int list 轉成 one-hot。那麼,如何優雅的實現 one-hot 程式碼呢?

def one_hot(ids, out_tensor):
    """
    ids: (list, ndarray) shape:[batch_size]
    out_tensor:FloatTensor shape:[batch_size, depth]
    """
    if not isinstance(ids, (list, np.ndarray)):
        raise
ValueError("ids must be 1-D list or array") ids = torch.LongTensor(ids).view(-1,1) out_tensor.zero_() out_tensor.scatter_(dim=1, index=ids, src=1.) # out_tensor.scatter_(1, ids, 1.0)

scatter_ 是什麼鬼?

從 value 中拿值,然後根據 dim 和 index 給自己的相應位置填上值

Tensor.scatter_(dim, index, src)
# index: LongTensor
# out[index[i, j], j] = value[i, j] dim=0 # out[i,index[i, j]] = value[i, j]] dim=1 # index 的 shape 可以不和 out 的 shape 一致 # value 也可以是一個 float 值, 也可以是一個 FloatTensor # 如果 value 是 FloatTensor 的話,那麼shape 需要和 index 保持一致

參考資料