1. 程式人生 > >tensorflow遷移學習:VGG16花朵分類

tensorflow遷移學習:VGG16花朵分類

其實上面的文章已經寫的很詳細了,但是還有一點小小的問題,通過參考其他的程式碼,將其進行補充,這樣完整的程式就可以運行了。

下面我就主要說說進行補充的地方:

補充1:

       如果按照原文章的步驟一步步進行,在進行到這步的時候,會有報錯,會提示找不到labels和codes檔案,原因是我們在執行上一步的程式碼沒有對其進行儲存。

# read codes and labels from file
import csv 

with open('labels') as f:
    reader = csv.reader(f, delimiter='\n')
    labels = np.array([each for each in reader if len(each) > 0]).squeeze()
with open('codes') as f:
    codes = np.fromfile(f, dtype=np.float32)
    codes = codes.reshape((len(labels), -1))

我們只需要在這段上面的程式碼處,對輸出的labels和codes檔案進行儲存處理即可。如下所示:

#將影象批量batches通過VGG模型,將輸出作為新的輸入:
# Set the batch size higher if you can fit in in your GPU memory
batch_size = 10
codes_list = []
labels = []
batch = []
 
codes = None
 
with tf.Session() as sess: 
    vgg = vgg16.Vgg16()
    input_ = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with tf.name_scope("content_vgg"):
        vgg.build(input_)
    
 
    for each in classes:
        print("Starting {} images".format(each))
        class_path = data_dir + each
        files = os.listdir(class_path)
        for ii, file in enumerate(files, 1):
            # Add images to the current batch
            # utils.load_image crops the input images for us, from the center
            img = utils.load_image(os.path.join(class_path, file))
            batch.append(img.reshape((1, 224, 224, 3)))
            labels.append(each)
            
            # Running the batch through the network to get the codes
            if ii % batch_size == 0 or ii == len(files):
                images = np.concatenate(batch)
 
                feed_dict = {input_: images}
                codes_batch = sess.run(vgg.relu6, feed_dict=feed_dict)
                
                # Here I'm building an array of the codes
                if codes is None:
                    codes = codes_batch
                else:
                    codes = np.concatenate((codes, codes_batch))
                
                # Reset to start building the next batch
                batch = []
                print('{} images processed'.format(ii))
                
#這裡就是新增的儲存的程式碼
                #這樣我們就可以得到一個 codes 陣列,和一個 labels 陣列,分別儲存了所有花朵的特徵值和類別。
                with open('codes', 'w') as f:
                    codes.tofile(f)
    
                import csv
                with open('labels', 'w') as f:
                    writer = csv.writer(f, delimiter='\n')
                    writer.writerow(labels)

接下來就會在檔案所在的目錄內自動生成labels和codes檔案,然後繼續原文章的步驟進行就可以實現最後的結果。

補充2:

文章結尾處,作者是以柱狀圖的形式來展示預測結果的,現在我們以只顯示概率和品種的形式來展示結果。jin

#將陣列轉換為list
predic_list = prediction.tolist()
print(type(predic_list))
index = predic_list.index(max(predic_list))
print(lb.classes_[index]+":"+str(max(predic_list)))

這樣就可以滿足部分同學對概率顯示預測結果的需求啦。

今天第一次寫,以後會繼續堅持,加油。