
How to Diagnose Overfitting and Underfitting of LSTM Models

It can be difficult to determine whether your Long Short-Term Memory model is performing well on your sequence prediction problem.

You may be getting a good model skill score, but it is important to know whether your model is a good fit for your data or if it is underfit or overfit and could do better with a different configuration.

In this tutorial, you will discover how you can diagnose the fit of your LSTM model on your sequence prediction problem.

After completing this tutorial, you will know:

  • How to gather and plot training history of LSTM models.
  • How to diagnose an underfit, good fit, and overfit model.
  • How to develop more robust diagnostics by averaging multiple model runs.

Let’s get started.

Tutorial Overview

This tutorial is divided into 6 parts; they are:

  1. Training History in Keras
  2. Diagnostic Plots
  3. Underfit Example
  4. Good Fit Example
  5. Overfit Example
  6. Multiple Runs Example

1. Training History in Keras

You can learn a lot about the behavior of your model by reviewing its performance over time.

LSTM models are trained by calling the fit() function. This function returns a History object. Its history attribute is a dictionary containing a trace of the loss and any other metrics specified when the model was compiled. These scores are recorded at the end of each epoch.

...
history = model.fit(...)

For example, if your model was compiled to optimize the log loss (binary_crossentropy) and measure accuracy each epoch, then the log loss and accuracy will be calculated and recorded in the history trace for each training epoch.
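To make concrete what gets recorded, here is a minimal pure-Python sketch of the two numbers for a handful of predictions: the mean binary cross-entropy (log loss) and accuracy with predictions thresholded at 0.5. The function names are hypothetical. The clipping constant 1e-7 is an assumption meant to mirror the small epsilon Keras uses to avoid log(0). Keras computes these on tensors and, for the training set, reports roughly an average over the batches of each epoch; the sketch only shows the arithmetic on one small set of predictions.

```python
import math


def binary_log_loss(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped to [eps, 1 - eps]."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of predictions that fall on the correct side of the threshold."""
    hits = sum(1 for t, p in zip(y_true, y_prob) if (p > threshold) == bool(t))
    return hits / len(y_true)


y_true = [1, 0, 1, 0]
y_prob = [0.9, 0.2, 0.4, 0.1]
print(round(binary_log_loss(y_true, y_prob), 4))  # 0.3375
print(binary_accuracy(y_true, y_prob))            # 0.75
```

The third prediction (0.4 for a positive example) is the only one on the wrong side of 0.5, so accuracy is 0.75, while the log loss still rewards the confident correct predictions and penalizes the uncertain wrong one.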

Each score is accessed by a key in the history dictionary returned from calling fit(). The loss optimized when fitting the model is stored under “loss”. Accuracy is stored under “acc” in older versions of Keras and under “accuracy” in Keras 2.3 and later.
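Because the accuracy key depends on the Keras version, code that reads history.history directly can fail with a KeyError after an upgrade. A small hypothetical helper, get_metric, sidesteps this by trying candidate keys in order; it works on the plain dictionary stored in history.history, so no model is needed to try it. The dictionaries below are illustrative stand-ins, not real training results.

```python
def get_metric(history_dict, *candidates):
    """Return the series stored under the first candidate key that exists."""
    for key in candidates:
        if key in history_dict:
            return history_dict[key]
    raise KeyError(f"none of {candidates} found in {sorted(history_dict)}")


# Illustrative stand-ins for history.history (values are made up).
old_style = {"loss": [0.69, 0.52], "acc": [0.50, 0.75]}
new_style = {"loss": [0.69, 0.52], "accuracy": [0.50, 0.75]}
print(get_metric(old_style, "accuracy", "acc"))  # [0.5, 0.75]
print(get_metric(new_style, "accuracy", "acc"))  # [0.5, 0.75]
```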

...
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, Y, epochs=100)
print(history.history['loss'])
print(history.history['acc'])  # use 'accuracy' in Keras 2.3 and later

Keras also allows you to specify a separate validation dataset while fitting your model that can also be evaluated using the same loss and metrics.

This can be done by setting the validation_split argument on fit() to use a portion of the training data as a validation dataset.
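Note that validation_split does not sample at random: the Keras documentation states that the validation data is taken from the last samples in the arrays provided, before any shuffling. The sketch below mimics that slicing with plain lists so you can see which rows are held out. The function name split_like_keras is hypothetical, and the floor-based rounding at the boundary is an assumption, not a copy of Keras internals.

```python
import math


def split_like_keras(X, y, validation_split):
    """Hold out the last fraction of samples, approximating fit(validation_split=...)."""
    split_at = int(math.floor(len(X) * (1.0 - validation_split)))
    return (X[:split_at], y[:split_at]), (X[split_at:], y[split_at:])


X = list(range(10))
y = [v * 0.1 for v in X]
(train_X, train_y), (val_X, val_y) = split_like_keras(X, y, 0.2)
print(train_X)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(val_X)    # [8, 9]
```

If your samples are ordered, for example by time, this means the validation set is always the final segment of the data, which may or may not be what you want.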

...
history = model.fit(X, Y, epochs=100, validation_split=0.33)

This can also be done by setting the validation_data argument and passing a tuple of X and y datasets.

...
history = model.fit(X, Y, epochs=100, validation_data=(valX, valY))

The metrics evaluated on the validation dataset are keyed using the same names, with a “val_” prefix.
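Since every validation key is just the training key with a “val_” prefix, the two can be paired mechanically, which is handy when a model tracks several metrics. The helper paired_metrics below is a hypothetical sketch that works on the history.history dictionary; the example dictionary holds illustrative values, not real results.

```python
def paired_metrics(history_dict):
    """Map each training metric to (train_series, validation_series) using the 'val_' prefix."""
    pairs = {}
    for key, values in history_dict.items():
        if not key.startswith("val_") and "val_" + key in history_dict:
            pairs[key] = (values, history_dict["val_" + key])
    return pairs


# Illustrative stand-in for history.history (values are made up).
hist = {"loss": [0.7, 0.5], "acc": [0.5, 0.8],
        "val_loss": [0.72, 0.6], "val_acc": [0.5, 0.7]}
for name, (train, val) in paired_metrics(hist).items():
    print(name, "final train:", train[-1], "final validation:", val[-1])
```

Metrics recorded without a validation counterpart (for example when no validation data was supplied) are simply skipped.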

...
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
history = model.fit(X, Y, epochs=100, validation_split=0.33)
print(history.history['loss'])
print(history.history['acc'])  # 'accuracy' in Keras 2.3 and later
print(history.history['val_loss'])
print(history.history['val_acc'])  # 'val_accuracy' in Keras 2.3 and later

2. Diagnostic Plots

The training history of your LSTM models can be used to diagnose the behavior of your model.

You can plot the performance of your model using the Matplotlib library. For example, you can plot training loss vs validation loss as follows:

from matplotlib import pyplot
...
history = model.fit(X, Y, epochs=100, validation_data=(valX, valY))
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()

Creating and reviewing these plots can help to inform you about possible new configurations to try in order to get better performance from your model.

Next, we will look at some examples. We will judge model skill on the train and validation sets by the loss being minimized, but you can use any metric that is meaningful for your problem.
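Because each diagnosis rests on comparing two curves, it can help to compute their difference per epoch alongside the plot. The hypothetical helper below returns validation loss minus training loss for each epoch. As a rough guide, a gap that keeps widening while training loss keeps falling is the pattern usually associated with overfitting, while a small gap with both losses still high points toward underfitting. How large a gap matters is problem-specific; the function does not decide that for you. The curves are illustrative, not real results.

```python
def loss_gap(train_loss, val_loss):
    """Per-epoch generalization gap: validation loss minus training loss."""
    if len(train_loss) != len(val_loss):
        raise ValueError("train and validation histories must have the same length")
    return [v - t for t, v in zip(train_loss, val_loss)]


# Illustrative curves: training keeps improving while validation turns upward.
train = [0.50, 0.30, 0.20, 0.12, 0.08]
val = [0.52, 0.35, 0.30, 0.33, 0.40]
print([round(g, 2) for g in loss_gap(train, val)])  # [0.02, 0.05, 0.1, 0.21, 0.32]
```

In practice you would pass history.history['loss'] and history.history['val_loss'].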

3. Underfit Example

An underfit model is one that has not learned the training dataset well: its skill is poor on the training data and, as a result, also poor on held-out data.

This can be diagnosed from a plot where the training loss is lower than the validation loss, and the validation loss is still trending downward at the end of training, which suggests that further improvements are possible.
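The downward trend can also be checked numerically. The sketch below is a hypothetical heuristic, not a Keras function: it reports whether a loss series dropped by more than a chosen relative amount over the last few epochs. The window of 10 epochs and the 1% threshold are arbitrary assumptions to tune for your problem. If the validation loss is still falling at the final epoch, training for more epochs is a reasonable next experiment.

```python
def still_improving(losses, window=10, min_rel_drop=0.01):
    """Return True if the loss fell by more than min_rel_drop (relative) over the last `window` epochs.

    window and min_rel_drop are arbitrary defaults, not Keras settings.
    """
    if len(losses) <= window:
        return True  # too little history to judge, so assume it may still improve
    start, end = losses[-window - 1], losses[-1]
    if start == 0:
        return False
    return (start - end) / abs(start) > min_rel_drop


# Illustrative curves, not real training results.
declining = [1.0 / (1 + 0.1 * e) for e in range(100)]
flat = [0.2] * 100
print(still_improving(declining))  # True
print(still_improving(flat))       # False
```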

A small contrived example of an underfit LSTM model is provided below.

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from matplotlib import pyplot
from numpy import array

# return training data
def get_train():
    seq = [[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# return validation data
def get_val():
    seq = [[0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 1.0]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# define model
model = Sequential()
model.add(LSTM(10, input_shape=(1, 1)))
model.add(Dense(1, activation='linear'))
# compile model
model.compile(loss='mse', optimizer='adam')
# fit model
X, y = get_train()
valX, valY = get_val()
history = model.fit(X, y, epochs=100, validation_data=(valX, valY), shuffle=False)
# plot train and validation loss
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()
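One detail in get_train() and get_val() is easy to miss: Keras LSTM layers expect 3-D input shaped [samples, time steps, features]. That is why each 1-D input column is reshaped to (len(X), 1, 1) and why the layer is declared with input_shape=(1, 1). The snippet below repeats only that data-preparation step with NumPy, so the shapes can be checked without building a model.

```python
from numpy import array

# Same training pairs as get_train(): each row is [input value, value to predict].
seq = array([[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]])
X, y = seq[:, 0], seq[:, 1]
print(X.shape)  # (5,)
# One sample per row, one time step per sample, one feature per time step.
X = X.reshape((len(X), 1, 1))
print(X.shape)  # (5, 1, 1)
print(y.shape)  # (5,)
```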

Running this example produces a plot of train and validation loss showing the characteristics of an underfit model. In this case, performance may be improved by increasing the number of training epochs.


Diagnostic Line Plot Showing an Underfit Model
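When you do extend training, the validation history also indicates where the useful epochs end. A minimal sketch, assuming you simply want the epoch with the lowest validation loss (a common choice, though not the only one): the helper name best_epoch is hypothetical, and epochs are reported 1-based to match the "Epoch 1/100" numbering in Keras's progress output. The curve below is illustrative, not a real result.

```python
def best_epoch(val_loss):
    """Return (1-based epoch, loss) for the lowest validation loss; ties go to the earliest epoch."""
    index = min(range(len(val_loss)), key=val_loss.__getitem__)
    return index + 1, val_loss[index]


# Illustrative validation curve (values are made up).
val = [0.40, 0.25, 0.18, 0.17, 0.19, 0.22]
print(best_epoch(val))  # (4, 0.17)
```

If the best epoch is the final one, the model was probably still improving and more epochs may help; if it is well before the end, the extra epochs mostly added training time.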


Alternately, a model may be underfit if performance on the training set is better than on the validation set and performance has leveled off at a poor level.
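A plateau is only a problem if it sits at a poor level. One hedged way to judge that is to compare the final training loss with the loss of a trivial predictor. The sketch assumes mean absolute error (the loss used in the next example), for which the best constant prediction is the median of the targets. The function names, the 10-epoch window, the 1% flatness tolerance, and the 0.5 ratio are all hypothetical choices, not a Keras feature.

```python
from statistics import median


def mae_baseline(y):
    """MAE of always predicting the median target (the best constant prediction under MAE)."""
    m = median(y)
    return sum(abs(v - m) for v in y) / len(y)


def plateaued_near_baseline(train_loss, y, window=10, flat_tol=0.01, ratio=0.5):
    """True if training loss barely moved recently AND is not much better than the baseline.

    window, flat_tol, and ratio are arbitrary assumptions; tune them for your problem.
    """
    recent = train_loss[-window:]
    flat = max(recent) - min(recent) <= flat_tol * max(recent)
    return flat and train_loss[-1] > ratio * mae_baseline(y)


y_train = [0.1, 0.2, 0.3, 0.4, 0.5]   # targets produced by get_train()
print(mae_baseline(y_train))          # about 0.12
stuck = [0.5, 0.3] + [0.11] * 20      # illustrative: levels off near the baseline
learned = [0.5, 0.3] + [0.01] * 20    # illustrative: levels off far below it
print(plateaued_near_baseline(stuck, y_train))    # True
print(plateaued_near_baseline(learned, y_train))  # False
```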

Below is an example of an underfit model with insufficient memory cells.

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from matplotlib import pyplot
from numpy import array

# return training data
def get_train():
    seq = [[0.0, 0.1], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# return validation data
def get_val():
    seq = [[0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 1.0]]
    seq = array(seq)
    X, y = seq[:, 0], seq[:, 1]
    X = X.reshape((len(X), 1, 1))
    return X, y

# define model with a single memory cell
model = Sequential()
model.add(LSTM(1, input_shape=(1, 1)))
model.add(Dense(1, activation='linear'))
# compile model
model.compile(loss='mae', optimizer='sgd')
# fit model
X, y = get_train()
valX, valY = get_val()
history = model.fit(X, y, epochs=300, validation_data=(valX, valY), shuffle=False)
# plot train and validation loss
pyplot.plot(history.history['loss'])
pyplot.plot(history.history['val_loss'])
pyplot.title('model train vs validation loss')
pyplot.ylabel('loss')
pyplot.xlabel('epoch')
pyplot.legend(['train', 'validation'], loc='upper right')
pyplot.show()
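"Insufficient memory cells" is easiest to see as a parameter count. A standard Keras LSTM layer has four gates, each with an input kernel, a recurrent kernel, and a bias, so it holds 4 × (units × (input_dim + units) + units) weights. The short calculation below compares the two examples' models. The function names are hypothetical; the totals should agree with what model.summary() reports for these layers, but summary() on your own model remains the authoritative check.

```python
def lstm_params(units, input_dim):
    """Trainable weights in a standard LSTM layer: 4 gates x (input kernel + recurrent kernel + bias)."""
    return 4 * (units * (input_dim + units) + units)


def dense_params(units, input_dim):
    """Weights plus one bias per unit in a fully connected layer."""
    return units * input_dim + units


# First example: LSTM(10) on 1 feature, then Dense(1).
first = lstm_params(10, 1) + dense_params(1, 10)
# Second example: LSTM(1) on 1 feature, then Dense(1).
second = lstm_params(1, 1) + dense_params(1, 1)
print(first, second)  # 491 14
```

With only 14 weights, the second model has far less capacity to fit even this simple mapping, which is consistent with a training loss that levels off early.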
