用PyTorch實現PointNet中的點雲分類網路。
阿新 • • 發佈:2018-11-21
下面是PointNet論文中分類模型的結構:
但是對於模型的細節,PointNet論文中並沒有詳細的解釋,尤其是T-Net,可以參考PointNet的supplemental部分。如果找不到,可以留言找我要。
話不多說,下面是程式碼,基本上完全還原了論文中的PointNet分類模型。
第一部分:資料處理模組
import h5py
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Root folder of the ShapeNet HDF5 dump and the two text files that list the
# HDF5 shards belonging to the training / validation split.
main_path = "E:/DataSets/shapenet_part_seg_hdf5_data/hdf5_data/"
train_txt_path = main_path + "train_hdf5_file_list.txt"
valid_txt_path = main_path + "val_hdf5_file_list.txt"


def get_data(train=True):
    """Load every HDF5 shard of one split into memory.

    Args:
        train: read the training file list when true, otherwise the
            validation file list.

    Returns:
        A ``(clouds, labels)`` tuple: the concatenated ``"data"`` datasets as
        a float tensor and the concatenated ``"label"`` datasets as a squeezed
        ``long`` tensor.
    """
    data_txt_path = train_txt_path if train else valid_txt_path
    with open(data_txt_path, "r") as f:
        file_names = f.read().split()

    clouds_li = []
    labels_li = []
    for file_name in file_names:
        # Open read-only and close the handle as soon as the arrays are read.
        # ``dataset[()]`` replaces the ``.value`` attribute, which was
        # deprecated and then removed in h5py 3.0.
        with h5py.File(main_path + file_name, "r") as h5:
            pts = h5["data"][()]
            lbl = h5["label"][()]
        clouds_li.append(torch.Tensor(pts))
        labels_li.append(torch.Tensor(lbl))

    clouds = torch.cat(clouds_li)
    labels = torch.cat(labels_li)
    return clouds, labels.long().squeeze()


class PointDataSet(Dataset):
    """In-memory dataset of (point cloud, class label) pairs for one split."""

    def __init__(self, train=True):
        clouds, labels = get_data(train=train)
        self.x_data = clouds
        self.y_data = labels
        # NOTE: the attribute name ``lenth`` (sic) is kept so that existing
        # code reading it keeps working.
        self.lenth = clouds.size(0)

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.lenth


def get_dataLoader(train=True, batch_size=16):
    """Build a ``DataLoader`` over one split.

    Args:
        train: select the training split; the training split is shuffled,
            the validation split is not.
        batch_size: number of clouds per batch (defaults to the original 16).

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``PointDataSet``.
    """
    point_data_set = PointDataSet(train=train)
    data_loader = DataLoader(
        dataset=point_data_set, batch_size=batch_size, shuffle=train
    )
    return data_loader
第二部分:模型及其訓練
import torch
import torch.nn as nn
import datetime


class PointNet(nn.Module):
    """PointNet classification network with input and feature T-Nets.

    Args:
        point_num: number of points N in every input cloud. It fixes the
            max-pooling window of both T-Nets, so inputs must have exactly
            this many points.
        num_classes: number of output classes K (defaults to the original 16).
    """

    def __init__(self, point_num, num_classes=16):
        super(PointNet, self).__init__()

        # Input T-Net: per-point features -> global max-pool -> 3x3 matrix.
        self.inputTransform = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((point_num, 1)),
        )
        self.inputFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 9),
        )

        # Shared per-point MLP (64, 64); the (1, 3) kernel consumes XYZ.
        self.mlp1 = nn.Sequential(
            nn.Conv2d(1, 64, (1, 3)),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Feature T-Net: same layout as the input T-Net, yields a 64x64 matrix.
        self.featureTransform = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((point_num, 1)),
        )
        self.featureFC = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64 * 64),
        )

        # Shared per-point MLP (64, 128, 1024).
        self.mlp2 = nn.Sequential(
            nn.Conv2d(64, 64, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1024, 1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )

        # Classification head. It ends in a plain Linear layer (raw logits):
        # nn.CrossEntropyLoss applies log-softmax itself, so a Softmax layer
        # here would normalise twice. Apply softmax to the output if
        # probabilities are needed.
        # Dropout(p=0.7) is intentionally left out: on the ShapeNet data it
        # lowered the accuracy.
        self.fc = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

        # Start the input transform at the identity matrix: zero weights and
        # an identity bias in the last T-Net layer.
        with torch.no_grad():
            self.inputFC[4].weight.zero_()
            self.inputFC[4].bias.copy_(torch.eye(3).view(-1))

    def forward(self, x):
        """Classify a batch of point clouds.

        Shape legend: B = batch size, N = number of points, K = classes.

        Args:
            x: tensor of shape [B, N, 3] holding XYZ coordinates.

        Returns:
            Tensor of shape [B, K] with unnormalised class scores (logits).
        """
        batch_size = x.size(0)

        # Input transform. ``view(batch_size, -1)`` is used instead of a bare
        # ``squeeze()`` so a batch holding a single cloud keeps its batch axis.
        t_net = self.inputTransform(x.unsqueeze(1))    # [B, 1024, 1, 1]
        t_net = self.inputFC(t_net.view(batch_size, -1))  # [B, 3*3]
        t_net = t_net.view(batch_size, 3, 3)           # [B, 3, 3]
        x = torch.bmm(x, t_net)                        # [B, N, 3]

        x = self.mlp1(x.unsqueeze(1))                  # [B, 64, N, 1]

        # Feature transform.
        t_net = self.featureTransform(x)               # [B, 1024, 1, 1]
        t_net = self.featureFC(t_net.view(batch_size, -1))  # [B, 64*64]
        t_net = t_net.view(batch_size, 64, 64)         # [B, 64, 64]
        x = x.squeeze(-1).permute(0, 2, 1)             # [B, N, 64]
        x = torch.bmm(x, t_net)                        # [B, N, 64]
        x = x.permute(0, 2, 1).unsqueeze(-1)           # [B, 64, N, 1]

        x = self.mlp2(x)                               # [B, 1024, N, 1]
        x, _ = torch.max(x, 2)                         # [B, 1024, 1]
        return self.fc(x.view(batch_size, -1))         # [B, K]


EPOCHES = 100
POINT_NUM = 2048


def main():
    """Train PointNet and print the validation accuracy after each epoch."""
    # Imported here so the PointNet class can be imported from this module
    # without the dataset module (and the files it reads) being available.
    import getData

    # Fall back to the CPU when no GPU is present instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader = getData.get_dataLoader(train=True)
    test_loader = getData.get_dataLoader(train=False)

    net = PointNet(POINT_NUM).to(device)
    optimizer = torch.optim.Adam(net.parameters(), weight_decay=0.001)
    loss_function = nn.CrossEntropyLoss()

    for epoch in range(EPOCHES):
        time_start = datetime.datetime.now()

        net.train()
        for cloud, label in train_loader:
            cloud, label = cloud.to(device), label.to(device)
            out = net(cloud)
            loss = loss_function(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total = 0
        net.eval()
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for cloud, label in test_loader:
                cloud, label = cloud.to(device), label.to(device)
                out = net(cloud)
                _, pre = torch.max(out, 1)
                correct = (pre == label).sum()
                total += correct.item()

        time_end = datetime.datetime.now()
        time_span_str = str(int((time_end - time_start).total_seconds()))
        print(str(epoch+1)+"迭代期準確率:"+
              str(total/len(test_loader.dataset))+"耗時"+time_span_str+"S")


if __name__ == "__main__":
    main()
使用上面的配置,在所使用的ShapeNet資料集上,準確度可以達到93%以上。如發現任何問題或bug,請留言。