1. 程式人生 > >Pytorch中的torch.cat()函式

Pytorch中的torch.cat()函式

cat是concatnate的意思:拼接,聯絡在一起。


 

先說cat( )的普通用法

如果我們有兩個tensor是A和B,想把他們拼接在一起,需要如下操作:

C = torch.cat( (A,B),0 )  #按維數0拼接(豎著拼)

C = torch.cat( (A,B),1 )  #按維數1拼接(橫著拼)
>>> import torch
>>> A=torch.ones(2,3)    #2x3的張量(矩陣)                                     
>>> A
tensor([[ 
1., 1., 1.], [ 1., 1., 1.]]) >>> B=2*torch.ones(4,3) #4x3的張量(矩陣) >>> B tensor([[ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.]]) >>> C=torch.cat((A,B),0) #按維數0(行)拼接 >>> C tensor([[
1., 1., 1.], [ 1., 1., 1.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.], [ 2., 2., 2.]]) >>> C.size() torch.Size([6, 3]) >>> D=2*torch.ones(2,4) #2x4的張量(矩陣) >>> C=torch.cat((A,D),1)#按維數1(列)拼接 >>> C tensor([[
1., 1., 1., 2., 2., 2., 2.], [ 1., 1., 1., 2., 2., 2., 2.]]) >>> C.size() torch.Size([2, 7])

其次,cat還可以把list中的tensor拼接起來。

比如:

上面的程式碼可以合成一行來寫: