Pytorch系列之常用基礎操作
阿新 • • 發佈:2020-12-12
## 各種張量初始化
### 建立特殊型別的tensor
```Python
a = torch.FloatTensor(2,3)
a = torch.DoubleTensor(2,3)
...
```
### 設定pytorch中tensor的預設型別
```Python
torch.set_default_tensor_type(torch.DoubleTensor)
```
### 更改tensor型別
```Python
a.float()
```
### 各種常用初始化
```Python
torch.randn_like()
torch.rand(3,3) #建立 0-1 (3,3)矩陣
torch.randn(3,3) #建立 -1-1 (3,3)矩陣
torch.randint(1,10,[2,2]) #建立 1-10 (2,2) int型矩陣
```
### 按照不同的均值和方差進行初始化
```Python
torch.normal(mean=torch.full([20],0),std=torch.arange(0,1,0.1))
```
### 按照間隔初始化
```Python
torch.linspace(0,10,step=3)
torch.arange(1,10,5)
```
### 建立單位矩陣
```Python
torch.eye(4,4)
```
### 建立打亂的數列
```Python
torch.randperm(10)
```
### 返回tensor元素個數
```Python
torch.numel(torch.rand(2,2))
```
## 維度操作
### 矩陣拼接
```Python
torch.cat((x,x),0)
torch.stack((x,x),0) #與cat不同的是,stack在拼接的時候,要增加一個維度
```
### 矩陣拆分
chuck直接按照數量來拆分,輸入N就拆分成N個
```Python
torch.chunk(a,N,dim)
```
split的兩種用法,第一種是輸入一個數字,這樣就會拆分成這個總維度/數字個維度,第二個是如輸入一個列表,會按照列表指定的維度進行拆分
```Python
torch.split(a,[1,2],dim)
```
### 矩陣選取
在某個維度上選擇連續的N 列或者行
```Python
torch.narrow(dim,index,size)
```
選擇一個維度dim,從index開始取size個列或者行
```Python
a.index_select(dim, list)
```
### 各種選取
```Python
a[ : , 1:10, ::2 , 1:10:2]
```
### 矩陣打平後選取
```Python
torch.take( tensor , list)
```
### 維度變化
```Python
a.view(1,5)
a.reshape(1,5)
```
### 維度減少和增加
只有一個維度的時候,就是0在前面插入,-1或1在後面插入,可以把list當成是0.5維度
```Python
a.unsqueeze(1)
a.squeeze(1)
```
### 維度擴張
```Python
a.expand()
```
維度擴充套件expand,注意這裡的維度只能由1擴張成N,其他情況下是不能擴張的,另外維度不變的時候也可以用-1代替
```Python
a.repead()
```
另外一種方式是使用repeat函式,repeat表示將之前的維度複製多少次,通過複製來進行擴張
### 維度交換
```Python
transpose(2,3) # 交換兩個維度
permute(4,2,1,3) # 交換多個維度
```
## 數學運算
### 基礎運算
其中加減除法都可以使用運算子直接計算,乘法需要額外注意兩種不同的乘法,其中:
mul或者*是矩陣對應元素相乘
mm是針對於二維的矩陣正常乘法
matmul是針對任意維度矩陣的正常乘法,@是其符號過載
### 數字近似
floor() 向下取整
ceil() 向上取整
trunc() 保留整數
frac() 保留小數
### 數值裁剪
clamp(min)
clamp(min,max) #在這個閾值之外的都變成閾值
### 累乘
prod()
### 線性代數相關
```Python
trace #矩陣的跡
diag #獲取主對角線元素
triu/tril #獲取上下三角矩陣
t #轉置
dot/cross #內積與外積
```
## 其他
### Numpy Tensor 互相轉換
```Python
np_data = np.arange(6).reshape((2, 3))
torch_data = torch.from_numpy(np_data)
tensor2array = torch_data.numpy()
```
### 型別判斷
```Python
isinstance(a,torch.FloatTensor)
```
### 廣播
什麼時候可以使用廣播,廣播將從最後一個維度開始,從後往前開始匹配,當一個物件的維度是1或者與另一個物件的維度大小一樣的時候,可以匹配上,另外,如果一個物件的維度少於另外一個維度的物件,只要從後往前開始的維度匹配,那麼就可以使用廣播。
例如
(1,2,3,4) 和 (2,3,4) or (1,2,3,4) 可以廣播
(1,2,3,4) 和 (1,1,1) or (1,1,1,1) 可以廣播
### topk
topk可以幫助返回在某一維度上最大的k個值以及下標,只需要將largest=False,就可以返回最小的k個值
### where條件選擇
根據條件是否成立,選擇矩陣X或者矩陣Y中的元素
```Python
where(condition > 0.5 , X , Y )
```
### gather
本質就是在查表,第一個引數是表格,第二個是維度,第三個是要查詢的索引
操作就是,在inpu中選擇維度dim,然後根據index編號,讀取input中的元素
```Python
torch.gather(input,dim,index,out=None)