1. 程式人生 > >PyTorch踩過的坑(長期更新)

PyTorch踩過的坑(長期更新)

1. nn.Module.cuda() 和 Tensor.cuda() 的作用效果差異

無論是對於模型還是資料,cuda()函式都能實現從CPU到GPU的記憶體遷移,但是他們的作用效果有所不同。

對於nn.Module:

model = model.cuda() 
model.cuda() 

上面兩句能夠達到一樣的效果,即對model自身進行的記憶體遷移。

對於Tensor:

和nn.Module不同,呼叫tensor.cuda()只是返回這個tensor物件在GPU記憶體上的拷貝,而不會對自身進行改變。因此必須對tensor進行重新賦值,即tensor=tensor.cuda().

例子:

model = create_a_model()
tensor = torch.zeros([2,3,10,10])
model.cuda()
tensor.cuda()
model(tensor)    # 會報錯
tensor = tensor.cuda()
model(tensor)    # 正常執行

2. PyTorch 0.4 計算累積損失的不同

以廣泛使用的模式total_loss += loss.data[0]為例。Python0.4.0之前,loss是一個封裝了(1,)張量的Variable,但Python0.4.0的loss現在是一個零維的標量。對標量進行索引是沒有意義的(似乎會報 invalid index to scalar variable 的錯誤)。使用loss.item()可以從標量中獲取Python數字。所以改為:

total_loss += loss.item()

如果在累加損失時未將其轉換為Python數字,則可能出現程式記憶體使用量增加的情況。這是因為上面表示式的右側原本是一個Python浮點數,而它現在是一個零維張量。因此,總損失累加了張量和它們的梯度歷史,這可能會產生很大的autograd 圖,耗費記憶體和計算資源。

3. PyTorch 0.4 編寫不限制裝置的程式碼

# torch.device object used throughout this script
device = torch.device("cuda" if use_cuda else "cpu")
model = MyRNN().to(device)

# train
total_loss= 0
for input, target in train_loader:
    input, target = input.to(device), target.to(device)
    hidden = input.new_zeros(*h_shape)       # has the same device & dtype as `input`
    ...                                                               # get loss and optimize
    total_loss += loss.item()

# test
with torch.no_grad():                                    # operations inside don't track history
    for input, targetin test_loader:
        ...

4. torch.Tensor.detach()的使用

detach()的官方說明如下:

Returns a new Tensor, detached from the current graph.
    The result will never require gradient.

假設有模型A和模型B,我們需要將A的輸出作為B的輸入,但訓練時我們只訓練模型B. 那麼可以這樣做:

input_B = output_A.detach()

它可以使兩個計算圖的梯度傳遞斷開,從而實現我們所需的功能。

5. ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm)

出現這個錯誤的情況是,在伺服器上的docker中執行訓練程式碼時,batch size設定得過大,shared memory不夠(因為docker限制了shm).解決方法是,將Dataloader的num_workers設定為0.

6. pytorch中loss函式的引數設定

以CrossEntropyLoss為例:

CrossEntropyLoss(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='elementwise_mean')
  • reduce = False,那麼 size_average 引數失效,直接返回向量形式的 loss,即batch中每個元素對應的loss.
  • reduce = True,那麼 loss 返回的是標量:
    • 如果 size_average = True,返回 loss.mean().
    • 如果 size_average = False,返回 loss.sum().
  • weight : 輸入一個1D的權值向量,為各個類別的loss加權,如下公式所示:\text{loss}(x, class) = weight[class] \left(-x[class] + \log\left(\sum_j \exp(x[j])\right)\right) 
  • ignore_index : 選擇要忽視的目標值,使其對輸入梯度不作貢獻。如果 size_average = True,那麼只計算不被忽視的目標的loss的均值。
  • reduction : 可選的引數有:‘none’ | ‘elementwise_mean’ | ‘sum’, 正如引數的字面意思,不解釋。

7. pytorch的可重複性問題

相關推薦

PyTorch長期更新

1. nn.Module.cuda() 和 Tensor.cuda() 的作用效果差異 無論是對於模型還是資料,cuda()函式都能實現從CPU到GPU的記憶體遷移,但是他們的作用效果有所不同。 對於nn.Module: model = model.cuda()  mo

網路調參時長期更新

1.學習率大小的設定 一般情況下,當網路收斂到一定程度時,loss曲線的變化不明顯,並出現上下的小幅度波動,這時候可以考慮調小學習率,幫助網路進一步收斂到最優值。如下圖所示: 但有些情況,網路看似收斂了,但實際上是到了某些平坦的曲面,離最優值還有一段距離。典型的los

那些年的CSS永久更新

1、img 標籤中的alt 與title的區別: alt  alt屬性的實質作用是在圖片無法正確顯示時起到文字替代的作用,不過在IE6下還起到了title的作用。 title 滑鼠滑過時顯示的文字提示。 對SEO優化的影響: 搜尋引擎對圖片理解是通過alt屬性,所以在圖片a

那些年我們的php持續更新

原因:在第一次迴圈時,陣列的指標指向下一個元素,得到的陣列值為2,這個時候,php陣列內部會複製一份臨時的陣列$tmp, $tmp的指標指向第二個元素,後續呼叫current($a),實際上是取的臨時陣列$tmp的當前值,而$tmp的指標始終指向第二個元素,所以輸出結果永遠是2

