搭建簡單圖片分類的卷積神經網路(三)-- 模型的測試和運用
阿新 • 發佈:2018-11-19
以下兩個功能(模型測試與模型運用)都寫在同一個檔案(Disimage.py)中。
一、新建Disimage.py檔案
import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN

classes = []
n_classes = 0

# Directory containing the checkpoints written during training.
LOGS_TRAIN_DIR = r'E:\PycharmPython\NewCnn\logs'

# Fallback index -> label names used when evaluate_one_image is given no
# class_names.
DEFAULT_CLASS_NAMES = ('animales', 'banded', 'potholed', 'writeflowers', 'yellowflowers')


def get_one_image(train):
    """Pick one random image path from *train* and load it as a 64x64 RGB array.

    Args:
        train: non-empty sequence of image file paths.

    Returns:
        numpy array of shape (64, 64, 3).

    Raises:
        ValueError: if *train* is empty.
    """
    if len(train) == 0:
        raise ValueError('get_one_image: the list of image paths is empty')
    img_dir = train[np.random.randint(0, len(train))]  # randomly chosen test image
    # The context manager closes the file handle once the data is converted.
    with Image.open(img_dir) as src:
        # convert('RGB') so grayscale / RGBA / palette images still have 3
        # channels and survive the [1, 64, 64, 3] reshape in evaluate_one_image.
        # The resize may be redundant if images were already resized during
        # preprocessing, but it is cheap and keeps this function standalone.
        rgb = src.convert('RGB').resize((64, 64))
    return np.array(rgb)


def evaluate_one_image(image_array, N_CLASSES, class_names=None, logs_train_dir=LOGS_TRAIN_DIR):
    """Classify one image with the latest checkpoint and print the result.

    Args:
        image_array: array of shape (64, 64, 3), e.g. from get_one_image.
        N_CLASSES: number of output classes the model was trained with.
        class_names: optional sequence mapping class index -> printable name;
            defaults to DEFAULT_CLASS_NAMES. Indices beyond its length are
            printed as "class <index>".
        logs_train_dir: directory holding the training checkpoints.

    Returns:
        Index of the most probable class (result of np.argmax).

    Raises:
        FileNotFoundError: if no checkpoint exists in *logs_train_dir*.
    """
    if class_names is None:
        class_names = DEFAULT_CLASS_NAMES
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        # Feed the image through a placeholder; the network is wired to x,
        # so feed_dict below actually supplies the input.
        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        image = tf.image.per_image_standardization(x)
        image = tf.reshape(image, [1, 64, 64, 3])
        logit = CNN.inference(image, BATCH_SIZE, N_CLASSES)
        logit = tf.nn.softmax(logit)
        saver = tf.train.Saver()
        with tf.Session() as sess:
            print('Reading checkpoints...')
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if not (ckpt and ckpt.model_checkpoint_path):
                print('No checkpoint file found')
                # Running the graph now would only fail on uninitialized variables.
                raise FileNotFoundError('No checkpoint file found in %s' % logs_train_dir)
            # basename works for both '/' and Windows '\' separators.
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Loading success, global_step is %s' % global_step)
            prediction = sess.run(logit, feed_dict={x: np.asarray(image_array, dtype=np.float32)})
    max_index = np.argmax(prediction)
    if max_index < len(class_names):
        name = class_names[max_index]
    else:
        name = 'class %d' % max_index
    # prediction has batch size 1; index the scalar probability explicitly.
    print('This is a %s with possibility %.6f' % (name, prediction[0, max_index]))
    return max_index


if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  # training-set root
    # NOTE(review): class indices follow os.listdir order; this must match the
    # label order used by get_files during training -- confirm.
    for str_classes in os.listdir(train_dir):
        classes.append(str_classes)
    n_classes = len(classes)
    train, train_label, val, val_label = get_files(train_dir, 0.3)
    img = get_one_image(val)  # pass train instead of val to test on the training split
    pre = evaluate_one_image(img, n_classes, class_names=classes)
上面的程式從之前已經處理好並劃分好的測試集中隨機取一張圖片,對模型進行測試。
二、將程式碼改成如下,用於對新圖片進行分類:
import tensorflow as tf
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from GetCnnData import get_files
import CNN

classes = []
n_classes = 0

# Directory containing the checkpoints written during training.
LOGS_TRAIN_DIR = r'E:\PycharmPython\NewCnn\logs'

# Image modes the JPEG writer accepts; anything else (RGBA, P, ...) must be
# converted before saving with a .jpg extension.
_JPEG_MODES = ('RGB', 'L', 'CMYK')


