1. 程式人生 > >caffe Python API 之Model訓練

caffe Python API 之Model訓練

# 訓練設定
# 使用GPU
caffe.set_device(gpu_id) # 若不設定,預設為0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu()

# 載入Solver,有兩種常用方法
# 1. 無論模型中Slover型別是什麼統一設定為SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt') 
# 2. 根據solver的prototxt中solver_type讀取,預設為SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt
') # 訓練模型 # 1.1 前向傳播 solver.net.forward() # train net solver.test_nets[0].forward() # test net (there can be more than one) # 1.2 反向傳播,計算梯度 solver.net.backward() # 2. 進行一次前向傳播一次反向傳播並根據梯度更新引數 solver.step(1) # 3. 根據solver檔案中設定進行完整model訓練 solver.solve()

如果想在訓練過程中儲存模型引數,呼叫

solver.net.save('mymodel.caffemodel
')