程式人生 > Pytorch實現PointNet中的點雲分類網路

Pytorch實現PointNet中的點雲分類網路。

下面是PointNet論文中分類模型的結構:

不過,PointNet論文正文對模型的細節(尤其是T-Net)並沒有詳細說明,這部分可以參考PointNet論文的補充材料(supplemental)。如果找不到,可以留言找我要。

話不多說,下面是程式碼,基本上完全還原了論文中的PointNet分類模型。

第一部分:資料處理模組(請儲存為 getData.py,第二部分會以 import getData 的方式使用它)

import h5py
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Directory holding the ShapeNet part-segmentation HDF5 shards and the text
# files that list which shards belong to each split.
main_path = "E:/DataSets/shapenet_part_seg_hdf5_data/hdf5_data/"
# Text files naming the HDF5 shards of the training / validation split.
train_txt_path = "".join((main_path, "train_hdf5_file_list.txt"))
valid_txt_path = "".join((main_path, "val_hdf5_file_list.txt"))

def get_data(train=True):
    """Load every HDF5 shard of one split and concatenate it into tensors.

    The split's list file (``train_txt_path`` or ``valid_txt_path``) holds
    whitespace-separated shard file names relative to ``main_path``. Each
    shard is expected to contain a ``"data"`` and a ``"label"`` dataset.

    Args:
        train: read the training list when true, the validation list otherwise.

    Returns:
        ``(clouds, labels)``: ``clouds`` is the concatenation of every shard's
        ``"data"`` array as a ``torch.Tensor``; ``labels`` is the concatenation
        of every shard's ``"label"`` array converted to int64 with size-1
        dimensions squeezed out.
    """
    data_txt_path = train_txt_path if train else valid_txt_path

    with open(data_txt_path, "r") as f:
        file_names = f.read().split()

    clouds_li = []
    labels_li = []
    for file_name in file_names:
        # Open read-only and close the handle as soon as the arrays are in
        # memory; the original left every shard open for the process lifetime.
        with h5py.File(main_path + file_name, "r") as h5:
            # ``[()]`` reads the whole dataset into a NumPy array. It replaces
            # ``Dataset.value``, which was removed in h5py 3.0.
            pts = h5["data"][()]
            lbl = h5["label"][()]
        clouds_li.append(torch.Tensor(pts))
        labels_li.append(torch.Tensor(lbl))
    clouds = torch.cat(clouds_li)
    labels = torch.cat(labels_li)
    return clouds, labels.long().squeeze()

class PointDataSet(Dataset):
    """Map-style dataset over the tensors returned by ``get_data``.

    ``x_data`` holds the point clouds, ``y_data`` the matching labels, and
    ``lenth`` (spelling kept for compatibility) the number of clouds.
    """

    def __init__(self, train=True):
        """Load the training split when ``train`` is true, else validation."""
        self.x_data, self.y_data = get_data(train=train)
        self.lenth = self.x_data.size(0)

    def __getitem__(self, index):
        """Return the ``(cloud, label)`` pair stored at ``index``."""
        sample = self.x_data[index]
        target = self.y_data[index]
        return sample, target

    def __len__(self):
        """Return the number of clouds recorded at construction time."""
        return self.lenth

def get_dataLoader(train=True):
    """Build a ``DataLoader`` with batch size 16 over one ``PointDataSet`` split.

    Batches are shuffled only for the training split (``shuffle=train``).
    """
    return DataLoader(
        dataset=PointDataSet(train=train),
        batch_size=16,
        shuffle=train,
    )

第二部分:模型及其訓練

