1. 程式人生 > tensorflow實現遷移學習

tensorflow實現遷移學習

import glob
import os.path
import random

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

# Number of nodes in the Inception-v3 bottleneck layer (length of one feature vector).
BOTTLENECK_TENSOR_SIZE = 2048
# Name of the tensor holding the bottleneck-layer output in Google's Inception
# graph; a tensor's name can be obtained from ``tensor.name``.
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'

# Name of the tensor that receives the input image
# (presumably fed the raw bytes of a JPEG file -- confirm where it is used).
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

MODEL_DIR = 'inception/'
MODEL_FILE = 'tensorflow_inception_graph.pb'
# Each training image is used many times, so its Inception feature vector is
# cached in a file under this (relative) directory to avoid recomputation.
# The 'bottlenect' spelling is kept so previously written caches still resolve.
CACHE_DIR = 'tmp/bottlenect'

INPUT_DATA = 'flower_photos'
# Percentages (0-100) of images assigned to the validation and testing sets.
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
# Original, misspelled name; kept as an alias so existing callers keep working.
TEST_PERCETAGE = TEST_PERCENTAGE
# Learning rate (its use is outside this excerpt). The previous value of 10 is
# far outside the usual range for gradient-based training and looks like a
# typo for 0.01.
LEARNING_RATE = 0.01
# Presumably the number of training steps and the batch size (used outside
# this excerpt).
STEPS = 4000
BATCH = 100

def create_image_lists(testing_percetage, validation_percetage, input_data=None):
    """Randomly split the images under the input directory into three sets.

    Every directory below the root (as listed by ``os.walk``, root excluded)
    is treated as one class. Its images are looked up as
    ``<root>/<basename>/*.{jpg,jpeg,JPG,JPEG}``, so in practice one level of
    class folders is expected. Each image is assigned independently at
    random: roughly ``validation_percetage``% go to validation,
    ``testing_percetage``% to testing, and the rest to training.

    Args:
        testing_percetage: percentage (0-100) of images for the testing set.
        validation_percetage: percentage (0-100) of images for the validation set.
        input_data: root image directory; ``None`` (default) uses ``INPUT_DATA``.

    Returns:
        dict mapping the lower-cased directory name to
        ``{'dir': <original directory name>, 'training': [...],
        'testing': [...], 'validation': [...]}``. Each list holds file base
        names. Directories without matching images are omitted.
    """
    root = INPUT_DATA if input_data is None else input_data
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    result = {}

    # os.walk yields the root directory itself first; only the directories
    # below it are classes.
    sub_dirs = [entry[0] for entry in os.walk(root)]
    for sub_dir in sub_dirs[1:]:
        # Collect all valid images in this directory.
        dir_name = os.path.basename(sub_dir)
        file_list = []
        for extension in extensions:
            file_glob = os.path.join(root, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        # Where glob matching is case-insensitive (e.g. Windows), '*.jpg' and
        # '*.JPG' return the same files; drop duplicates, keeping first order.
        file_list = list(dict.fromkeys(file_list))
        if not file_list:
            continue

        label_name = dir_name.lower()
        training_images = []
        testing_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # Randomly assign the image to validation, testing or training.
            chance = np.random.randint(100)
            if chance < validation_percetage:
                validation_images.append(base_name)
            elif chance < (testing_percetage + validation_percetage):
                testing_images.append(base_name)
            else:
                training_images.append(base_name)

        # Store this class's split in the result dict.
        result[label_name] = {
            'dir': dir_name,
            'training': training_images,
            'testing': testing_images,
            'validation': validation_images,
        }
    return result


def get_image_path(image_lists, image_dir, label_name, index, category):
    """Return the path of one image in a given class and category.

    Args:
        image_lists: dict in the format returned by ``create_image_lists``.
        image_dir: root directory the path is joined onto (e.g. the image
            directory, or the feature-cache directory).
        label_name: class key into ``image_lists``.
        index: any int; it is wrapped modulo the category size, so callers
            may pass arbitrary (e.g. random) indices.
        category: ``'training'``, ``'testing'`` or ``'validation'``.

    Returns:
        ``os.path.join(image_dir, <class dir>, <file base name>)``.

    Raises:
        ValueError: if the requested category contains no images.
    """
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    if not category_list:
        # Otherwise the modulo below fails with an unhelpful ZeroDivisionError.
        raise ValueError(
            'No images in category %r for label %r' % (category, label_name))
    mod_index = index % len(category_list)
    # File name of the selected image.
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    return os.path.join(image_dir, sub_dir, base_name)

def get_bottlenek_path(image_lists, label_name, index, category):
    """Return the cache-file path of one image's feature vector.

    The image is chosen by class name, category and index exactly as in
    ``get_image_path``. The returned path is rooted at ``CACHE_DIR`` instead
    of the image directory, with ``'.txt'`` appended to the image file name.
    """
    cached_path = get_image_path(
        image_lists, CACHE_DIR, label_name, index, category)
    return cached_path + '.txt'

def run_bottlenect_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
    """Run one image through the loaded model and return its feature vector.

    Args:
        sess: session used to evaluate the graph (``sess.run``).
        image_data: image contents to feed (the caller in this file passes the
            raw bytes of the image file).
        image_data_tensor: input tensor that ``image_data`` is fed into.
        bottleneck_tensor: tensor whose value is fetched.

    Returns:
        The fetched value with every size-1 dimension removed (``np.squeeze``).
    """
    bottleneck_values = sess.run(
        bottleneck_tensor, feed_dict={image_data_tensor: image_data})
    # Drop size-1 axes (presumably a batch axis of 1) to get a flat vector.
    return np.squeeze(bottleneck_values)

def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):
    """Return one image's feature vector, using the on-disk cache when possible.

    The cache file is ``CACHE_DIR/<class dir>/<image name>.txt``. If it exists,
    its comma-separated floats are read back. Otherwise the image under
    ``INPUT_DATA`` is run through the model and the result is written to that
    file before being returned.

    Returns:
        A list of floats when read from the cache; otherwise the array
        returned by ``run_bottlenect_on_image``.
    """
    label_lists = image_lists[label_name]
    sub_dir_path = os.path.join(CACHE_DIR, label_lists['dir'])
    # exist_ok: no error if the directory already exists or is created meanwhile.
    os.makedirs(sub_dir_path, exist_ok=True)
    # Needed on every call, not only when the cache directory was just created.
    bottleneck_path = get_bottlenek_path(image_lists, label_name, index, category)

    if os.path.exists(bottleneck_path):
        with open(bottleneck_path, 'r') as bottleneck_file:
            bottleneck_string = bottleneck_file.read()
        return [float(x) for x in bottleneck_string.split(',')]

    image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category)
    # Context manager so the image file handle is always closed.
    with gfile.FastGFile(image_path, 'rb') as image_file:
        image_data = image_file.read()
    bottleneck_values = run_bottlenect_on_image(
        sess, image_data, jpeg_data_tensor, bottleneck_tensor)
    # Store the computed vector as comma-separated text for later reuse.
    bottleneck_string = ','.join(str(x) for x in bottleneck_values)
    with open(bottleneck_path, 'w') as bottleneck_file:
        bottleneck_file.write(bottleneck_string)
    return bottleneck_values

