PyTorch(三)——使用訓練好的模型測試自己圖片
阿新 • • 發佈:2019-01-01
PyTorch的學習和使用(三)
在上一篇文章中實現瞭如何增加一個自定義的Loss,以Siamese network為例。現在實現使用訓練好的該網路對自己手寫的數字圖片進行測試。
首先需要對訓練時的權重進行儲存,然後在測試時直接載入即可。
torch.save(net, path)
torch.load(path)
即可。
然後自己手寫的圖片進行處理。
把需要測試的圖片放入一個資料夾中,然後使用然後對圖片資料進行載入,對圖片資料進行歸一化處理,並且調整大小為(B,C,H,W)。
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((1.1618,), (1.1180,))]) def get_files(directory): return [os.path.join(directory, f) for f in sorted(list(os.listdir(directory))) if os.path.isfile(os.path.join(directory, f))] images = np.array([]) file = get_files('./data/figure') for i, item in enumerate(file): print('Processing %i of %i (%s)' % (i+1, len(file), item)) image = transform(Image.open(item).convert('L')) images = np.append(images, image.numpy()) img = images.reshape(-1, 1, 28, 28) img = torch.from_numpy(img).float() label = torch.ones(5,1).long()
其載入後的資料視覺化為:
最後載入模型並測試。
torch.load('./saveT.pt') def test(data, label): net.eval() data, label = Variable(data, volatile=True), Variable(label) output = net(data) out = output.view(-1, 4) test_loss = criterion(out[:, 0:2], out[:, 2:4], label).data[0] pred = classify(out.data[:, 0:2], out.data[:, 2:4]) correct = pred.eq(label.data).sum()