1. 程式人生 > >Tensorflow + ResNet101 + fasterRcnn 訓練自己的模型 資料(一)

Tensorflow + ResNet101 + fasterRcnn 訓練自己的模型 資料(一)

一、資料準備:

1、PASCAL VOC資料集格式

2、資料擴充:做了旋轉【0, 90, 180, 270】與降取樣(備註:這裡可以不做那麼多種旋轉,因為Faster R-CNN在訓練的時候本身就會做圖片的鏡像(水平翻轉)變換)

降取樣:

import os
import cv2
import numpy as np
try:
	import xml.etree.cElementTree as ET
except ImportError:
	# xml.etree.cElementTree was removed in Python 3.9; the pure
	# ElementTree module is the drop-in replacement there. (The exception
	# name was previously misspelled, so this fallback raised NameError.)
	import xml.etree.ElementTree as ET
import copy,random

def _image_downsampling(folder,Savefolder,scale=0.5):
	"""Resize every image in <folder>/JPEGImages by `scale`.

	Each result is written to <Savefolder>/JPEGImages as
	'<stem>_downsampled<scale>.jpg' (the directory is created if needed).
	Entries that cv2.imread cannot decode (it returns None for them) are
	skipped instead of crashing on `img.shape`.
	"""
	JPEGImages = os.path.join(folder, 'JPEGImages')
	saveJPEGImages = os.path.join(Savefolder, 'JPEGImages')
	if not os.path.exists(saveJPEGImages):
		os.makedirs(saveJPEGImages)

	for imgfile in os.listdir(JPEGImages):
		print(imgfile)
		img = cv2.imread(os.path.join(JPEGImages,imgfile))
		if img is None:
			print('skipping unreadable file: ' + imgfile)
			continue
		# cv2.resize takes the target size as (width, height).
		new_size = (int(scale*img.shape[1]), int(scale*img.shape[0]))
		img_downsampled = cv2.resize(img, new_size)
		# splitext (not imgfile[:-4]) so '.jpeg' / '.JPEG' names keep a clean stem.
		stem = os.path.splitext(imgfile)[0]
		new_name = stem + '_downsampled' + str(scale) + '.jpg'
		cv2.imwrite(os.path.join(saveJPEGImages,new_name),img_downsampled)

def _rewrite_txt(folder,Savefolder,scale):
	Annotations = os.path.join(folder, 'Annotations/labels')
	saveAnnotations = os.path.join(Savefolder, 'Annotations/labels')
	if not os.path.exists(saveAnnotations):
		os.makedirs(saveAnnotations)

	for txtfile in os.listdir(Annotations):
		print(txtfile)
		ff = open(os.path.join(Annotations,txtfile),'r')
		f = open(os.path.join(saveAnnotations,txtfile[:-4] + '_downsampled'  + str(scale) + '.txt'),'at')
		for line in ff.readlines():
			[filename, cls, x1, y1, x2, y2] = line.strip().split(' ')
			x1 = int(int(x1) * scale)
			y1 = int(int(y1) * scale)
			x2 = int(int(x2) * scale)
			y2 = int(int(y2) * scale)
			new_name = filename[:-4] + '_downsampled' + str(scale) + '.jpg'
			nline = new_name + ' ' + cls + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(x2) + ' ' + str(y2) + '\n'
			f.write(nline)
		f.close()
		ff.close()