def get_random_cached_bottlenecks(sess, n_classes, image_lists, how_many, category, jpeg_data_tensor, bottleneck_tensor):
    """Sample a random batch of (feature vector, one-hot label) pairs.

    Each of the ``how_many`` samples picks a uniformly random class index in
    ``[0, n_classes)`` and a random image index within ``category``.

    Args:
        n_classes: number of classes; assumed to equal ``len(image_lists)``
            (TODO confirm at the call site), since it indexes the key list.
        how_many: number of samples to draw.
        category: ``'training'``, ``'testing'`` or ``'validation'``.

    Returns:
        ``(bottlenecks, ground_truths)``: two lists of length ``how_many``.
        The ground truths are float32 one-hot vectors of length ``n_classes``.
    """
    bottlenecks = []
    ground_truths = []
    # The key order is fixed for the whole call, so build the list once.
    label_names = list(image_lists.keys())

    for _ in range(how_many):
        label_index = random.randrange(n_classes)
        label_name = label_names[label_index]
        # Arbitrary large index; get_image_path wraps it modulo the category size.
        image_index = random.randrange(65536)
        bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, category,
            jpeg_data_tensor, bottleneck_tensor)
        ground_truth = np.zeros(n_classes, dtype=np.float32)
        ground_truth[label_index] = 1.0
        bottlenecks.append(bottleneck)
        ground_truths.append(ground_truth)
    return bottlenecks, ground_truths

def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor):
    """Return feature vectors and one-hot labels for every image in the testing set.

    Classes are numbered in ``image_lists`` key order, which is the same
    ordering ``get_random_cached_bottlenecks`` uses.

    Returns:
        ``(bottlenecks, ground_truths)``: parallel lists with one entry per
        testing image. The ground truths are float32 one-hot vectors of length
        ``n_classes``.
    """
    bottlenecks = []
    ground_truths = []
    category = 'testing'
    for label_index, label_name in enumerate(list(image_lists.keys())):
        for index, _unused_base_name in enumerate(image_lists[label_name][category]):
            bottleneck = get_or_create_bottleneck(
                sess, image_lists, label_name, index, category,
                jpeg_data_tensor, bottleneck_tensor)
            ground_truth = np.zeros(n_classes, dtype=np.float32)
            ground_truth[label_index] = 1.0
            bottlenecks.append(bottleneck)
            ground_truths.append(ground_truth)

    return bottlenecks, ground_truths

def main(_):
    #讀取所有圖片
    image_lists=create_image_lists(TEST_PERCETAGE,VALIDATION_PERCENTAGE)