1. 程式人生 > >『TensorFlow』遷移學習_他山之石,可以攻玉


不同 target lis 理解 pla 之前 photo 多說 esc







# Download the flower photo data set used as training samples.
curl -O http://download.tensorflow.org/example_images/flower_photos.tgz


# Download the pre-trained Inception-v3 model archive.
wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip


  1 import glob
  2 import os.path
  3 import random
  4 import numpy as np
  5 import tensorflow as tf
  6 from tensorflow.python.platform import gfile
  8 ‘‘‘模型及樣本路徑設置‘‘‘
 10 BOTTLENECK_TENSOR_SIZE = 2048                          # 瓶頸層節點個數
 11 BOTTLENECK_TENSOR_NAME = pool_3/_reshape:0
# 瓶頸層輸出張量名稱 12 JPEG_DATA_TENSOR_NAME = DecodeJpeg/contents:0 # 輸入層張量名稱 13 14 MODEL_DIR = ./inception_dec_2015 # 模型存放文件夾 15 MODEL_FILE = tensorflow_inception_graph.pb # 模型名 16 17 CACHE_DIR = ./bottleneck # 瓶頸輸出中轉文件夾
18 INPUT_DATA = ./flower_photos # 數據文件夾 19 20 VALIDATION_PERCENTAGE = 10 # 驗證用數據百分比 21 TEST_PERCENTAGE = 10 # 測試用數據百分比 22 23 ‘‘‘新添加神經網絡部參數設置‘‘‘ 24 25 LEARNING_RATE = 0.01 26 STEP = 4000 27 BATCH = 100 28 29 def creat_image_lists(validation_percentage,testing_percentage): 30 ‘‘‘ 31 將圖片(無路徑文件名)信息保存在字典中 32 :param validation_percentage: 驗證數據百分比 33 :param testing_percentage: 測試數據百分比 34 :return: 字典{標簽:{文件夾:str,訓練:[],驗證:[],測試:[]},...} 35 ‘‘‘ 36 result = {} 37 sub_dirs = [x[0] for x in os.walk(INPUT_DATA)] 38 # 由於os.walk()列表第一個是‘./‘,所以排除 39 is_root_dir = True #<----- 40 # 遍歷各個label文件夾 41 for sub_dir in sub_dirs: 42 if is_root_dir: #<----- 43 is_root_dir = False 44 continue 45 46 extensions = [jpg, jpeg, JPG, JPEG] 47 file_list = [] 48 dir_name = os.path.basename(sub_dir) 49 # 遍歷各個可能的文件尾綴 50 for extension in extensions: 51 # file_glob = os.path.join(INPUT_DATA,dir_name,‘*.‘+extension) 52 file_glob = os.path.join(sub_dir, *. + extension) 53 file_list.extend(glob.glob(file_glob)) # 匹配並收集路徑&文件名 54 # print(file_glob,‘\n‘,glob.glob(file_glob)) 55 if not file_list: continue 56 57 label_name = dir_name.lower() # 生成label,實際就是小寫文件夾名 58 59 # 初始化各個路徑&文件收集list 60 training_images = [] 61 testing_images = [] 62 validation_images = [] 63 64 # 去路徑,只保留文件名 65 for file_name in file_list: 66 base_name = os.path.basename(file_name) 67 68 # 隨機劃分數據給驗證和測試 69 chance = np.random.randint(100) 70 if chance < validation_percentage: 71 validation_images.append(base_name) 72 elif chance < (validation_percentage + testing_percentage): 73 testing_images.append(base_name) 74 else: 75 training_images.append(base_name) 76 # 本標簽字典項生成 77 result[label_name] = { 78 dir : dir_name, 79 training : training_images, 80 testing : testing_images, 81 validation : validation_images 82 } 83 return result 84 85 def get_random_cached_bottlenecks(sess,n_class,image_lists,batch,category,jpeg_data_tensor,bottleneck_tensor): 86 ‘‘‘ 87 函數隨機獲取一個batch的圖片作為訓練數據 88 :param sess: 89 :param n_class: 90 :param image_lists: 91 :param 
how_many: 92 :param category: training or validation 93 :param jpeg_data_tensor: 94 :param bottleneck_tensor: 95 :return: 瓶頸張量輸出 & label 96 ‘‘‘ 97 bottlenecks = [] 98 ground_truths = [] 99 for i in range(batch): 100 label_index = random.randrange(n_class) # 標簽索引隨機生成 101 label_name = list(image_lists.keys())[label_index] # 標簽名獲取 102 image_index = random.randrange(65536) # 標簽內圖片索引隨機種子 103 # 瓶頸層張量 104 bottleneck = get_or_create_bottleneck( # 獲取對應標簽隨機圖片瓶頸張量 105 sess,image_lists,label_name,image_index,category, 106 jpeg_data_tensor,bottleneck_tensor) 107 ground_truth = np.zeros(n_class,dtype=np.float32) 108 ground_truth[label_index] = 1.0 # 標準結果[0,0,1,0...] 109 # 收集瓶頸張量和label 110 bottlenecks.append(bottleneck) 111 ground_truths.append(ground_truth) 112 return bottlenecks,ground_truths 113 114 def get_or_create_bottleneck( 115 sess,image_lists,label_name,index,category,jpeg_data_tensor,bottleneck_tensor): 116 ‘‘‘ 117 尋找已經計算且保存下來的特征向量,如果找不到則先計算這個特征向量,然後保存到文件 118 :param sess: 119 :param image_lists: 全圖像字典 120 :param label_name: 當前標簽 121 :param index: 圖片索引 122 :param category: training or validation 123 :param jpeg_data_tensor: 124 :param bottleneck_tensor: 125 :return: 126 ‘‘‘ 127 label_lists = image_lists[label_name] # 本標簽字典獲取 標簽:{文件夾:str,訓練:[],驗證:[],測試:[]} 128 sub_dir = label_lists[dir] # 獲取標簽值 129 sub_dir_path = os.path.join(CACHE_DIR,sub_dir) # 保存文件路徑 130 if not os.path.exists(sub_dir_path):os.mkdir(sub_dir_path) 131 bottleneck_path = get_bottleneck_path(image_lists,label_name,index,category) 132 if not os.path.exists(bottleneck_path): 133 image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category) 134 #image_data = gfile.FastGFile(image_path,‘rb‘).read() 135 image_data = open(image_path,rb).read() 136 # print(gfile.FastGFile(image_path,‘rb‘).read()==open(image_path,‘rb‘).read()) 137 # 生成向前傳播後的瓶頸張量 138 bottleneck_values = run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor) 139 # list2string以便於寫入文件 140 bottleneck_string = 
,.join(str(x) for x in bottleneck_values) 141 # print(bottleneck_values) 142 # print(bottleneck_string) 143 with open(bottleneck_path, w) as bottleneck_file: 144 bottleneck_file.write(bottleneck_string) 145 else: 146 with open(bottleneck_path, r) as bottleneck_file: 147 bottleneck_string = bottleneck_file.read() 148 bottleneck_values = [float(x) for x in bottleneck_string.split(,)] 149 # 返回的是list註意 150 return bottleneck_values 151 152 def run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor): 153 ‘‘‘ 154 使用加載的訓練好的Inception-v3模型處理一張圖片,得到這個圖片的特征向量。 155 :param sess: 會話句柄 156 :param image_data: 圖片文件句柄 157 :param jpeg_data_tensor: 輸入張量句柄 158 :param bottleneck_tensor: 瓶頸張量句柄 159 :return: 瓶頸張量值 160 ‘‘‘ 161 # print(‘input:‘,len(image_data)) 162 bottleneck_values = sess.run(bottleneck_tensor,feed_dict={jpeg_data_tensor:image_data}) 163 bottleneck_values = np.squeeze(bottleneck_values) 164 # print(‘bottle:‘,len(bottleneck_values)) 165 return bottleneck_values 166 167 def get_bottleneck_path(image_lists, label_name, index, category): 168 ‘‘‘ 169 獲取一張圖片的中轉(featuremap)地址(添加txt) 170 :param image_lists: 全圖片字典 171 :param label_name: 標簽名 172 :param index: 隨機數索引 173 :param category: training or validation 174 :return: 中轉(featuremap)地址(添加txt) 175 ‘‘‘ 176 return get_image_path(image_lists, CACHE_DIR, label_name, index, category) + .txt 177 178 def get_image_path(image_lists, image_dir, label_name, index, category): 179 ‘‘‘ 180 通過類別名稱、所屬數據集和圖片編號獲取一張圖片的中轉(featuremap)地址(無txt) 181 :param image_lists: 全圖片字典 182 :param image_dir: 外層文件夾(內部是標簽文件夾) 183 :param label_name: 標簽名 184 :param index: 隨機數索引 185 :param category: training or validation 186 :return: 圖片中間變量地址 187 ‘‘‘ 188 label_lists = image_lists[label_name] 189 category_list = label_lists[category] # 獲取目標category圖片列表 190 mod_index = index % len(category_list) # 隨機獲取一張圖片的索引 191 base_name = category_list[mod_index] # 通過索引獲取圖片名 192 return os.path.join(image_dir,label_lists[dir],base_name) 193 194 def 
get_test_bottlenecks(sess,image_lists,n_class,jpeg_data_tensor,bottleneck_tensor): 195 ‘‘‘ 196 獲取全部的測試數據,計算輸出 197 :param sess: 198 :param image_lists: 199 :param n_class: 200 :param jpeg_data_tensor: 201 :param bottleneck_tensor: 202 :return: 瓶頸輸出 & label 203 ‘‘‘ 204 bottlenecks = [] 205 ground_truths = [] 206 label_name_list = list(image_lists.keys()) 207 for label_index,label_name in enumerate(image_lists[label_name_list]): 208 category = testing 209 for index, unused_base_name in enumerate(image_lists[label_name][category]): # 索引, {文件名} 210 bottleneck = get_or_create_bottleneck( 211 sess, image_lists, label_name, index, 212 category, jpeg_data_tensor, bottleneck_tensor) 213 ground_truth = np.zeros(n_class, dtype=np.float32) 214 ground_truth[label_index] = 1.0 215 bottlenecks.append(bottleneck) 216 ground_truths.append(ground_truth) 217 return bottlenecks, ground_truths 218 219 def main(): 220 # 生成文件字典 221 images_lists = creat_image_lists(VALIDATION_PERCENTAGE,TEST_PERCENTAGE) 222 # 記錄label種類(字典項數) 223 n_class = len(images_lists.keys()) 224 225 # 加載模型 226 # with gfile.FastGFile(os.path.join(MODEL_DIR,MODEL_FILE),‘rb‘) as f: # 閱讀器上下文 227 with open(os.path.join(MODEL_DIR, MODEL_FILE), rb) as f: # 閱讀器上下文 228 graph_def = tf.GraphDef() # 生成圖 229 graph_def.ParseFromString(f.read()) # 圖加載模型 230 # 加載圖上節點張量(按照句柄理解) 231 bottleneck_tensor,jpeg_data_tensor = tf.import_graph_def( # 從圖上讀取張量,同時導入默認圖 232 graph_def, 233 return_elements=[BOTTLENECK_TENSOR_NAME,JPEG_DATA_TENSOR_NAME]) 234 235 ‘‘‘新的神經網絡‘‘‘ 236 # 輸入層,由原模型輸出層feed 237 bottleneck_input = tf.placeholder(tf.float32,[None,BOTTLENECK_TENSOR_SIZE],name=BottleneckInputPlaceholder) 238 ground_truth_input = tf.placeholder(tf.float32,[None,n_class] ,name=GroundTruthInput) 239 # 全連接層 240 with tf.name_scope(final_train_ops): 241 Weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE,n_class],stddev=0.001)) 242 biases = tf.Variable(tf.zeros([n_class])) 243 logits = tf.matmul(bottleneck_input,Weights) + biases 244 
final_tensor = tf.nn.softmax(logits) 245 # 交叉熵損失函數 246 cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=ground_truth_input)) 247 # 優化算法選擇 248 train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy) 249 250 # 正確率 251 with tf.name_scope(evaluation): 252 correct_prediction = tf.equal(tf.argmax(final_tensor,1),tf.argmax(ground_truth_input,1)) 253 evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 254 255 with tf.Session() as sess: 256 init = tf.global_variables_initializer() 257 sess.run(init) 258 for i in range(STEP): 259 # 隨機batch獲取瓶頸輸出 & label 260 train_bottlenecks,train_ground_truth = get_random_cached_bottlenecks( 261 sess,n_class,images_lists,BATCH,training,jpeg_data_tensor,bottleneck_tensor) 262 sess.run(train_step,feed_dict={bottleneck_input:train_bottlenecks,ground_truth_input:train_ground_truth}) 263 264 # 每叠代100次運行一次驗證程序 265 if i % 100 == 0 or i + 1 == STEP: 266 validation_bottlenecks, validation_ground_truth = get_random_cached_bottlenecks( 267 sess, n_class, images_lists, BATCH, validation, jpeg_data_tensor, bottleneck_tensor) 268 validation_accuracy = sess.run(evaluation_step, feed_dict={ 269 bottleneck_input: validation_bottlenecks, ground_truth_input: validation_ground_truth}) 270 print(Step %d: Validation accuracy on random sampled %d examples = %.1f%% % 271 (i, BATCH, validation_accuracy * 100)) 272 273 test_bottlenecks,test_ground_truth = get_test_bottlenecks( 274 sess,images_lists,n_class,jpeg_data_tensor,bottleneck_tensor) 275 test_accuracy = sess.run(evaluation_step,feed_dict={ 276 bottleneck_input:test_bottlenecks,ground_truth_input:test_ground_truth}) 277 print(Final test accuracy = %.1f%% % (test_accuracy * 100)) 278 279 if __name__ == __main__: 280 main()





  • 如何使用把自己的圖片數據導入框架中訓練測試
  • 如何加載模型,import模型中的張量(在源代碼以及[置頂]『TensorFlow』常用函數實踐筆記給出了介紹)