def _rewrite_xml(folder,Savefolder,scale):
	Annotations = os.path.join(folder, 'Annotations')
	saveAnnotations = os.path.join(Savefolder, 'Annotations')
	if not os.path.exists(saveAnnotations):
		os.makedirs(saveAnnotations)
	for xmlfile in os.listdir(Annotations):
		print(xmlfile)
		# if xmlfile[:-4] != '.xml':
		#     continue
		new_name = xmlfile[:-4] + '_downsampled' + str(scale) + '.jpg'
		tree = ET.parse(os.path.join(Annotations,xmlfile))
		root = tree.getroot()

		for obj in root.findall('object'):
			bndbox = obj.find('bndbox')
			bndbox[0].text = str(int(float(bndbox[0].text)*scale))
			bndbox[1].text = str(int(float(bndbox[1].text)*scale))
			bndbox[2].text = str(int(float(bndbox[2].text)*scale))
			bndbox[3].text = str(int(float(bndbox[3].text)*scale))
		size = tree.find('size')
		height = size.find('height')
		width = size.find('width')
		height.text = str(int(float(height.text)*scale))
		width.text = str(int(float(width.text)*scale))

		filename = tree.find('filename')
		filename.text = new_name
		print(new_name)

		tree.write(os.path.join(saveAnnotations,new_name[:-4]+'.xml'))


def _rewrite_trainval_and_test(folder,Savefolder,scales):
	if not os.path.exists(os.path.join(Savefolder,'ImageSets', 'Main')):
		os.makedirs(os.path.join(Savefolder,'ImageSets', 'Main'))

	# folder = '../enhancement/' + cate + '/'
	# Savefolder = '../downsampled/'+ cate + '_downsampled'
	trainval = open(folder + 'ImageSets/Main/train.txt','r')
	print 'trainval:', trainval
	test = open(folder + 'ImageSets/Main/val.txt','r')
	print 'test:',test

	save_trainval = open(Savefolder + '/ImageSets/Main/train.txt','a')
	print 'save_trainval:', save_trainval
	save_test = open(Savefolder + '/ImageSets/Main/val.txt','a')
	print 'save_test:', save_test


	for line in trainval.readlines():
		for i in range(len(scales)):
			nline = line.strip() + '_downsampled' + str(i)
			save_trainval.write(nline + '\n')
	save_trainval.close()

	for line in test.readlines():
		for i in range(len(scales)):
			nline = line.strip() + '_downsampled' + str(i)
			save_test.write(nline + '\n')
	save_test.close()


def main(folder,scale = 0.5):
	"""Downsample one VOC-style dataset by a single fixed `scale`.

	Reads the module globals `Savefolder` (output root) and
	`annotations_file` ('txt' or 'xml'). Writes resized images, rescaled
	annotations and extended ImageSets/Main/{train,val}.txt under Savefolder.
	"""
	#images
	_image_downsampling(folder,Savefolder,scale)

	#annotations -- compare strings with ==; `is` tests object identity and
	# only works by accident of string interning.
	if annotations_file == 'txt':
		_rewrite_txt(folder,Savefolder,scale)
	elif annotations_file == 'xml':
		_rewrite_xml(folder,Savefolder,scale)

	#trainval,test -- the files written above are named
	# '<id>_downsampled<scale>' (e.g. '_downsampled0.5'), whereas
	# _rewrite_trainval_and_test suffixes the scale *index* ('_downsampled0'),
	# which would list ids that do not exist. Write the lists here with the
	# matching suffix instead.
	suffix = '_downsampled' + str(scale)
	dst_dir = os.path.join(Savefolder, 'ImageSets', 'Main')
	if not os.path.exists(dst_dir):
		os.makedirs(dst_dir)
	for split in ('train.txt', 'val.txt'):
		src_path = os.path.join(folder, 'ImageSets', 'Main', split)
		with open(src_path, 'r') as src, open(os.path.join(dst_dir, split), 'a') as dst:
			for line in src:
				sample_id = line.strip()
				if sample_id:
					dst.write(sample_id + suffix + '\n')


