1. 程式人生 > >問題解決:Pytorch :Trying to backward through the graph a second time, but the buffers。。

問題解決:Pytorch :Trying to backward through the graph a second time, but the buffers。。

最近在學習Pytorch,剛用Pytorch重寫了之前用Tensorlfow寫的論文程式碼。
首次執行就碰到了一個bug:
Pytorch - RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
剛開始按照這個錯誤提示,設定loss.backward(retain_graph=True),雖然解決了這個問題,但是隨著訓練的繼續,報錯OOM。很尷尬。。。
查了stackoverflow上的方法,最終解決了問題。

我原來的程式碼是:

 for side in outputs:
     loss += Loss(side, label)

 loss.backward(retain_graph=True)

很顯然,一旦呼叫loss.backward(), 就相當於呼叫了多次的Loss(side, label)的.backward()方法,而Pytorch的機制是每次呼叫.backward()都會free掉所有buffers,所以它提示,讓retain_graph。然而當retain後,buffers就不會被free了,所以會OOM。
最後的解決辦法就是, 分開寫:

side0 = Loss(output[0], label)
side1 = Loss(output[1], label)
side2 = Loss(output[2], label)
side3 = Loss(output[3], label)
side4 = Loss(output[4], label)
side5 = Loss(output[5], label)
loss = side0 + side1 + side2 + side3 + side4 + side5

作者:Mundane_World
來源:CSDN
原文:https://blog.csdn.net/Mundane_World/article/details/81038274
版權宣告:本文為博主原創文章,轉載請附上博文連結!