程式人生 > 多光譜遙感分類:使用CNN1(一)

多光譜遙感分類:使用CNN1(一)

程式碼源於很久以前練手的一個Demo,時間久了,許多魔改版本都已遺失,目前只剩下此簡陋版本。讀者如有相關需求,可參考其中片段自行取用。由於程式碼較為混亂且基礎,不再上傳至GitHub。

所用資料為多光譜遙感影像(.tif,由ArcGIS匯出的RGB彩色影像),以及摳圖所得的點檔案(.shp)(由摳圖所得的面檔案經ArcGIS隨機生成點工具生成,至少包含一個欄位,即標籤)。

工具篇

根據點shp檔案(樣本點集合),以每個樣本點為中心,對柵格影像的第3、2、1波段切取64×64畫素的影像塊,並儲存至對應標籤的資料夾中。注意:shp與tif的投影座標系必須一致。

from osgeo import gdal
import numpy as np
import shapefile
import cv2
import os

# Side length, in pixels, of the square chip cut around each sample point.
size = 64
# Raster bands read for each chip, in output channel order.
# NOTE(review): presumably 3,2,1 so that cv2.imwrite (which expects BGR)
# writes natural colours from an RGB-ordered tif -- confirm the band order.
BAND_ORDER = (3, 2, 1)
# Number of channels per chip (kept for compatibility; derived from BAND_ORDER).
bands = len(BAND_ORDER)

# The shapefile and the raster must share the same projected coordinate system.
dataset = gdal.Open(r"E:\資料2\test_tif_peizhun_subset_proj_.tif")
rer = shapefile.Reader(r'E:\shps\test.shp')


def __createDir(path):
    """Create directory *path* (with parents) if missing; exit with status 1 on failure."""
    if not os.path.exists(path):
        try:
            # exist_ok guards against the directory appearing between the
            # existence check and this call.
            os.makedirs(path, exist_ok=True)
        except OSError:
            print("建立資料夾失敗")
            # Raise SystemExit directly: the `exit` builtin is injected by
            # `site` and is not guaranteed to exist.
            raise SystemExit(1)


def __getACell(geo, pos):
    """Cut a ``size`` x ``size`` chip centred on map coordinate *pos*.

    Args:
        geo: GDAL geotransform of ``dataset``.
        pos: (x, y) point in the raster's projected coordinates.

    Returns:
        A (size, size, len(BAND_ORDER)) uint8 array with channels in
        BAND_ORDER, or None if the window leaves the raster or reading fails.
    """
    try:
        # Map coordinates -> pixel offsets. Only geo[0], geo[1], geo[3] and
        # geo[5] are used, i.e. this assumes a north-up raster (no rotation).
        xoffset = int((pos[0] - geo[0]) / geo[1])
        yoffset = int((pos[1] - geo[3]) / geo[5])
        print("pixels: x= %d,y= %d" % (xoffset, yoffset))
        left = int(xoffset - size / 2)
        top = int(yoffset - size / 2)
        # The window is the same for every band, so check it once.
        if (left < 0 or top < 0
                or left + size > dataset.RasterXSize
                or top + size > dataset.RasterYSize):
            return None
        output = [
            dataset.GetRasterBand(band).ReadAsArray(left, top, size, size)
            for band in BAND_ORDER
        ]
        # (channels, rows, cols) -> (rows, cols, channels) for OpenCV.
        img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
    except Exception:
        return None
    return img


def getShpDataForNum():
    """Cut one chip per shapefile point and save it as data/org/<label>/<label>.<i>.jpg.

    The label is the first attribute field of each record. Points whose
    window falls outside the raster (or cannot be read) are skipped.
    """
    labels = [record[0] for record in rer.records()]
    for label in set(labels):
        __createDir(os.path.join("data/org/", str(label)))
    # The geotransform does not change between points; fetch it once.
    geo = dataset.GetGeoTransform()
    for i in range(rer.numRecords):
        print("deal %d: " % (i + 1))
        img = __getACell(geo, rer.shape(i).points[0])
        if img is None:
            print("the area of points %d is out range." % (i))
            continue
        label = labels[i]
        out_path = "data/org/%s/%s.%d.jpg" % (label, label, i)
        cv2.imwrite(out_path, img)
        print(out_path)
    print("deal finish,to numpy array.")


getShpDataForNum()

以下程式碼將上述所得檔案按類別隨機拆分為訓練集和測試集(每個類別取約25%作為測試集)。

import os
import shutil
import random

def createDir(path):
    """Create directory *path* (including parents) if it does not already exist.

    Prints an error message and exits the process with status 1 if the
    directory cannot be created.
    """
    if not os.path.exists(path):
        try:
            # exist_ok avoids a spurious failure if the directory is created
            # between the existence check and this call.
            os.makedirs(path, exist_ok=True)
        except OSError:
            print("建立資料夾失敗")
            # Raise SystemExit directly: the `exit` builtin is injected by
            # `site`, is not guaranteed to exist, and closes stdin when called.
            raise SystemExit(1)