def main2():
	"""Random-scale augmentation of one category (driven by module globals).

	Reads the globals `folder` (source VOC root) and `Savefolder` (output
	root). For every image in <folder>/JPEGImages that has a matching
	<folder>/Annotations/<stem>.xml, it draws `num_scales` random factors
	in [0.2, 1.5). For each factor i:
	  * saves the resized image as <Savefolder>/JPEGImages/<stem>_downsampled<i>.jpg
	  * rewrites bndbox / size / filename of the XML and saves it as
	    <Savefolder>/Annotations/<stem>_downsampled<i>.xml
	Finally it extends the train/val lists via _rewrite_trainval_and_test.
	Note that factors above 1.0 upsample despite the '_downsampled' suffix.
	"""
	num_scales = 2

	JPEGImages = os.path.join(folder, 'JPEGImages')
	saveJPEGImages = os.path.join(Savefolder, 'JPEGImages')
	if not os.path.exists(saveJPEGImages):
		os.makedirs(saveJPEGImages)

	Annotations = os.path.join(folder, 'Annotations')
	saveAnnotations = os.path.join(Savefolder, 'Annotations')
	if not os.path.exists(saveAnnotations):
		os.makedirs(saveAnnotations)

	box_tags = ('xmin', 'ymin', 'xmax', 'ymax')
	for imgfile in os.listdir(JPEGImages):
		stem = os.path.splitext(imgfile)[0]
		xml_path = os.path.join(Annotations, stem + '.xml')
		if not os.path.isfile(xml_path):
			print('no annotation for ' + imgfile + ', skipped')
			continue
		img = cv2.imread(os.path.join(JPEGImages,imgfile))
		if img is None:
			print('cannot read ' + imgfile + ', skipped')
			continue

		scales = [random.uniform(2,15)*0.1 for _ in range(num_scales)]

		for i,scale in enumerate(scales):
			#images
			img_downsampled = cv2.resize(img,(int(scale*img.shape[1]),int(scale*img.shape[0])))
			new_name = stem + '_downsampled' + str(i) + '.jpg'
			print(new_name)
			cv2.imwrite(os.path.join(saveJPEGImages,new_name),img_downsampled)

			#annotations -- re-parse per scale because the tree is edited in place.
			tree = ET.parse(xml_path)
			root = tree.getroot()
			for obj in root.findall('object'):
				bndbox = obj.find('bndbox')
				# float() so '12.0'-style coordinates parse as well.
				box = [int(float(bndbox.find(tag).text)*scale) for tag in box_tags]
				# Validate the scaled, truncated box. The old assert checked the
				# unscaled extent, so tiny boxes that collapse to zero size after
				# int() truncation were still written out.
				if box[2] - box[0] <= 0 or box[3] - box[1] <= 0:
					print('dropping degenerate box in ' + new_name)
					root.remove(obj)
					continue
				for tag, value in zip(box_tags, box):
					bndbox.find(tag).text = str(value)

			root.find('filename').text = new_name
			size = root.find('size')
			height = size.find('height')
			width = size.find('width')
			height.text = str(int(float(height.text)*scale))
			width.text = str(int(float(width.text)*scale))
			tree.write(os.path.join(saveAnnotations, stem + '_downsampled' + str(i) + '.xml'))

	#trainval,test -- only the number of scale variants per image matters
	# there. Passing a fixed-length list also avoids len(None) when
	# JPEGImages is empty.
	_rewrite_trainval_and_test(folder,Savefolder,list(range(num_scales)))


if __name__ == '__main__':
	# Plain module-level assignments are what main()/main2() read as
	# globals (annotations_file, folder, Savefolder); a `global` statement
	# has no effect at module scope.
	annotations_file = 'xml'  # 'txt' or 'xml'
	catelist = ['hydropower', 'thermalpower', 'tower', 'windpower']

	for category in catelist:
		# Source: the enhanced (rotated) data; destination: a per-category
		# downsampled copy.
		folder = '../enhancement/' + category + '/'
		Savefolder = '../downsampled/' + category + '_downsampled'
		main2()





旋轉:
import cv2
import os
import numpy as np
import math
import copy
try:
	import xml.etree.cElementTree as ET
except ImportError:
	import xml.etree.ElementTree as ET


def _write_(f,ff):
	for lines in ff.readlines():
		for angle in rotation_angle:
			line1 = lines.strip()+'_rotated_' + str(angle)
			# line2 = lines.strip()+'_rotated_' + str(angle) + '_mirror'
			f.write(line1+'\n')
			# f.write(line2+'\n')
	f.close()
	ff.close()


