
Improving Skin Cancer Detection with Deep Learning

Step 3: Understanding the Evaluation Metric

Balanced multi-class accuracy

Balanced multi-class accuracy is an important metric when working with imbalanced datasets, because it penalizes algorithms for favoring the most common classes. For example, an algorithm working with an unbalanced dataset of 99 cat images and 1 dog image would achieve a 99% accuracy if it classified everything as a cat, but only a 50% balanced accuracy.

The balanced multi-class accuracy can be calculated from the confusion matrix: divide each diagonal entry by its row sum to get the per-class recall, then average over the rows.
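As a quick sanity check of the cat/dog example above, here is a minimal sketch using scikit-learn. The balanced_accuracy_score call and the toy labels are my own illustration, not part of the original post's code.

import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Toy version of the example above: 99 cats (class 0), 1 dog (class 1),
# and a classifier that predicts "cat" for every image
y_true = np.array([0] * 99 + [1])
y_pred = np.zeros(100, dtype=int)

print(accuracy_score(y_true, y_pred))           # 0.99
print(balanced_accuracy_score(y_true, y_pred))  # 0.5, the mean of the per-class recalls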

Accuracy vs Balanced Accuracy
Balanced multi-class accuracy is a dismal 54.1%.
import numpy as np
from sklearn.metrics import confusion_matrix
from fastai.plots import plot_confusion_matrix  # confusion matrix plotting helper in fast.ai v0.7

def get_balanced_accuracy(probs, y, classes):
    preds = np.argmax(probs, axis=1)
    cm = confusion_matrix(y, preds)
    plot_confusion_matrix(cm, classes)
    # Per-class recall (each diagonal entry divided by its row sum), averaged over the 7 classes
    return (np.diag(cm) / cm.sum(axis=1)).mean()

Three lines to evaluate our balanced multi-class accuracy.

## TTA stands for Test Time Augmentation
log_preds, y = learn.TTA()
probs = np.mean(np.exp(log_preds), 0)
get_balanced_accuracy(probs, y, data.classes)

Step 4: Finding the Optimal Learning Rate

The learning rate is one of the most important hyperparameters for deep learning models, and it often requires time and experimentation to set correctly. The fast.ai library is one of the first libraries to provide a quick, systematic way of finding a good learning rate for your model. It implements a technique from the 2015 paper “Cyclical Learning Rates for Training Neural Networks”: slowly increase the learning rate from a small value until the loss blows up.

The learning rate finder plots the learning rate against the loss, letting us see which learning rate produces the steepest decrease in loss. The optimal learning rate here is 2e-2, where the loss is still falling rapidly, not the point where the loss is lowest (1e-1).
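For readers outside fast.ai, the idea behind lr_find() can be sketched in plain PyTorch as below. This is only an illustration under assumed model, optimizer, criterion, and train_loader objects, not the library's actual implementation.

import copy
import torch

def lr_range_test(model, optimizer, criterion, train_loader,
                  start_lr=1e-7, end_lr=10.0, num_iters=100):
    # Multiply the learning rate by a constant factor after every mini-batch
    mult = (end_lr / start_lr) ** (1 / num_iters)
    lr = start_lr
    lrs, losses = [], []
    saved_state = copy.deepcopy(model.state_dict())  # restore weights after the test
    data_iter = iter(train_loader)
    for _ in range(num_iters):
        try:
            x, y = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            x, y = next(data_iter)
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())
        if loss.item() > 4 * min(losses):  # stop once the loss blows up
            break
        lr *= mult
    model.load_state_dict(saved_state)
    return lrs, losses  # plot losses against lrs and pick the steepest descent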

## Create a new learner object to find the optimal learning rate
learn = ConvLearner.pretrained(arch, data, precompute=True)
lrf = learn.lr_find()
learn.sched.plot_lr()
learn.sched.plot()

Step 5: Data Augmentation

The best way to improve our model is to get more data. Data augmentation is the practice of applying random transforms on our existing data to get new data. Back in our initial three lines of code, we created an object called tfms, or transforms, and then a data object that applied transforms on the data. However, when we created our model, we told it not to use data augmentation by setting precompute = True.
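For reference, the first two of those “initial three lines” presumably looked something like the standard fast.ai v0.7 pattern below; the imports, PATH, sz, bs, and the choice of aug_tfms are my assumptions (the original may have loaded labels from a CSV rather than folders), and the third line is the ConvLearner call repeated just after this sketch.

from fastai.transforms import *
from fastai.conv_learner import *
from fastai.dataset import *

## Assumed setup: arch was defined earlier; sz is the image size, bs the batch size, PATH the data folder
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)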

learn = ConvLearner.pretrained(arch, data, precompute=True) 

Now remember, our previous objective was to train a model quickly. Setting precompute = True effectively told the model to ignore data augmentation so that it could speed up training by precomputing the activations of the layer just before the custom head.
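To make that concrete, here is a rough sketch of what precomputing activations means, written in plain PyTorch rather than fast.ai's internals; the backbone and loader names are placeholders.

import torch

@torch.no_grad()
def precompute_features(backbone, loader):
    # Run the frozen backbone once over the whole dataset and cache the
    # activations that feed the custom head
    backbone.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(backbone(x))
        labels.append(y)
    return torch.cat(feats), torch.cat(labels)

# The custom head is then trained on these cached tensors, which is fast but
# means each image is seen exactly once, so augmentation has no effect.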

To use data augmentation, set precompute = False and retrain. The cycle_len parameter sets the number of epochs per cycle, i.e. the number of epochs until SGDR (stochastic gradient descent with restarts) resets the learning rate.

learn.precompute = False
learn.fit(1e-2, 3, cycle_len=1)
Accuracy increased to 79.6%.

Step 6: Fine-Tuning and Differential Learning Rate Annealing

Up until this point we took an existing model, added a custom head, and solely trained the custom head. The weights of the existing model have not been changed. For datasets that are similar to ImageNet, weights of the existing model are near perfect, and barely need any fine tuning. However, for datasets that differ from ImageNet, like this skin lesion dataset, we need to update the weights in the existing model to get competitive results.

learn.unfreeze() effectively tells the model to update every layer's weights during backpropagation, instead of just the weights in the custom head. Since we are updating every layer's weights, training time increases significantly.

When I unfreeze every layer in the model, I also use a technique called differential learning rate annealing. Differential learning rates allow each group of layers in the model to be updated by a different amount. This matters because the first group of layers is mostly made up of simple “edge detectors” and doesn't need significant change. On the other hand, the later convolutional layers may be looking for more complex features, like eyeballs, that are not present in our new dataset. Setting a higher learning rate for the later groups allows these layers to specialize in the complex features that are present in our new dataset.
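Outside fast.ai, the same idea can be expressed with per-parameter-group learning rates in PyTorch. The three-group split below is a hypothetical stand-in for the layer groups fast.ai builds automatically.

import torch
import torch.nn as nn

# Hypothetical split of a network into three groups (early convs, later convs, custom head)
early  = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
middle = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
head   = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, 7))

optimizer = torch.optim.SGD([
    {"params": early.parameters(),  "lr": 1e-2 / 9},  # generic edge detectors: smallest updates
    {"params": middle.parameters(), "lr": 1e-2 / 3},  # mid-level features: moderate updates
    {"params": head.parameters(),   "lr": 1e-2},      # task-specific head: full learning rate
], momentum=0.9)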

Using differential learning rates and unfreezing the model led to my highest balanced accuracy score.

In the code below, cycle_mult is a parameter that multiplies the length of each successive cycle, giving us more epochs to train our model. For example, learn.fit(lr, 4, cycle_len=1, cycle_mult=2) results in 15 epochs because the cycle lengths are 1, 2, 4, and 8 (1 + 2 + 4 + 8 = 15).
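The arithmetic is easier to see as a tiny helper (hypothetical, not a fast.ai function):

def total_epochs(n_cycles, cycle_len, cycle_mult):
    # Each successive cycle is cycle_mult times longer than the previous one
    return sum(cycle_len * cycle_mult ** i for i in range(n_cycles))

print(total_epochs(4, 1, 2))  # 1 + 2 + 4 + 8 = 15 (the example above)
print(total_epochs(3, 1, 2))  # 1 + 2 + 4 = 7 (the fit call used below)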

learn.unfreeze()
lr = np.array([1e-2/9, 1e-2/3, 1e-2])
learn.fit(lr, 3, cycle_len=1, cycle_mult=2)
log_preds, y = learn.TTA()
probs = np.mean(np.exp(log_preds), 0)
get_balanced_accuracy(probs, y, data.classes)

Step 7: Increase Image Size and Retrain

This last step unfortunately yielded worse results. However, the technique is often effective at reducing overfitting. It involves training the model for a few epochs on small images, then gradually resizing the images to larger ones and retraining.
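The call below relies on a get_data helper that is not shown in this section. It presumably rebuilds the data object at the requested image size, roughly along these lines; the augmentation choice and from_paths call are assumptions matching the sketch in Step 5.

def get_data(sz, bs):
    # Rebuild the transforms and data object for the new image size
    tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
    return ImageClassifierData.from_paths(PATH, tfms=tfms, bs=bs)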

learn.set_data(get_data(299, bs))
learn.freeze()
learn.fit(1e-3, 3, cycle_len=1)
log_preds, y = learn.TTA()
probs = np.mean(np.exp(log_preds), 0)
get_balanced_accuracy(probs, y, data.classes)