資料集.npy格式與png格式互換
阿新 • • 發佈:2018-12-10
深度學習中,有時我們需要對資料集進行預處理,這樣能夠更好的讀取資料。
一、png格式生成.npy格式
import numpy as np import os from PIL import Image dir="C:/Users/Administrator/Desktop/trainA" def getFileArr(dir): result_arr=[] label_list=[] map={} map_file_result={} map_file_label={} map_new={} count_label=0 count=0 file_list=os.listdir(dir) for file in file_list: file_path=os.path.join(dir,file) label=file.split(".")[0].split("_")[0] map[file]=label if label not in label_list: label_list.append(label) map_new[label]=count_label count_label=count_label+1 img=Image.open(file_path) result=np.array([]) r,g,b=img.split() r_arr=np.array(r).reshape(4096) g_arr=np.array(g).reshape(4096) b_arr=np.array(b).reshape(4096) img_arr=np.concatenate((r_arr,g_arr,b_arr)) result=np.concatenate((result,img_arr)) result=result.reshape((64,64,3)) result=result/255.0 map_file_result[file]=result result_arr.append(result) count=count+1 for file in file_list: map_file_label[file]=map_new[map[file]] #map[file]=map_new[map[file]] ret_arr=[] for file in file_list: each_list=[] label_one_zero=np.zeros(count_label) result=map_file_result[file] label=map_file_label[file] label_one_zero[label]=1.0 #print(label_one_zero) each_list.append(result) each_list.append(label_one_zero) ret_arr.append(each_list) os.makedirs("C:/Users/Administrator/Desktop/npy") np.save('C:/Users/Administrator/Desktop/npy/test_data.npy', ret_arr) return ret_arr if __name__=="__main__": ret_arr=getFileArr(dir)
二、.npy格式生成png格式
import numpy as np from PIL import Image import os dir="C:/Users/Administrator/Desktop/npy/"#npy檔案路徑 dest_dir="C:/Users/Administrator/Desktop/train/" def npy2jpg(dir,dest_dir): if os.path.exists(dir)==False: os.makedirs(dir) if os.path.exists(dest_dir)==False: os.makedirs(dest_dir) file=dir+'test_data.npy' con_arr=np.load(file) count=0 for con in con_arr: arr=con[0] label=con[1] print(np.argmax(label)) arr=arr*255 #arr=np.transpose(arr,(2,1,0)) arr=np.reshape(arr,(3,64,64)) r=Image.fromarray(arr[0]).convert("L") g=Image.fromarray(arr[1]).convert("L") b=Image.fromarray(arr[2]).convert("L") img=Image.merge("RGB",(r,g,b)) label_index=np.argmax(label) img.save(dest_dir+str(label_index)+"_"+str(count)+".png") count=count+1 if __name__=="__main__": npy2jpg(dir,dest_dir)
三、注意
根據自己的資料集需要改尺寸和維度以及改路徑。