TensorFlow 實現遷移學習
阿新 • • 發佈:2018-12-13
import glob
import os.path
import random

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Number of nodes in the bottleneck layer of the Inception-v3 model.
BOTTLENECK_TENSOR_SIZE = 2048
# Name of the tensor holding the bottleneck-layer output. In the model
# published by Google this is 'pool_3/_reshape:0'; a tensor's name can be
# obtained via tensor.name.
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
# Name of the tensor that receives the raw image input.
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
MODEL_DIR = 'inception/'
MODEL_FILE = 'tensorflow_inception_graph.pb'
# Each training image is used many times, so the feature vector computed by
# Inception for it is cached in a file to avoid recomputing it.
CACHE_DIR = 'tmp/bottlenect'
INPUT_DATA = 'flower_photos'
VALIDATION_PERCENTAGE = 10
TEST_PERCETAGE = 10
# The original listing had 10 here, which is almost certainly a typo for
# 0.01 (the value used in the commonly published version of this example);
# a gradient-descent step of 10 would very likely make training diverge.
LEARNING_RATE = 0.01
STEPS = 4000
BATCH = 100


def create_image_lists(testing_percetage, validation_percetage):
    """Scan INPUT_DATA and randomly split each class's images into data sets.

    Each sub-directory of INPUT_DATA is treated as one class whose label is
    the lower-cased directory name.  Every image is assigned independently
    at random: about ``validation_percetage`` percent go to the validation
    set, about ``testing_percetage`` percent to the test set, and the rest
    to the training set.  Directories containing no jpg/jpeg images are
    skipped.

    Args:
        testing_percetage: Percentage (0-100) of images put in the test set.
        validation_percetage: Percentage (0-100) of images put in the
            validation set.

    Returns:
        dict mapping label name to a dict with keys 'dir' (sub-directory
        name), 'training', 'testing' and 'validation' (lists of image file
        base names).
    """
    result = {}
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    # os.walk always yields INPUT_DATA itself first; skip that entry so only
    # the per-class sub-directories are processed.
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    for sub_dir in sub_dirs[1:]:
        dir_name = os.path.basename(sub_dir)

        # Collect all valid images in the current directory.
        file_list = []
        for extension in extensions:
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        # On case-insensitive file systems '*.jpg' and '*.JPG' match the same
        # files; drop duplicates (keeping order) so one image cannot end up
        # in two different data sets.
        file_list = list(dict.fromkeys(file_list))
        if not file_list:
            continue

        label_name = dir_name.lower()
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Randomly assign the image to validation, testing or training.
            chance = np.random.randint(100)
            if chance < validation_percetage:
                validation_images.append(base_name)
            elif chance < (testing_percetage + validation_percetage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)

        # Store this class's split in the result.
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result


def get_image_path(image_lists, image_dir, label_name, index, category):
    """Return the path of one image belonging to a class and data set.

    Args:
        image_lists: Mapping as returned by create_image_lists.
        image_dir: Root directory the class sub-directory is joined to
            (INPUT_DATA for source images, CACHE_DIR for cached vectors).
        label_name: Class label key into image_lists.
        index: Image index; taken modulo the size of the set, so any
            integer is accepted.
        category: 'training', 'testing' or 'validation'.

    Returns:
        os.path.join(image_dir, <class sub-directory>, <image base name>).

    Raises:
        ValueError: if the requested data set of this class is empty.
    """
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    if not category_list:
        # Small classes can end up with an empty set after the random split;
        # report that clearly instead of failing with a ZeroDivisionError.
        raise ValueError('no %s images for label %r' % (category, label_name))
    mod_index = index % len(category_list)
    # File name of the selected image.
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    full_path = os.path.join(image_dir, sub_dir, base_name)
    return full_path
def get_bottlenek_path(image_lists, label_name, index, category):
    """Return the cache-file path of the feature vector for one image.

    The image is identified by class label, data set and index.  The path is
    the image's path re-rooted under CACHE_DIR with '.txt' appended.
    """
    return get_image_path(image_lists, CACHE_DIR, label_name, index,
                          category) + '.txt'


