1. 程式人生 > >keras讀取訓練好的模型引數並把引數賦值給其它模型

keras讀取訓練好的模型引數並把引數賦值給其它模型

介紹

本博文中的程式碼,實現的是載入訓練好的模型model_halcon_resenet.h5,並把該模型的引數賦值給兩個不同的新的model。

函式式模型

官網上給出的呼叫一個訓練好模型,並輸出任意層的feature。
model = Model(inputs=base_model.input, outputs=base_model.get_layer('block4_pool').output)

但是這有一個問題,就是新的model,如果輸入inputs和訓練好的model的inputs大小不同呢?比如我想建立一個輸入是600x600x3的新model,但是訓練好的model輸入是200x200x3,而這時我又想呼叫訓練好模型的卷積核引數,這時該怎麼辦呢?

其實想一下,卷積層和BatchNormalization層的引數形狀只取決於卷積核大小和通道數,與輸入影象的長寬無關,所以即使輸入的尺寸不同,這些訓練好的引數仍然可以直接用來計算,只是輸出的feature map大小不同(全連線層例外,它的引數形狀取決於輸入的長度)。那到底怎麼賦值呢?其實很簡單。

定義新的model時,每一層都要用name引數指定名字,而且這個名字必須與訓練好的模型中對應層的名字完全相同,這樣才能按名字把引數載入進來。如下程式碼所示:

inputs=Input(shape=(400,500,3))
X=Conv2D(32, (3, 3),name="conv2d_1")(inputs)
X=BatchNormalization(name="batch_normalization_1")(X)
X=Activation('relu',name="activation_1")(X)

最後通過以下程式碼即可建立一個新的模型並擁有訓練好模型的引數:

model=Model(inputs=inputs, outputs=X)
model.load_weights('model_halcon_resenet.h5', by_name=True)

原始碼

from keras.models import load_model
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
from keras.layers import Conv2D, MaxPooling2D,merge
from keras.layers import BatchNormalization,Activation
from keras.layers import Input, Dense
from PIL import Image
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten,Input
from keras.layers import Conv2D, MaxPooling2D,merge,AveragePooling2D,GlobalAveragePooling2D
from keras.layers import BatchNormalization,Activation
from sklearn.model_selection import train_test_split
from keras.applications.densenet import DenseNet169, DenseNet121
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.optimizers import SGD
from keras import regularizers
from keras.models import Model
import tensorflow as tf
from PIL import Image
from keras.callbacks import TensorBoard
import os
import cv2
from keras import backend as K
from model import focal_loss
import keras.losses


# README: this script follows the Fast R-CNN family of detectors. It first
# extracts a feature map from the whole image, then maps positions in the
# original image onto that feature map, slices the feature map to obtain the
# features of a given position, and feeds those features into the downstream
# recognition network for classification.
#
# The saved model was compiled with the project's custom `focal_loss`, so the
# loss has to be resolvable by name when the model is deserialized. Pass it
# through `custom_objects` instead of monkey-patching `keras.losses`, which
# would mutate library state for the whole process.
base_model = load_model('model_halcon_resenet.h5',
                        custom_objects={'focal_loss': focal_loss})
# Print the trained architecture; its layer names are the ones reused below
# when loading weights by name.
base_model.summary()

def _conv_bn_relu(tensor, filters, index, padding='valid',
                  kernel_regularizer=None):
    """Apply a 3x3 Conv2D -> BatchNormalization -> ReLU stack.

    The three layers are named ``conv2d_<index>``,
    ``batch_normalization_<index>`` and ``activation_<index>`` so that
    ``load_weights(..., by_name=True)`` can match them against the layers
    stored in the weights file.
    """
    tensor = Conv2D(filters, (3, 3), padding=padding,
                    kernel_regularizer=kernel_regularizer,
                    name="conv2d_%d" % index)(tensor)
    tensor = BatchNormalization(name="batch_normalization_%d" % index)(tensor)
    return Activation('relu', name="activation_%d" % index)(tensor)


def _residual_block(tensor, filters, index, merge_index):
    """Apply a two-convolution residual block with an identity shortcut.

    Uses layer indices ``index`` and ``index + 1`` for the two convolutions
    and names the element-wise sum ``merge_<merge_index>``. Both convolutions
    use ``padding='same'`` so the branch keeps the shortcut's spatial size.
    """
    shortcut = tensor
    branch = _conv_bn_relu(tensor, filters, index, padding='same')
    branch = Conv2D(filters, (3, 3), padding='same',
                    name="conv2d_%d" % (index + 1))(branch)
    branch = BatchNormalization(
        name="batch_normalization_%d" % (index + 1))(branch)
    # `merge([...], mode='sum')` is the Keras 1 functional API and is no
    # longer callable in Keras 2.x; `Add` is its Keras 2 equivalent. The
    # layer has no weights, so by-name weight loading is unaffected.
    merged = keras.layers.Add(name="merge_%d" % merge_index)([branch, shortcut])
    return Activation('relu', name="activation_%d" % (index + 1))(merged)


