1. 程式人生 > >pytorch中一些常用方法的總結

pytorch中一些常用方法的總結

主要介紹一些pytorch框架常用的方法:

2、我個人也是 pytorch 的初學者,我以一個初學者的身份來簡單介紹torch的使用,pytorch是使用GPU和CPU優化的深度張量庫,torch中最重要的一個數據型別就是Tensor(張量),我們計算的時候用Tensor來計算,速度要快一點,如果在訓練神經網路進行梯度運算的時候,一般都會用Varaible型別來計算。

常用的Tensor方法:

input:輸入可以是Tensor向量,也可以輸入單個值。output表示返回的結果,返回可以是向量也可以單個值。

  1. torch.lerp(star, end, weight) : 返回結果是out= star t+ (end-start) * weight
  2. torch.rsqrt(input) : 返回平方根的倒數
  3. torch.mean(input) : 返回平均值
  4. torch.std(input) : 返回標準偏差
  5. torch.prod(input) : 返回所有元素的乘積
  6. torch.sum(input) : 返回所有元素的之和
  7. torch.var(input) : 返回所有元素的方差
  8. torch.tanh(input) :返回元素雙正切的結果
  9. torch.equal(torch.Tensor(a), torch.Tensor(b)) :兩個張量進行比較,如果相等返回true,否則返回false
  10. torch.ge(input,other,out=none) 、 torch.ge(torch.Tensor(a),torch.Tensor(b))    比較內容:
  • ge: input>=other  也就是a>=b, 返回true,否則返回false
  • gt: input> other    也就是a>b, 返回true,否則返回false
  • lt: input<other 也就是a<b, 返回true,否則返回false
  1. torch.max(input): 返回輸入元素的最大值
  2. torch.min(input) : 返回輸入元素的最
  3. element_size() :返回單個元素的位元組
>>>torch.FloatTensor.element_size()
>>>4
>>>torch.ByteTensor.element_size()
>>>1

pytorch 在神經網路常用的方法

1、expand(*size)

返回tensor的一個新檢視,單個維度擴大為更大的尺寸。tensor也可以擴大為更高維,新增加的維度將附在前面。 擴大tensor不需要分配新記憶體,只是僅僅新建一個tensor的檢視,其中通過將stride設為0,一維將會擴充套件位更高維。任何一個一維的在不分配新記憶體情況下可擴充套件為任意的數值。

2、index_add_(dim,  index,  tensor) → Tensor

按引數index中的索引數確定的順序,將引數tensor中的元素加到原來的tensor中。引數tensor的尺寸必須嚴格地與原tensor匹配,否則會發生錯誤。引數: - dim(int)-索引index所指向的維度 - index(LongTensor)-需要從tensor中選取的指數 - tensor(Tensor)-含有相加元素的tensor

>>> x = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
>>> t = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
>>> index = torch.LongTensor([0, 2, 1]) #這裡是把第2行和第3行進行對調>>> x.index_add_(0, index, t)
>>> x
  2348910
[torch.FloatTensor of size 3x3]

3、index_copy_(dim, index, tensor) → Tensor

按引數index中的索引數確定的順序,將引數tensor中的元素複製到原來的tensor中。引數tensor的尺寸必須嚴格地與原tensor匹配,否則會發生錯誤。解釋一下:index = torch.LongTensor([0,2,1]),這裡是把第2行和第3行進行對調,所以得到上述結果。

4、norrow (demension , start, length)---> te

返回一個本tensor經過縮小後的tensor。維度dim縮小範圍是startstart+length。原tensor與返回的tensor共享相同的底層記憶體。

>>>x = torch.Tensor([[1,2,3], [4,5,6], [7,8,9]])
>>>x.narrow(0,0,2)
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
>>>x.narrow(1,1,2)
2 3 5
6 8 9
[torch.FloatTensor of size 3x2]

4、permute(dims),常用的維度轉換方法

將tensor的維度換位      引數:dim(int)---換位順序

>>>x = torch.randn(2,3,4)
>>>x.size()
torch.size([2,3,5])
>>>x.permute(2,0,1).size()
torch.size([5,2,3])

5、repeat(*sizes)

沿著指定的維度重複tensor。不同與expand(),本函式複製的是tensor中的資料。

引數:*size(torch.size ot int...)-沿著每一維重複的次數

>>>x = torch.Tensor([1,2,3])
>>>x.repeat(4,2)
1 2 3 1 2 3
1 2 3 1 2 3
1 2 3 1 2 3
1 2 3 1 2 3
[torch.FloatTensor of size 4x6]
6、resize_(*size)
>>>x.repeat(4,2)
1 2 3 1 2 3
1 2 3 1 2 3
1 2 3 1 2 3
1 2 3 1 2 3
[torch.FloatTensor of size 4x6]
>>>x.repeat(4,2,1).size()
torch.Size([4,2,3])

7、resize_(*size)

