梳理一下Pytorch專案的基本結構(其實TF的也差不多是這樣,這種思路可以遷移到別的深度學習框架中)

結構樹

-------checkpoints #存放訓練完成的模型檔案

​ ----xxx.pkl #模型檔案

--------data #存放資料檔案(如txt)或者資料預處理檔案

​ ---__ init __.py

​ ---xxx.txt #資料

​ ---dataset.py #資料集相關

​ ---get_data.sh #一般用於下載某些資料

--------models #存放模型,一般一個模型對應一個.py檔案

​ ---__ init __.py

​ ---xxxNet.py

​ ---xxxModel.py

--------utils #存放一些工具函式,如視覺化等

​ ---__ init __.py

​ ---visualize.py

--------config.py #配置檔案

--------train.py #用於訓練模型,可視為主檔案

--------test.py #用於測試模型

流程

1、獲取資料

使用.sh檔案下載或者其他方法獲得資料

2、資料載入

一般會有一個檔案把資料處理成適合的格式,然後通過載入器(Dataloader)載入模型中使用,這個Dataloader可能是獨立的,也可能整合在train.py裡面

3、訓練

顧名思義,使用載入的資料對定義的模型進行訓練。這個過程基本上是使用train.py進行,結果是你會得到一個.pkl結尾的模型檔案

4、測試

用一部分資料對訓練好的模型進行測試(這些資料可以來自之前匯入的資料,也可以是新的資料),使用test.py進行,呼叫損失函式,列印日誌(就是你看到的那些在console裡重新整理的log)

5、使用模型

就是呼叫即可,先給出我們存放模型的位置,然後載入即可(沒有實操,後續再更新)

注:

  • 模型.py檔案中,一般是用一個函式或者一個類來承載一個具體模型,其中定義著模型的不同層
  • train.py是工程的核心,裡面定義了訓練時需要的各項引數、訓練次數等重要資訊