1. 程式人生 > >pytorch(二):分類

pytorch(二):分類

import torch
import torch.nn.functional as f
from torch.autograd import Variable
import matplotlib.pyplot as plt


# 建造資料集
data = torch.ones((100, 2))
x0 = torch.normal(2*data, 1)
y0 = torch.zeros(100)  # y0是標籤  shape(100,),是一維
x1 = torch.normal(-2*data, 1)
y1 = torch.ones(100)  # y1也是標籤 shape(100,),是一維
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # 引數0表示維度,在縱向方向將x0,x1合併,合併後shape(200, 2))
y = torch.cat((y0, y1), 0).type(torch.LongTensor)  # 標籤是0或1,型別為整數,LongTensor = 64-bit integer,
x, y = Variable(x), Variable(y)  # 訓練神經網路只能接受變數輸入,故要把x, y轉化為變數
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1],  # 這兩個引數分別代表x,y軸座標
            c=y.data.numpy(), s=100, cmap='RdYlGn')  # c為color,y有兩種標籤,代表兩種顏色的點,'RdYlGn'紅色和綠色
plt.show()


# 建造神經網路模型
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)
        
    def forward(self, x):
        x = f.relu(self.hidden(x))
        y = self.out(x)
        return y


# 定義神經網路
net = Net(n_feature=2, n_hidden=10, n_output=2) 
# n_output=2,因為它返回一個元素為2的列表。[0, 1]表示學習到的內容為標籤1,[1, 0]表示學習到的內容為標籤0。
print(net)


