用於說明卷積神經網路(ConvNet)的Python指令碼
阿新 • • 發佈:2018-11-20
借鑑:https://github.com/gwding/draw_convnet
直接上程式碼:
"""Draw a schematic diagram of a convolutional neural network with matplotlib.

Edit the configuration lists in ``main()`` to describe your own network.
Running the script shows the figure and then saves it as ``convnet_fig.png``
in the current directory.
"""
import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
from matplotlib.patches import Rectangle

plt.rcdefaults()

# Maximum number of shapes drawn per layer, and number of "omission" dots.
NumDots = 4
NumConvMax = 8
NumFcMax = 20

# Gray levels (0 = black, 1 = white) used for the shapes.
White = 1.
Light = 0.7
Medium = 0.5
Dark = 0.3
Darker = 0.15
Black = 0.


def add_layer(patches, colors, size=(24, 24), num=5,
              top_left=(0, 0),
              loc_diff=(3, -3),
              ):
    """Append ``num`` stacked rectangles representing one layer.

    Args:
        patches: list that receives the new ``Rectangle`` objects.
        colors: list that receives one gray level per appended patch.
        size: ``(height, width)`` of each rectangle.
        num: number of rectangles to draw.
        top_left: ``(x, y)`` of the top-left corner of the first rectangle.
        loc_diff: ``(dx, dy)`` offset between consecutive rectangles.
    """
    top_left = np.array(top_left)
    loc_diff = np.array(loc_diff)
    # Rectangles are anchored at their lower-left corner, so move down by the
    # height from the requested top-left corner.
    loc_start = top_left - np.array([0, size[0]])
    for ind in range(num):
        patches.append(Rectangle(loc_start + ind * loc_diff, size[1], size[0]))
        # Alternate shades so neighbouring rectangles stay distinguishable.
        colors.append(Medium if ind % 2 else Light)


def add_layer_with_omission(patches, colors, size=(24, 24),
                            num=5, num_max=8,
                            num_dots=4,
                            top_left=(0, 0),
                            loc_diff=(3, -3),
                            ):
    """Append one layer, replacing the middle shapes with dots if it is large.

    At most ``num_max`` shapes are drawn. When ``num`` exceeds ``num_max``,
    the shapes in the middle of the stack are drawn as small black circles
    instead of rectangles to indicate that units were omitted.

    Args:
        patches: list that receives the new ``Rectangle``/``Circle`` objects.
        colors: list that receives one gray level per appended patch.
        size: ``(height, width)`` of each rectangle.
        num: real number of units / feature maps in the layer.
        num_max: maximum number of shapes to draw.
        num_dots: number of shapes shown as dots when units are omitted.
        top_left: ``(x, y)`` of the top-left corner of the first rectangle.
        loc_diff: ``(dx, dy)`` offset between consecutive shapes.
    """
    top_left = np.array(top_left)
    loc_diff = np.array(loc_diff)
    loc_start = top_left - np.array([0, size[0]])

    this_num = min(num, num_max)
    # Indices strictly between start_omit and end_omit are drawn as dots.
    start_omit = (this_num - num_dots) // 2
    end_omit = this_num - start_omit
    start_omit -= 1

    for ind in range(this_num):
        omit = (num > num_max) and (start_omit < ind < end_omit)
        if omit:
            patches.append(
                Circle(loc_start + ind * loc_diff + np.array(size) / 2, 0.5))
            colors.append(Black)
        else:
            patches.append(Rectangle(loc_start + ind * loc_diff,
                                     size[1], size[0]))
            colors.append(Medium if ind % 2 else Light)


