Handling class imbalance in PyTorch
阿新 • Published: 2018-11-08
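When the training samples are unevenly distributed across classes, we can use oversampling, undersampling, data augmentation and similar techniques to keep the model from overfitting to the majority class. Today I ran into a 3D point cloud dataset whose class distribution is extremely skewed: the numbers of positive and negative examples differ by 4 to 5 orders of magnitude. With that much imbalance, data augmentation is unlikely to help much. Oversampling is not a good fit either: this is spatial data, so synthetically added points could distort the true distribution in unpredictable ways. I therefore chose undersampling to mitigate the class imbalance.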
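For comparison, plain random undersampling (discarding part of the majority class before training) can be sketched as below. This is a minimal illustration, not the author's method. The helper name `random_undersample` and the default 1:1 target ratio are assumptions made for this example.

```python
import numpy as np

def random_undersample(labels, ratio=1.0, seed=0):
    """Hypothetical helper: return sorted indices that keep at most
    ratio * (smallest class size) samples from each class, chosen at random
    without replacement. With ratio >= 1 the smallest class is kept whole."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(labels, return_counts=True)
    keep_per_class = int(counts.min() * ratio)
    kept = []
    for c in classes:
        idx = np.flatnonzero(labels == c)          # positions of class c
        n = min(len(idx), keep_per_class)          # cap per-class count
        kept.append(rng.choice(idx, size=n, replace=False))
    return np.sort(np.concatenate(kept))

# 9:1 toy labels: 900 negatives, 100 positives
labels = np.array([0] * 900 + [1] * 100)
idx = random_undersample(labels)
print(np.bincount(labels[idx]))  # [100 100]
```

Unlike the sampler-based approach, this permanently shrinks the dataset to 200 samples. The sampler approach can instead draw a different subset of the majority class on each pass.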
The code below shows how to use `WeightedRandomSampler` to do the sampling.
```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32)))
print('target train 0/1: {}/{}'.format(
    len(np.where(target == 0)[0]), len(np.where(target == 1)[0])))

# Number of samples per class (labels are assumed to be 0..C-1)
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
# Per-class weight: inverse of the class frequency
weight = 1. / class_sample_count
# Per-sample weight: every sample gets the weight of its class
samples_weight = np.array([weight[t] for t in target])
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
# Draw len(samples_weight) indices per epoch, with replacement (the default)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = TensorDataset(data, target)
# num_workers > 0 needs an `if __name__ == '__main__':` guard on Windows/macOS
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=0, sampler=sampler)

for i, (batch_data, batch_target) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i,
        len(np.where(batch_target.numpy() == 0)[0]),
        len(np.where(batch_target.numpy() == 1)[0])))
```
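Why inverse-frequency weights balance the batches: each class's total weight is its count times its per-sample weight. Here that gives 900 × (1/900) = 1 for class 0 and 100 × (1/100) = 1 for class 1, so every draw picks either class with probability 1/2. The short check below confirms this numerically. It assumes the same 9:1 label layout as above and draws 100,000 indices straight from the sampler. The seed and sample count are arbitrary choices.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # fixed seed so the draw is repeatable

# Same 9:1 label layout as the example above
target = np.array([0] * 900 + [1] * 100)
class_sample_count = np.bincount(target)             # [900, 100]
weight = 1.0 / class_sample_count                    # [1/900, 1/100]
samples_weight = torch.from_numpy(weight[target])    # per-sample weights (float64)

# Total weight per class: 900 * (1/900) and 100 * (1/100), both about 1
per_class_total = [samples_weight[torch.from_numpy(target == c)].sum().item()
                   for c in (0, 1)]
print(per_class_total)  # approximately [1.0, 1.0]

# Empirical check: share of class-1 labels among many sampled indices
sampler = WeightedRandomSampler(samples_weight, num_samples=100_000)
drawn = np.fromiter(iter(sampler), dtype=np.int64)
print((target[drawn] == 1).mean())  # close to 0.5
```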
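The core part is below. In practice, substitute your own variables and pass `sampler` to `DataLoader`. Note that once you pass a `sampler` you cannot also use `shuffle` (`DataLoader` raises a `ValueError` if both are set). You also need to specify the number of samples to draw, which is the second argument of `WeightedRandomSampler`: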
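```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# `target`: NumPy array of integer class labels 0..C-1 from your own dataset
class_sample_count = np.array(
    [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count                         # one weight per class
samples_weight = np.array([weight[t] for t in target])   # one weight per sample
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
# Second argument: how many indices one pass over the DataLoader draws
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
```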
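With `num_samples = len(samples_weight)` and the default `replacement=True`, minority-class samples are drawn repeatedly within one epoch. A variant closer to undersampling passes `replacement=False` together with a smaller `num_samples`, so each sample appears at most once per epoch and most of the majority class is skipped. The sketch below assumes an epoch size of twice the minority-class count; that size is illustrative. Because drawn minority samples leave the pool, their share ends up well above the original 10% but below 50%. The sketch also checks that `sampler` combined with `shuffle=True` is rejected.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# 9:1 imbalanced toy data, as in the example above
target = np.array([0] * 900 + [1] * 100)
data = torch.randn(len(target), 5)
dataset = TensorDataset(data, torch.from_numpy(target).long())

class_sample_count = np.bincount(target)
samples_weight = torch.from_numpy(1.0 / class_sample_count[target])

# Assumption: one epoch = 2 * (minority count) samples, each drawn at most once
num_samples = 2 * int(class_sample_count.min())
sampler = WeightedRandomSampler(samples_weight, num_samples, replacement=False)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

idx = list(sampler)                        # one epoch's indices: no repeats
seen = torch.cat([y for _, y in loader])   # labels from another epoch
print(len(seen), torch.bincount(seen, minlength=2).tolist())

# sampler and shuffle=True are mutually exclusive in DataLoader
raised = False
try:
    DataLoader(dataset, batch_size=100, sampler=sampler, shuffle=True)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```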