1. 程式人生 > >Pytorch系列之常用基礎操作

Pytorch系列之常用基礎操作

## 各種張量初始化 ### 建立特殊型別的tensor ```Python a = torch.FloatTensor(2,3) a = torch.DoubleTensor(2,3) ... ``` ### 設定pytorch中tensor的預設型別 ```Python torch.set_default_tensor_type(torch.DoubleTensor) ``` ### 更改tensor型別 ```Python a.float() ``` ### 各種常用初始化 ```Python torch.randn_like() torch.rand(3,3) #建立 0-1 (3,3)矩陣 torch.randn(3,3) #建立 -1-1 (3,3)矩陣 torch.randint(1,10,[2,2]) #建立 1-10 (2,2) int型矩陣 ``` ### 按照不同的均值和方差進行初始化 ```Python torch.normal(mean=torch.full([20],0),std=torch.arange(0,1,0.1)) ``` ### 按照間隔初始化 ```Python torch.linspace(0,10,step=3) torch.arange(1,10,5) ``` ### 建立單位矩陣 ```Python torch.eye(4,4) ``` ### 建立打亂的數列 ```Python torch.randperm(10) ``` ### 返回tensor元素個數 ```Python torch.numel(torch.rand(2,2)) ``` ## 維度操作 ### 矩陣拼接 ```Python torch.cat((x,x),0) torch.stack((x,x),0) #與cat不同的是,stack在拼接的時候,要增加一個維度 ``` ### 矩陣拆分 chuck直接按照數量來拆分,輸入N就拆分成N個 ```Python torch.chunk(a,N,dim) ``` split的兩種用法,第一種是輸入一個數字,這樣就會拆分成這個總維度/數字個維度,第二個是如輸入一個列表,會按照列表指定的維度進行拆分 ```Python torch.split(a,[1,2],dim) ``` ### 矩陣選取 在某個維度上選擇連續的N 列或者行 ```Python torch.narrow(dim,index,size) ``` 選擇一個維度dim,從index開始取size個列或者行 ```Python a.index_select(dim, list) ``` ### 各種選取 ```Python a[ : , 1:10, ::2 , 1:10:2] ``` ### 矩陣打平後選取 ```Python torch.take( tensor , list) ``` ### 維度變化 ```Python a.view(1,5) a.reshape(1,5) ``` ### 維度減少和增加 只有一個維度的時候,就是0在前面插入,-1或1在後面插入,可以把list當成是0.5維度 ```Python a.unsqueeze(1) a.squeeze(1) ``` ### 維度擴張 ```Python a.expand() ``` 維度擴充套件expand,注意這裡的維度只能由1擴張成N,其他情況下是不能擴張的,另外維度不變的時候也可以用-1代替 ```Python a.repead() ``` 另外一種方式是使用repeat函式,repeat表示將之前的維度複製多少次,通過複製來進行擴張 ### 維度交換 ```Python transpose(2,3) # 交換兩個維度 permute(4,2,1,3) # 交換多個維度 ``` ## 數學運算 ### 基礎運算 其中加減除法都可以使用運算子直接計算,乘法需要額外注意兩種不同的乘法,其中: mul或者*是矩陣對應元素相乘 mm是針對於二維的矩陣正常乘法 matmul是針對任意維度矩陣的正常乘法,@是其符號過載 ### 數字近似 floor() 向下取整 ceil() 向上取整 trunc() 保留整數 frac() 保留小數 ### 數值裁剪 clamp(min) clamp(min,max) #在這個閾值之外的都變成閾值 ### 累乘 prod() ### 線性代數相關 ```Python trace #矩陣的跡 diag #獲取主對角線元素 triu/tril #獲取上下三角矩陣 t #轉置 dot/cross #內積與外積 ``` ## 其他 ### Numpy Tensor 互相轉換 ```Python np_data = np.arange(6).reshape((2, 3)) torch_data = torch.from_numpy(np_data) tensor2array = torch_data.numpy() ``` ### 型別判斷 ```Python isinstance(a,torch.FloatTensor) ``` ### 廣播 什麼時候可以使用廣播,廣播將從最後一個維度開始,從後往前開始匹配,當一個物件的維度是1或者與另一個物件的維度大小一樣的時候,可以匹配上,另外,如果一個物件的維度少於另外一個維度的物件,只要從後往前開始的維度匹配,那麼就可以使用廣播。 例如 (1,2,3,4) 和 (2,3,4) or (1,2,3,4) 可以廣播 (1,2,3,4) 和 (1,1,1) or (1,1,1,1) 可以廣播 ### topk topk可以幫助返回在某一維度上最大的k個值以及下標,只需要將largest=False,就可以返回最小的k個值 ### where條件選擇 根據條件是否成立,選擇矩陣X或者矩陣Y中的元素 ```Python where(condition > 0.5 , X , Y ) ``` ### gather 本質就是在查表,第一個引數是表格,第二個是維度,第三個是要查詢的索引 操作就是,在inpu中選擇維度dim,然後根據index編號,讀取input中的元素 ```Python torch.gather(input,dim,index,out=None)