# Split every class folder under data/org/ into data/train/ and data/test/.
createDir("data/train/")
createDir("data/test/")

# Fraction of each class's files held out for the test set.
TEST_RATIO = 0.25

src_root = 'data/org/'
for class_name in os.listdir(src_root):
    class_dir = os.path.join(src_root, class_name)
    if not os.path.isdir(class_dir):
        continue  # ignore stray files next to the class folders

    createDir("data/train/" + class_name)
    createDir("data/test/" + class_name)

    files = os.listdir(class_dir)
    random.shuffle(files)
    num_test = int(len(files) * TEST_RATIO)
    # Split at an explicit index. Slicing with [:-num] / [-num:] breaks when
    # num == 0 ([:-0] is empty and [-0:] is everything), which sent every
    # file of a small class (fewer than 4 samples) to the test set.
    split = len(files) - num_test

    print(src_root + class_name + " start.")
    for name in files[:split]:
        shutil.copyfile(os.path.join(class_dir, name),
                        os.path.join("data/train", class_name, name))
    for name in files[split:]:
        shutil.copyfile(os.path.join(class_dir, name),
                        os.path.join("data/test", class_name, name))
    print(src_root + class_name + " finished")


以下程式碼以長條圖顯示指定資料夾下各子資料夾(即各樣本類別)中的檔案數目。

import os
import seaborn as sns
import matplotlib.pyplot as plt
def show(path, title):
    """Draw a bar chart of how many files each sub-folder of *path* contains.

    Each immediate sub-directory is treated as one sample class; its file
    count is drawn as a bar and annotated with the number above the bar.

    Args:
        path: directory whose sub-directories are counted.
        title: chart title.
    """
    # Sorted for a stable bar order (os.listdir order is arbitrary); plain
    # files directly under *path* are ignored instead of crashing listdir.
    classes = sorted(
        name for name in os.listdir(path)
        if os.path.isdir(os.path.join(path, name))
    )
    counts = [len(os.listdir(os.path.join(path, name))) for name in classes]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font able to render the Chinese labels
    plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly with that font
    # x/y must be passed by keyword: since seaborn 0.12 the only positional
    # parameter is `data`, so barplot(classes, counts) raises a TypeError.
    sns.barplot(x=classes, y=counts)
    plt.xlabel("樣本型別")
    plt.ylabel("數量")
    plt.title(title)

    # Bars sit at x = 0..n-1 in the order given; label each with its count.
    for i, count in enumerate(counts):
        plt.text(i, count + 2, "%d" % count, ha="center", va="bottom")
    plt.show()

# Plot the per-class sample counts of the training split.
show(path=r"data/1_train", title="訓練集源資料取樣集")


由於其他原因,資料有所更改。以下改為使用shp樣本點對應的畫素座標進行切圖取樣,此時座標檔案分為訓練集與測試集兩份(如train pos.txt與test pos.txt等)。

from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2,shutil


