程式人生 > 搭建簡單圖片分類的卷積神經網路(三)-- 模型的測試和運用

搭建簡單圖片分類的卷積神經網路(三)-- 模型的測試和運用

模型的測試與運用兩個功能都寫在同一個檔案(Disimage.py)中。

一、新建Disimage.py檔案

import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN

# Class names discovered under the training directory, and how many there are.
# Both start empty and are populated by the __main__ block.
classes, n_classes = [], 0

def get_one_image(train):
    """Pick one image path at random from *train* and load it as a 64x64 RGB array.

    Args:
        train: sequence of image file paths (e.g. the train or val split
            returned by ``get_files``).

    Returns:
        numpy array of the resized image. Because of the RGB conversion it has
        three channels, matching the [1, 64, 64, 3] reshape done before
        inference.

    Raises:
        ValueError: if *train* is empty.
    """
    n = len(train)
    if n == 0:
        raise ValueError('get_one_image() needs at least one image path')
    img_dir = train[np.random.randint(0, n)]  # randomly choose the test image

    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(img_dir) as src:
        # Convert to RGB so grayscale / RGBA / palette files still yield 3 channels.
        # Images were already resized during preprocessing, so the resize is
        # normally redundant; it is kept as a safeguard.
        imag = src.convert('RGB').resize((64, 64))
    return np.array(imag)

def evaluate_one_image(image_array, N_CLASSES,
                       logs_train_dir=r'E:\PycharmPython\NewCnn\logs'):
    """Classify one image with the latest CNN checkpoint and print the result.

    Args:
        image_array: image pixels; must be reshapeable to [1, 64, 64, 3]
            (e.g. the array returned by ``get_one_image``).
        N_CLASSES: number of output classes passed to ``CNN.inference``.
        logs_train_dir: directory containing the training checkpoints.
            Defaults to the path this script was originally written against.

    Returns:
        Index of the most probable class (``np.argmax`` of the softmax output).

    Raises:
        FileNotFoundError: if no checkpoint is found in *logs_train_dir*.
            Without a checkpoint the variables would be uninitialised and
            running the graph would fail anyway.
    """
    # Human-readable names for the first five class indices.
    # NOTE(review): presumably these match the training folder order — confirm.
    class_names = ('animales', 'banded', 'potholed', 'writeflowers', 'yellowflowers')

    with tf.Graph().as_default():
        BATCH_SIZE = 1

        # The image is embedded in the graph as a constant, so no placeholder
        # or feed_dict is needed to evaluate it.
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                raise FileNotFoundError('No checkpoint found in %s' % logs_train_dir)
            # The step is taken as the text after the last '-' in the checkpoint
            # file name; basename uses the platform's path separator(s).
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            prediction = sess.run(logit)
            max_index = np.argmax(prediction)
            # Batch size is 1, so read the single row; this yields a scalar for %.6f.
            possibility = prediction[0, max_index]
            if max_index < len(class_names):
                print('This is a %s with possibility %.6f' % (class_names[max_index], possibility))
            else:
                # No name known for this index; do not mislabel it.
                print('This is class %d with possibility %.6f' % (max_index, possibility))
    return max_index

if __name__ == '__main__':
    # Root of the training set; every entry in it is treated as a class name.
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'

    # Record the class names and their count in the module-level globals.
    found = os.listdir(train_dir)
    classes.extend(found)
    n_classes += len(found)

    # Split the data (0.3 is presumably the validation ratio) and test one image.
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    img = get_one_image(val)  # pass `train` instead of `val` to sample the training split
    pre = evaluate_one_image(img, n_classes)

上面的程式碼用來對先前已處理好、並劃分出的測試集圖片進行測試。

二、將程式碼改成

import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN


# Class names discovered under the training directory, and how many there are.
# Both start empty and are populated by the __main__ block.
classes, n_classes = [], 0

