1. 程式人生 > >Pytorch學習筆記(四)

Pytorch學習筆記(四)

(8)遷移學習(Transfer Learning)
接下來將會使用ResNet進行遷移學習,完成圖片分類。目前遷移學習的方式主要有兩種,一種是fineturning,就是隻改變pretrain網路最後一層或者幾層的網路結構,對於pretrain網路的全域性引數在原來的基礎上進行微調;另外一種是將ConvNet當做一個特徵提取器(Feature Extractor),結構方面只改變pretrain網路最後一層或者幾層的網路結構,對於引數的話固定住前面沒有改變部分的引數,只對後面修改過的層進行更新。
兩種方式的程式碼如下:

# -*- coding:utf-8 -*-
# Transfer Learning tutorial
import torch import torch.nn as nn from torch.autograd import Variable import torch.optim as optim import numpy as np import torchvision from torchvision import datasets, models, transforms import matplotlib.pyplot as plt import time import copy import os data_transforms = { 'train': transforms.Compose([ transforms.RandomSizedCrop(224
), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'val': transforms.Compose([ transforms.Scale(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485
, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) } data_dir = './data/hymenoptera_data' dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']} dset_loaders = {x: torch.utils.data.DataLoader(dsets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']} dset_sizes = {x: len(dsets[x]) for x in ['train', 'val']} dset_classes = dsets['train'].classes print(dset_classes) use_gpu = torch.cuda.is_available() print(use_gpu) def imshow(inp, title=None): inp = inp.numpy().transpose(1, 2, 0) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) inp = std * inp + mean plt.imshow(inp) if title is not None: plt.title(title) inputs, classes = next(iter(dset_loaders['train'])) out = torchvision.utils.make_grid(inputs) imshow(out, title=[dset_classes[x] for x in classes]) # plt.show() def train_model(model, criterion, optimizer, lr_scheduler, num_epoch=25): since = time.time() best_model = model best_acc = 0.0 for epoch in range(num_epoch): print('Epoch {}/{}'.format(epoch, num_epoch - 1)) print('-' * 10) for phase in ['train', 'val']: if phase == 'train': optimizer = lr_scheduler(optimizer, epoch) model.train(True) else: model.train(False) running_loss = 0.0 running_corrects = 0 for data in dset_loaders[phase]: inputs, labels = data if use_gpu: inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda()) else: inputs, labels = Variable(inputs), Variable(labels) optimizer.zero_grad() outputs = model(inputs) _, preds = torch.max(outputs.data, 1) loss = criterion(outputs, labels) if phase == 'train': loss.backward() optimizer.step() running_loss += loss.data[0] running_corrects += torch.sum(preds == labels.data) epoch_loss = running_loss / dset_sizes[phase] epoch_acc = running_corrects / dset_sizes[phase] print('{} Loss: {:.4f} Acc: {:.4f}'.format( phase, epoch_loss, epoch_acc)) if phase == 'val' and epoch_acc > best_acc: best_acc = epoch_acc best_model = copy.deepcopy(model) print() time_elapsed = time.time() - since print('Training complete in {:.0f}m {:.0f}s'.format( time_elapsed // 60, time_elapsed % 60)) print("Best val Acc: {:4f}".format(best_acc)) return best_model def exp_lr_scheduler(optimizer, epoch, init_lr=0.001, lr_decay_epoch=7): lr = init_lr * (0.1 ** (epoch // lr_decay_epoch)) if epoch % lr_decay_epoch == 0: print("LR is set to {}".format(lr)) for param_group in optimizer.param_groups: param_group['lr'] = lr return optimizer def visualize_model(model, num_images=6): images_so_far = 0 fig = plt.figure() for i, data in enumerate(dset_loaders['val']): inputs, labels = data if use_gpu: inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda()) else: inputs, labels = Variable(inputs), Variable(labels) outputs = model(inputs) _, preds = torch.max(outputs.data, 1) for j in range(inputs.size()[0]): images_so_far += 1 ax = plt.subplot(num_images // 2, 2, images_so_far) ax.axis('off') ax.set_title('predicted: {}'.format(dset_classes[labels.data[j]])) imshow(inputs.cpu().data[j]) if images_so_far == num_images: return # Finetuning the convnet model_ft = models.resnet18(pretrained=True) num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, 2) if use_gpu: model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss() optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9) model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epoch=25) visualize_model(model_ft) plt.ioff() plt.show() # ConvNet as feature extractor model_conv = models.resnet18(pretrained=True) for param in model_conv.parameters(): param.requires_grad = False num_ftrs = model_conv.fc.in_features model_conv.fc = nn.Linear(num_ftrs, 2) if use_gpu: model_conv = model_conv.cuda() criterion = nn.CrossEntropyLoss() optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9) model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25) visualize_model(model_conv) plt.ioff() plt.show()

執行結果如下:

['ants', 'bees']
True
Epoch 0/24
----------
LR is set to 0.001
train Loss: 0.1694 Acc: 0.6311
val Loss: 0.1212 Acc: 0.7974

Epoch 1/24
----------
train Loss: 0.1318 Acc: 0.7623
val Loss: 0.0505 Acc: 0.9216

Epoch 2/24
----------
train Loss: 0.1236 Acc: 0.7992
val Loss: 0.0510 Acc: 0.9085

Epoch 3/24
----------
train Loss: 0.1451 Acc: 0.7705
val Loss: 0.0487 Acc: 0.9412

Epoch 4/24
----------
train Loss: 0.1047 Acc: 0.8525
val Loss: 0.0753 Acc: 0.9020

Epoch 5/24
----------
train Loss: 0.1324 Acc: 0.8115
val Loss: 0.0756 Acc: 0.8889

這裡寫圖片描述

相關推薦

莫煩pytorch學習筆記——激勵函式Activation

1.什麼是Activation 普通神經網路出來的資料都是一個線性的資料,將輸出來的結果用激勵函式處理。 2.Torch中的激勵函式 import torch import torch.nn.functional as F # 激勵函式都在這,nn是神經網路模組

Pytorch學習筆記

(8)遷移學習(Transfer Learning) 接下來將會使用ResNet進行遷移學習,完成圖片分類。目前遷移學習的方式主要有兩種,一種是fineturning,就是隻改變pretrain網路最後一層或者幾層的網路結構,對於pretrain網路的全域性引

Cocos2d-x學習筆記 布景層的加入移除

dcl from position 顏色 顯示地圖 idt col 分享 學習 布景層類也就是CCLayer類,每一個遊戲場景中都能夠有非常多層,每一層負責各自的任務。顯示地圖、顯示人物等。同一時候層還是一個容器,能夠放入文本、圖片和菜單。構成遊戲中一個個UI。這次

機器學習筆記機器學習可行性分析

資料 表示 image 隨機 訓練樣本 -s mage 例如 lin 從大量數據中抽取出一些樣本,例如,從大量彈珠中隨機抽取出一些樣本,總的樣本中橘色彈珠的比例為,抽取出的樣本中橘色彈珠的比例為,這兩個比例的值相差很大的幾率很小,數學公式表示為: 用抽取到的樣本作為訓練

Python_sklearn機器學習學習筆記decision_tree決策樹

min n) 空間 strong output epo from 標簽 ict # 決策樹 import pandas as pd from sklearn.tree import DecisionTreeClassifier from sklearn.