# Inspect the decoded size of the image.
# NOTE(review): this snippet relies on `Image` (presumably PIL's Image module)
# and `image_path` being defined elsewhere - confirm before running.
img = np.asarray(Image.open(image_path))
print(np.prod(img.shape))  # total number of array elements
print(img.shape)


 1 def run_bottleneck_on_images(sess,image_data,jpeg_data_tensor,bottleneck_tensor):
 2     ‘‘‘
 3     使用加載的訓練好的Inception-v3模型處理一張圖片,得到這個圖片的特征向量。
 4     :param sess:              會話句柄
 5     :param image_data:        圖片文件句柄
 6     :param jpeg_data_tensor:  輸入張量句柄
 7     :param bottleneck_tensor: 瓶頸張量句柄
 8     :return:                  瓶頸張量值
 9     ‘‘‘
10     bottleneck_values = sess.run(bottleneck_tensor,feed_dict={jpeg_data_tensor:image_data})
11     bottleneck_values = np.squeeze(bottleneck_values)
12     print(input:,len(image_data))
13     print(bottle:,len(bottleneck_values))
14     return bottleneck_values


(240, 320, 3)
input: 45685 <class 'bytes'>
172800
(240, 240, 3)
input: 30673 <class 'bytes'>
2048


  • Inception-v3網絡的tensorflow的源碼
  • 繼續研究一下其他的輸入數據的方式,交叉印證一下到底怎麽傳入圖片數據,是不是只能二進制輸入




 1 import os
 2 import tensorflow as tf
 4 inception_graph_def_file = os.path.join(./, tensorflow_inception_graph.pb)
 5 with tf.Session() as sess:
 6     with tf.gfile.FastGFile(inception_graph_def_file, rb) as f:
 7         graph_def = tf.GraphDef()
 8         graph_def.ParseFromString(f.read())
 9         tf.import_graph_def(graph_def, name=‘‘)
10     writer = tf.summary.FileWriter(./, sess.graph)
11     writer.close()


 1 import tensorflow as tf
 2 import matplotlib.pyplot as plt
 4 # 使用‘r‘會出錯,無法解碼,只能以2進制形式讀取
 5 # img_raw = tf.gfile.FastGFile(‘./123.png‘,‘rb‘).read()
 6 img_raw = open(./123.png,rb).read()
 8 # 把二進制文件解碼為uint8
 9 img_0 = tf.image.decode_png(img_raw)
10 # img_1 = tf.image.convert_image_dtype(img_0,dtype=tf.uint8)
12 sess = tf.Session()
13 print(sess.run(img_0).shape)
14 plt.imshow(sess.run(img_0))
15 plt.show()