import torch
import torch.nn as nn
import getData
import datetime
class PointNet(nn.Module):
    """PointNet classification network with an input and a feature T-Net.

    Pipeline: input T-Net (3x3 transform) -> mlp1 -> feature T-Net (64x64
    transform) -> mlp2 -> max over points -> fully connected classifier with
    16 outputs.

    The network returns raw class scores (logits). ``nn.CrossEntropyLoss``
    applies log-softmax itself and expects unnormalized scores, so a trailing
    ``nn.Softmax`` layer would normalize twice. Apply
    ``torch.softmax(out, dim=1)`` to the output if probabilities are needed;
    the arg-max prediction is the same either way.

    Args:
        point_num: number of points N in every input cloud. It sizes the
            max-pool windows of both T-Nets, so inputs must have exactly
            this many points.
    """

    def __init__(self, point_num):

        super(PointNet, self).__init__()

        # Input T-Net, per-point part: the (1, 3) kernel consumes the three
        # coordinates of each point, then 1x1 convs act as a shared MLP.
        self.inputTransform = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),

            # Max over all points -> one 1024-d global descriptor per cloud.
            nn.MaxPool2d((point_num, 1)),
        )
        # Input T-Net, regression head: predicts the 9 entries of a 3x3 matrix.
        self.inputFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 9),
        )
        # First shared MLP: lifts every point to 64 features.
        self.mlp1 = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Feature T-Net, per-point part (same layout as the input T-Net but
        # operating on the 64-channel features).
        self.featureTransform = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),

            nn.MaxPool2d((point_num, 1)),
        )
        # Feature T-Net, regression head: predicts a 64x64 matrix.
        self.featureFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64 * 64),
        )
        # Second shared MLP: 64 -> 128 -> 1024 features per point.
        self.mlp2 = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),

            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )
        # Classifier head over the 1024-d global feature; emits logits.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            # nn.Dropout(p=0.7, inplace=True),  # author's note: on ShapeNet,
            # dropout actually lowered accuracy
            nn.Linear(256, 16),
        )

        # Start the input T-Net at the identity transform: zero weights plus
        # an identity bias make the predicted 3x3 matrix equal to I at init.
        input_head = self.inputFC[4]
        with torch.no_grad():
            input_head.weight.zero_()
            input_head.bias.copy_(torch.eye(3).view(-1))

    def forward(self, x):               # [B, N, XYZ]
        """Classify a batch of point clouds.

        Shape legend: B = batch size, N = number of points (``point_num``),
        K = number of classes (16), XYZ = 3 input coordinates.

        Args:
            x: tensor of shape [B, N, 3].

        Returns:
            Tensor of shape [B, K] holding unnormalized class scores.
        """
        batch_size = x.size(0)

        # Every reshape below names its dimension explicitly: a bare
        # ``squeeze()`` would also drop the batch axis when B == 1.
        x = x.unsqueeze(1)                          # [B, 1, N, 3]

        t_net = self.inputTransform(x)              # [B, 1024, 1, 1]
        t_net = t_net.reshape(batch_size, -1)       # [B, 1024]
        t_net = self.inputFC(t_net)                 # [B, 9]
        t_net = t_net.view(batch_size, 3, 3)        # [B, 3, 3]

        # Batched matrix product: transform each cloud with its own matrix.
        x = torch.bmm(x.squeeze(1), t_net)          # [B, N, 3]

        x = self.mlp1(x.unsqueeze(1))               # [B, 64, N, 1]

        t_net = self.featureTransform(x)            # [B, 1024, 1, 1]
        t_net = t_net.reshape(batch_size, -1)       # [B, 1024]
        t_net = self.featureFC(t_net)               # [B, 64*64]
        t_net = t_net.view(batch_size, 64, 64)      # [B, 64, 64]

        x = x.squeeze(-1).permute(0, 2, 1)          # [B, N, 64]
        x = torch.bmm(x, t_net)                     # [B, N, 64]
        x = x.permute(0, 2, 1).unsqueeze(-1)        # [B, 64, N, 1]

        x = self.mlp2(x)                            # [B, 1024, N, 1]

        # Symmetric function: max over the point axis.
        x, _ = torch.max(x, 2)                      # [B, 1024, 1]

        x = self.fc(x.reshape(batch_size, -1))      # [B, K]
        return x

# Number of passes over the training set.
EPOCHES=100
# Points per cloud; passed to PointNet, which sizes its max-pool windows with it.
POINT_NUM=2048

# Run on the GPU when one is available, otherwise fall back to the CPU so the
# script does not crash on machines without CUDA.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader=getData.get_dataLoader(train=True)
test_loader=getData.get_dataLoader(train=False)

net=PointNet(POINT_NUM).to(device)

optimizer=torch.optim.Adam(net.parameters(),weight_decay=0.001)
loss_function=nn.CrossEntropyLoss()

for epoch in range(EPOCHES):
    time_start=datetime.datetime.now()

    # ---- training pass ----
    net.train()
    for cloud,label in train_loader:
        cloud,label=cloud.to(device),label.to(device)
        out=net(cloud)
        loss=loss_function(out,label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- evaluation pass ----
    total=0  # number of correctly classified evaluation samples
    net.eval()
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for cloud,label in test_loader:
            cloud,label=cloud.to(device),label.to(device)
            out=net(cloud)
            _,pre=torch.max(out,1)
            correct=(pre==label).sum()
            total+=correct.item()

    time_end=datetime.datetime.now()
    # total_seconds() covers the whole span; timedelta.seconds would silently
    # drop whole days.
    time_span_str=str(int((time_end-time_start).total_seconds()))
    print(str(epoch+1)+"迭代期準確率:"+ str(total/len(test_loader.dataset))+"耗時"+time_span_str+"S")

# The power of Python: the whole evaluation pass written as a one-liner.
#acc=sum([(torch.max(net(cloud.cuda()),1)[1]==label.cuda()).sum() for cloud,label in test_loader]).item()/len(test_loader.dataset)

使用上面的配置,在所用的ShapeNet資料集上,準確率可以達到93%以上。如發現任何問題或bug,請留言。