1. 程式人生 > >torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 'target'的一種可能原因

torch.cuda.LongTensor but found type torch.cuda.FloatTensor for argument #2 'target'的一種可能原因

可能是在使用交叉熵損失函式的時候,target需要是整數,才能轉化成索引值,進而進行one-hot編碼。

輸出一下target的張量,可以看到每個值都後面有一個點.比如5.這樣,應該表示的就是浮點型別的值。

這個時候需要target=target.long()執行一下型別轉換。