梳理一下Pytorch專案的基本結構(其實TF的也差不多是這樣,這種思路可以遷移到別的深度學習框架中)
結構樹
-------checkpoints #存放訓練完成的模型檔案
----xxx.pkl #模型檔案
--------data #存放資料檔案(如txt)或者資料預處理檔案
---__ init __.py
---xxx.txt #資料
---dataset.py #資料集相關
---get_data.sh #一般用於下載某些資料
--------models #存放模型,一般一個模型對應一個.py檔案
---__ init __.py
---xxxNet.py
---xxxModel.py
--------utils #存放一些工具函式,如視覺化等
---__ init __.py
---visualize.py
--------config.py #配置檔案
--------train.py #用於訓練模型,可視為主檔案
--------test.py #用於測試模型
流程
1、獲取資料
使用.sh檔案下載或者其他方法獲得資料
2、資料載入
一般會有一個檔案把資料處理成適合的格式,然後通過載入器(Dataloader)載入模型中使用,這個Dataloader可能是獨立的,也可能整合在train.py裡面
3、訓練
顧名思義,使用載入的資料對定義的模型進行訓練。這個過程基本上是使用train.py進行,結果是你會得到一個.pkl結尾的模型檔案
4、測試
用一部分資料對訓練好的模型進行測試(這些資料可以來自之前匯入的資料,也可以是新的資料),使用test.py進行,呼叫損失函式,列印日誌(就是你看到的那些在console裡重新整理的log)
5、使用模型
就是呼叫即可,先給出我們存放模型的位置,然後載入即可(沒有實操,後續再更新)
注:
- 模型.py檔案中,一般是用一個函式或者一個類來承載一個具體模型,其中定義著模型的不同層
- train.py是工程的核心,裡面定義了訓練時需要的各項引數、訓練次數等重要資訊