def rotate_about_center(src, angle, scale=1.):
	"""Rotate `src` by `angle` degrees about its centre without cropping.

	The output canvas is the bounding box of the rotated image, multiplied
	by `scale`. The affine matrix is built around the new canvas centre and
	then shifted so that the source centre lands on it. Uses Lanczos
	interpolation. The sign convention for `angle` is that of
	cv2.getRotationMatrix2D.
	"""
	w = src.shape[1]
	h = src.shape[0]

	rangle = np.deg2rad(angle) #angle in radians
	nw = (abs(np.sin(rangle)*h)+abs(np.cos(rangle)*w))*scale
	nh = (abs(np.cos(rangle)*h)+abs(np.sin(rangle)*w))*scale

	rot_mat = cv2.getRotationMatrix2D((nw*0.5,nh*0.5), angle, scale)# rotate with center
	# Shift so that the source centre maps onto the new canvas centre.
	rot_move = np.dot(rot_mat,np.array([(nw-w)*0.5,(nh-h)*0.5,0]))
	rot_mat[0,2] += rot_move[0]
	rot_mat[1,2] += rot_move[1]

	# np.sin/np.cos are not exactly 0 at multiples of 90 degrees
	# (cos(pi/2) ~ 6e-17). That noise can push nw/nh a hair above an
	# integer, and ceil() then adds a spurious extra row/column. Round it
	# away before taking the ceiling.
	out_size = (int(math.ceil(round(nw, 6))), int(math.ceil(round(nh, 6))))
	return cv2.warpAffine(src,rot_mat,out_size,flags = cv2.INTER_LANCZOS4)