將tensor的大小調整為指定的大小。如果元素個數比當前的記憶體大小大,就將底層儲存大小調整為與新元素數目一致的大小。

如果元素個數比當前記憶體小,則底層儲存不會被改變。原來tensor中被儲存下來的元素將保持不變,但新記憶體將不會被初始化。

引數:sizes(torch.Size or int....)需要調整的大小

>>>x = torch.Tensor([[1,2], [3,4], [5,6]])
>>>x.resize_(2,2) #這兩個2,分別表示兩行兩列,如果換成(1,3),則列印的結果是一個一行三列的向量
>>>x
12
34
[torch.FloatTensor of size 2x2]

8、storage_offset()---->int

以儲存元素的個數的形式返回tensor在地城記憶體中的偏移量。

>>>x = torch.Tensor([1,2,3,4,5])
>>>x.storage_offset()
0
>>>x[3:].storage_offset()
3


9、unfold(dim, size, step)---->Tensor

返回一個Tensor,其中含有在dim維tianchong度上所有大小為size的分片。兩個分片之間的步長為step。如果_sizedim_是dim維度的原始大小,則返回tensor

中的維度dim大小是_(sizesdim-size)/step+1_維度大小的附加維度將附加在返回的tensor中。

引數:_dim(int)--需要展開的維度--size(int)每一個分片需要展開的大小--step(int)-相鄰分片之間的步長


10、torch.gather(input, dim, index)

說明一下: dim=1,表示按行索引,dim = 0,按列索引。舉例說明一下

import torch
input = torch.LongTensor([[1,2],[3,4]])
print(input)
index = torch.LongTensor([[0,0],[1,0]])
res = torch.gather(input, 1, index)
print(res)

結果是:

 1  2
 3  4
[torch.LongTensor of size 2x2]
 1  1
 4  3
[torch.LongTensor of size 2x2]

按列索引:

import torch
input = torch.LongTensor([[1,2],[3,4]])
index = torch.LongTensor([[0,0],[1,0]])
res = torch.gather(input, 0, index)
print(res)

結果是:

 1  2
 3  4
[torch.LongTensor of size 2x2]
 1  2
 3  2
[torch.LongTensor of size 2x2]

另一個例子

index0 = torch.LongTensor([[0], [1]])
print(index0)
res = torch.gather(input, 1, index0)
print(res)

結果是:

 0
 1
[torch.LongTensor of size 2x1]
 1
 4
[torch.LongTensor of size 2x1]

11、torch.masked_select(input, mask)

說明一下這個方法的使用,input、mask這兩個引數維度必須一致,我試過不同維度,但是都沒有成功,姑且這麼認為吧。mask必須是ByteTensor型別。

masked =torch.ByteTensor([[0,1,0,0],[0,0,0,1],[0,0,1,0]])
input = torch.LongTensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
res = torch.masked_select(input,masked)
print(res)

結果是:

 2
 8
 11
[torch.LongTensor of size 3]

12、target.masked_scatter_(mask, source)

說明:source是表示用來替換(改變)的值,而mask只有0,1;0表示不替換目標值(或者叫不改變);1是表示要替換目標值。target是需要被替換(改變)的值 。特別說明一下,mask、source的維度一致,而且必須和target的行的維度一致。

source = torch.LongTensor([[5],[7]])
mask = torch.ByteTensor([[1],[0]])
target = torch.LongTensor([[1,2],[3,4]])
res = target.masked_scatter_(mask,source)
print(res)

結果是:

 5  7
 3  4
[torch.LongTensor of size 2x2]
source = torch.LongTensor([[5],[7]])
mask = torch.ByteTensor([[0],[0]]) ###改變索引值
target = torch.LongTensor([[1,2],[3,4]])
res = target.masked_scatter_(mask,source)
print(res)

 1  2
 3  4
[torch.LongTensor of size 2x2]
source = torch.LongTensor([[5],[7]])
mask = torch.ByteTensor([[0],[1]])  #####改變索引值
target = torch.LongTensor([[1,2],[3,4]])
res = target.masked_scatter_(mask,source)
print(res)

1  2
5  7
[torch.LongTensor of size 2x2]

上述的每個方法都參考過文章前面的兩個連結,如果你想深入瞭解,轉向連結。


相關推薦

pytorch一些常用方法總結

主要介紹一些pytorch框架常用的方法:2、我個人也是 pytorch 的初學者,我以一個初學者的身份來簡單介紹torch的使用,pytorch是使用GPU和CPU優化的深度張量庫,torch中最重要的一個數據型別就是Tensor(張量),我們計算的時候用Tensor來計算

jquery項目一些常用方法