Android 開發時遇到持續更新

1.匯入工程後,更改應用報名報錯,clean 無反應。 在網上查詢資料沒有頭緒,後面發現,自定義的控制元件所在的路徑因為更改報名之後改變了,需要在引用該控制元件的佈局檔案中修改屬性的路徑 xmlns:example="http://schemas.android.com/a

tomcat 與 java web中url路徑的配置以及使用規則詳情長期更新

root 每一個 ava 目錄 clip ima 文件夾 logs 需要 首先我們看一下在myeclipse中建立的java web項目的結構 在這裏我們需要註意這個webroot也就是我們在tomcat裏的webapp裏面的應用 之所以每一個項目都有這個webroot

Java成神之路技術整理長期更新

重復註解 java多線程 加載機制 rom 倒計時器 dad 免費 dcm servle 以下是Java技術棧微信公眾號發布的關於 Java 的技術幹貨,從以下幾個方面匯總。 Java 基礎篇 Java 集合篇 Java 多線程篇 Java JVM篇 Java 進階篇 J

vscode常見錯誤匯總長期更新

python git vscode debug 1.錯誤提示 Q:不是每一個紅波浪線都是錯誤,都需要修改 A: 看下面這個地方: 這裏的from確實標記了紅色波浪線,鼠標放上去還有提示: 但是,這裏並不需要修改,因為pep8檢查很嚴格,我們這裏前面是針對整個工程,把工程目錄添加到了環境變量

Python花式錯誤集錦長期更新

留言 int 項目 add encoding ror 操作 pat oba Python是一門靈活的,有意思的,用途廣泛的語言。近些年來,收到越來越多的重視。也有越來越多的人來學習這門語言。 於是,問題來了,對於初學者,往往在寫代碼的過程中,出現這樣或那樣的錯誤,導致程序運

git 指令長期更新

引言:git 是一個非常棒的分散式版本管理系統,我想做開發的小夥伴們對 git 都不陌生,我平時也很喜歡用 git 與github 協同開發(想起以前沒用git 的日子,真是很難受,現在已經是離不開了)。關於git 不得不說的就是 git 指令,平時我自己用的比較多的是:git add ; git commi

長距離單曆元非差GNSS網路RTK理論與方法總結長期更新

1.狀態空間: 狀態空間是控制工程中的一個名詞。狀態是指在系統中可決定系統狀態、最小數目變數的有序集合。    而所謂狀態空間則是指該系統全部可能狀態的集合。簡單來說,狀態空間可以視為一個以狀態變數為座標軸的空間,因此係統的狀態可以表示為此空間中的一個向量。  狀

unity優化一些總結 長期更新

unity優化一些總結 (長期更新) UI: 1:儘量不要使用動態文字 2: 使用更多畫布 拆分畫布 ​ 我開始使用3幅畫布。一個用於我的背景影象,一個用於我的主要UI元素,另一個用於需要放置在其他所有元素頂部的元素。 我瞭解到,每當畫布中的某些內容發生變化時,整個畫布都會被重新評估並重新繪製。因此

vue中的細節長期更新

(一)條件渲染:v-if 與 v-show:          兩者均用於條件渲染,都可以與”v-else”搭配使用。區別在於使用” v-if “時,如果條件不滿足,被” v-if “包裹的元素不會進行初始化,即DOM結構中沒有插入該標籤包

科研心得日記長期更新

2018/9/13 頭幾天導師在群裡發了一篇推送,推送中寫到了現在研究生普遍存在的一些問題,其中最引起導師共鳴的就是“不會自己想idea”。導師說,這個問題就很嚴重的發生在我們組裡。仔細想想,其實我就是屬於非常不會想idea的。 科研的本領對於大部分人而言,

Python——關於常見模組長期更新

1、在Python中,我們最常見的估計就是時間模組,所以第一個我來說一下時間模組: import time print(time.time()) #時間戳,從1970年8點開始的 print(time.clock()) #計算cpu的執行時間 print(ti

codeforces div1 ABC彙總長期更新

Codeforces Round #310 (Div. 1) Codeforces Round #309 (Div. 1) Codeforces Round #305 (Div. 1) B - Mike and Feet

centos 更新與清理教程 長期更新

在使用sudo yum update命令更新centos的核心與庫後,需要進行一些清理工作。1.刪除CentOS更新後的舊核心首先檢視系統當前核心版本uname -a提示如下Linux localhost.localdomain 3.10.0-693.21.1.el7.x86

關於sql語句中的一些函式長期更新

前言在最近看別人的sql的時候,看到一些函式,比如left(),right()等等,好奇是什麼意思,查詢之後覺得還是挺有用的,特此記錄下來。部落格會在遇到新的函式的時候定期更新。—————————————————————————————————————————————————

JAVA IO 長期更新

1.OutputStream 類 1.1 void write(int b) 方法 本來應該是接受一個無符號的1個位元組的整數(0-255). 接受 Int型資料,但是java內部其實會將b 轉換成0-255之間的數字,原因是Stream是以位元組

單目,雙目以及RGB-D相機的資料長期更新

一、相機標定以及畸變矯正1、https://blog.csdn.net/humanking7/article/details/45037239——[影象]畸變校正詳解2、https://blog.csdn.net/liulina603/article/details/5295