Handling Class Imbalance in PyTorch

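When training samples are unevenly distributed across classes, we can use oversampling, undersampling, data augmentation, and similar techniques to keep the model from overfitting to the majority class. Today I ran into a 3D point cloud dataset whose class distribution is extremely skewed: positive and negative examples differ by 4 to 5 orders of magnitude. Data augmentation does not help much at that ratio, and oversampling is not a good fit either, because the data is spatial and newly added points could affect the true distribution in unknown ways. So I use undersampling to mitigate the class imbalance.
The code below shows how to do the sampling with WeightedRandomSampler. Note that this sampler draws with replacement by default: the majority class is undersampled and minority samples may be repeated within an epoch, but no new points are synthesized.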

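import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int64),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int64)))

print('target train 0/1: {}/{}'.format(
    len(np.where(target == 0)[0]), len(np.where(target == 1)[0])))

# Number of samples in each class, ordered by class label
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
# Inverse-frequency weight per class, looked up for every sample.
# Indexing weight[t] assumes the labels are the integers 0..K-1.
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight).double()
# Draw len(samples_weight) indices per epoch (with replacement by default)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = TensorDataset(data, target)

# shuffle is left unset because a sampler is supplied.
# num_workers=0 loads in the main process; the original used 1, which needs
# an `if __name__ == '__main__':` guard on platforms that spawn workers.
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)

for i, (batch_data, batch_target) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i,
        len(np.where(batch_target.numpy() == 0)[0]),
        len(np.where(batch_target.numpy() == 1)[0])))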
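
To see why the batches printed above come out roughly balanced, it helps to add up the sampling weight each class receives. The following is a minimal sketch in plain Python, not part of the original post; the helper name `class_sampling_mass` and the 900/100 label list are made up for illustration. It assigns each sample the weight 1 / (size of its class) and sums the normalized weights per class.

```python
from collections import Counter


def class_sampling_mass(labels):
    """Return the share of total sampling probability held by each class
    when every sample is weighted by 1 / (size of its class)."""
    counts = Counter(labels)
    weights = [1.0 / counts[y] for y in labels]
    total = sum(weights)
    mass = {c: 0.0 for c in counts}
    for y, w in zip(labels, weights):
        mass[y] += w / total
    return mass


# Hypothetical labels with the same 9-to-1 imbalance as the dummy data.
labels = [0] * 900 + [1] * 100
mass = class_sampling_mass(labels)
print(mass)  # each class ends up with about half of the probability mass
```

Each class contributes count × (1 / count) = 1 to the total weight, so with K classes every class is drawn with probability about 1/K, however skewed the raw counts are.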

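The core part is shown below. In practice, substitute your own variables and pass the sampler to DataLoader. Note that once a sampler is supplied, shuffle cannot be used (DataLoader raises a ValueError if both are given), and you must specify how many samples to draw per epoch (the num_samples argument).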

# target is a NumPy array of integer class labels 0..K-1
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
# Inverse-frequency weight per class, looked up for every sample
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight).double()
# Second argument is num_samples: how many indices to draw per epoch
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
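
The two caveats about `shuffle` and the number of samples can be checked directly. This is a rough sketch under assumed toy data (900 negatives, 100 positives, 5 random features); the value 200 for `num_samples` is an arbitrary choice for illustration, not something from the original post. It draws fewer indices per epoch than the dataset holds, then confirms that DataLoader refuses `shuffle=True` when a sampler is given.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical toy data: 900 negatives and 100 positives.
target = torch.cat([torch.zeros(900, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])
data = torch.randn(len(target), 5)
dataset = TensorDataset(data, target)

# Inverse-frequency weight for every sample.
class_count = torch.bincount(target)
samples_weight = (1.0 / class_count.double())[target]

# num_samples sets how many indices one pass over the sampler yields.
# Here it is 200 rather than len(dataset), so each epoch is a subsample.
sampler = WeightedRandomSampler(samples_weight, num_samples=200,
                                replacement=True)
loader = DataLoader(dataset, batch_size=50, sampler=sampler)

seen = torch.cat([batch_target for _, batch_target in loader])
print(len(seen), torch.bincount(seen, minlength=2).tolist())

# Passing shuffle=True together with a sampler is rejected by DataLoader.
try:
    DataLoader(dataset, batch_size=50, sampler=sampler, shuffle=True)
    shuffle_rejected = False
except ValueError:
    shuffle_rejected = True
print("shuffle rejected:", shuffle_rejected)
```

The per-class counts printed for `seen` vary with the random seed but should be close to an even split. Ordering is already randomized by the sampler, which is why `shuffle` is not needed.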