dev touch wid sets add subst arch param 時間 1、獲取url中的參數 function getUrlParam(name) { var reg = new RegExp("(^|&)" + name + "=([^&am

rails 日期Date、時間Time的一些常用方法總結

獲取當前時間、今天的當前日期 Time.now  國際時間就是Time.now.utc Date.today 日相關的獲取方法 一天的開始也就是0點00分 2.4.1 :001 > Time.now.beginning_of_day

js一些常用方法總結

  這兩天開始在牛客網上做一些js線上程式設計,發現很多程式設計題其實呼叫的js方法都差不多一樣,所以覺得可以彙總一下,方便記憶也可以多多熟悉。   1.slice()方法     這個方法就是可以從已有的陣列中返回選定的元素。     必須得有start,但是可以沒有end。     2

JAVAsort()常用方法總結

一、Arrays.sort()的用法 import java.util.Arrays; public class Main{ public static void main(String args[

java陣列常用方法總結

Java和C陣列的一些異同: 相同點:陣列名都是首元素的地址 不同點:C語言宣告變數可以直接定義陣列長度,java不可以                 java只有在為陣列分配變數時,可以宣告陣列長度                 java:int  [] a;    

安卓一些常用方法

根據手機的解析度從 dp 的單位 轉成為 px(畫素) :切勿在返回值後面+0.5增加精度 因為在某些低解析度跟高分辨的手機上會有大的誤差 public static int dip2px(Context context, float dpValue) { fin

Linux c一些常用函式總結(c語言中文網。。。)

fgets()函式 標頭檔案:include<stdio.h> fgets()函式 標頭檔案:include<stdio.h>fgets()函式用於從檔案流中讀取一行或指定個數的字元,其原型為:    char * fgets(char * 

js陣列常用方法總結

  運算元組 運算元組,印象中運算元組的方法很多,下面總結了一下陣列中常用的幾個方法: JavaScript中建立陣列有兩種方式 (一)使用 Array 建構函式:   var arr1 = new Array(); //建立一個空陣列var arr2 = new Array(

numpy一些常用函數的用法總結

num matrix 空白 記錄 維數 補充 結果 創建 array 先簡單記錄一下,後續補充詳細的例子 1. strip()函數 s.strip(rm):s為字符串,rm為要刪除的字符序列 只能刪除開頭或是結尾的字符或者字符串。不能刪除中間的字符或是字符串 當rm為空

關於機器學習一些常用方法的補充

機器學習 k近鄰 apriori pagerank前言 機器學習相關算法數量龐大,很難一一窮盡,網上有好事之人也評選了相關所謂十大算法(可能排名不分先後),它們分別是: 1. 決策樹2. 隨機森林算法3. 邏輯回歸4. 支持向量機5. 樸素貝葉斯6

tp5.0及其常用方法一些函數方法(自己看)和技巧(不斷添加

pro xtend yml 數據庫操作 apach txt 圖標 index run 1.目錄結構 2.路由 3..控制器 4.模型寫法 5.視圖標簽 6.數據庫操作 7.表單驗證 --------------------------- 1.目錄結構

js數組常用方法總結

dds 設置 布爾 nsh border 方式 cal AR 操作數 前言 從事前端到現在也有快兩年了,平時也會收集整理一些筆記放在印象筆記,不過收集過之後就在沒有看過,經大佬指點,真正掌握一個知識點,最好的方式就是用自己的話把內容講明白,就開始將以前零散的東西整合一下,和

bash shell 時間操作常用方法總結

hour day 當前時間 簡單的 之前 nbsp seconds 獲取 相互   在日常的工作中,bash shell 的時間操作非常頻繁。比如shell腳本定時發送數據統計的時候,會查看當前是否為預定的發送時間。或者使用文件保存一些數據時,一般會生成時間字符串當做文

JavaScriptArray型別一些常用方法

與其他語言中的陣列有著極大的區別,JavaScript中的陣列,每一項都可以儲存任何型別的資料,且陣列的大小可以動態的調整,即可以隨著資料的新增自動增長以容納新增的資料。 1.陣列的建立方式 建立陣列的基本方式有兩種 (1)使用Array建構函式 var colors = new Ar

JS一些常用的陣列方法

<!doctype html> <html lang="en"> <head> <meta charset="UTF-8"> <meta name="viewport" content="width=d

總結下git一些常用命令

一、目錄操作 1、cd    即change directory,改變目錄,如 cd d:/www,切換到d盤的www目錄。 2、cd ..   cd+空格+兩個點,回退到上一目錄。 3、pwd        即 print workin

J2EE一些常用方法和細節整理

1.setAttribute、getAttribute方法 方法 描述 注意點 void setAttribute(String name,Object o) 設定屬性的名稱及內容

淺談自定義View一些常用的回撥方法

1. 構造方法 1.public View(Context context) 2.public View(Context context, @Nullable AttributeSet attrs) 3.public View(Context context, @Nulla

執行緒一些常用方法的用法 join()、yield()、sleep()、wait()、notify()、notifyAll()

1.執行緒休眠sleep();:執行緒有優先順序,但是我們可以用此方法人為的改變它們的優先順序,讓執行緒暫停,它其他執行緒獲得分配空間。 用法:Thread.sleep(2000);//休眠兩秒 2.執行緒讓步yield();就是讓出自己的分配空間給其他執行