1. 程式人生 > >pytorch中圖片顯示問題

pytorch中圖片顯示問題

最近看pytorch的一個程式碼,結果中間有一句還是看不太懂,最後過了一些陣子才看懂,在此Mark一下。

程式碼如下:

def imshow(img,text,should_save=False):
    npimg = img.numpy()  # 將torch.FloatTensor 轉換為numpy
    plt.axis("off")  # 不顯示座標尺寸
    if text:
        plt.text(75, 8, text, style='italic',fontweight='bold',
            bbox={'facecolor':'white'
, 'alpha':0.8, 'pad':10}) # facecolor前景色 # pytorch 圖片的顯示問題 plt.imshow(np.transpose(npimg, (1, 2, 0))) plt.show()

解釋這句話:plt.imshow(np.transpose(npimg, (1, 2, 0)))。因為在plt.imshow在現實的時候輸入的是(imagesize,imagesize,channels),而def imshow(img,text,should_save=False)中,引數img的格式為(channels,imagesize,imagesize),這兩者的格式不一致,我們需要呼叫一次np.transpose函式,即np.transpose(npimg,(1,2,0)),將npimg的資料格式由(channels,imagesize,imagesize)轉化為(imagesize,imagesize,channels),進行格式的轉換後方可進行顯示。

總結一下,pytorch在載入資料集是元組tuple的形式,裡面包括了資料及標籤,其中的資料可以轉換為torch .Tensor的形式,方便後面計算使用。在顯示資料的時候,需要將torchtensor
在pytorch中,讀入圖片並進行顯示的方式有兩種。

方式一

將讀取出來的torch.FloatTensor轉換為numpy,然後將其(1 ,imagesize,imagesize)給reshape一下,變成(imagesize,imagesize)的形式,最後進行顯示,上程式碼:

# dataset的格式為:([torch.FloatTensor of size 1x28x28],3)  其中圖片的格式為(1x28x28)圖片的標籤為3
# 這裡我們只取這一張圖片本身,先不管它的標籤。 img=dataset[0] # First 將 torch.FloatTensor 轉換為 numpy的格式 img=img.numpy() # Second 將shape(1,28,28)轉化為(28,28) img=img.reshape(28,28) # Third 呼叫plt 將圖片顯示出來 plt.imshow(img,cmap='gray') plt.show() #然後就可以顯示圖片了
方式二:呼叫torch的介面
img=torchvision.utils.make_grid(img).numpy()
plt.imshow(np.transpose(img,(1,2,0)))
plt.show()

這裡用np.transpose(img,(1,2,0))將圖片的格式由(channels,imagesize,imagesize)轉化為(imagesize,imagesize,channels),這樣plt.show()就可以顯示圖片了。

我對此處np.transpose(1,2,0)理解參考此處