def add_mapping(patches, colors, start_ratio, end_ratio, patch_size, ind_bgn,
                top_left_list, loc_diff_list, num_show_list, size_list):
    """Append a kernel patch on layer ``ind_bgn`` and lines to the next layer.

    A dark rectangle of ``patch_size`` is placed on the front-most shape of
    layer ``ind_bgn``; four lines connect its corners to a single point on
    the front-most shape of layer ``ind_bgn + 1``.

    Args:
        patches: list that receives the new ``Rectangle``/``Line2D`` objects.
        colors: list that receives one gray level per appended object.
        start_ratio: ``(x, y)`` relative position of the patch on the source.
        end_ratio: ``(x, y)`` relative position of the target point.
        patch_size: ``(height, width)`` of the kernel patch.
        ind_bgn: index of the source layer.
        top_left_list: top-left corner of every layer.
        loc_diff_list: per-layer offset between consecutive shapes.
        num_show_list: number of shapes actually drawn for every layer.
        size_list: ``(height, width)`` of the shapes of every layer.
    """
    ind_end = ind_bgn + 1

    start_loc = (
        top_left_list[ind_bgn]
        + (num_show_list[ind_bgn] - 1) * np.array(loc_diff_list[ind_bgn])
        + np.array([start_ratio[0] * (size_list[ind_bgn][1] - patch_size[1]),
                    - start_ratio[1] * (size_list[ind_bgn][0] - patch_size[0])])
    )

    end_loc = (
        top_left_list[ind_end]
        + (num_show_list[ind_end] - 1) * np.array(loc_diff_list[ind_end])
        + np.array([end_ratio[0] * size_list[ind_end][1],
                    - end_ratio[1] * size_list[ind_end][0]])
    )

    # Negative height: the patch extends downward from its top-left corner.
    patches.append(Rectangle(start_loc, patch_size[1], -patch_size[0]))
    colors.append(Dark)

    # One line from each corner of the patch (top-left, top-right,
    # bottom-left, bottom-right) to the target point.
    for y_off in (0, patch_size[0]):
        for x_off in (0, patch_size[1]):
            patches.append(Line2D([start_loc[0] + x_off, end_loc[0]],
                                  [start_loc[1] - y_off, end_loc[1]]))
            colors.append(Darker)


def label(xy, text, xy_off=(0, 4)):
    """Write ``text`` on the current axes at ``xy`` shifted by ``xy_off``."""
    plt.text(xy[0] + xy_off[0], xy[1] + xy_off[1], text,
             family='sans-serif', size=8)


