1. 程式人生 > >【用Python學習Caffe】5. 生成solver檔案

【用Python學習Caffe】5. 生成solver檔案

5. 生成solver檔案

網路訓練一般是通過solver來進行的。對於caffe來說,其是通過solver檔案來生成solver訓練器進行網路訓練及測試的,該solver檔案中包含了訓練及測試網路的配置檔案的地址,及相關訓練方法及一些訓練的超引數,該檔案一般不是很大,可以直接在一些solver.prototxt檔案上更改。也可以通過Python結合caffe_pb2.SolverParameter()結構自動生成solver.prototxt檔案

    def solver_file(model_root, model_name):
        s = caffe_pb2.SolverParameter() # 宣告solver結構
        s.train_net = model_root+'train.prototxt' # 訓練網路結構配置檔案
        s.test_net.append(model_root+'test.prototxt') # 測試時網路結構配置檔案,測試網路可有多個
        # 每訓練迭代test_interval次進行一次測試。
        s.test_interval = 500
        # 每次測試時的批量數,測試裡網路可有多個
        s.test_iter.append(100)
        # 最大訓練迭代次數
        s.max_iter = 10000
        # 基礎學習率
        s.base_lr = 0.01
        # 動量,記憶因子
        s.momentum = 0.9
        # 權重衰減值,遺忘因子
        s.weight_decay = 5e-4
        # 學習率變化策略。可選引數:fixed、step、exp、inv、multistep
        # fixed: 保持base_lr不變;
        # step: 學習率變化規律base_lr * gamma ^ (floor(iter / stepsize)),其中iter表示當前的迭代次數;
        # exp: 學習率變化規律base_lr * gamma ^ iter;
        # inv: 還需要設定一個power,學習率變化規律base_lr * (1 + gamma * iter) ^ (- power);
        # multistep: 還需要設定一個stepvalue,這個引數和step相似,step是均勻等間隔變化,而multistep則是根據stepvalue值變化;
        #   stepvalue引數說明:
        #       poly: 學習率進行多項式誤差,返回base_lr (1 - iter/max_iter) ^ (power);
        #       sigmoid: 學習率進行sigmod衰減,返回base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))。
        s.lr_policy = 'inv'
        s.gamma = 0.0001
        s.power = 0.75

        s.display = 100 # 每迭代display次顯示結果
        s.snapshot = 5000 # 儲存臨時模型的迭代數
        s.snapshot_prefix = model_root+model_name+'shapshot' # 模型字首,就是訓練好生成model的名字
        s.type = 'SGD' # 訓練方法(各類梯度下降法),可選引數:SGD,AdaDelta,AdaGrad,Adam,Nesterov,RMSProp
        s.solver_mode = caffe_pb2.SolverParameter.GPU # 訓練及測試模型,GPU或CPU

        solver_file=model_root+'solver.prototxt' # 要儲存的solver檔名

        with open(solver_file, 'w') as f:
            f.write(str(s))

5.1 具體程式碼下載