def enhancement_using_rotation(ImagePath,AnnotationsPath):
	"""Rotate every annotated image by each angle and rewrite its boxes.

	Module globals read: rotation_angle, annotation_file ('txt' or 'xml'),
	ifshow (1 = draw boxes and preview), ImageSavePath, AnnotationsSavePath.

	Outputs, per image <stem> and angle <a>:
	  ImageSavePath/<stem>_rotated_<a>.jpg
	  'xml': AnnotationsSavePath/<stem>_rotated_<a>.xml
	  'txt': AnnotationsSavePath/labels/<stem>_rotated_<a>.txt (append mode),
	         read from AnnotationsPath/labels/<stem>.txt
	Images without the annotation file for the current mode, or that cv2
	cannot read, are skipped.
	"""
	print('============')
	print(ImagePath)

	for imgfile in os.listdir(ImagePath):
		stem = os.path.splitext(imgfile)[0]
		print(stem)
		# Check for the annotation file that will actually be read. The old
		# code always looked for an .xml, even in 'txt' mode, and misspelled
		# `continue`, so a missing file raised NameError.
		if annotation_file == 'txt':
			src_annotation = os.path.join(AnnotationsPath, 'labels', stem + '.txt')
		else:
			src_annotation = os.path.join(AnnotationsPath, stem + '.xml')
		if not os.path.isfile(src_annotation):
			continue
		img = cv2.imread(os.path.join(ImagePath,imgfile))
		if img is None:
			print('cannot read image, skipped: ' + imgfile)
			continue

		# Exact (float) centre; a bare `/ 2` was integer division under Python 2.
		center_x = img.shape[1] / 2.0
		center_y = img.shape[0] / 2.0

		#rotation
		for angle in rotation_angle:
			rotated_stem = stem + '_rotated_' + str(angle)
			new_name = rotated_stem + '.jpg'
			rotate_img = rotate_about_center(img,angle)
			cv2.imwrite(os.path.join(ImageSavePath,new_name),rotate_img)

			new_h, new_w = rotate_img.shape[0], rotate_img.shape[1]
			new_center_x = new_w / 2.0
			new_center_y = new_h / 2.0

			if annotation_file == 'txt':
				label_dir = os.path.join(AnnotationsSavePath, 'labels')
				if not os.path.exists(label_dir):
					os.makedirs(label_dir)
				dst_annotation = os.path.join(label_dir, rotated_stem + '.txt')
				with open(src_annotation, 'r') as ff, open(dst_annotation, 'a') as f:
					for line in ff:
						# '<filename> <cls> <x1> <y1> <x2> <y2> [extra fields...]'
						parts = line.split()
						if len(parts) < 6:
							continue
						cls, x1, y1, x2, y2 = parts[1:6]
						# `angle` was previously omitted here (TypeError).
						final_x1, final_y1, final_x2, final_y2 = _rotated_location(
							x1, y1, x2, y2, center_x, center_y,
							new_center_x, new_center_y, angle)
						assert final_y2 - final_y1 > 0 and final_x2 - final_x1 > 0
						if ifshow == 1:
							cv2.rectangle(rotate_img,(final_x1,final_y1),(final_x2,final_y2),(0,0,255),2)
							cv2.putText(rotate_img, cls, (int(final_x1),int(final_y1)),0,1.2,(0,0,255),2)
						nline = new_name + ' ' + cls + ' ' + str(final_x1) + ' ' + str(final_y1) + ' ' + str(final_x2) + ' ' + str(final_y2) + '\n'
						f.write(nline)

			elif annotation_file == 'xml':
				tree = ET.parse(src_annotation)
				root = tree.getroot()
				root.find('filename').text = new_name
				size = root.find('size')
				size.find('height').text = str(new_h)
				size.find('width').text = str(new_w)

				for obj in root.findall('object'):
					# Defined here so the ifshow label below has a value (it was
					# previously undefined in this branch).
					cls = obj.find('name').text
					bndbox = obj.find('bndbox')
					x1, y1, x2, y2 = [float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
					# Drop boxes that are already empty in the source annotation.
					if not (int(y2) - int(y1) > 0 and int(x2) - int(x1) > 0):
						print(src_annotation)
						root.remove(obj)
						continue

					final_x1, final_y1, final_x2, final_y2 = _rotated_location(
						x1, y1, x2, y2, center_x, center_y,
						new_center_x, new_center_y, angle)
					assert final_y2 - final_y1 > 0
					assert final_x2 - final_x1 > 0
					assert final_x1 >= 0
					assert final_y1 >= 0
					# Compare with the integer canvas size. The old code compared
					# an int with the XML *string*, which is always True on
					# Python 2 and a TypeError on Python 3.
					assert final_x2 <= new_w
					assert final_y2 <= new_h
					if ifshow == 1:
						cv2.rectangle(rotate_img,(final_x1,final_y1),(final_x2,final_y2),(0,0,255),2)
						cv2.putText(rotate_img, cls, (int(final_x1),int(final_y1)),0,1.2,(0,0,255),2)
					bndbox.find('xmin').text = str(final_x1)
					bndbox.find('ymin').text = str(final_y1)
					bndbox.find('xmax').text = str(final_x2)
					bndbox.find('ymax').text = str(final_y2)

				tree.write(os.path.join(AnnotationsSavePath, rotated_stem + '.xml'))

			if ifshow == 1:
				cv2.imshow(new_name,rotate_img)
				key = cv2.waitKey(500)
				if key == 27:
					quit()
				cv2.destroyWindow(new_name)


def _rotated_location(x1,y1,x2,y2,center_x,center_y,new_center_x,new_center_y,angle):
	x1 = float(x1) - center_x
	y1 = -(float(y1) - center_y)
	x2 = float(x2) - center_x
	y2 = -(float(y2) - center_y)

	rangle = np.deg2rad(angle)
	rotated_x1 = np.cos(rangle)*x1 - np.sin(rangle)*y1
	rotated_y1 = np.cos(rangle)*y1 + np.sin(rangle)*x1
	rotated_x2 = np.cos(rangle)*x2 - np.sin(rangle)*y2
	rotated_y2 = np.cos(rangle)*y2 + np.sin(rangle)*x2

	rotated_x1 = int(rotated_x1 + new_center_x)
	rotated_y1 = int(-rotated_y1 + new_center_y)
	rotated_x2 = int(rotated_x2 + new_center_x)
	rotated_y2 = int(-rotated_y2 + new_center_y)

	final_x1 = int(min(rotated_x1,rotated_x2))
	final_y1 = int(min(rotated_y1,rotated_y2))
	final_x2 = int(max(rotated_x1,rotated_x2))
	final_y2 = int(max(rotated_y1,rotated_y2))
	return final_x1, final_y1, final_x2, final_y2


def enhancement_using_mirror(ImagePath,AnnotationsPath):
	"""Horizontally flip every annotated image in ImagePath and its boxes.

	Module globals read: annotation_file ('txt' or 'xml'), ifshow (1 = draw
	boxes and preview), ImageSavePath, AnnotationsSavePath.

	Outputs, per image <stem>:
	  ImageSavePath/<stem>_mirror.jpg
	  'xml': AnnotationsSavePath/<stem>_mirror.xml
	  'txt': AnnotationsSavePath/labels/<stem>_mirror.txt (append mode),
	         read from AnnotationsPath/labels/<stem>.txt. This is the same
	         layout enhancement_using_rotation uses; the old code referenced
	         an undefined global `txtpath`.
	An x coordinate x becomes abs(x - w), where w is the image width; y is
	unchanged. Images without the annotation file for the current mode, or
	that cv2 cannot read, are skipped.
	"""
	for imgfile in os.listdir(ImagePath):
		stem = os.path.splitext(imgfile)[0]
		if annotation_file == 'txt':
			src_annotation = os.path.join(AnnotationsPath, 'labels', stem + '.txt')
		else:
			src_annotation = os.path.join(AnnotationsPath, stem + '.xml')
		if not os.path.isfile(src_annotation):
			continue
		img = cv2.imread(os.path.join(ImagePath,imgfile))
		if img is None:
			continue
		new_name = stem + '_mirror' + '.jpg'

		mirror_img = mirroir_hierachically(img)
		cv2.imwrite(os.path.join(ImageSavePath,new_name),mirror_img)

		w = img.shape[1]

		if annotation_file == 'txt':
			label_dir = os.path.join(AnnotationsSavePath, 'labels')
			if not os.path.exists(label_dir):
				os.makedirs(label_dir)
			dst_annotation = os.path.join(label_dir, stem + '_mirror.txt')
			with open(src_annotation, 'r') as ff, open(dst_annotation, 'a') as f:
				for line in ff:
					parts = line.split()
					if len(parts) < 6:
						continue
					cls, x1, y1, x2, y2 = parts[1:6]
					# Mirror both edges from the ORIGINAL x values. The old code
					# overwrote x2 first and then derived x1 from it, which just
					# reproduced the unflipped x1.
					new_x1 = abs(int(float(x2)) - w)
					new_x2 = abs(int(float(x1)) - w)
					final_x1 = min(new_x1, new_x2)
					final_x2 = max(new_x1, new_x2)
					assert(final_x2-final_x1>0)

					nline = new_name + ' ' + cls + ' ' + str(final_x1) + ' ' + y1 + ' ' + str(final_x2) + ' ' + y2 + '\n'
					f.write(nline)

					if ifshow == 1:
						cv2.rectangle(mirror_img,(final_x1,int(float(y1))),(final_x2,int(float(y2))),(0,0,255),2)
						cv2.putText(mirror_img, cls, (final_x1,int(float(y1))),0,1.2,(0,0,255),2)

		elif annotation_file == 'xml':
			tree = ET.parse(src_annotation)
			tree.find('filename').text = new_name

			for obj in tree.findall('object'):
				bndbox = obj.find('bndbox')
				x1 = int(float(bndbox.find('xmin').text))
				y1 = int(float(bndbox.find('ymin').text))
				x2 = int(float(bndbox.find('xmax').text))
				y2 = int(float(bndbox.find('ymax').text))

				new_x1 = abs(x2 - w)
				new_x2 = abs(x1 - w)
				final_x1 = min(new_x1, new_x2)
				final_x2 = max(new_x1, new_x2)
				assert(final_x2-final_x1>0)

				cls = obj.find('name').text
				if ifshow == 1:
					cv2.rectangle(mirror_img,(final_x1,y1),(final_x2,y2),(0,0,255),2)
					cv2.putText(mirror_img, cls, (final_x1,y1),0,1.2,(0,0,255),2)

				# Only x changes under a horizontal flip; ymin/ymax text is kept.
				bndbox.find('xmin').text = str(final_x1)
				bndbox.find('xmax').text = str(final_x2)

			tree.write(os.path.join(AnnotationsSavePath, stem + '_mirror.xml'))

		if ifshow == 1:
			cv2.imshow(new_name,mirror_img)
			key = cv2.waitKey(500)
			if key == 27:
				quit()
			cv2.destroyWindow(new_name)

def mirroir_hierachically(src):
	"""Return a horizontally mirrored copy of `src` (column order reversed).

	Accepts any array with at least two dimensions, e.g. grayscale H x W or
	colour H x W x C. The previous version read src.shape[2], which failed
	for 2-D input, and used Python 2's xrange. The input is not modified.
	"""
	# Reversing axis 1 gives a view with a negative stride; .copy() turns it
	# into an independent C-contiguous array, as the old deepcopy-based
	# version returned.
	return src[:, ::-1].copy()




if __name__ == '__main__':
	# Plain module-level assignments are what the functions above read as
	# globals; a `global` statement has no effect at module scope.
	annotation_file = 'xml'  # 'txt' or 'xml'
	rotation_angle = [0, 90, 180, 270]
	ifshow = 0  # 1 = draw boxes and preview every generated image
	catelist = ['hydropower', 'thermalpower', 'tower', 'windpower']

	for cate in catelist:
		# --- PART 1: output directories under ../enhancement/<cate> ---
		savepath = '../enhancement/' + cate
		ImageSavePath = savepath + '/JPEGImages/'
		AnnotationsSavePath = savepath + '/Annotations/'
		ImageSetsSavePath = savepath + '/ImageSets/Main/'
		for out_dir in (ImageSavePath, AnnotationsSavePath, ImageSetsSavePath):
			if not os.path.exists(out_dir):
				os.makedirs(out_dir)

		# --- PART 2: rotate the training images and their annotations ---
		print('rotating...')
		path = '../TrainData/' + cate
		ImagePath = path + '/JPEG/'
		AnnotationsPath = path + '/XML/'
		enhancement_using_rotation(ImagePath, AnnotationsPath)

		# --- PART 3: extend the train/val id lists with rotated names ---
		print('writing trainval / test txt...')
		path = '../TrainData/' + cate + '/ImageSets/Main/'
		enhanced_main = '../enhancement/' + cate + '/ImageSets/Main/'
		f1 = open(path + 'train.txt', 'r')
		f2 = open(path + 'val.txt', 'r')
		trainval_txt = open(enhanced_main + 'train.txt', 'a')
		test_txt = open(enhanced_main + 'val.txt', 'a')
		# _write_ closes both handles it is given.
		_write_(trainval_txt, f1)
		_write_(test_txt, f2)


3、訓練中資料出現的問題:

(1)此次只訓練了4類目標,但是資料裡面實際上還存在其他類別的標註,需要在訓練前先處理掉。

(2)圖片邊界問題:x,y不能為負數; anchor一致

(3)一些引數問題要好好看下論文

(4)記得刪除cache檔案:訓練時會先做資料初始化,把train_val和test的資料寫入cache資料夾的pkl檔案中;因為沒有刪除舊的cache,導致在test時出現KeyError。