PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space (Paper Reading and Implementation)
Overview:
This post addresses the classification and semantic segmentation of spatial point sets; the discussion and implementation below use semantic segmentation as the running example. Here is a schematic of the problem to be solved:
Common 2D semantic segmentation data formats are either semantic boundaries such as Polygon Annotation (see the paper reading and implementation on polygon instance segmentation), or pixel-level data as in FCN; these correspond, respectively, to regression over "position" and classification over "category".
This post discusses 3D semantic segmentation. An immediate issue is the cost of the added dimension: on the data side, one must find an annotation format the network can afford to consume while still producing good estimates. These are the two problems the network structure has to address. The data format targeted by the paper is the 3D point cloud, which amounts to pixel-level style processing. This choice guarantees the richness of the data, but at the same time heightens concerns about the performance demanded of the network.
Given 3D point-cloud data, its features split into two parts: positional features and "pixel" (per-point) features. In explaining the proposed method, the authors contrast it with the sequence view of RNNs (an RNN's dependence on ordering, and the combinatorial complexity of positional permutations) and point out that sequence-style modeling is infeasible here. It has to be dropped in favor of a point-set, manifold view. (Traditional semantic segmentation sometimes adds a CRF, e.g. U-Net + CRF; in this data setting that approach arguably no longer applies.)
Facing the multi-instance problem, RNN structures are hard to apply, which is perhaps one reason many people consider pooling. Pooling carries two meanings here: aggregation and order invariance. The question is how to preserve aggregation quality while guaranteeing order invariance (most basically, choosing an appropriate feature-window width in the multi-instance setting so that feature extraction stays effective with respect to the spatial structure).
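As a minimal illustration of order invariance (a toy numpy check, not from the original implementation): max pooling over the point axis yields the same aggregated feature no matter how the points are ordered.

import numpy as np

points = np.random.rand(1024, 64)          # 1024 points with 64-dim per-point features
perm = np.random.permutation(len(points))  # an arbitrary reordering of the points
pooled = points.max(axis=0)                # max pooling over the point axis
pooled_perm = points[perm].max(axis=0)
assert np.allclose(pooled, pooled_perm)    # identical: the pooled feature ignores point order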
From a mathematical point of view, the authors give a formulation for spatial semantic segmentation over point sets that states the problem and, at the same time, suggests its solution:
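For reference, the statement (Theorem 1 of the PointNet paper) is:

Suppose $f : \mathcal{X} \rightarrow \mathbb{R}$ is a continuous set function w.r.t. the Hausdorff distance $d_H(\cdot, \cdot)$, where $\mathcal{X}$ is the family of point sets $S \subseteq [0, 1]^m$ with $|S| = n$. Then for every $\epsilon > 0$ there exist a continuous function $h$ and a symmetric function $g(x_1, \dots, x_n) = \gamma \circ \mathrm{MAX}$ such that for any $S \in \mathcal{X}$,

$$\left| f(S) - \gamma\Big( \underset{x_i \in S}{\mathrm{MAX}} \{ h(x_i) \} \Big) \right| < \epsilon ,$$

where $\mathrm{MAX}$ is a vector max operator taking the element-wise maximum of $n$ vectors.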
The proof of this theorem essentially runs through the (epsilon-delta) definition of continuity for set functions: construct a partition of the point set over a closed interval, then use that partition to build the function whose convergence the theorem asserts. Such interval-partition and subsequence-extraction arguments (essentially derived from Cauchy sequences) are among the harder material of mathematical analysis.
The theorem shows that the authors use the permutation invariance of max pooling to solve the problem. Below is the PointNet architecture diagram:
n denotes the number of points in one input set; in the segmentation branch, the first layer fuses the local features with the global feature. The structure is simple and clear. The T-net, not expanded in the figure, is a network that generates approximately orthogonal matrices (orthogonality enforced by an orthogonality loss), used to align features. (Every MLP layer except the last is followed by ReLU and BatchNormalization; in practice, changing the activation or removing BatchNormalization is inadvisable from the standpoint of scale stability.)
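The orthogonality regularizer for the T-net, as given in the PointNet paper, is

$$L_{reg} = \left\lVert I - AA^{\mathsf T} \right\rVert_F^2 ,$$

where $A$ is the alignment matrix predicted by the T-net. (The implementation below computes tf.nn.l2_loss of $A^{\mathsf T}A - I$, which for square $A$ differs from this only by a constant factor of $1/2$.)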
In the end everything is abstracted into a pixel-level classification problem. (The key point of this network is essentially the pooling.)
If the structure above is summarized as a convolution, PointNet's convolution has a fixed scale, independent of the spatial structure of the point set itself. This causes a problem analogous to extreme multi-label text classification (see the discussion in http://nyc.lti.cs.cmu.edu/yiming/Publications/jliu-sigir17.pdf), where the remedy is to passively change the kernel scale (dynamic max pooling), precisely because it is hard to dynamically capture and adjust the "window width" of the semantic levels. Mapped onto our problem, such positional features act like a "tokenization"; if a "tokenization method" could be found (in effect, building "semantic coordinates") that routes tokens with different semantic relations into different kernels, that might be another way to solve this problem.
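A toy numpy sketch of that dynamic max pooling idea (the pooling window scales with the input length by splitting the token axis into a fixed number of chunks; the chunk count of 4 is an assumption for illustration):

import numpy as np

def dynamic_max_pool(features, num_chunks=4):
    # split the token axis into num_chunks roughly equal parts and max-pool each,
    # so the window width adapts to the input length instead of being fixed
    chunks = np.array_split(features, num_chunks, axis=0)
    return np.stack([c.max(axis=0) for c in chunks])

tokens = np.random.rand(100, 64)       # 100 tokens, 64-dim features
print(dynamic_max_pool(tokens).shape)  # (4, 64): one pooled vector per chunk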
"Coordinates", as an essential part of a feature description, matter a great deal when the features are unevenly distributed over relative coordinates; it then becomes important to fold the relative relations into the concrete structure.
Projected onto PointNet, the same problem exists: when such relations hold among the features, how to find a coordinate abstraction, or "aggregation" scheme, is the problem PointNet++ solves below.
First, the network schematic:
The basic idea of this architecture is to abstract the point-set features around centroids at each hierarchical feature level. Concretely: initialize a number of centroids (the paper uses farthest points for uniform coverage); within the cluster determined by each centroid, apply a structure like the one that produces the global feature in PointNet to obtain the cluster's feature abstraction; then iterate the process. The process can also be read as a locally sensitive (locally accurate) description, a spatial discretization, of the global feature PointNet produces. The subsequent segmentation stage then gradually re-globalizes these features through "interpolation" (weighted inversely by distance to the centroids), fusing in skip-connected local features for a more precise description, again iterating.
In its goal, PointNet++ sets out to solve PointNet's multi-scale pooling problem; in its means, it chooses interpolation over a locally accurate "discretization" of the global feature. The shape flow of one such level is sketched below.
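One set-abstraction level, as a comment sketch (dimensions follow the paper's notation; the values in parentheses are the defaults of the implementation later in this post):

# one set-abstraction (SA) level:
#   input:    [N,  d + C ]   coordinates plus per-point features    (N = 8192, d = 3, C = 4)
#   sample:   [N', d     ]   N' centroids via FPS / k-means++ seeds  (N' = N1 = 1024)
#   group:    [N', K, d+C]   K neighbors per centroid, coordinates centered  (K = 8)
#   PointNet: [N', d + C']   pooled local feature, centroid coordinate re-attached  (C' = C1 = 16)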
Some details first, via paper screenshots and notes:
The Sampling Layer uses FPS (farthest point sampling; the exact version of the underlying covering problem is NP-hard). In the self-implementation below it is replaced by the initial-centroid selection of KMeans++.
The PointNet Layer basically follows the T-net structure of PointNet above; see the implementation for the concrete parameters.
The PointNet Layer fuses multi-scale features in one of the two ways below (fusion of multiple pooled features).
The self-implementation uses (a); some might say (b) looks a bit like a Stack.
Feature interpolation in the segmentation stage uses inverse-distance weighting:
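In the paper's notation (its Eq. 2, with the defaults $p = 2$, $k = 3$):

$$f^{(j)}(x) = \frac{\sum_{i=1}^{k} w_i(x)\, f_i^{(j)}}{\sum_{i=1}^{k} w_i(x)} , \qquad w_i(x) = \frac{1}{d(x, x_i)^{p}} .$$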
The K-element clusters are selected with a KD-tree; the authors compared this against KNN, and the difference is not large.
Next, an implementation against a specific dataset:
A brief note on the dataset:
The larger of the two datasets is chosen, and its point count is genuinely large. On the one hand this simplifies augmentation (sampling can directly stand in for augmentation); on the other hand it makes splitting a single sample (one 3D point-set file) challenging. Experiments show that not splitting and sampling directly works poorly, and splitting too finely (10 bins along each of x and y, i.e. 100 grid cells) works poorly as well; these correspond to the two failure modes of the multi-instance window width. The experiments below split x and y into 3 bins each.
Data processing here uses https://github.com/daavoo/pyntcloud, which must be installed; matplotlib is not recommended (the reason can be glimpsed by constructing a PyntCloud object and calling plot with the matplotlib backend).
Download the dataset from its website (the filenames below match the Semantic3D semantic-8 benchmark) and merge the labels into the point-cloud features (dropping unlabeled points) as follows:
def join_source_lable_file(source = r"E:\Temp\bildstein_station1_xyz_intensity_rgb\bildstein_station1_xyz_intensity_rgb.txt",
label = r"E:\Temp\sem8_labels_training\bildstein_station1_xyz_intensity_rgb.labels",
sample_rate = 1):
header_str = \
'''ply
format ascii 1.0
element vertex 0
property float x
property float y
property float z
property float intensity
property uchar diffuse_red
property uchar diffuse_green
property uchar diffuse_blue
property uchar label
end_header\n'''
req_file_path = label.split("\\")[-1].replace(".labels", "_{}.ply")
req_file_list = list(map(lambda idx: open(req_file_path.format(idx), "w"), range(sample_rate)))
for i in range(sample_rate):
req_file_list[i].write(header_str)
i = 0
with open(label, "r") as l:
with open(source, "r") as f:
            while True:
                l_line = l.readline()
                if not l_line:
                    break
                # always consume the matching source line so the two files stay in sync
                s_line = f.readline()
                if l_line.strip() != "0":  # drop unlabeled points
                    idx = i % sample_rate
                    req_file_list[idx].write("{} {}\n".format(s_line.strip(), l_line.strip()))
                if i % int(1e6) == 0:
                    print("i {}".format(i))
                i += 1
for i in range(sample_rate):
req_file_list[i].close()
print("all write end")
return req_file_path
def join_train_files():
source_target_list = [
{"source": r"E:\Temp\bildstein_station1_xyz_intensity_rgb\bildstein_station1_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\bildstein_station1_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{"source": r"E:\Temp\bildstein_station3_xyz_intensity_rgb\bildstein_station3_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\bildstein_station3_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{"source": r"E:\Temp\bildstein_station5_xyz_intensity_rgb\bildstein_station5_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\bildstein_station5_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{"source": r"E:\Temp\domfountain_station1_xyz_intensity_rgb\domfountain_station1_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\domfountain_station1_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{"source": r"E:\Temp\domfountain_station2_xyz_intensity_rgb\domfountain_station2_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\domfountain_station2_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{"source": r"E:\Temp\domfountain_station3_xyz_intensity_rgb\domfountain_station3_xyz_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\domfountain_station3_xyz_intensity_rgb.labels",
"sample_rate": 1,},
{
"source": r"E:\Temp\sg27_station1_intensity_rgb\sg27_station1_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\sg27_station1_intensity_rgb.labels",
"sample_rate": 4 * 9,
},
{
"source": r"E:\Temp\sg27_station2_intensity_rgb\sg27_station2_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\sg27_station2_intensity_rgb.labels",
"sample_rate": 6 * 9,
},
{
"source": r"E:\Temp\sg27_station4_intensity_rgb\sg27_station4_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\sg27_station4_intensity_rgb.labels",
"sample_rate": 3 * 9,
},
{
"source": r"E:\Temp\sg27_station9_intensity_rgb\sg27_station9_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\sg27_station9_intensity_rgb.labels",
"sample_rate": 3 * 9,
},
]
req_source_list = []
for inner_dict in source_target_list:
source, target, sample_rate = inner_dict["source"], inner_dict["target"], inner_dict["sample_rate"]
req_source_list.append(join_source_lable_file(source, target, sample_rate))
return req_source_list
def join_valid_files():
source_target_list = [
{
"source": r"E:\Temp\sg27_station5_intensity_rgb\sg27_station5_intensity_rgb.txt",
"target": r"E:\Temp\sem8_labels_training\sg27_station5_intensity_rgb.labels",
"sample_rate": 2 * 9,
},
]
req_source_list = []
for inner_dict in source_target_list:
        source, target, sample_rate = inner_dict["source"], inner_dict["target"], inner_dict["sample_rate"]
        req_source_list.append(join_source_lable_file(source, target, sample_rate))
return req_source_list
if __name__ == "__main__":
join_train_files()
join_valid_files()
pass
With the setting n = 1024 * 8 (matching the data_loader below), pkl files are generated directly for fast sampling (the dataset export process is included):
(Calling compute on a BoundingBoxFilter may complain that the points object is uninitialized; it suffices to call the extract_info method in the constructor. valid_xx_part_second performs the data-balance filtering.)
from pyntcloud import PyntCloud
from pyntcloud.filters import BoundingBoxFilter
import numpy as np
import gc
import glob
import pickle
from sklearn.model_selection import train_test_split
from random import choice
# use single stats to produce 3d split stats
def single_stats(file_path = r"C:\Coding\Python\pyntcloud_test\pyntcloud_dataloader\sg27_station1_intensity_rgb.ply"
,grid_num = 10, filter_num = 1024 * 8):
point_file = PyntCloud.from_file(file_path)
xy_array = point_file.points[["x", "y"]].values
x_max, y_max = np.max(xy_array, axis=0)
x_min, y_min = np.min(xy_array, axis=0)
del xy_array
gc.collect()
req_xy_t4_list = []
x_grid = np.linspace(x_min, x_max, grid_num)
y_grid = np.linspace(y_min, y_max, grid_num)
points_num_array = np.zeros(shape=[len(x_grid) - 1, len(y_grid) - 1])
for i in range(len(x_grid) - 1):
for j in range(len(y_grid) - 1):
xmin, xmax = x_grid[i], x_grid[i + 1]
ymin, ymax = y_grid[j], y_grid[j + 1]
compute_mask = BoundingBoxFilter(pyntcloud = point_file,
min_x=xmin, max_x=xmax, min_y=ymin,
max_y=ymax, ).compute()
points_num_array[i][j] = np.sum(compute_mask.astype(np.int32))
if points_num_array[i][j] > filter_num:
req_xy_t4_list.append((xmin, xmax, ymin, ymax))
print("points_num_array :")
print(points_num_array)
return (point_file ,req_xy_t4_list)
def single_df_gen(point_file, req_xy_t4_list, ply_file):
#points_df = PyntCloud.from_file(ply_file).points
points_df = point_file.points
total_X, total_y = points_df[["x", "y", "z", "intensity", "diffuse_red", "diffuse_green", "diffuse_blue"]].values, points_df["label"].values
pkl_path_format = ply_file.replace(".ply", "{}_.pkl")
for idx ,t4 in enumerate(req_xy_t4_list):
x_min, x_max, y_min, y_max = t4
x_part_first, xx_part_first = total_X[np.where(total_X[:, 0] >= x_min)], total_y[np.where(total_X[:, 0] >= x_min)]
x_part_second, xx_part_second = x_part_first[np.where(x_part_first[:, 0] < x_max)], xx_part_first[np.where(x_part_first[:, 0] < x_max)]
y_part_first, yy_part_first = x_part_second[np.where(x_part_second[:, 1] >= y_min)], xx_part_second[np.where(x_part_second[:, 1] >= y_min)],
y_part_second, yy_part_second = y_part_first[np.where(y_part_first[:, 1] < y_max)], yy_part_first[np.where(y_part_first[:, 1] < y_max)]
print(y_part_second.shape, yy_part_second.shape)
with open(pkl_path_format.format(idx), "wb") as f:
pickle.dump((y_part_second, yy_part_second), f)
del total_y, total_X, points_df
gc.collect()
def serlize_pd_df(ply_dir):
ply_file_list = glob.glob(ply_dir + "\\" + "*.ply")
for ply_file in ply_file_list:
point_file ,req_xy_t4_list = single_stats(ply_file)
single_df_gen(point_file, req_xy_t4_list, ply_file)
print("{} pkl end".format(ply_file))
def data_loader(pkl_dir, batch_size = 4, n = 1024 * 8, centered = False):
files = list(glob.glob(pkl_dir + "\\" + "*.pkl"))
def valid_xx_part_second(xx_part_second, min_categories = 2, ratio = 0.5):
categories, category_num = np.unique(xx_part_second, return_counts=True)
if len(categories) >= min_categories:
sort_list = sorted(category_num.tolist())
if sort_list[-2] > sort_list[-1] * ratio:
return True
return False
# not yield above max_times
def single_gen(file, max_times = 10):
with open(file, "rb") as f:
x_part_second, xx_part_second = pickle.load(f)
if valid_xx_part_second(xx_part_second):
if centered:
x_part_second[:, 0] = (x_part_second[:, 0] - x_part_second[:, 0].min()) / (x_part_second[:, 0].max() - x_part_second[:, 0].min())
x_part_second[:, 1] = (x_part_second[:, 1] - x_part_second[:, 1].min()) / (x_part_second[:, 1].max() - x_part_second[:, 1].min())
times = int(len(x_part_second) / n)
for i in range(min(times, max_times)):
input_cloud_points, _, targets, _ = train_test_split(x_part_second, xx_part_second, train_size=n, shuffle=True)
yield (input_cloud_points, targets)
batch_input_cloud_points = np.zeros(shape=[batch_size, n, 7], dtype=np.float32)
batch_targets = np.zeros(shape=[batch_size, n], dtype=np.int32)
start_idx = 0
while True:
file = choice(files)
for input_cloud_points, targets in single_gen(file):
batch_input_cloud_points[start_idx] = input_cloud_points
batch_targets[start_idx] = targets
start_idx += 1
if start_idx == batch_size:
# random sample points yield without shuffle
yield (batch_input_cloud_points, batch_targets - 1)
batch_input_cloud_points = np.zeros(shape=[batch_size, n, 7], dtype=np.float32)
batch_targets = np.zeros(shape=[batch_size, n], dtype=np.int32)
start_idx = 0
train_dir = r"C:\Coding\Python\train_split_dir"
valid_dir = r"C:\Coding\Python\valid_split_dir"
if __name__ == "__main__":
serlize_pd_df(train_dir)
serlize_pd_df(valid_dir)
pass
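A quick sanity check of the generator (a hypothetical snippet, assuming the pkl files produced above already exist under train_dir):

gen = data_loader(train_dir, batch_size=4, n=1024 * 8)
batch_points, batch_labels = next(gen)
print(batch_points.shape)  # (4, 8192, 7): x y z intensity r g b
print(batch_labels.shape)  # (4, 8192); labels are shifted from 1..8 down to 0..7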
Next, the model construction. Since MLPs are used over and over, sonnet is used for simple MLP construction and invocation.
To embed BatchNormalization between the MLP layers, the two are fused as follows:
import tensorflow as tf
def batch_normal_with_relu(inputs, name = None):
with tf.variable_scope("batch_normal_with_relu_{}".format(name)):
fc_mean, fc_var = tf.nn.moments(
inputs,
axes=[0],
)
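        # note: the statistics come from the current batch only (axis 0); no moving
        # averages are tracked, so evaluation also runs on per-batch statistics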
out_size = int(inputs.get_shape()[-1])
scale = tf.Variable(tf.ones([out_size]))
shift = tf.Variable(tf.zeros([out_size]))
epsilon = 0.001
outputs = tf.nn.relu(tf.nn.batch_normalization(inputs, fc_mean, fc_var, shift, scale, epsilon))
return outputs
if __name__ == "__main__":
pass
The PointNet structure can then be obtained as follows:
import tensorflow as tf
from sonnet.python.modules.nets import mlp
from sonnet.python.modules import basic
from model.model_utils import batch_normal_with_relu
from functools import partial
from sklearn.metrics import f1_score
from data_preprocess.sample_dataset_generator import train_dir, valid_dir, data_loader
import keras.backend as K
sess = tf.Session()
K.set_session(sess)
class PointNet(object):
def __init__(self, n = 100, m = 20, batch_size = 4, reg_val = 0.001):
# n indicate points num, m indicate segmentation label num
self.input_cloud_points = tf.placeholder(tf.float32, [None, n, 7])
# max target val m - 1
self.target = tf.placeholder(tf.int32, [None, n])
self.n = n
self.m = m
self.batch_size = batch_size
self.reg_val = reg_val
self.model_construct()
self.opt_construct()
def T_net_layer(self, input, name = None):
matrix_dim = int(input.get_shape()[-1])
assert matrix_dim in [7, 64]
batch_normalization_1 = partial(batch_normal_with_relu, name = "t_net_mlp_layer_first")
t_net_mlp_layer_first = mlp.MLP([64, 128, 1024], name = "t_net_mlp_{}_first".format(name),
activation=batch_normalization_1)
t_net_mlp_output_first = basic.BatchApply(t_net_mlp_layer_first)(input)
max_pool_output = tf.reshape(tf.layers.max_pooling1d(inputs=t_net_mlp_output_first, pool_size=self.n, strides=1,
name="t_net_maxpool_f"), [-1, 1024])
batch_normalization_2 = partial(batch_normal_with_relu, name = "t_net_mlp_layer_second")
t_net_mlp_layer_second = mlp.MLP([512, 256, matrix_dim * matrix_dim], name = "t_net_mlp_{}_second".format(name),
activation=batch_normalization_2)
t_net_mlp_output_second = t_net_mlp_layer_second(max_pool_output)
        # the transform matrix varies per input sample
        batch_output_matrix = tf.reshape(t_net_mlp_output_second, [-1, matrix_dim, matrix_dim])
        # batched matmul: each sample's points [n, matrix_dim] are multiplied by that
        # sample's own [matrix_dim, matrix_dim] transform matrix
        transformed_input = tf.matmul(input, batch_output_matrix)
        return transformed_input, batch_output_matrix
def compute_orthogonal_loss(self, batch_transform_matrix):
def produce_inner_prod_m(matrix):
return tf.matmul(tf.transpose(matrix, [1, 0]), matrix)
# batch_transform_matrix [batch, matrix_dim, matrix_dim]
matrix_dim = int(batch_transform_matrix.get_shape()[-1])
left_hand = batch_transform_matrix[:self.batch_size, ...]
left_hand = tf.map_fn(produce_inner_prod_m, left_hand)
right_hand = tf.tile(tf.expand_dims(tf.eye(num_rows=matrix_dim), 0), [self.batch_size, 1, 1])
return tf.nn.l2_loss(left_hand - right_hand)
def model_construct(self):
transformed_input, self.batch_transform_matrix_first = self.T_net_layer(self.input_cloud_points, name="T_net_0")
batch_normalization_3 = partial(batch_normal_with_relu, name = "mlp_layer_first")
mlp_layer_first = mlp.MLP([64, 64], name = "mlp_layer_first", activation=batch_normalization_3)
mlp_layer_first_output = basic.BatchApply(mlp_layer_first)(transformed_input)
# [batch, n, 64]
transformed_feature, self.batch_transform_matrix_second = self.T_net_layer(mlp_layer_first_output, name="T_net_1")
batch_normalization_4 = partial(batch_normal_with_relu, name = "mlp_layer_second")
mlp_layer_second = mlp.MLP([64, 128, 1024], name = "mlp_layer_second", activation=batch_normalization_4)
mlp_layer_second_output = basic.BatchApply(mlp_layer_second)(transformed_feature)
# [batch, 1, 1024]
max_pool_output = tf.reshape(tf.layers.max_pooling1d(inputs=mlp_layer_second_output, pool_size=self.n, strides=1,
name="t_net_maxpool_s"), [-1, 1, 1024])
concat_feature = tf.concat([transformed_feature, tf.tile(max_pool_output, [1, self.n, 1])], axis=-1)
batch_normalization_5 = partial(batch_normal_with_relu, name = "mlp_layer_third")
mlp_layer_third = mlp.MLP([512, 256, 128], name = "mlp_layer_third", activation=batch_normalization_5)
self.point_features = basic.BatchApply(mlp_layer_third)(concat_feature)
batch_normalization_6 = partial(batch_normal_with_relu, name = "mlp_layer_fourth")
mlp_layer_fourth = mlp.MLP([128, self.m], name = "mlp_layer_fourth", activation=batch_normalization_6)
self.output_scores = basic.BatchApply(mlp_layer_fourth)(self.point_features)
def opt_construct(self):
self.first_orthogonal_loss = self.compute_orthogonal_loss(self.batch_transform_matrix_first)
self.second_orthogonal_loss = self.compute_orthogonal_loss(self.batch_transform_matrix_second)
labels = tf.one_hot(self.target, depth=self.m)
logits = self.output_scores
self.logits = logits
self.segmentation_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels,
logits = logits))
self.total_loss = self.segmentation_loss + self.reg_val * (self.first_orthogonal_loss + self.second_orthogonal_loss)
self.prediction = tf.argmax(tf.nn.softmax(logits, axis=-1), axis=-1)
self.accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(self.prediction, tf.int32), self.target), tf.float32))
self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.total_loss)
@staticmethod
def train(input_sess):
import os
import numpy as np
from uuid import uuid1
from copy import deepcopy
header_str = \
'''ply
format ascii 1.0
element vertex 0
property float x
property float y
property float z
property float intensity
property uchar diffuse_red
property uchar diffuse_green
property uchar diffuse_blue
property uchar pred_labels
property uchar true_labels
end_header\n'''
def centered_xy(cloud_points):
input_cloud_points = deepcopy(cloud_points)
input_cloud_points_list = [ele for ele in input_cloud_points]
def process_single(x_part_second):
x_part_second[:, 0] = (x_part_second[:, 0] - x_part_second[:, 0].min()) / (x_part_second[:, 0].max() - x_part_second[:, 0].min())
x_part_second[:, 1] = (x_part_second[:, 1] - x_part_second[:, 1].min()) / (x_part_second[:, 1].max() - x_part_second[:, 1].min())
return x_part_second
return np.asarray(list(map(process_single, input_cloud_points_list)), np.float32)
def serlize_points_and_label(input_cloud_points, pred_labels, true_labels, epoch):
if not os.path.exists(r"C:\Coding\Python\conclusion_{}".format(epoch)):
os.mkdir(r"C:\Coding\Python\conclusion_{}".format(epoch))
input_cloud_points, pred_labels, true_labels = input_cloud_points[0], pred_labels[0], true_labels[0]
concat_ndarray = np.concatenate([input_cloud_points, pred_labels.reshape([len(pred_labels), 1]), true_labels.reshape([len(true_labels), 1])],
axis=-1)
tail_str = "\n".join([" ".join(str(inner_ele) for inner_ele in line_array) for line_array in concat_ndarray])
full_str = header_str + tail_str
with open(r"C:\Coding\Python\conclusion_{}\{}.ply".format(epoch ,uuid1()), "w") as f:
f.write(full_str)
batch_size = 4
n = 1024 * 8
m = 8
train_gen = data_loader(train_dir, batch_size=batch_size)
valid_gen = data_loader(valid_dir, batch_size=batch_size)
model_ext = PointNet(batch_size=batch_size, n = n, m = m)
step = 0
save_epoch = 3
saver = tf.train.Saver()
with input_sess as sess:
if os.path.exists(r"C:\Coding\Python\PointNet\pm_{}.meta".format(save_epoch)):
saver.restore(sess ,r"C:\Coding\Python\PointNet\pm_{}".format(save_epoch))
print("load exist")
else:
sess.run(tf.global_variables_initializer())
print("init_new")
while True:
step += 1
train_data = train_gen.__next__()
if step % 100 == 0:
print("train data consume end !")
saver.save(sess, r"C:\Coding\Python\PointNet\pm_{}".format(save_epoch))
save_epoch += 1
step = 0
input_cloud_points, targets = train_data
_, train_loss, train_acc, train_pred, train_logits = sess.run([model_ext.train_op, model_ext.total_loss, model_ext.accuracy, model_ext.prediction,
model_ext.logits],
feed_dict={
model_ext.input_cloud_points: centered_xy(input_cloud_points),
model_ext.target: targets
})
if step % 1 == 0:
train_targets = targets.reshape([-1])
train_pred = train_pred.reshape([-1])
valid_data = valid_gen.__next__()
input_cloud_points, targets = valid_data
valid_loss, valid_acc, valid_pred = sess.run([model_ext.total_loss, model_ext.accuracy, model_ext.prediction],
feed_dict={
model_ext.input_cloud_points: centered_xy(input_cloud_points),
model_ext.target: targets
})
serlize_points_and_label(input_cloud_points, valid_pred * 32, targets * 32, save_epoch)
valid_targets = targets.reshape([-1])
valid_pred = valid_pred.reshape([-1])
print("epoch : {} step : {} train_loss : {:.2f} train_acc : {:.2f} valid_loss : {:.2f} valid_acc : {:.2f} train_f1 : {:.2f} valid_f1 : {:.2f}".format(save_epoch, step, train_loss, train_acc, valid_loss, valid_acc,
f1_score(train_targets, train_pred, average="macro"), f1_score(valid_targets, valid_pred, average="macro")))
if __name__ == "__main__":
PointNet.train(sess)
Keras is not used here; interested readers can try replacing the MaxPooling with a Keras AutoPooling variant and compare the results.
Below are the scipy-based helper functions PointNet++ needs:
import random
import numpy as np
from scipy.spatial import KDTree
from functools import reduce
from sklearn.cluster.k_means_ import _init_centroids
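# note: _init_centroids is a private sklearn API (sklearn.cluster.k_means_) that was
# removed in later releases; pin an older sklearn or inline k-means++ seeding if missing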
def distance(p0, p1):
p0, p1 = map(np.array, [p0, p1])
return ((p0 - p1) ** 2).mean()
# input points and output solution_set are all list type
def incremental_farthest_search(points, k):
#start_time = time()
points = points.tolist()
remaining_points = points[:]
solution_set = []
solution_set.append(remaining_points.pop( \
random.randint(0, len(remaining_points) - 1)))
for _ in range(k-1):
distances = [distance(p, solution_set[0]) for p in remaining_points]
for i, p in enumerate(remaining_points):
for j, s in enumerate(solution_set):
distances[i] = min(distances[i], distance(p, s))
solution_set.append(remaining_points.pop(distances.index(max(distances))))
#print("incremental_farthest_search time_consume: {}".format(time() - start_time))
return np.array(solution_set, dtype=np.float32)
def incremental_farthest_search_with_kmeans_pp_centroids(points, k):
#start_time = time()
solution_set = _init_centroids(points, k, init="k-means++")
#print("incremental_farthest_search time_consume: {}".format(time() - start_time))
return np.array(solution_set, dtype=np.float32)
def ball_tree_query_with_corr(points, centroid_points, K = 3):
kdTree = KDTree(points)
# return shape [len(centroid_points), K] the element constructed by indices.
#mean_distance = (np.sum((points[1:, ...] - points[:-1,...]) ** 2, axis=-1) ** 0.5).mean()
#return kdTree.query(centroid_points, k = K, distance_upper_bound=mean_distance * 10)
return kdTree.query(centroid_points, k = K, eps=0.0, distance_upper_bound=np.inf)
def prepare_process_before_PointNet_with_feature(points, features, centroid_points, K):
#start_time = time()
d = 3
    # zeros are usually not appended here: they lower the degrees of freedom
    # while carrying no information
# features [batch, c]
# point_coordinates [batch, 2] ball_tree_query_output [len(centroid_points), K]
# centroid_points [len(centroid_points), 2]
#temp_start = time()
ball_tree_query_corr, ball_tree_query_indexes = ball_tree_query_with_corr(points, centroid_points, K)
#print(ball_tree_query_corr.shape, ball_tree_query_indexes.shape)
#print("prepare_process_before_PointNet_with_feature ball_tree time_consume: {}".format(time() - temp_start))
flatten_indices = np.reshape(ball_tree_query_indexes, [-1]).tolist()
flatten_coordinate_array = np.asarray(reduce(lambda a, b : a + b, map(lambda index: points[int(index): int(index) + 1].tolist(), flatten_indices)))
# [len(centroid_points), K, d]
inner_points_coordinates = flatten_coordinate_array.reshape([-1 ,K, d])
#inner_points_coordinates = flatten_coordinate_array.reshape([len(centroid_points) ,-1, d])
# [len(centroid_points), K - 1, d]
normalized_points = inner_points_coordinates[:,1:,:] - inner_points_coordinates[:, 0:1, :]
# [len(centroid_points), K - 1, C]
req_features = features[ball_tree_query_indexes[:, 1:]]
#print("prepare_process_before_PointNet_with_feature time_consume: {}".format(time() - start_time))
# [len(centroid_points), K - 1, d + C]
return np.asarray(np.concatenate([normalized_points, req_features], axis=-1), dtype=np.float32)
def segmentation_precedure(points, features, centroid_points, K):
#start_time = time()
d = 3
#temp_start = time()
ball_tree_query_corr, ball_tree_query_indexes = ball_tree_query_with_corr(points, centroid_points, K)
#print("segmentation_precedure ball_tree time_consume: {}".format(time() - temp_start))
flatten_indices = np.reshape(ball_tree_query_indexes, [-1]).tolist()
flatten_coordinate_array = np.asarray(reduce(lambda a, b : a + b, map(lambda index: list(points[int(index): int(index) + 1]), flatten_indices)))
# [len(centroid_points), K, d]
inner_points_coordinates = flatten_coordinate_array.reshape([-1 ,K, d])
#inner_points_coordinates = flatten_coordinate_array.reshape([len(centroid_points) ,-1, d])
# [len(centroid_points), K]
distance_array = 1 / (np.sum((inner_points_coordinates - centroid_points[:, np.newaxis]) ** 2, axis=-1) + np.finfo(np.float32).eps)
# [len(centroid_points), K, C]
req_features = features[ball_tree_query_indexes]
    weighted_features = req_features * distance_array[..., np.newaxis]
    # [len(centroid_points), C]: normalize by the sum of the weights so the result is a
    # proper inverse-distance weighted average of the neighbor features
    weighted_features = np.sum(weighted_features, axis=1) / np.sum(distance_array, axis=1)[:, np.newaxis]
#print("segmentation_precedure time_consume: {}".format(time() - start_time))
return weighted_features
if __name__ == "__main__":
pass
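A shape sanity check for these helpers (a hypothetical snippet on random data, run in the same module as the functions above):

import numpy as np

points = np.random.rand(256, 3).astype(np.float32)      # 256 points in R^3
features = np.random.rand(256, 4).astype(np.float32)    # C = 4 per-point features
centroids = incremental_farthest_search(points, 16)     # [16, 3] greedy FPS centroids
grouped = prepare_process_before_PointNet_with_feature(points, features, centroids, 8)
print(grouped.shape)    # (16, 7, 7): K - 1 = 7 neighbors, d + C = 3 + 4 channels
coarse_feats = np.random.rand(16, 4).astype(np.float32)
upsampled = segmentation_precedure(centroids, coarse_feats, points, 3)
print(upsampled.shape)  # (256, 4): coarse features interpolated back onto every point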
With tf.py_func the PointNet++ network structure can be obtained as follows (an implementation that genuinely cares about performance would rewrite these operations as C++ ops):
import tensorflow as tf
from sonnet.python.modules.nets import mlp
from sonnet.python.modules import basic
from model.model_utils import batch_normal_with_relu
from functools import partial
from sklearn.metrics import f1_score
from model.model_utils_final_py import incremental_farthest_search, prepare_process_before_PointNet_with_feature, segmentation_precedure, \
incremental_farthest_search_with_kmeans_pp_centroids
from data_preprocess.sample_dataset_generator import train_dir, valid_dir, data_loader
import keras.backend as K
sess = tf.Session()
K.set_session(sess)
class PointNet_pp(object):
def __init__(self, n = 1024 * 8, m = 8, batch_size = 4,
N1 = 512 * 2, N2 = 256 * 2, K = 8, d = 3, C = 4, C1 = 16, C2 = 32, C3 = 64,
use_kpp_init = True):
# n indicate points num, m indicate segmentation label num
self.input_cloud_points = tf.placeholder(tf.float32, [None, n, d + C])
# max target val m - 1
self.target = tf.placeholder(tf.int32, [None, n])
self.n = n
self.m = m
self.batch_size = batch_size
self.N1 = N1
self.N2 = N2
self.K = K
self.d = d
self.C = C
self.C1 = C1
self.C2 = C2
self.C3 = C3
self.m = m
self.use_kpp_init = use_kpp_init
self.Hierarchical_point_set_feature_learning_layer()
self.Segmentation_layer()
self.opt_construct()
def sampling_grouping_layer(self, Ni, inputs):
# Ni in {N1, N2}
# inputs [batch, N, d + c]
def single_sampling_grouping_procedure(input):
d_plus_c = int(input.get_shape()[-1])
# [Ni, d] [Ni, c]
single_corrdinate_part, single_feature = input[..., :self.d], input[..., self.d:]
if self.use_kpp_init:
incremental_farthest_search_op = tf.py_func(incremental_farthest_search_with_kmeans_pp_centroids, [single_corrdinate_part, Ni], tf.float32,
name="incremental_farthest_search_op")
else:
incremental_farthest_search_op = tf.py_func(incremental_farthest_search, [single_corrdinate_part, Ni], tf.float32,
name="incremental_farthest_search_op")
# this part not processed
# [Ni, 1, d]
centroid_corrdinates = tf.reshape(incremental_farthest_search_op, [Ni, 1, self.d])
# [Ni, 1, d + c]
centroid_corrdinates_with_zeros = tf.concat([centroid_corrdinates, tf.zeros([Ni, 1, d_plus_c - self.d])],axis=-1)
# this part had centered
# [Ni, K - 1, d + c]
prepare_process_before_PointNet_op = tf.py_func(prepare_process_before_PointNet_with_feature, [single_corrdinate_part, single_feature, incremental_farthest_search_op, self.K], tf.float32,
name="prepare_process_before_PointNet_op")
########## 32 * 5 * 4
prepare_process_before_PointNet_op = tf.reshape(prepare_process_before_PointNet_op, [Ni, self.K - 1, d_plus_c])
# [Ni, K, d + c]
return tf.concat([centroid_corrdinates_with_zeros, prepare_process_before_PointNet_op], axis=1)
# [batch, Ni, K, d + c]
sampling_grouping_conclusion = tf.map_fn(single_sampling_grouping_procedure, inputs, dtype=tf.float32)
# [batch, Ni, 1, d + c] [batch, Ni, K - 1, d + c]
center_part, centerd_part = sampling_grouping_conclusion[:,:, 0:1, :], sampling_grouping_conclusion[:,:, 1:, :]
# [batch, Ni, d] drop feature dims (which filled by zeros)
center_part = tf.reshape(tf.squeeze(center_part[..., :self.d], axis=[2]), [-1, Ni, self.d])
return center_part, centerd_part
def fuse_corrdinate(self, corrdinate_inputs, pointNet_outputs):
# corrdinate_inputs [batch, Ni, d] pointNet_outputs [batch, Ni, Ci]
# [batch, Ni, d + Ci]
return tf.concat([corrdinate_inputs, pointNet_outputs], axis=-1)
def map_mini_pointNet_slice(self, inputs, slice_list = [3, 5, 7]):
assert slice_list[-1] == self.K - 1
return list(map(lambda slice_idx: inputs[:, :, :slice_idx, :], slice_list))
# implementation of mini_pointNet will reference to T_net_layer in PointNet
def mini_pointNet(self, inputs, Ni ,Ci, name = None):
        # inputs [batch, N1, K - 1, d + c]: N1 is the number of clusters, K the per-cluster
        # element count, d the coordinate dim, c the feature count. Strictly speaking a
        # cluster may not be full and should be indexed through a mask; since features are
        # retrieved by lookup, 0 is used in place of the mask for now.
        # The input of each K-cluster must be centered so the network can transform it uniformly.
        # The output of this layer is [batch, N1, C1], expanded with the centroid
        # coordinates to [batch, N1, d + C1].
d_plus_c = int(inputs.get_shape()[-1])
# [batch * N1, K - 1, d + c]
K_1 = int(inputs.get_shape()[-2])
reshape_inputs = tf.reshape(inputs, [-1, K_1, d_plus_c])
batch_normalization_1 = partial(batch_normal_with_relu, name = "t_net_mlp_layer_first")
t_net_mlp_layer_first = mlp.MLP([64, 128, 1024], name = "t_net_mlp_{}_first".format(name),
activation=batch_normalization_1)
t_net_mlp_output_first = basic.BatchApply(t_net_mlp_layer_first)(reshape_inputs)
pool_size = int(t_net_mlp_output_first.get_shape()[1])
max_pool_output = tf.reshape(tf.layers.max_pooling1d(inputs=t_net_mlp_output_first, pool_size=pool_size, strides=1,
name="t_net_maxpool"), [-1, 1024])
batch_normalization_2 = partial(batch_normal_with_relu, name = "t_net_mlp_layer_second")
t_net_mlp_layer_second = mlp.MLP([512, 256, Ci], name = "t_net_mlp_{}_second".format(name),
activation=batch_normalization_2)
t_net_mlp_output_second = t_net_mlp_layer_second(max_pool_output)
output = tf.reshape(t_net_mlp_output_second, [-1, Ni, Ci])
return output
def Hierarchical_point_set_feature_learning_layer(self):
# in the first step only use one scale, in the future add mult-scale fuse procedure.
self.center_part_zero = self.input_cloud_points[...,:self.d]
self.mini_pointNet_zero_output = self.input_cloud_points[...,self.d:]
self.fuse_corrdinate_zero = self.input_cloud_points
with tf.variable_scope("sampling_grouping_layer_first"):
self.center_part_first, centerd_part_first = self.sampling_grouping_layer(self.N1, self.input_cloud_points)
centerd_part_first_list = self.map_mini_pointNet_slice(centerd_part_first)
centerd_part_first_feature = []
for i in range(len(centerd_part_first_list)):
with tf.variable_scope("centered_part_first_feature_{}".format(i)):
centerd_part_first_feature.append(self.mini_pointNet(centerd_part_first_list[i], self.N1, self.C1, name="mini_pointNet_first")[:,:self.N1,:self.C1])
self.mini_pointNet_first_output = tf.concat(centerd_part_first_feature, axis=-1)
self.fuse_corrdinate_first = self.fuse_corrdinate(self.center_part_first[:,:self.N1,:self.d], self.mini_pointNet_first_output)
with tf.variable_scope("sampling_grouping_layer_second"):
self.center_part_second, centerd_part_second = self.sampling_grouping_layer(self.N2, self.fuse_corrdinate_first)
centerd_part_second_list = self.map_mini_pointNet_slice(centerd_part_second)
centerd_part_second_feature = []
for i in range(len(centerd_part_second_list)):
with tf.variable_scope("centered_part_second_feature_{}".format(i)):
                    centerd_part_second_feature.append(self.mini_pointNet(centerd_part_second_list[i], self.N2, self.C2, name="mini_pointNet_second")[:, :self.N2, :self.C2])
self.mini_pointNet_second_output = tf.concat(centerd_part_second_feature, axis=-1)
self.fuse_corrdinate_second = self.fuse_corrdinate(self.center_part_second[:,:self.N2,:self.d], self.mini_pointNet_second_output)
def Segmentation_layer(self):
with tf.variable_scope("interpolate_layer_first"):
# [batch, N1, C2]
interpolate_feature_first = self.interpolate_layer(self.fuse_corrdinate_second, self.center_part_first)
# [batch, N1, C1 + C2]
before_input_pointNet_first = tf.concat([interpolate_feature_first, self.mini_pointNet_first_output], axis=-1)
# [batch, N1, 1, C1 + C2]
before_input_pointNet_first = tf.expand_dims(before_input_pointNet_first, 2)
# [batch, N1, C3]
self.mini_pointNet_third = self.mini_pointNet(before_input_pointNet_first, self.N1, self.C3, name="mini_pointNet_third")
# [batch, N1, d + C3]
self.fuse_corrdinate_third = self.fuse_corrdinate(self.center_part_first[:,:self.N1,:self.d], self.mini_pointNet_third)
with tf.variable_scope("interpolate_layer_second"):
# [batch, N, C3]
interpolate_feature_second = self.interpolate_layer(self.fuse_corrdinate_third, self.center_part_zero)
# [batch, N, C + C3]
before_input_pointNet_second = tf.concat([interpolate_feature_second, self.mini_pointNet_zero_output], axis=-1)
# [batch, N, 1, C + C3]
before_input_pointNet_second = tf.expand_dims(before_input_pointNet_second, 2)
# [batch, N, m]
self.mini_pointNet_fourth = self.mini_pointNet(before_input_pointNet_second, self.n, self.m, name="mini_pointNet_fourth")
def interpolate_layer(self, inputs, output_corrdinates):
# use self.center_part_* to indicate points for interplote
# inputs [batch, Ni, d + Ci] output_corrdinates [batch, Nj, d]
Ni = int(inputs.get_shape()[1])
Nj = int(output_corrdinates.get_shape()[1])
d_plus_ci = int(inputs.get_shape()[-1])
fuse_features = tf.concat([tf.reshape(inputs, [-1, Ni * d_plus_ci]),
tf.reshape(output_corrdinates, [-1, Nj * self.d])],
axis=-1)
def single_interpolate_procedure(input):
# [Ni, d + c] [Nj, d]
inputs_part, output_part = tf.reshape(input[ :Ni * d_plus_ci], [Ni, d_plus_ci]), tf.reshape(input[ Ni * d_plus_ci:], [Nj, self.d])
# [Ni, d] [Ni, c]
points, features = inputs_part[..., :self.d], inputs_part[..., self.d:]
centroid_points = output_part
segmentation_precedure_op = tf.py_func(segmentation_precedure, [points, features, centroid_points, self.K], tf.float32,
name="segmentation_precedure_op")
# [Nj, c]
# c2 * N1 : 32 * 32
segmentation_precedure_op = tf.reshape(segmentation_precedure_op, [Nj, d_plus_ci - self.d])
return segmentation_precedure_op
# [batch, Nj, c]
interpolate_feature = tf.map_fn(single_interpolate_procedure, fuse_features, dtype=tf.float32)
return interpolate_feature
def opt_construct(self):
labels = tf.one_hot(self.target, depth=self.m)
logits = self.mini_pointNet_fourth
self.logits = logits
self.segmentation_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels,
logits = logits))
self.total_loss = self.segmentation_loss
self.prediction = tf.argmax(tf.nn.softmax(logits, axis=-1), axis=-1)
self.accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.cast(self.prediction, tf.int32), self.target), tf.float32))
self.train_op = tf.train.AdamOptimizer(0.001).minimize(self.total_loss)
@staticmethod
def train(input_sess):
import os
import numpy as np
from uuid import uuid1
from copy import deepcopy
header_str = \
'''ply
format ascii 1.0
element vertex 0
property float x
property float y
property float z
property float intensity
property uchar diffuse_red
property uchar diffuse_green
property uchar diffuse_blue
property uchar pred_labels
property uchar true_labels
end_header\n'''
def centered_xy(cloud_points):
input_cloud_points = deepcopy(cloud_points)
input_cloud_points_list = [ele for ele in input_cloud_points]
def process_single(x_part_second):
x_part_second[:, 0] = (x_part_second[:, 0] - x_part_second[:, 0].min()) / (x_part_second[:, 0].max() - x_part_second[:, 0].min())
x_part_second[:, 1] = (x_part_second[:, 1] - x_part_second[:, 1].min()) / (x_part_second[:, 1].max() - x_part_second[:, 1].min())
return x_part_second
return np.asarray(list(map(process_single, input_cloud_points_list)), np.float32)
def serlize_points_and_label(input_cloud_points, pred_labels, true_labels, epoch):
if not os.path.exists(r"C:\Coding\Python\conclusion_p_{}".format(epoch)):
os.mkdir(r"C:\Coding\Python\conclusion_p_{}".format(epoch))
input_cloud_points, pred_labels, true_labels = input_cloud_points[0], pred_labels[0], true_labels[0]
concat_ndarray = np.concatenate([input_cloud_points, pred_labels.reshape([len(pred_labels), 1]), true_labels.reshape([len(true_labels), 1])],
axis=-1)
tail_str = "\n".join([" ".join(str(inner_ele) for inner_ele in line_array) for line_array in concat_ndarray])
full_str = header_str + tail_str
with open(r"C:\Coding\Python\conclusion_p_{}\{}.ply".format(epoch ,uuid1()), "w") as f:
f.write(full_str)
batch_size = 4
n = 1024 * 8
m = 8
train_gen = data_loader(train_dir, batch_size=batch_size)
valid_gen = data_loader(valid_dir, batch_size=batch_size)
model_ext = PointNet_pp(batch_size=batch_size, n = n, m = m)
step = 0
save_epoch = 3
saver = tf.train.Saver()
with input_sess as sess:
if os.path.exists(r"C:\Coding\Python\PointNet\pmm_{}.meta".format(save_epoch)):
saver.restore(sess ,r"C:\Coding\Python\PointNet\pmm_{}".format(save_epoch))
print("load exist")
else:
sess.run(tf.global_variables_initializer())
print("init_new")
while True:
step += 1
train_data = train_gen.__next__()
if step % 100 == 0:
print("train data consume end !")
saver.save(sess, r"C:\Coding\Python\PointNet\pmm_{}".format(save_epoch))
save_epoch += 1
step = 0
input_cloud_points, targets = train_data
_, train_loss, train_acc, train_pred, train_logits = sess.run([model_ext.train_op, model_ext.total_loss, model_ext.accuracy, model_ext.prediction,
model_ext.logits],
feed_dict={
model_ext.input_cloud_points: centered_xy(input_cloud_points),
model_ext.target: targets
})
if step % 1 == 0:
train_targets = targets.reshape([-1])
train_pred = train_pred.reshape([-1])
valid_data = valid_gen.__next__()
input_cloud_points, targets = valid_data
valid_loss, valid_acc, valid_pred = sess.run([model_ext.total_loss, model_ext.accuracy, model_ext.prediction],
feed_dict={
model_ext.input_cloud_points: centered_xy(input_cloud_points),
model_ext.target: targets
})
serlize_points_and_label(input_cloud_points, valid_pred * 32, targets * 32, save_epoch)
valid_targets = targets.reshape([-1])
valid_pred = valid_pred.reshape([-1])
print("epoch : {} step : {} train_loss : {:.2f} train_acc : {:.2f} valid_loss : {:.2f} valid_acc : {:.2f} train_f1 : {:.2f} valid_f1 : {:.2f}".format(save_epoch, step, train_loss, train_acc, valid_loss, valid_acc,
f1_score(train_targets, train_pred, average="macro"), f1_score(valid_targets, valid_pred, average="macro")))
if __name__ == "__main__":
PointNet_pp.train(sess)
The results on the valid set are aggregated as follows:
from pyntcloud import PyntCloud
import glob
def join_ply():
def process_values(concat_ndarray):
return "\n".join([" ".join(str(inner_ele) for inner_ele in line_array) for line_array in concat_ndarray])
header_str = \
'''ply
format ascii 1.0
element vertex 0
property float x
property float y
property float z
property float intensity
property uchar diffuse_red
property uchar diffuse_green
property uchar diffuse_blue
property uchar pred_labels
property uchar true_labels
end_header\n'''
all_files = glob.glob(r"C:\Coding\Python\conclusion_pointNet_PP" + "\\" + "*")
tail_str = "\n".join(map(lambda file: process_values(PyntCloud.from_file(file).points.values), all_files))
full_str = header_str + tail_str
with open("pointNet_PP.ply", "w") as f:
f.write(full_str)
if __name__ == "__main__":
join_ply()
Visualization can then be done in Jupyter as follows (run the code in Jupyter; rendering relies on JS):
First, a picture of the whole valid set:
from pyntcloud import PyntCloud
pointNet_cloud = PyntCloud.from_file(r"C:\Coding\Python\PointNet\pyntcloud_dataloader\pointNet.ply")
pointNet_cloud.plot(use_as_color="true_labels",cmap="cool")
pointNet_cloud.plot(use_as_color="pred_labels",cmap="cool")
pointNet_PP_cloud = PyntCloud.from_file(r"C:\Coding\Python\PointNet\pyntcloud_dataloader\pointNet_PP.ply")
pointNet_PP_cloud.plot(use_as_color="pred_labels",cmap="cool")
The differences between the two in some of the details are visible.