def prediction_image_path(Classes, dir):
    """Ensure there is an empty output folder for every class.

    For each name in *Classes*, the folder ``dir/<name>`` is created if it is
    missing. If it already exists, the regular files inside it are deleted so
    results from a previous run do not mix with new ones.

    Args:
        Classes: iterable of class names (one output sub-folder each).
        dir: root directory under which the class folders live. (The name
            shadows the ``dir`` builtin but is kept for caller compatibility.)
    """
    for name in Classes:
        # os.path.join instead of a hard-coded '\\' so this also works off Windows.
        class_dir = os.path.join(dir, name)
        if not os.path.exists(class_dir):
            os.makedirs(class_dir)  # create the folder
            print(class_dir, 'new file')
        else:
            # Empty the folder. Only regular files are removed: os.remove cannot
            # delete a sub-directory and would raise.
            for entry in os.listdir(class_dir):
                entry_path = os.path.join(class_dir, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
            print('There is this flie')

def get_one_image(train):
    """Load the image at path *train* for classification.

    Args:
        train: path of a single image file.

    Returns:
        (img_data, image): the original PIL image, already loaded into memory,
        and a 64x64 RGB numpy array to feed the network.
    """
    # Open the file once (the previous version opened it twice).
    img_data = Image.open(train)
    # Read the pixels now; for single-frame images Pillow then closes the file,
    # so the handle is not kept open until the image is eventually saved.
    img_data.load()
    # Images were already resized during preprocessing, so the resize is
    # normally redundant; it is kept as a safeguard.
    imag = img_data.convert('RGB').resize((64, 64))
    image = np.array(imag)
    return img_data, image


def evaluate_one_image(image_array, N_CLASSES,
                       logs_train_dir=r'E:\PycharmPython\NewCnn\logs'):
    """Return the index of the most probable class for one image.

    Args:
        image_array: image pixels; must be reshapeable to [1, 64, 64, 3]
            (e.g. the array returned by ``get_one_image``).
        N_CLASSES: number of output classes passed to ``CNN.inference``.
        logs_train_dir: directory containing the training checkpoints.
            Defaults to the path this script was originally written against.

    Returns:
        Index of the most probable class (``np.argmax`` of the softmax output).

    Raises:
        FileNotFoundError: if no checkpoint is found in *logs_train_dir*.
            Without a checkpoint the variables would be uninitialised and
            running the graph would fail anyway.
    """
    with tf.Graph().as_default():
        BATCH_SIZE = 1

        # The image is embedded in the graph as a constant, so no placeholder
        # or feed_dict is needed to evaluate it.
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])

        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                raise FileNotFoundError('No checkpoint found in %s' % logs_train_dir)
            # The step is taken as the text after the last '-' in the checkpoint
            # file name; basename uses the platform's path separator(s).
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)

            prediction = sess.run(logit)
            max_index = np.argmax(prediction)
    return max_index

if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  # training-set root
    image_dir = r'E:\PycharmPython\NewCnn\image'  # images to classify
    prediction_dir = r'E:\PycharmPython\NewCnn\prediction'  # where classified images are saved

    # Every entry in train_dir is treated as a class name.
    # NOTE(review): os.listdir order is arbitrary; it only lines up with the
    # network's output indices if training enumerated the folders the same way —
    # confirm against GetCnnData.get_files.
    found = os.listdir(train_dir)
    classes.extend(found)
    n_classes += len(found)

    # Create (or empty) one output folder per class.
    prediction_image_path(classes, prediction_dir)

    # Classify every file in image_dir and save a copy into its class folder.
    for image_data in os.listdir(image_dir):
        image_data_path = os.path.join(image_dir, image_data)
        if not os.path.isfile(image_data_path):
            continue  # skip sub-folders; Image.open cannot read them
        orig_img, img = get_one_image(image_data_path)
        pre = evaluate_one_image(img, n_classes)
        print(classes[pre])
        # The output is written as JPEG, which cannot store alpha or palette
        # images; convert so that save() does not raise.
        if orig_img.mode not in ('RGB', 'L'):
            orig_img = orig_img.convert('RGB')
        out_name = str(pre) + image_data + '.jpg'
        orig_img.save(os.path.join(prediction_dir, classes[pre], out_name))

上面的程式碼會對image資料夾中的圖片進行分類,並將結果存到對應類別的資料夾中。

連載:https://blog.csdn.net/qq_28821995/article/details/83587032

https://blog.csdn.net/qq_28821995/article/details/83587530