1. 程式人生 > >pytorch藉助tensorboard實現模型視覺化

pytorch藉助tensorboard實現模型視覺化

補充 : 剛發現貌似sqrt操作是不支援的

python庫:

pytorch(>=0.3) , onnx, tensorboardX

原理:

 

Open Neural Network Exchange (ONNX)是開放生態系統的第一步,它使人工智慧開發人員可以在專案的發展過程中選擇合適的工具;ONNX為AI models提供了一種開源格式。它定義了一個可以擴充套件的計算圖模型,同時也定義了內建操作符和標準資料型別。最初我們關注的是推理(評估)所需的能力。

Caffe2, PyTorch, Microsoft Cognitive Toolkit, Apache MXNet 和其他工具都在對ONNX進行支援。在不同的框架之間實現互操作性,並簡化從研究到產品化的過程,將提高人工智慧社群的創新速度。

簡單來說就是藉助onnx將pytorch的模型存為model.proto的檔案,然後藉助於tensorboardX這個工具將model.proto轉換為tensorboar的graph.

程式碼:

#對於pytorch0.3以上的版本
import tensorboardX 
import torch
from torchvision.models import resnet34
import torch.onnx

x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入
model=resnet34()

proto=torch.onnx.export(model,x,"resnet34.proto",verbose=True) #將model的結構和引數全部儲存為 resnet32.proto

writer=tensorboardX.SummaryWriter("./logs/")  #定義一個tensorboardX的寫物件 
writer.add_graph_onnx("./resnet34.proto")  #將proto格式的檔案轉換為tensorboard中的graph

對於pytorch 0.2來說可以直接來畫:

import tensorboardX 
import torch
from torchvision.models import resnet34
import torch.onnx

x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入
model=resnet34()
writer=tensorboardX.SummaryWriter("./logs/")  #定義一個tensorboardX的寫物件 
writer.add_graph(model,x,verbose=True)  #將proto格式的檔案轉換為tensorboard中的graph

 

效果如下 ,確實有點醜,不如tensorflow那樣五顏六色,也沒有更加詳細的操作:

 

拉近的圖片:

補充,剛才有人說好像max_pool2d是不支援的,我自己的測試時可以的,建議檢查一下tensorboardX和ONNX的版本,程式碼如下: 我的tensorboardX版本是1.4的,onnx版本是1.3.0

import torch
import  torch.nn.functional as F
import torch.onnx
import tensorboardX

class ResNet(nn.Module):

	def __init__(self, block, layers, num_classes=1000):
		super(ResNet, self).__init__()
	def forward(self, x):
		#這兒就是我加的操作
		x=F.max_pool2d(x,kernel_size=7)
		return x

def resnet50():
	"""Constructs a ResNet-50 model.

	Args:
		pretrained (bool): If True, returns a model pre-trained on ImageNet
	"""
	model = ResNet(Bottleneck, [3, 4, 6, 3])
	return model

if __name__=="__main__":
	x=torch.autograd.Variable(torch.rand(1,3,224,224)) #隨便定義一個輸入
	model=resnet50()
	 
	proto=torch.onnx.export(model,x,"resnet50.proto",verbose=True) #將model的結構和引數全部儲存為 resnet32.proto
	 
	writer=tensorboardX.SummaryWriter("./logs/")  #定義一個tensorboardX的寫物件 
	writer.add_graph_onnx("./resnet50.proto")