1. 程式人生 > >Pytorch常用函式整理

Pytorch常用函式整理

1.torch.numel() 返回tensor變數內所有元素的個數,也可以簡單理解為矩陣內yu元素的個數

   例如,a的size為([64, 3, 7, 7]),那麼a.numel() 返回值為64*3*7*7=9408

2.torch.squeeze() 將輸入張量形狀中的1去除並返回,如果輸入是形如(Ax1xBx1xCx1xD),那麼輸出形狀就為(AxBxCxD)

3.torch.unsqueeze() 在指定維度上增加一個維度,該維度數為1

4.torch.sort() 返回元組,第一個是pa排序後的值,第二個是排序後的值在元資料中的index,如果不給dim,預設

按照最後一維

5.torch.mean() 如果不給維度,則fa返回所有值的均值

6.torch.gather(input, dim, index) 把輸入按照給定的維度和index取input中的值

7.torch.stack(sequence,dim,out)  stack()會增加一個維度出來,sequence是tensor列表,dim表示拼接的維度,而torch.cat()是在已有的維度上拼接,