1. 程式人生 > 顯示資料集

顯示資料集

from keras.datasets import cifar10

# Load the CIFAR-10 train and test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

import matplotlib.pyplot as plt
def plot_images_labels_prediction(images, labels, idx, num=10):
    """Show consecutive images in a 5x5 grid, each titled with its label.

    Args:
        images: Indexable, sized collection of images that ``imshow`` accepts.
        labels: Indexable collection parallel to ``images``; pass ``None``
            to title each tile with its position only.
        idx: Index of the first image to draw.
        num: Number of images to draw. Capped at 25 (the grid capacity) and
            at the number of images available from ``idx`` onwards.
    """
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    # Never draw more tiles than the grid holds, nor read past the last image.
    num = min(num, 25, len(images) - idx)
    for i in range(num):
        position = idx + i
        title = str(i)
        if labels is not None:
            label = labels[position]
            # Array-like labels (e.g. a length-1 NumPy row) are converted to
            # plain Python values so the title shows "3" rather than "[3]".
            if hasattr(label, "tolist"):
                label = label.tolist()
            if isinstance(label, (list, tuple)) and len(label) == 1:
                label = label[0]
            title = title + ',' + str(label)
        ax = plt.subplot(5, 5, i + 1)
        # cmap only applies to 2-D (grayscale) data; RGB(A) images ignore it.
        ax.imshow(images[position], cmap='binary')
        ax.set_title(title, fontsize=10)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()
    
# Preview test-set images starting at index 100 (num keeps its default).
plot_images_labels_prediction(x_test, y_test, idx=100)