# Feature extractor: the convolutional part of the trained network rebuilt
# for a 400x500 RGB input. Every layer carries an explicit name so that the
# trained weights can be loaded by name below.
inputs = Input(shape=(400, 500, 3))
X = _conv_bn_relu(inputs, 32, 1)
# Residual block 1
X = _residual_block(X, 32, 2, 1)
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="max_pooling2d_1")(X)
X = _conv_bn_relu(X, 64, 4, kernel_regularizer=regularizers.l2(0.01))
# Residual block 2
X = _residual_block(X, 64, 5, 2)
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="max_pooling2d_2")(X)
X = _conv_bn_relu(X, 64, 7)
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="max_pooling2d_3")(X)
# Residual block 3
X = _residual_block(X, 64, 8, 3)
X = _conv_bn_relu(X, 32, 10, kernel_regularizer=regularizers.l2(0.01))
# Residual block 4
X = _residual_block(X, 32, 11, 4)
X = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="max_pooling2d_4")(X)
X = _conv_bn_relu(X, 64, 13)
# Residual block 5
X = _residual_block(X, 64, 14, 5)
# With the unpadded 3x3 convolutions and four 2x2 poolings above, a 400x500
# input yields a 21x27x64 feature map.
model = Model(inputs=inputs, outputs=X)
# by_name=True copies weights only into layers whose names match layers in
# the file, so the differing input size is not an obstacle.
model.load_weights('model_halcon_resenet.h5', by_name=True)
# Load the image to analyse, resized to the 400x500 input of the feature
# extractor, and turn it into a preprocessed batch containing one image.
image_dir = 'C:/Users/18301/Desktop/blister/new/blister_mixed_11.png'
img = image.load_img(image_dir, target_size=(400, 500))
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))
# Run the first model to obtain the feature map, then cut one sample window
# (rows 16-20, columns 0-9) out of it.
feature_map = model.predict(x)
T = np.array(feature_map)
f_1 = T[:, 16:21, 0:10, :]
print(f_1.shape)
print(feature_map.shape)
# The first model works as expected.
# Second model: the classification head. It takes one sliced feature window
# (with an extra leading axis of size 1) and outputs 6 class probabilities.
# Layers are explicitly named so their weights can be loaded by name from
# the same weights file.
inputs_sec = Input(shape=(1, 5, 10, 64))
head_layers = [
    Flatten(name="flatten_1"),
    Dense(256, activation='relu', name="dense_1"),
    Dropout(0.5, name="dropout_1"),
    Dense(6, activation='softmax', name="dense_2"),
]
# Chain the layers in order, feeding each one the previous output.
predictions = inputs_sec
for head_layer in head_layers:
    predictions = head_layer(predictions)
model_sec = Model(inputs=inputs_sec, outputs=predictions)
model_sec.load_weights('model_halcon_resenet.h5', by_name=True)
# End of the second model definition.
model_sec.summary()
# Slide a 5x10 window over the feature map, classify every window with the
# second model, and draw the predicted label and its box on the original image.
pic = cv2.imread(image_dir)
if pic is None:
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface later as an obscure drawing error.
    raise IOError("cannot read image: %s" % image_dir)
# Class labels, indexed by the position of the highest softmax score.
# NOTE(review): 'yellow_balnk' looks like a typo for 'yellow_blank'; left
# unchanged because it is runtime text drawn on the image.
name_list = ['blank', 'green_blank', 'red_blank', 'yellow', 'yellow_balnk', 'yellow_blue']
font = cv2.FONT_HERSHEY_SIMPLEX
for i in range(3):
    # i is the grid column. Windows advance 9 feature columns at a time; the
    # last one starts at 17 instead of 18 so that the 10-column window ends at
    # column 27 (assumes the 27-column feature map of a 400x500 input).
    col_start = 17 if i == 2 else 9 * i
    for j in range(5):
        # j is the grid row; windows advance 4 feature rows at a time.
        row_start = 4 * j
        cut_feature = T[:, row_start:row_start + 5, col_start:col_start + 10, :]
        # Add the leading batch axis expected by the second model's input.
        data = np.expand_dims(cut_feature, axis=0)
        result = model_sec.predict(data)
        print(result)
        scores = result[0]
        # To discard low-confidence windows, `continue` here when
        # scores.max() <= 0.7.
        index_num = int(np.argmax(scores))
        name = name_list[index_num]
        # Top-left corner of the detection box in original-image pixels; each
        # box is 160 pixels wide and 80 pixels high.
        left, top = 160 * i + 6, 80 * j
        cv2.rectangle(pic, (left, top), (left + 160, top + 80), (0, 255, 0), 2)
        cv2.putText(pic, name, (left + 40, top + 40), font, 0.5, (0, 0, 255), 1)

# Show the annotated image until a key is pressed.
cv2.imshow('pic', pic)
cv2.waitKey(0)
cv2.destroyAllWindows()
# data= np.expand_dims(f_1, axis=0)
# result=model_sec.predict(data)
# print(result)
#第二個模型可以完全預測,沒有問題