Image Augmentation for Deep Learning With Keras

Data preparation is required when working with neural network and deep learning models. Increasingly, data augmentation is also required on more complex object recognition tasks.

In this post you will discover how to use data preparation and data augmentation with your image datasets when developing and evaluating deep learning models in Python with Keras.

After reading this post, you will know:

  • About the image augmentation API provided by Keras and how to use it with your models.
  • How to perform feature standardization.
  • How to perform ZCA whitening of your images.
  • How to augment data with random rotations, shifts and flips.
  • How to save augmented image data to disk.

Let’s get started.

  • Update: The examples in this post were updated for the latest Keras API. The datagen.next() function was removed.
  • Update Oct/2016: Updated examples for Keras 1.1.0, TensorFlow 0.10.0 and scikit-learn v0.18.
  • Update Jan/2017: Updated examples for Keras 1.2.0 and TensorFlow 0.12.1.
  • Update Mar/2017: Updated example for Keras 2.0.2, TensorFlow 1.0.1 and Theano 0.9.0.

Keras Image Augmentation API

Like the rest of Keras, the image augmentation API is simple and powerful.

Keras provides the ImageDataGenerator class that defines the configuration for image data preparation and augmentation. This includes capabilities such as:

  • Sample-wise standardization.
  • Feature-wise standardization.
  • ZCA whitening.
  • Random rotation, shifts, shear and flips.
  • Dimension reordering.
  • Save augmented images to disk.

An augmented image generator can be created as follows:

from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator()

Rather than performing the operations on your entire image dataset in memory, the API is designed to be iterated by the deep learning model fitting process, creating augmented image data for you just-in-time. This reduces your memory overhead, but adds some additional time cost during model training.
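The just-in-time idea can be illustrated without Keras. The sketch below is not how ImageDataGenerator is implemented; it is a minimal NumPy generator, with hypothetical names (batch_generator, random_flip) and toy data, that transforms only the batch currently being requested and loops over the data indefinitely.

```python
import numpy as np

def random_flip(batch, rng):
    """Hypothetical augmentation: flip each image left-right with probability 0.5."""
    flip = rng.random(len(batch)) < 0.5
    out = batch.copy()
    out[flip] = out[flip][..., ::-1]  # reverse the last (width) axis
    return out

def batch_generator(X, y, batch_size, seed=0):
    """Yield augmented (X_batch, y_batch) pairs forever, one batch at a time."""
    rng = np.random.default_rng(seed)
    n = len(X)
    while True:
        order = rng.permutation(n)  # reshuffle at the start of every pass
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            # only this batch is copied and transformed, not the whole dataset
            yield random_flip(X[idx], rng), y[idx]

# toy data standing in for images: 10 samples of 4x4 pixels
X = np.arange(10 * 4 * 4, dtype="float32").reshape(10, 4, 4)
y = np.arange(10)
gen = batch_generator(X, y, batch_size=4)
X_batch, y_batch = next(gen)
print(X_batch.shape, y_batch.shape)
```

Because each batch is built on request, memory use stays at roughly one batch of augmented images, while the transform cost is paid again on every pass through the data.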

After you have created and configured your ImageDataGenerator, you must fit it on your data. This will calculate any statistics required to actually perform the transforms on your image data. You can do this by calling the fit() function on the data generator and passing it your training dataset.

datagen.fit(train)

Calling the flow() function on the data generator returns an iterator that yields batches of image samples when requested. We configure the batch size in that call and retrieve a batch of images by calling next() on the iterator.

X_batch, y_batch = next(datagen.flow(train, train, batch_size=32))

Finally, we can make use of the data generator. Instead of calling the fit() function on our model, we must call the fit_generator() function and pass in the batch iterator, the desired length of an epoch (in batches) and the total number of epochs on which to train.

model.fit_generator(datagen.flow(train, train, batch_size=32), steps_per_epoch=len(train) // 32, epochs=100)
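In the Keras 2 API, fit_generator() measures the length of an epoch in batches, through its steps_per_epoch argument, rather than in samples. The arithmetic is small enough to sketch; batches_per_epoch is a hypothetical helper for illustration, not a Keras function.

```python
import math

def batches_per_epoch(n_samples, batch_size):
    """Batches needed to see every sample once (the last batch may be smaller)."""
    return math.ceil(n_samples / batch_size)

# the MNIST training set has 60,000 images
print(batches_per_epoch(60000, 32))
```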

Point of Comparison for Image Augmentation

Now that you know how the image augmentation API in Keras works, let’s look at some examples.

We will use the MNIST handwritten digit recognition task in these examples. To begin with, let’s take a look at the first 9 images in the training dataset.

# Plot images
from keras.datasets import mnist
from matplotlib import pyplot
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# create a grid of 3x3 images
for i in range(0, 9):
    pyplot.subplot(330 + 1 + i)
    pyplot.imshow(X_train[i], cmap=pyplot.get_cmap('gray'))
# show the plot
pyplot.show()

Running this example provides the following image that we can use as a point of comparison with the image preparation and augmentation in the examples below.

Example MNIST images

Feature Standardization

It is also possible to standardize pixel values across the entire dataset. This is called feature standardization and mirrors the type of standardization often performed for each column in a tabular dataset.

You can perform feature standardization by setting the featurewise_center and featurewise_std_normalization arguments on the ImageDataGenerator class. Both default to False, so you must set them to True explicitly; an ImageDataGenerator created with no arguments does not standardize the data.

# Standardize images across the dataset, mean=0, stdev=1
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K
K.set_image_dim_ordering('th')
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][pixels][width][height]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# define data preparation
datagen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
# fit parameters from data
datagen.fit(X_train)
# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
    # create a grid of 3x3 images
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(X_batch[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
    # show the plot
    pyplot.show()
    break

Running this example, you can see that the effect varies from image to image, seemingly darkening some digits and lightening others.

Standardized Feature MNIST Images
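To make the computation concrete, here is a minimal NumPy sketch of feature standardization. It assumes channels-first data and one mean and standard deviation per channel; the helper names (fit_featurewise, standardize) and the eps value are illustrative assumptions, not the Keras internals.

```python
import numpy as np

def fit_featurewise(X):
    """Compute one mean and one standard deviation per channel.

    X is assumed to have shape (samples, channels, height, width).
    """
    mean = X.mean(axis=(0, 2, 3), keepdims=True)
    std = X.std(axis=(0, 2, 3), keepdims=True)
    return mean, std

def standardize(X, mean, std, eps=1e-7):
    """Centre and scale with the fitted statistics; eps avoids division by zero."""
    return (X - mean) / (std + eps)

# random stand-in for a channels-first image dataset
rng = np.random.default_rng(0)
X_train = rng.uniform(0, 255, size=(100, 1, 28, 28)).astype("float32")
X_test = rng.uniform(0, 255, size=(20, 1, 28, 28)).astype("float32")

mean, std = fit_featurewise(X_train)
X_train_std = standardize(X_train, mean, std)
X_test_std = standardize(X_test, mean, std)  # reuse the training statistics

print(float(X_train_std.mean()), float(X_train_std.std()))
```

The statistics are computed once from the training data and then reused for any other data, which is why the generator has to be fit before it can be used.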

ZCA Whitening

A whitening transform of an image is a linear algebra operation that reduces the redundancy in the matrix of image pixels.

Less redundancy in the image is intended to better highlight the structures and features in the image to the learning algorithm.

Typically, image whitening is performed using the Principal Component Analysis (PCA) technique. More recently, an alternative called ZCA (learn more in Appendix A of this tech report) has shown better results. ZCA keeps all of the original dimensions and, unlike PCA, produces transformed images that still look like their originals.

You can perform a ZCA whitening transform by setting the zca_whitening argument to True.

# ZCA whitening
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K
K.set_image_dim_ordering('th')
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][pixels][width][height]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# define data preparation
datagen = ImageDataGenerator(zca_whitening=True)
# fit parameters from data
datagen.fit(X_train)
# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
    # create a grid of 3x3 images
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(X_batch[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
    # show the plot
    pyplot.show()
    break

Running the example, you can see the same general structure in the images and how the outline of each digit has been highlighted.

ZCA Whitening MNIST Images
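For readers who want to see the linear algebra, the following is a minimal NumPy sketch of a textbook ZCA transform, not a copy of the Keras implementation. The helper names (fit_zca, apply_zca), the per-pixel centring and the eps value are assumptions made for illustration, and small random arrays stand in for real images.

```python
import numpy as np

def fit_zca(X, eps=1e-6):
    """Estimate the per-pixel mean and the ZCA whitening matrix from X."""
    flat = X.reshape(len(X), -1).astype("float64")
    mean = flat.mean(axis=0)
    flat = flat - mean
    cov = flat.T @ flat / len(flat)          # pixel-by-pixel covariance
    U, S, _ = np.linalg.svd(cov)             # eigenvectors and eigenvalues
    # rotate to the PCA basis, rescale, then rotate back to pixel space
    W = U @ np.diag(1.0 / np.sqrt(S + eps)) @ U.T
    return mean, W

def apply_zca(X, mean, W):
    """Whiten X with a previously fitted mean and matrix, keeping its shape."""
    flat = X.reshape(len(X), -1) - mean
    return (flat @ W).reshape(X.shape)

# small random stand-in for images: 500 samples, 1 channel, 4x4 pixels
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 1.0, size=(500, 1, 4, 4))

mean, W = fit_zca(X)
X_white = apply_zca(X, mean, W)

# after whitening, the pixel covariance is close to the identity matrix
flat_white = X_white.reshape(len(X), -1)
cov_white = flat_white.T @ flat_white / len(X)
print(np.round(np.diag(cov_white), 3))
```

The final rotation back by U.T is what separates ZCA from PCA whitening: the output stays in the original pixel space, which is why whitened images remain recognizable.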

Random Rotations

Sometimes the images in your sample data show objects at varying rotations in the scene.

You can train your model to better handle rotations of images by artificially and randomly rotating images from your dataset during training.

The example below creates random rotations of the MNIST digits up to 90 degrees by setting the rotation_range argument.

# Random Rotations
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as K
K.set_image_dim_ordering('th')
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# reshape to be [samples][pixels][width][height]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
# define data preparation
datagen = ImageDataGenerator(rotation_range=90)
# fit parameters from data
datagen.fit(X_train)
# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
    # create a grid of 3x3 images
    for i in range(0, 9):
        pyplot.subplot(330 + 1 + i)
        pyplot.imshow(X_batch[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))
    # show the plot
    pyplot.show()
    break

Running the example, you can see that images have been rotated left and right up to a limit of 90 degrees. This is not helpful on this problem because the MNIST digits have a normalized orientation, but this transform might be of help when learning from photographs where the objects may have different orientations.

Random Rotations of MNIST Images
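As a rough illustration of what a random rotation involves, the sketch below draws an angle uniformly from [-rotation_range, rotation_range] degrees and rotates a single image about its centre. It uses plain NumPy with nearest-neighbour sampling; the helper names (rotate_image, random_rotation), the zero fill value and the toy image are assumptions, and Keras uses its own interpolation and fill settings.

```python
import numpy as np

def rotate_image(img, angle_deg, fill=0.0):
    """Rotate a 2-D image about its centre using nearest-neighbour sampling."""
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.indices((h, w))
    # inverse mapping: for each output pixel, locate the source pixel
    y0, x0 = ys - cy, xs - cx
    src_x = np.rint(np.cos(theta) * x0 + np.sin(theta) * y0 + cx).astype(int)
    src_y = np.rint(-np.sin(theta) * x0 + np.cos(theta) * y0 + cy).astype(int)
    inside = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.full_like(img, fill)  # pixels with no source keep the fill value
    out[inside] = img[src_y[inside], src_x[inside]]
    return out

def random_rotation(img, rotation_range, rng):
    """Rotate by an angle drawn uniformly from [-rotation_range, rotation_range]."""
    angle = rng.uniform(-rotation_range, rotation_range)
    return rotate_image(img, angle), angle

# toy 28x28 image with a vertical bar standing in for a digit
img = np.zeros((28, 28), dtype="float32")
img[4:24, 13:15] = 1.0

rng = np.random.default_rng(0)
rotated, angle = random_rotation(img, 90, rng)
print(rotated.shape)
```

A new angle is drawn on every call, so the same training image is presented at a different orientation each time it is requested.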

Random Shifts

Objects in your images may not be centered in the frame. They may be off-center in a variety of different ways.

You can train your deep learning network to expect and correctly handle off-center objects by artificially creating shifted versions of your training data. Keras supports separate horizontal and vertical random shifting of training data through the width_shift_range and height_shift_range arguments.