def main():
    """Build the example network diagram, show it and save it as a PNG."""
    fc_unit_size = 2
    layer_width = 40
    flag_omit = True

    patches = []
    colors = []

    fig, ax = plt.subplots()

    ############################
    # conv layers
    # Image size at each stage, from the input to the last conv/pool output.
    size_list = [(28, 28), (28, 28), (28, 28), (14, 14),
                 (14, 14), (14, 14), (7, 7)]
    # Number of feature maps at each stage.
    num_list = [1, 32, 32, 32, 64, 64, 64]
    # Horizontal gap before each stage; one entry per stage above.
    x_diff_list = [0, layer_width, layer_width, layer_width,
                   layer_width, layer_width, layer_width]
    text_list = ['Inputs'] + ['Feature\nmaps'] * (len(size_list) - 1)
    loc_diff_list = [[3, -3]] * len(size_list)

    num_show_list = [min(num, NumConvMax) for num in num_list]
    top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]

    # Layers are added from the last one back to the input.
    for ind in range(len(size_list) - 1, -1, -1):
        if flag_omit:
            add_layer_with_omission(patches, colors, size=size_list[ind],
                                    num=num_list[ind],
                                    num_max=NumConvMax,
                                    num_dots=NumDots,
                                    top_left=top_left_list[ind],
                                    loc_diff=loc_diff_list[ind])
        else:
            add_layer(patches, colors, size=size_list[ind],
                      num=num_show_list[ind],
                      top_left=top_left_list[ind],
                      loc_diff=loc_diff_list[ind])
        label(top_left_list[ind], text_list[ind] + '\n{}@\n{}x{}'.format(
            num_list[ind], size_list[ind][0], size_list[ind][1]))

    ############################
    # in between layers
    # One entry per mapping between consecutive stages (6 here).
    start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5],
                        [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]
    end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5],
                      [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]
    # Kernel size of each convolution / pooling step (6 here).
    patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3), (2, 2)]
    # Caption of each step in the diagram (6 here).
    text_list = ['Conv', 'Conv', 'pool', 'Conv', 'Conv', 'pool']

    for ind, patch_size in enumerate(patch_size_list):
        add_mapping(
            patches, colors, start_ratio_list[ind], end_ratio_list[ind],
            patch_size, ind,
            top_left_list, loc_diff_list, num_show_list, size_list)
        # Tune xy_off by comparing relative positions in the rendered figure.
        label(top_left_list[ind], text_list[ind] + '\n{}x{}'.format(
            patch_size[0], patch_size[1]), xy_off=[65, -65])

    ############################
    # fully connected layers
    size_list = [(fc_unit_size, fc_unit_size)] * 3
    num_list = [3136, 256, 10]
    num_show_list = [min(num, NumFcMax) for num in num_list]
    # Continue to the right of the last conv stage.
    x_diff_list = [sum(x_diff_list) + layer_width, layer_width, layer_width]
    top_left_list = np.c_[np.cumsum(x_diff_list), np.zeros(len(x_diff_list))]
    loc_diff_list = [[fc_unit_size, -fc_unit_size]] * len(top_left_list)
    text_list = ['Hidden\nunits'] * (len(size_list) - 1) + ['Outputs']

    for ind in range(len(size_list)):
        if flag_omit:
            add_layer_with_omission(patches, colors, size=size_list[ind],
                                    num=num_list[ind],
                                    num_max=NumFcMax,
                                    num_dots=NumDots,
                                    top_left=top_left_list[ind],
                                    loc_diff=loc_diff_list[ind])
        else:
            add_layer(patches, colors, size=size_list[ind],
                      num=num_show_list[ind],
                      top_left=top_left_list[ind],
                      loc_diff=loc_diff_list[ind])
        label(top_left_list[ind], text_list[ind] + '\n{}'.format(
            num_list[ind]))

    text_list = ['Flatten\n', 'Fully\nconnected', 'Fully\nconnected']

    for ind in range(len(size_list)):
        # Tune xy_off by comparing relative positions in the rendered figure.
        label(top_left_list[ind], text_list[ind], xy_off=[30, -65])

    ############################
    for patch, color in zip(patches, colors):
        patch.set_color(color * np.ones(3))
        if isinstance(patch, Line2D):
            ax.add_line(patch)
        else:
            patch.set_edgecolor(Black * np.ones(3))
            ax.add_patch(patch)

    plt.tight_layout()
    plt.axis('equal')
    plt.axis('off')
    # Show the figure first, then resize it and write it to disk.
    plt.show()
    fig.set_size_inches(8, 2.5)

    fig_dir = './'
    fig_ext = '.png'
    fig.savefig(os.path.join(fig_dir, 'convnet_fig' + fig_ext),
                bbox_inches='tight', pad_inches=0)


if __name__ == '__main__':
    main()
以下幾行相當於這份程式碼的使用說明:依照自己的網路結構修改這些設定,就能畫出自己的 CNN 結構圖。
size_list = [(28, 28),(28, 28), (28, 28), (14, 14), (14, 14),(14, 14), (7, 7)]  # image size at each stage, from the input to the last conv/pool output
num_list = [1, 32, 32, 32,64, 64,64]  # number of feature maps at each stage
x_diff_list = [0, layer_width, layer_width, layer_width,layer_width, layer_width, layer_width]  # one entry per item in the lists above
start_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]  # one entry per mapping between stages; 6 here
end_ratio_list = [[0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8], [0.4, 0.5], [0.4, 0.8]]  # one entry per mapping between stages; 6 here
patch_size_list = [(3, 3), (3, 3), (2, 2), (3, 3), (3, 3),(2, 2)]  # kernel size of each conv or pooling step; one entry per mapping, 6 here
text_list = ['Conv','Conv', 'pool', 'Conv','Conv', 'pool']  # captions shown in the diagram; 6 here
label(top_left_list[ind], text_list[ind], xy_off=[30, -65])  # adjust xy_off by comparing relative positions in the rendered figure
執行如下: