1. 程式人生 > >pytorch1.0 用torch script匯出儲存模型

pytorch1.0 用torch script匯出儲存模型

python的易上手和pytorch的動態圖特性,使得pytorch在學術研究中越來越受歡迎,但在生產環境,礙於python的GIL等特性,可能達不到高併發、低延遲的要求,存在需要用c++介面的情況。除了將模型匯出為ONNX外,pytorch1.0給出了新的解決方案:pytorch 訓練模型 - 通過torch script中間指令碼儲存模型 -- C++載入模型。最近工作需要嘗試做了轉換,總結一下步驟和遇到的坑。

用torch script把torch模型轉成c++介面可讀的模型有兩種方式:trace && script. trace比script簡單,但只適合結構固定的網路模型,即forward中沒有控制流的情況,因為trace只會儲存執行時實際走的路徑。如果forward函式中有控制流,需要用script方式實現。

trace顧名思義,就是沿著資料運算的路徑走一遍,官方例子:

 
import torch
def foo(x, y): return 2*x + y traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

 

 

 

 

script稍複雜,主要改三處:

1. Model由之前繼承 nn.Model 改為繼承 torch.jit.ScriptModule

2. forward函式前加 @torch.jit.script_method

3. 其他需要呼叫的函式前加 @torch.jit.script

 

踩過的坑&&解決方法:

A. torch script預設函式或方法的引數都是Tensor型別的,如果不是需要說明,不然呼叫非Tensor引數時會報型別不符的編譯錯誤。

python3可以直接:

def example_func(param_1: Tensor, param_2: int, param_3: List[int]):

 

 

python2需要用type註釋:

def example_func(param_1, param_2, param_3):

#type: (Tensor, int, List[int]) -> Tensor

 

 

 

B. model的方法中orward加@torch.jit.script_method, __init__函式不用

C. 前面說過,torch scrip支援的函式是pytorch的子集,意味著有一部分函式不支援,例如: not boolean,pass, List的切片賦值,CPU和GPU切換的value.to( ), 需要想辦法繞過去。看github上討論區說新版好像已經支援not操作了,沒有驗證。

 

結論:pytorch 1.0目前的預覽版還有比較多優化的空間,至少是在torch script支援的函式集合上,不建議使用,等穩定版釋出再看看吧。

  

 

原創內容,轉載請註明出處。

 

參考資料:

https://pytorch.org/docs/master/jit.html

https://pytorch.org/tutorials/beginner/deploy_seq2seq_hybrid_frontend_tutorial.html