def run_bottlenect_on_image(sess, image_data, image_data_tensor,
                            bottleneck_tensor):
    """Run one image through the loaded model and return its feature vector.

    Args:
        sess: Session holding the loaded model graph.
        image_data: Raw image bytes fed to image_data_tensor.
        image_data_tensor: Input tensor of the model.
        bottleneck_tensor: Tensor whose value is returned.

    Returns:
        The evaluated bottleneck_tensor with all size-1 dimensions removed.
    """
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {image_data_tensor: image_data})
    # Squeeze out size-1 dimensions (presumably the batch dimension) to get
    # a flat vector.
    return np.squeeze(bottleneck_values)


def get_or_create_bottleneck(sess, image_lists, label_name, index, category,
                             jpeg_data_tensor, bottleneck_tensor):
    """Return the feature vector of one image, using the on-disk cache.

    If the cache file for the image exists it is parsed; otherwise the vector
    is computed with run_bottlenect_on_image and saved to the cache file as
    comma-separated text.

    Returns:
        The feature vector: a list of floats on a cache hit, or the NumPy
        value returned by run_bottlenect_on_image on a cache miss.
    """
    label_lists = image_lists[label_name]
    sub_dir = label_lists['dir']
    sub_dir_path = os.path.join(CACHE_DIR, sub_dir)
    if not os.path.exists(sub_dir_path):
        os.makedirs(sub_dir_path)

    bottleneck_path = get_bottlenek_path(image_lists, label_name, index,
                                         category)
    if not os.path.exists(bottleneck_path):
        image_path = get_image_path(image_lists, INPUT_DATA, label_name,
                                    index, category)
        # Read inside a context manager so the file handle is closed.
        with gfile.FastGFile(image_path, 'rb') as image_file:
            image_data = image_file.read()
        bottleneck_values = run_bottlenect_on_image(
            sess, image_data, jpeg_data_tensor, bottleneck_tensor)
        # Save the computed vector to the cache file.
        bottleneck_string = ','.join(str(x) for x in bottleneck_values)
        with open(bottleneck_path, 'w') as bottleneck_file:
            bottleneck_file.write(bottleneck_string)
    else:
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read()
        bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
    return bottleneck_values


def _one_hot(label_index, n_classes):
    """Return a float32 vector of length n_classes with 1.0 at label_index."""
    ground_truth = np.zeros(n_classes, dtype=np.float32)
    ground_truth[label_index] = 1.0
    return ground_truth


def get_random_cached_bottlenecks(sess, n_classes, image_lists, how_many,
                                  category, jpeg_data_tensor,
                                  bottleneck_tensor):
    """Randomly sample a batch of feature vectors with one-hot labels.

    For each of how_many samples a class index is drawn from
    range(n_classes), then an image index is drawn from range(65536) and
    wrapped into that class's data set by get_image_path.  Assumes
    n_classes == len(image_lists) -- TODO confirm at the call site.

    Returns:
        (bottlenecks, ground_truths): parallel lists of feature vectors and
        float32 one-hot label vectors.
    """
    # Label order must match get_test_bottlenecks (dict iteration order).
    label_names = list(image_lists.keys())
    bottlenecks = []
    ground_truths = []
    for _ in range(how_many):
        label_index = random.randrange(n_classes)
        label_name = label_names[label_index]
        image_index = random.randrange(65536)
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, category,
            jpeg_data_tensor, bottleneck_tensor)
        bottlenecks.append(bottleneck)
        ground_truths.append(_one_hot(label_index, n_classes))
    return bottlenecks, ground_truths


def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor,
                         bottleneck_tensor):
    """Return feature vectors and one-hot labels for every test image.

    Returns:
        (bottlenecks, ground_truths): parallel lists covering each image in
        the 'testing' set of every class, in dict iteration order.
    """
    bottlenecks = []
    ground_truths = []
    category = 'testing'
    label_name_list = list(image_lists.keys())
    for label_index, label_name in enumerate(label_name_list):
        for index in range(len(image_lists[label_name][category])):
            bottleneck = get_or_create_bottleneck(
                sess, image_lists, label_name, index, category,
                jpeg_data_tensor, bottleneck_tensor)
            bottlenecks.append(bottleneck)
            ground_truths.append(_one_hot(label_index, n_classes))
    return bottlenecks, ground_truths


def main(_):
    # Read all images.
    image_lists = create_image_lists(TEST_PERCETAGE, VALIDATION_PERCENTAGE)