1. 程式人生 > >利用keras中image.ImageDataGenerator.flow_from_directory()實現從資料夾中提取圖片和進行簡單歸一化處理

利用keras中image.ImageDataGenerator.flow_from_directory()實現從資料夾中提取圖片和進行簡單歸一化處理

keras中有很多封裝好的API可以幫助我們實現對圖片資料的讀取和處理。

比如 : keras.preprocessing.image.ImageDataGenerator.flow_from_directory( ) 這個函式

這個函式的引數包括:

flow_from_directory(self, directory,
                            target_size=(256, 256), color_mode='rgb',
                            classes=None, class_mode='categorical',
                            batch_size=32, shuffle=True, seed=None,
                            save_to_dir=None,
                            save_prefix='',
                            save_format='jpeg',
                            follow_links=False)

使用flow_from_directory最值得注意的是directory這個引數:

directory: path to the target directory. It should contain one subdirectory per class. Any PNG, JPG, BMP, PPM or TIF images inside each of the subdirectories directory tree will be included in the generator. 

這是官方文件的定義,它的目錄格式一定要注意是包含一個子目錄下的所有圖片這種格式,driectoty路徑只要寫到標籤路徑上面的那個路徑即可。

target_size:可是實現對圖片的尺寸轉換,是預處理中比較常用的方法

save_to_dir: 可以設定儲存處理後圖片的路徑。

save_prefix: 可以對處理後圖片設定字首。

舉個例子:

這個路徑結構,注意train路徑


程式碼:

import matplotlib.pyplot as  plt
import glob
from PIL import Image
from keras.preprocessing import image

path = 'train/'
gen_path = 'result/'

def print_result(path):
    name_list = glob.glob(path)
    fig = plt.figure()
    for i in range(9):
        img = Image.open(name_list[i])
        # add_subplot(331) 引數一:子圖總行數,引數二:子圖總列數,引數三:子圖位置
        sub_img = fig.add_subplot(331 + i)
        sub_img.imshow(img)
    plt.show()
    return fig

# 列印圖片列表
name_list = glob.glob(path + '*/*')
print(name_list)
# ['train\\00a366d4b4a9bbb6c8a63126697b7656.jpg', 'train\\00f34ac0a16ef43e6fd1de49a26081ce.jpg', 'train\\0a5f744c5077ad8f8d580081ba599ff5.jpg', 'train\\0a70f64352edfef4c82c22015f0e3a20.jpg', 'train\\0a783538d5f3aaf017b435ddf14cc5c2.jpg', 'train\\0a896d2b3af617df543787b571e439d8.jpg', 'train\\0abdda879bb143b19e3c480279541915.jpg', 'train\\0ac12f840df2b15d46622e244501a88c.jpg', 'train\\0b6c5bc46b7a0e29cddfa45b0b786d09.jpg']

# 列印圖片
fig = print_result(path + '*/*')

# 儲存圖片
fig.savefig(gen_path + '/original_0.png', dpi=200, papertype='a5')

# 原圖
datagen = image.ImageDataGenerator()
gen_data = datagen.flow_from_directory(path, batch_size=1, shuffle=False, save_to_dir=gen_path ,
                                       save_prefix='dog_gen', target_size=(224, 224))
for i in range(9):
    gen_data.next()

fig = print_result(gen_path + '/*')
fig.savefig(gen_path + '/original_1.png', dpi=200, papertype='a5')
這個例子還是比較好理解的。