def prediction_image_path(Classes, dir):
    """Create, or empty, one output sub-folder per class under *dir*.

    For each name in *Classes*, ``dir/name`` is created if it does not exist.
    If it already exists, the regular files inside it are deleted so results
    from a previous run do not mix with the new ones.
    """
    for name in Classes:
        class_dir = os.path.join(dir, name)
        if not os.path.exists(class_dir):
            os.makedirs(class_dir)  # create the folder
            print(class_dir, 'new file')
        else:
            for entry in os.listdir(class_dir):
                entry_path = os.path.join(class_dir, entry)
                # Only delete files; os.remove raises on sub-directories.
                if os.path.isfile(entry_path):
                    os.remove(entry_path)  # empty the folder
            print('There is this flie')


def get_one_image(train):
    """Load the image stored at path *train*.

    Returns:
        (img_data, image): *img_data* is an in-memory copy of the image in its
        original mode and size; *image* is that image converted to RGB and
        resized to 64x64, as a numpy array of shape (64, 64, 3).
    """
    # Open the file once; copy() loads the pixels so the handle can be closed.
    with Image.open(train) as src:
        img_data = src.copy()
    # The resize may be redundant if images were preprocessed already.
    image = np.array(img_data.convert('RGB').resize((64, 64)))
    return img_data, image


def _build_graph(n_classes):
    """Build the inference graph around a [64, 64, 3] float32 placeholder.

    Returns:
        (graph, x, probs, saver): the graph, its input placeholder, the
        softmax output tensor, and a Saver for restoring checkpoints.
    """
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        image = tf.image.per_image_standardization(x)
        image = tf.reshape(image, [1, 64, 64, 3])
        probs = tf.nn.softmax(CNN.inference(image, 1, n_classes))
        saver = tf.train.Saver()
    return graph, x, probs, saver


def _restore_checkpoint(sess, saver, logs_train_dir):
    """Restore the latest checkpoint in *logs_train_dir* into *sess*.

    Raises:
        FileNotFoundError: if there is no checkpoint; running the graph
            anyway would fail on uninitialized variables.
    """
    print('Reading checkpoints...')
    ckpt = tf.train.get_checkpoint_state(logs_train_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
        print('No checkpoint file found')
        raise FileNotFoundError('No checkpoint file found in %s' % logs_train_dir)
    # basename works for both '/' and Windows '\' separators.
    global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Loading success, global_step is %s' % global_step)


def _predict(sess, x, probs, image_array):
    """Return the index (int) of the most probable class for one image."""
    prediction = sess.run(probs, feed_dict={x: np.asarray(image_array, dtype=np.float32)})
    return int(np.argmax(prediction))


def evaluate_one_image(image_array, N_CLASSES, logs_train_dir=LOGS_TRAIN_DIR):
    """Classify a single (64, 64, 3) image and return its class index.

    Builds the graph and restores the checkpoint on every call. When
    classifying many images, use _build_graph/_restore_checkpoint once and
    call _predict per image, as the __main__ section does.

    Raises:
        FileNotFoundError: if no checkpoint exists in *logs_train_dir*.
    """
    graph, x, probs, saver = _build_graph(N_CLASSES)
    with tf.Session(graph=graph) as sess:
        _restore_checkpoint(sess, saver, logs_train_dir)
        return _predict(sess, x, probs, image_array)


if __name__ == '__main__':
    train_dir = r'E:\PycharmPython\NewCnn\train\train_data'  # training-set root
    image_dir = r'E:\PycharmPython\NewCnn\image'  # images to classify
    prediction_dir = r'E:\PycharmPython\NewCnn\prediction'  # where classified images are saved

    # NOTE(review): class indices follow os.listdir order; this must match the
    # label order used during training -- confirm.
    for str_classes in os.listdir(train_dir):
        classes.append(str_classes)
    n_classes = len(classes)

    # Create (or empty) one output folder per class.
    prediction_image_path(classes, prediction_dir)

    # Build the graph and restore weights once, then classify every image in
    # image_dir and save it into the folder of its predicted class.
    graph, x, probs, saver = _build_graph(n_classes)
    with tf.Session(graph=graph) as sess:
        _restore_checkpoint(sess, saver, LOGS_TRAIN_DIR)
        for image_data in os.listdir(image_dir):
            image_data_path = os.path.join(image_dir, image_data)
            if not os.path.isfile(image_data_path):
                continue  # skip sub-directories
            try:
                orig_img, img = get_one_image(image_data_path)
            except OSError as err:  # includes PIL.UnidentifiedImageError
                print('Skipping %s: %s' % (image_data_path, err))
                continue
            pre = _predict(sess, x, probs, img)
            print(classes[pre])
            if orig_img.mode not in _JPEG_MODES:
                # JPEG cannot store alpha/palette data; saving would raise.
                orig_img = orig_img.convert('RGB')
            orig_img.save(os.path.join(prediction_dir, classes[pre], str(pre) + image_data + '.jpg'))
上面的程式對image資料夾中的圖片進行分類,並將每張圖片存到prediction資料夾下對應類別的子資料夾中。
連載:https://blog.csdn.net/qq_28821995/article/details/83587032