1. 程式人生 > >符號*,torch.max 和 torch.sum, item()方法

符號*,torch.max 和 torch.sum, item()方法

*的作用可以參考https://www.cnblogs.com/jony7/p/8035376.html

torch.max可以參考https://blog.csdn.net/Z_lbj/article/details/79766690

a.size()
# Out[134]: torch.Size([6, 4, 3])
torch.max(a, 0)[1].size()
# Out[135]: torch.Size([4, 3])
torch.max(a, 1)[1].size()
# Out[136]: torch.Size([6, 3])
torch.max(a, 2)[1].size()
# Out[137]: torch.Size([6, 4])

 具體怎麼比較的可以看下面

b

tensor([[[  0.,   1.,   2.,   3.],
         [  4.,   5.,   6.,   7.],
         [  8.,   9.,  10.,  11.]],
        [[ 12.,  13.,  14.,  15.],
         [ 16.,  17.,  18.,  19.],
         [ 20.,  21.,  22.,  23.]]])

torch.max(b,0)[0]
 
tensor([[ 12.,  13.,  14.,  15.],
        [ 16.,  17.,  18.,  19.],
        [ 20.,  21.,  22.,  23.]])

torch.max(b,1)[0]

tensor([[  8.,   9.,  10.,  11.],
        [ 20.,  21.,  22.,  23.]])

torch.max(b,2)[0]

tensor([[  3.,   7.,  11.],
        [ 15.,  19.,  23.]])

相應的下標可以得到

b

tensor([[[  0.,   1.,   2.,   3.],
         [  4.,   5.,   6.,   7.],
         [  8.,   9.,  10.,  11.]],
        [[ 12.,  13.,  14.,  15.],
         [ 16.,  17.,  18.,  19.],
         [ 20.,  21.,  22.,  23.]]])

torch.max(b,0)[1]

tensor([[ 1,  1,  1,  1],
        [ 1,  1,  1,  1],
        [ 1,  1,  1,  1]])

torch.max(b,1)[1]

tensor([[ 2,  2,  2,  2],
        [ 2,  2,  2,  2]])

torch.max(b,2)[1]

tensor([[ 3,  3,  3],
        [ 3,  3,  3]])

 torch.sum:

torch.sum(input) → Tensor
torch.sum(input, dim, out=None) → Tensor
引數:

    input (Tensor) – 輸入張量
    dim (int) – 縮減的維度
    out (Tensor, optional) – 結果張量

函式的輸出是一個tensor

match
out:
tensor([[[ 0,  0,  2,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]],
        [[ 0,  0,  0,  0],
         [ 0,  0,  0,  0],
         [ 0,  0,  0,  0]]], dtype=torch.uint8)

torch.sum(match)
Out: 
tensor(2)

torch.sum(match,0)
Out: 
tensor([[ 0,  0,  2,  0],
        [ 0,  0,  0,  0],
        [ 0,  0,  0,  0]])
torch.sum(match,1)
Out: 
tensor([[ 0,  0,  2,  0],
        [ 0,  0,  0,  0]])
torch.sum(match,2)
Out: 
tensor([[ 2,  0,  0],
        [ 0,  0,  0]])

還要補充一點的就是item方法的使用:如果tensor只有一個元素那麼呼叫item方法的時候就是將tensor轉換成python的scalars;如果tensor不是單個元素的話那就會引發ValueError,如下面

b.item()
Traceback (most recent call last):
    b.item()
ValueError: only one element tensors can be converted to Python scalars

torch.sum(b)
Out: tensor(276.)
torch.sum(b).item()
Out: 276.0

那麼在python中的item方法一般是怎麼樣的呢?可參見https://blog.csdn.net/qq_34941023/article/details/78431376