Python學習筆記 列表生成式_生成器

rec triangle 小寫 ont 無限 end clas 普通 執行過程 筆記摘抄來自:https://www.liaoxuefeng.com/wiki/0014316089557264a6b348958f449949df42a6d3a2e542c000/001431

Unity3D之Mecanim動畫系統學習筆記:Animation State

大致 面板 輸入 jpg any 動畫播放 速度 nsf 顯示 動畫的設置 我們先看看Animation Clip的一些設置: Loop time:動畫是否循環播放。 下面出現了3個大致一樣的選項: Root Transform Rotation:表示為播放動畫

.net core 2.0學習筆記:遷移.net framework 工程到.net core

編譯 its evel hashtable ref 學習筆記 inline null 創建 在遷移.net core的過程中,第一步就是要把.net framework 工程的目標框架改為.net core2.0,但是官網卻沒有提供轉換工具,需要我們自己動手完成了

ES6學習筆記—— async 函數

ons fst cte code span pre getname 普通 聲明 await 是 async wait 的簡寫, 是 generator 函數的語法糖。 async 函數的特點: async 聲明一個方法是異步的,await 則等待這個異步方法執行的完

Hibernate學習筆記 --- 映射基本數據類型的List集合

varchar prim drop n) 進行 lis auth pos 方案 集合按其內元素的數據類型分為兩種:基本數據類型集合及復雜對象類型集合,Hibernate對於兩類集合提供不同的映射方式。(在類上以@Embeddable註解的復雜對象數據類型處理方式同基本數據類

java學習筆記:import語法

employee sign cnblogs java 調用 變量賦值 temp 職位 求職 Import 語法是給編譯器尋找特定類的適當位置的一種方法。 創建一個Employee 類,包括四個實體變量姓名(name),年齡(age),職位(designation)和薪水(s

Cesium學習筆記Camera

ttr can str efault 簡單的 list 事件處理 get provider http://blog.csdn.net/HobHunter/article/details/74909641 Cesium 相機控制場景中的視野。操作相機的方法有很多,如

python學習筆記-數據類型

rand 兩個 urn 浪費 line 平年 randint .com .cn 0. 在 Python 中的數據類型詳解 http://www.cnblogs.com/scios/p/8026576.html 1. 為什麽布爾類型(bool)的 True 和 False 分

Nodejs學習筆記-----Buffer

pretty 成員 保存 n) tin 設置 amp 個數 普通 Node.js Buffer(緩沖區) JavaScript 語言自身只有字符串數據類型,沒有二進制數據類型。 但在處理像TCP流或文件流時,必須使用到二進制數據。因此在 Node.js中,定義了一個 Buf

Elasticsearch學習筆記ElasticSearch分布式機制

clas cse 負載均衡 丟失 數據 不可 分布式 復雜 發生 一、Elasticsearch對復雜分布式機制透明的隱藏特性 1、分片機制: (1)index包含多個shard,每個shard都是一個最小工作單元,承載部分數據,lucen

DeepLearning.ai學習筆記卷積神經網絡 -- week1 卷積神經網絡基礎知識介紹

除了 lock 還需要 情況 好處 計算公式 max 位置 網絡基礎 一、計算機視覺 如圖示,之前課程中介紹的都是64* 64 3的圖像,而一旦圖像質量增加,例如變成1000 1000 * 3的時候那麽此時的神經網絡的計算量會巨大,顯然這不現實。所以需要引入其他的方法來

python學習筆記字符串及字符串操作

默認 小寫字母 是不是 swap git 查找字符 英文 去掉 title 字符串   字符串可以存任意類型的字符串,比如名字,一句話等等。 字符串還有很多內置方法,對字符串進行操作,常用的方法如下: 1 name1=‘hello world‘ 2 print(nam

day3-python學習筆記

end tar upper date update size upd sdi reat 字符串方法 #字符串這些方法都不會改變原來字符串的值name = ‘beSTtest‘# new_name = name.strip()#默認是去掉空格和換行符# new_name =

DeepLearning.ai學習筆記卷積神經網絡 -- week2深度卷積神經網絡 實例探究

過濾 common 經典 上一個 問題 inline 最壞情況 ali method 一、為什麽要進行實例探究? 通過他人的實例可以更好的理解如何構建卷積神經網絡,本周課程主要會介紹如下網絡 LeNet-5 AlexNet VGG ResNet (有152層) Incep

《Qt5 開發與實例第三版學習筆記

常用 斷言 max swap 正則表達 4.2 debug 實例 筆記 1 //2.4 算法及正則表達式 2 //2.4.1 Qt5常用算法 3 double c=qAbs(a);//返回絕對值 4 double max=qMax(b,c);//返回最大值 5