# 訓練神經網路模型並將訓練過程視覺化
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss()
plt.ion()
for i in range(100):
    out = net(x)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 繪圖
    if i % 2 == 0:
        plt.cla()
        # torch.max(a,1) 返回每一行中最大值的那個元素,且返回其索引(返回最大元素在這一行的列索引
        # f.softmax(out)是將out的內容以概率表示。
        # torch.max()返回的是兩個Variable,第一個Variable存的是最大值,第二個存的是其對應的位置索引index。這裡我們想要得到的是索引,所以後面用[1]。
        prediction = torch.max(f.softmax(out), 1)[1]
        pred_y = prediction.data.numpy().squeeze()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, cmap='RdYlGn')
        accuracy = sum(pred_y == target_y)/200
        plt.text(1.5, -4, 'accuracy=%.2f'%accuracy, fontdict={'size':10, 'color':'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

相關推薦

pytorch分類

import torch import torch.nn.functional as f from torch.autograd import Variable import matplotlib.pyplot as plt # 建造資料集 data = torch.on

【更新】Infragistics Ultimate UI for WPF v18.2分類

下載Infragistics Ultimate UI for WPF最新版本 Infragistics Ultimate UI for WPF是一款提供高速的網格和圖表,輕鬆建立仿Office應用程式的WPF介面框架,從廣度和深度兩方面使得開發者在縮短開發時間的同時能夠為市場構建出現代化的,引領

faster rcnn pytorch 復現系列generate_anchors原始碼解析

目錄​ 1. 總函式 generate_anchors 2. 函式分功能寫,首先是ratios的實現,其次是scale的實現 3. anchor2WHXY函式+WsHsXsYs2anchors函式[s表示複數] 4.  _ratio_enum(anchor,r

線性分類模型logistic迴歸模型分析

前言 上一篇文章介紹了線性判別模型,本文介紹線性生成模型——logistic迴歸模型。本文介紹logstic迴歸模型相關的知識,為了更好理解模型的決策邊界函式,本文同時分析了多元變數的協方差對概率分佈的影響。   目錄 1、logistic迴歸模型的含義 2、l

OpenCV學習記錄自己訓練haar特徵的adaboost分類器進行人臉識別

上一篇文章中介紹瞭如何使用OpenCV自帶的haar分類器進行人臉識別(點我開啟)。 這次我試著自己去訓練一個haar分類器,前後花了兩天,最後總算是訓練完了。不過效果並不是特別理想,由於我是在自己的筆記本上進行訓練,為減少訓練時間我的樣本量不是很大,最後也只是勉強看看效果了

pytorch學習筆記gradient

gradient 在BP的時候,pytorch是將Variable的梯度放在Variable物件中的,我們隨時都可以使用Variable.grad得到對應Variable的grad。剛建立Variable的時候,它的grad屬性是初始化為0.0的。 import tor

PyTorch 學習筆記PyTorch的資料增強與資料標準化

本文擷取自《PyTorch 模型訓練實用教程》,獲取全文pdf請點選:https://github.com/tensor-yu/PyTorch_Tutorial 文章目錄 transform的使用 在實際應用過程中,我們會在資

文字分類scrapy爬取網易新聞

文字分類的第一項應該就是獲取文字了吧。 在木有弄懂scrapy的情況下寫的,純應用,或許後續會補上scrapy的原理。 首先說一下我的環境:ubuntu14.10 scrapy安裝指南(肯定官網的最權威了):[傳送門](http://scrapy-chs.rea

機器學習快速入門SVM分類

定義 SVM便是根據訓練樣本的分佈,搜尋所有可能的線性分類器中最佳的那個。仔細觀察彩圖中的藍線,會發現決定其位置的樣本並不是所有訓練資料,而是其中的兩個空間間隔最小的兩個不同類別的資料點,而我們把這種可以用來真正幫助決策最優線性分類模型的資料點稱為”支

.NET物件與Windows控制代碼控制代碼分類和.NET控制代碼洩露的例子

上一篇文章介紹了控制代碼的基本概念,也描述了C#中建立檔案控制代碼的過程。我們已經知道控制代碼代表Windows內部物件,檔案物件就是其中一種,但顯然系統中還有更多其它型別的物件。本文將簡單介紹Windows物件的分類。 控制代碼可以代表的Windows物件分為三類,核心物件(Kernel Object)、

keras學習例項mnist 手寫體分類

承接上次筆記,這次進行mnist 的手寫體目標識別例項,先說明一下出現的問題。 如上圖,源程式類似keras的mnist_example例項,資料來源是通過 url = https://s3.amazonaws.com/img-datasets/mnist.npz 進行

Javascript面向對象編程構造函數的繼承

沒有 cal type 這一 今天 nts 實現繼承 刪除 函數綁定 今天要介紹的是,對象之間的"繼承"的五種方法。 比如,現在有一個"動物"對象的構造函數。   function Animal(){     this.species = "動物";   } 還有一個

虛擬化虛擬化及vmware workstation產品使用

應該 server esxi aof 手機 text 產品 窗體 pass 虛擬化(一):虛擬化及vmware產品介紹 vmware workstation的最新版本號是10.0.2。相信大家也都使用過,當中的簡單的虛擬機的創建。刪除等,都非常easy

CSS3動畫波浪效果

col -1 loading ack css代碼 code load width ase 實現效果 如圖所示: 首先得準備三張圖,一張是淺黃色的背景圖loading_bg.png,一張是深紅色的圖loading.png,最後一張為bolang.png。 css代碼

設計模式 工廠模式

dem blank hibernate 執行 oid code 做出 void actor 工廠模式 工廠模式(Factory Pattern)是 Java 中最常用的設計模式之一。這種類型的設計模式屬於創建型模式,它提供了一種創建對象的最佳方式。 在工廠模式中,我們在創建

iptables實用教程管理鏈和策略

否則 命令顯示 accept 目的 number cep 存在 當前 末尾 概念和原理請參考上一篇文章“iptables實用教程(一)”。 本文講解如果管理iptables中的鏈和策略。 下面的代碼格式中,下劃線表示是一個占位符,需要根據實際情況輸入參數,不帶下劃線的表示是

中國mooc北京理工大學機器學習第二周分類

kmeans 方法 輸入 nump arr mod 理工大學 each orm 一、K近鄰方法(KNeighborsClassifier) 使用方法同kmeans方法,先構造分類器,再進行擬合。區別是Kmeans聚類是無監督學習,KNN是監督學習,因此需要劃分出訓練集和測試

javascript學習筆記定義函數、調用函數、參數、返回值、局部和全局變量

兩個 cnblogs bsp 結果 value ava ase com 調用 定義函數、調用函數、參數、返回值 關鍵字function定義函數,格式如下: function 函數名(){ 函數體 } 調用函數、參數、返回值的規則和c語言規則類似。 1 <!DOC

Nginx實用教程配置文件入門

affinity type 服務 源碼編譯 設置時間 shutdown ber 可用 控制指令 Nginx配置文件結構 nginx配置文件由指令(directive)組成,指令分為兩種形式,簡單指令和區塊指令。 一條簡單指令由指令名、參數和結尾的分號(;)組成,例如:

Python和C|C++的混編利用Cython進行混編

cde uil 有時 當前 class def 將在 python 混編 還能夠使用Cython來實現混編 1 下載Cython。用python setup.py install進行安裝 2 一個實例 ① 創建helloworld文件夾創建hellowor