Python使用tensorflow實現影象識別(貓狗大戰)-01
阿新 • • 發佈:2018-11-02
Python使用tensorflow實現影象識別(貓狗大戰)-01
import_data.py
import tensorflow as tf
import numpy as np
import os
#引入tensorflow、numpy、os 三個第三方模組
img_width = 208
img_height = 208
#此處設定一個影象的寬度高度,後面會用的到
train_dir = 'C:/Python/data/train/'
#設定訓練樣本所在路徑,可根據自己的實際修改
#讀取檔案函式
def get_files(file_dir):
cats = [ ]
label_cats = []
dogs = []
label_dogs = []
#定義一個函式 get_files,獲取目錄下的圖片及標籤,cats、label_cats等都是列表型別
for file in os.listdir(file_dir):
name = file.split('.')
if name[0] == 'cat':
cats.append(file_dir + file)
label_cats.append(0)
else:
dogs.append(file_dir + file)
label_dogs.append(1)
print('There are %d cats \nThere are %d dogs' %(len(cats), len(dogs)))
#split() 通過指定分隔符對字串進行切片,切片之後為列表型別,os.listdir(file_dir)返回列表型別;
#圖片的命名格式為:cat.0.jpg 按照 ‘ . ’ 進行分割,結果為[cat, 0, jpg],對訓練路徑下的所有圖片進行操作
#然後根據判斷條件在dogs列表或cats列表中加入file_dir+file
image_list = np.hstack((cats, dogs))
label_list = np.hstack((label_cats, label_dogs))
#numpy的hstack() 函式的用法,返回numpy的陣列
temp = np.array([image_list, label_list])
'''
陣列: =[[cats, dogs],
[label_cats, dogs_label]]
'''
temp = temp.transpose() #矩陣轉置
np.random.shuffle(temp) # 打亂存放的順序,random庫中shuffle()函式的用法,但image 和label 一一對應,不會混亂
image_list = list(temp[:, 0]) # 獲取圖片
label_list = list(temp[:, 1]) # 獲取標籤
label_list = [float(i) for i in label_list]
return image_list, label_list
#經過操作後 label_list = [0, 1]
#影象裁剪函式
def get_batch(image, label, image_W, image_H, batch_size, capacity):
# 型別轉換函式
image = tf.cast(image, tf.string) # 資料型別轉換 image->string 此處image為影象的儲存路徑,轉換為字串格式
label = tf.cast(label, tf.int32) # 資料型別轉換 label->int32---------------------------------------(1)
#make an input queue 生成輸入對列--------------------------------------------------------------------(2)
input_queue = tf.train.slice_input_producer([image, label])
label = input_queue[1] # 讀取標籤
image_contents = tf.read_file(input_queue[0]) # 讀取影象 string型別
image = tf.image.decode_jpeg(image_contents, channels = 3) #解碼
# 對圖片進行裁剪或擴充【在影象中心處裁剪】,統一大小
image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
# 資料標準化 訓練前需要對資料進行標準化
image = tf.image.per_image_standardization(image) # 生成批次 在輸入的tensor中建立一些tensor資料batch
image_batch, label_batch = tf.train.batch([image, label],
batch_size = batch_size,
num_threads = 64,
capacity = capacity)
# 重新生成大小,即將label_batch變換成[batch_size]行的形式
label_batch = tf.reshape(label_batch, [batch_size])
return image_batch, label_batch
- tensorflow下cast()函式用法:
cast(x,dtype,name=None)
'''
將x的資料格式轉化成dtype資料型別,name是名字。例如:原來的資料格式是bool,
那麼將其轉化成float以後,就能夠將其轉化為0和1的序列。反之也可以。
'''
- tf.train.slice_input_producer()函式
https://blog.csdn.net/dcrmg/article/details/79776876 這個部落格詳細一點,就不寫了