程式人生 > caffe權值及featureMap視覺化

caffe權值及featureMap視覺化

1、權值視覺化

主函式 conv1_weights_vis.m（依序顯示 conv1～conv5 的權值）：放在 caffe 根目錄下執行，需要先編譯好 matcaffe。

% conv1_weights_vis.m: show the filter weights of conv1-conv5 of the
% reference CaffeNet, one figure per layer. Run from the Caffe root
% directory; requires matcaffe (the 'matlab' folder) to be built.
clear;
clc;
close all;
addpath('matlab')
caffe.set_mode_cpu();
% Pass the version as an argument rather than splicing it into the format.
fprintf('Caffe Version = %s\n', caffe.version());

net = caffe.Net('models/bvlc_reference_caffenet/deploy.prototxt', 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel', 'test');

fprintf('Load net done. Net layers : ');
net.layer_names

fprintf('Net blobs : ');
net.blob_names

% Look the conv layers up by name. Hard-coded positions in net.layer_vec
% (2, 6, 10, 12, 14) are only right for one particular deploy.prototxt
% layout and silently select the wrong layer for any other.
conv_names = {'conv1', 'conv2', 'conv3', 'conv4', 'conv5'};
for k = 1:numel(conv_names)
    % Parameter blob 1 of each conv layer holds its filter weights.
    w = net.params(conv_names{k}, 1).get_data();
    fprintf('Conv%d Weight shape: ', k);
    size(w)
    visualize_weights(w, 1);
end

輔助函式 visualize_weights.m（另存為獨立檔案，與主函式放在同一目錄或 MATLAB 路徑中）：

function [] = visualize_weights(w, s)
% VISUALIZE_WEIGHTS Tile a 4-D weight array into one grayscale image and show it.
%   visualize_weights(w, s) draws every 2-D slice w(:,:,u,v) as one cell of
%   a grid whose rows follow dimension 3 and whose columns follow dimension
%   4 of w; s is the number of blank pixels left between neighbouring cells.
%   Each slice is transposed before drawing, so dimension 2 of w runs down
%   the rows of a cell. The whole array is rescaled to 0..255 with a single
%   global min/max, so cells stay comparable to each other.
%
%   NOTE(review): w is presumably matcaffe's weight layout
%   (kernel width x kernel height x input channels x filters) -- confirm.

kw = size(w, 1);
kh = size(w, 2);
g = max(kw, kh) + s;            % cell pitch: kernel extent plus the gap

% Global min-max normalization to 0..255. A constant array has zero span;
% skip the division then instead of computing 0/0 = NaN.
w = w - min(w(:));
span = max(w(:));
if span > 0
    w = w / span * 255;
end
w = uint8(w);

n_rows = size(w, 3);
n_cols = size(w, 4);
W = zeros(g * n_rows, g * n_cols, 'uint8');
for u = 1:n_rows
    for v = 1:n_cols
        % Transposed slice is kh x kw; unused pixels of the cell stay 0.
        W(g * (u - 1) + (1:kh), g * (v - 1) + (1:kw)) = w(:, :, u, v)';
    end
end
figure; imshow(W);
end

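2、featureMap視覺化

以下指令碼同樣放在 caffe 根目錄下執行，並在檔案末尾定義區域函式 visualize_feature_maps（指令碼內的區域函式需要 MATLAB R2016b 以上版本，且每個函式都必須以 end 結尾）。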

% Feature-map visualization for the reference CaffeNet on one image.
% Run from the Caffe root directory; requires matcaffe to be built.
clear;
clc;
close all;
addpath('matlab')
caffe.set_mode_cpu();
% Pass the version as an argument rather than splicing it into the format.
fprintf('Caffe Version = %s\n', caffe.version());

net = caffe.Net('models/bvlc_reference_caffenet/deploy.prototxt', 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel', 'test');

fprintf('Load net done. Net layers : ');
net.layer_names

fprintf('Net blobs : ');
net.blob_names

im = imread('examples/images/cat.jpg');
figure, imshow(im); title('Original Image');
d = load('matlab/+caffe/imagenet/ilsvrc_2012_mean.mat');
mean_data = d.mean_data;
IMAGE_DIM = 256;
CROPPED_DIM = 227;
BATCH = 10;   % must match the batch size in deploy.prototxt (dim: 10)

% Convert the image into matcaffe's layout.
im_data = im(:, :, [3, 2, 1]);          % MATLAB RGB -> Caffe BGR channel order
im_data = permute(im_data, [2, 1, 3]);  % height x width -> width x height (matcaffe's reversed dim order)
im_data = single(im_data);              % uint8 -> single
im_data = imresize(im_data, [IMAGE_DIM IMAGE_DIM], 'bilinear');
% NOTE(review): assumes mean_data is IMAGE_DIM x IMAGE_DIM x 3 in the same
% BGR, width-first layout; the subtraction errors on a size mismatch.
im_data = im_data - mean_data;

% Centre crop to the network input size, as CROPPED_DIM implies. Resizing
% 256 -> 227 instead would shrink the whole image rather than crop it.
c0 = floor((IMAGE_DIM - CROPPED_DIM) / 2);
crop = im_data(c0 + (1:CROPPED_DIM), c0 + (1:CROPPED_DIM), :);

% deploy.prototxt declares input_param { shape: { dim: 10 dim: 3 dim: 227
% dim: 227 } } (N, C, H, W); matcaffe takes the reversed order W x H x C x N,
% so the single image is replicated along dimension 4: 227 x 227 x 3 x 10.
scores = net.forward({repmat(crop, [1, 1, 1, BATCH])});
scores = mean(scores{1}, 2);   % average the per-item scores over the batch
[~, maxlabel] = max(scores);
maxlabel
figure; plot(scores);

% Input blob and the five conv outputs, looked up by blob name instead of
% by position in net.blob_vec (positions depend on the prototxt).
vis_blobs = {'data', 'conv1', 'conv2', 'conv3', 'conv4', 'conv5'};
for k = 1:numel(vis_blobs)
    f = net.blobs(vis_blobs{k}).get_data();
    if k == 1
        fprintf('Data size=')
    else
        fprintf('Feature map %s size=', vis_blobs{k})
    end
    size(f)
    visualize_feature_maps(f, 1);
end
function [] = visualize_feature_maps(w, s, n)
% VISUALIZE_FEATURE_MAPS Show every channel of one batch item as a grid.
%   visualize_feature_maps(w, s) draws each 2-D slice w(:,:,k,1),
%   k = 1..size(w,3), as one cell of a ceil(sqrt(C)) x ceil(sqrt(C)) grid
%   filled row by row, with s blank pixels between cells. Each slice is
%   transposed before drawing and contrast-stretched to 0..255 on its own;
%   unused cells stay black.
%   visualize_feature_maps(w, s, n) shows batch item n instead of item 1.
%
%   (The original author's note here read "missing the last channel 10";
%   presumably it referred to only batch item 1 of 10 being shown, which
%   the optional n argument now addresses -- confirm.)
%
%   This is a local function inside a script, so it must end with END.
if nargin < 3
    n = 1;
end

mw = size(w, 1);
mh = size(w, 2);
g = max(mw, mh) + s;      % cell pitch: map extent plus the gap
c = size(w, 3);
cv = ceil(sqrt(c));        % the grid is cv x cv cells
W = zeros(g * cv, g * cv);

for k = 1:c
    u = ceil(k / cv);          % grid row of channel k
    v = k - (u - 1) * cv;      % grid column of channel k
    tw = w(:, :, k, n)';       % mh x mw, so non-square maps fit too
    tw = tw - min(tw(:));
    span = max(tw(:));
    % A constant channel (e.g. all zeros after ReLU) has zero span; leave
    % it black instead of computing 0/0 = NaN.
    if span > 0
        tw = tw / span * 255;
    end
    W(g * (u - 1) + (1:mh), g * (v - 1) + (1:mw)) = tw;
end

W = uint8(W);
figure, imshow(W);
end