1. 程式人生 > >pytorch視訊記憶體越來越多的一個潛在原因-- 這個函式還沒有在torch.cuda.Tensor中定義

pytorch視訊記憶體越來越多的一個潛在原因-- 這個函式還沒有在torch.cuda.Tensor中定義

最近在用pytorch跑實驗,有如下操作需要用到: (pytorch版本為0.3.1)

class SpatialFilter(nn.Module):
    def __init__(self,mode=True,sf_rate=0.8):
        '''
        給定特徵圖求hot_map
        '''
        super(SpatialFilter,self).__init__()
        self.sf_rate=sf_rate
        self.mode=mode
    def forward(self,x):
        b,c,h,w=x.size()
        
        if self.mode:
            #print("====",self.sf_rate)
            hot_map=torch.mean(x,dim=1).view(b,1,h*w).
            map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1)  #hot_map的中位數
            hot_map=hot_map.view(b,1,h,w)
            hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
            del(map_med)        
        else:
            #print("++++")
            hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
        return hot_map

發現在訓練的時候視訊記憶體會不斷增加,情況就是每隔一個epoch視訊記憶體會增加30M左右,在一個epoch之內視訊記憶體不會增加。剛開始我以為是我的訓練部分寫的有問題,後來發現不用torch.median()而是用 F.AdaptiveAvgPool2d()就不會有這個問題,於是我就去看了pytorch的中文文件,發現pytorch中文文件中torch.median()函式下有這麼一句話:   注意: 這個函式還沒有在torch.cuda.Tensor中定義. 所以問題就很清晰了,這個運算應該是先把Tensor轉到cpu上再把它挪回到gpu上,應該是由於內部轉換機制的不健全導致了gpu上的視訊記憶體沒有及時釋放,只需要人為的把操作轉到cpu上再把tensor轉到gpu上,並自己delete就可解決這個問題。

程式碼如下:

class SpatialFilter(nn.Module):
	def __init__(self,mode=True,sf_rate=0.8):
		'''
		給定特徵圖求hot_map
		'''
		super(SpatialFilter,self).__init__()
		self.sf_rate=sf_rate
		self.mode=mode
	def forward(self,x):
		b,c,h,w=x.size()
		
		if self.mode:
			#print("====",self.sf_rate)
			hot_map=torch.mean(x,dim=1).view(b,1,h*w).cpu()
			if isinstance(x.data,torch.cuda.FloatTensor):
				map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1).cuda()  #hot_map的中位數
				hot_map=hot_map.view(b,1,h,w).cuda()
			else:
				map_med=torch.median(hot_map,dim=2)[0].view(b,1,1,1)  #hot_map的中位數
				hot_map=hot_map.view(b,1,h,w)
			hot_map=torch.gt(hot_map,map_med*self.sf_rate).float()
			del(map_med)		
		else:
			#print("++++")
			hot_map=Variable(torch.ones(b,1,h,w),requires_grad=False)
		return hot_map

這樣改動之後就可以在cpu上和gpu上同時跑了。

相似的可能導致這種問題的操作還有:torch.mode() 求眾數函式

希望能幫助大家以後跳過這個坑。