1. 程式人生 > >[Keras] 使用Keras呼叫多GPU時出現無法儲存模型的解決方法

[Keras] 使用Keras呼叫多GPU時出現無法儲存模型的解決方法

在使用keras 的並行多路GPU時出現了模型無法

儲存,在使用單個GPU時執行完全沒有問題。執行出現can't pickle的問題

隨後在網上找了很多解決方法。下面列舉一些我實驗成功的方法。

方法一

class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
		self.single_model = model
		super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)

    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)
check_point = ParallelModelCheckpoint(single_model ,'best.hd5')

在這裡需要解釋一下 這個single_model 時沒有進行並行的model 

 model = Model(inputs = [inputs], outputs = [conv10])
 parallel_model = multi_gpu_model(model, gpus=2)

在這裡第一就是single_model.直接用這個model checkpoint 就可以了 親測有效

方法二

original_model = ...
parallel_model = multi_gpu_model(original_model, gpus=n)

class MyCbk(keras.callbacks.Callback):

    def __init__(self, model):
         self.model_to_save = model

    def on_epoch_end(self, epoch, logs=None):
        self.model_to_save.save('model_at_epoch_%d.h5' % epoch)

cbk = MyCbk(original_model)
parallel_model.fit(..., callbacks=[cbk])

同理這裡也是一樣的。其實在上面兩種方法中可以發現,基本都是在checkpoint 問題上都是使用了單個model進行執行的。

意思就是直接使用傳入方法keras.utils.multi_gpu_model(model, gpus)中的model即可,而不要使用返回的parallel_model

參考網站