class Tiff:
    """Cut fixed-size image chips from a GeoTIFF at listed pixel positions.

    On construction, the position file and the feature/label file are joined
    column-wise into ``contact_src`` (skipped if that file already exists), and
    the joined table is loaded as ``self.fea``. ``get_cells`` then writes one
    JPEG chip per row into a per-label sub-folder.
    """

    # Band order used when ``bands`` is not given.
    DEFAULT_BANDS = (3, 2, 1)

    def createDir(self, path):
        """Create directory *path* (with parents) if missing; exit with status 1 on failure."""
        if not os.path.exists(path):
            try:
                # exist_ok guards against the directory appearing between the
                # existence check and this call.
                os.makedirs(path, exist_ok=True)
            except OSError:
                print("建立資料夾失敗")
                # The `exit` builtin comes from `site` and is not guaranteed.
                raise SystemExit(1)

    def __init__(self, pos_src, other_feather, contact_src, size=128, bands=None,
                 tif_src=r"D:/lishihang/jiangxia_simple/ZY3_GS_jiangxia1.tif"):
        """Open the raster and build/load the joined position+feature table.

        Args:
            pos_src: space-separated position file, one sample per line.
                NOTE(review): ``get_cells`` reads column 0 as the row (y) and
                column 1 as the column (x) pixel coordinate -- confirm.
            other_feather: tab-separated feature/label file, aligned row by
                row with ``pos_src``.
            contact_src: path of the joined CSV; created if missing, then loaded.
            size: side length, in pixels, of each square chip.
            bands: raster band indices to read, in output channel order;
                defaults to ``DEFAULT_BANDS`` (3, 2, 1).
            tif_src: path of the raster to cut from.
        """
        self.dataset = gdal.Open(tif_src)
        self.size = size
        # Copy into a fresh list so no list object is shared between instances
        # (the previous mutable default argument was).
        self.bands = list(self.DEFAULT_BANDS if bands is None else bands)
        self.contact_pos_feather(pos_src, other_feather, contact_src)
        self.fea = pd.read_csv(contact_src, header=None)

    def get_cell(self, pos_x, pos_y):
        """Read a ``size`` x ``size`` chip centred on pixel (pos_x, pos_y).

        Args:
            pos_x: column (x) pixel coordinate of the chip centre.
            pos_y: row (y) pixel coordinate of the chip centre.

        Returns:
            A (size, size, len(bands)) uint8 array with channels in
            ``self.bands`` order, or None if the window leaves the raster or
            reading fails.
        """
        try:
            left = int(pos_x - self.size / 2)
            top = int(pos_y - self.size / 2)
            # Explicit bounds check instead of relying on GDAL failing on an
            # out-of-range window (which also logs a GDAL error per band).
            if (left < 0 or top < 0
                    or left + self.size > self.dataset.RasterXSize
                    or top + self.size > self.dataset.RasterYSize):
                return None
            output = [
                self.dataset.GetRasterBand(band).ReadAsArray(left, top, self.size, self.size)
                for band in self.bands
            ]
            # NOTE(review): values are cast to uint8; if the raster is not
            # 8-bit they will wrap -- confirm the tif's data type.
            img2 = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
        except Exception:
            return None
        return img2

    def get_cells(self, target_src):
        """Cut one chip per row of ``self.fea`` and save it as a JPEG.

        Each chip is written to ``<target_src>/<label>/<label>.<row>.jpg``.
        Column 0 of ``self.fea`` is used as the y (row) pixel coordinate,
        column 1 as x, and the second-to-last column as the label. Rows whose
        window cannot be read are skipped.

        Raises:
            RuntimeError: if the raster could not be opened.
        """
        if self.dataset is None:
            # gdal.Open returns None when it cannot open the file; without this
            # check every row would be skipped silently.
            raise RuntimeError("raster could not be opened; no chips can be cut")
        fea_len = len(self.fea)

        self.createDir(target_src)
        labels = self.fea.iloc[:, -2]
        for label in set(labels):
            self.createDir("%s/%s" % (target_src, label))

        print("fea length: %d" % fea_len)

        # Read coordinates and labels column-wise. Taking a whole row with iloc
        # upcasts every value to a common dtype, so an integer label next to
        # float features became e.g. "3.0" and no longer matched the "3" folder
        # created above (cv2.imwrite then failed silently).
        ys = self.fea.iloc[:, 0].tolist()
        xs = self.fea.iloc[:, 1].tolist()
        for i, (pos_y, pos_x, label) in enumerate(zip(ys, xs, labels.tolist())):
            img = self.get_cell(pos_x, pos_y)
            if img is not None:
                path = "%s/%s/%s.%d.jpg" % (target_src, label, label, i)
                # cv2.imwrite reports failure through its return value.
                if not cv2.imwrite(path, img):
                    print("failed to write %s" % path)
            # Report progress whether or not this row was skipped.
            if i % 1000 == 0:
                print("%d/%d hava finsh save." % (i, fea_len))

    def contact_pos_feather(self, pos_src, other_feather, target):
        """Join the position file and the feature file column-wise into a CSV.

        Args:
            pos_src: space-separated position file, no header.
            other_feather: tab-separated feature/label file, no header.
            target: output CSV (written without header or index). If it
                already exists, nothing is done.

        Raises:
            ValueError: if the two inputs have different row counts;
                concatenating them would pad with NaN and misalign positions
                with their labels.
        """
        if os.path.exists(target):
            print("檔案已存在")
            return
        pos = pd.read_csv(pos_src, header=None, sep=' ')
        feather = pd.read_csv(other_feather, header=None, sep='\t')
        fea = pd.concat([pos, feather], axis=1)
        print("pos Length=%d,feather Length=%d,fea Length=%d" % (len(pos), len(feather), len(fea)))
        if len(pos) != len(feather):
            raise ValueError(
                "row count mismatch: %s has %d rows, %s has %d rows"
                % (pos_src, len(pos), other_feather, len(feather))
            )
        fea.to_csv(target, index=False, header=False)


if __name__ == '__main__':
    # Join the training positions with their features/labels into tr_1.txt.
    # Constructing Tiff only builds/loads that table; chips are cut by get_cells().
    train_tiff = Tiff(r"D:/tr_sample_1.txt", r"D:/train1.txt", r"tr_1.txt")
    # Same for the test split (uncomment to run):
    # tiff=Tiff(r"D:/te_sample_1.txt",r"D:/test1.txt",r"te_1.txt")
    # tiff.get_cells("data/1_test")