37 Reasons why your Neural Network is not working
By Slav Ivanov, Entrepreneur & ML Practitioner.
The network had been training for the last 12 hours. It all looked good: the gradients were flowing and the loss was decreasing. But then came the predictions: all zeroes, all background, nothing detected. “What did I do wrong?” — I asked my computer, who didn’t answer.
Where do you start checking if your model is outputting garbage (for example, predicting the mean of all outputs, or having really poor accuracy)?
A network might not be training for a number of reasons. Over the course of many debugging sessions, I would often find myself doing the same checks. I’ve compiled my experience along with the best ideas around in this handy list. I hope they will be of use to you, too.
Table of Contents
- 0. How to use this guide?
- I. Dataset issues
- II. Data Normalization/Augmentation issues
- III. Implementation issues
- IV. Training issues
0. How to use this guide?
A lot of things can go wrong. But some of them are more likely to be broken than others. I usually start with this short list as an emergency first response:
- Start with a simple model that is known to work for this type of data (for example, VGG for images). Use a standard loss if possible.
- Turn off all bells and whistles, e.g. regularization and data augmentation.
- If fine-tuning a model, double-check the preprocessing: it should be the same as the preprocessing used to train the original model.
- Verify that the input data is correct.
- Start with a really small dataset (2–20 samples). Overfit on it and gradually add more data.
- Gradually add back all the pieces that were omitted: augmentation/regularization, custom loss functions, more complex models.
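The "overfit a tiny dataset" step can be tried without any framework. The sketch below assumes a toy logistic-regression model as a stand-in for your network and a hypothetical six-sample dataset; the function name `overfit_tiny` is made up for illustration. The only point is that the training loss on a handful of samples should head toward zero. If your real model cannot manage that, look for a bug before adding more data.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function (avoids overflow in math.exp).
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def overfit_tiny(xs, ys, steps=3000, lr=0.5, seed=0):
    """Train a toy logistic-regression model with full-batch gradient descent.

    Returns the mean binary cross-entropy recorded at each step (computed
    before that step's update), so the caller can check it approaches zero.
    """
    rng = random.Random(seed)
    n, d = len(xs), len(xs[0])
    w = [rng.uniform(-0.1, 0.1) for _ in range(d)]
    b = 0.0
    history = []
    for _ in range(steps):
        grad_w = [0.0] * d
        grad_b = 0.0
        loss = 0.0
        for x, y in zip(xs, ys):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            # Clamp only inside the log to avoid log(0).
            q = min(max(p, 1e-12), 1.0 - 1e-12)
            loss -= y * math.log(q) + (1 - y) * math.log(1.0 - q)
            # For a sigmoid output with BCE loss, dLoss/dz = p - y.
            for j in range(d):
                grad_w[j] += (p - y) * x[j]
            grad_b += p - y
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
        history.append(loss / n)
    return history


# Hypothetical tiny, linearly separable dataset.
xs = [[0, 0], [0, 1], [1, 0], [3, 3], [3, 4], [4, 3]]
ys = [0, 0, 0, 1, 1, 1]
history = overfit_tiny(xs, ys)
print(f"first loss {history[0]:.3f}, last loss {history[-1]:.4f}")
```

With a real network the same idea applies: take 2 to 20 samples, disable regularization and augmentation, and watch the training loss.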
If the steps above don’t do it, start going down the following big list and verify things one by one.
I. Dataset issues
1. Check your input data
Check if the input data you are feeding the network makes sense. For example, I’ve more than once mixed up the width and the height of an image. Sometimes I would feed all zeroes by mistake, or I would use the same batch over and over. So print/display a couple of batches of input and target output and make sure they are OK.
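Some of these mistakes can also be caught automatically. The helper below is a hypothetical sketch, assuming each batch is a list of images stored as nested `[height][width]` lists of pixel values; it flags wrong shapes (e.g. swapped width and height), all-zero batches, and a batch repeated back to back. It complements, rather than replaces, looking at a few batches yourself.

```python
def batch_warnings(batches, expected_hw):
    """Return warnings for common input-pipeline mistakes.

    `batches` is a list of batches; each batch is a list of images stored as
    nested [height][width] lists. `expected_hw` is the (height, width) the
    network expects.
    """
    expected = tuple(expected_hw)
    warnings = []
    previous = None
    for i, batch in enumerate(batches):
        for j, image in enumerate(batch):
            shape = (len(image), len(image[0]))
            if shape != expected:
                warnings.append(f"batch {i} image {j}: shape {shape}, expected {expected}")
        if all(px == 0 for image in batch for row in image for px in row):
            warnings.append(f"batch {i}: all pixels are zero")
        if previous is not None and batch == previous:
            warnings.append(f"batch {i}: identical to batch {i - 1}")
        previous = batch
    return warnings


img = [[1, 2, 3], [4, 5, 6]]        # 2 x 3, as expected
swapped = [[1, 4], [2, 5], [3, 6]]  # 3 x 2: width and height mixed up
blank = [[0, 0, 0], [0, 0, 0]]
for msg in batch_warnings([[img], [img], [blank], [swapped]], expected_hw=(2, 3)):
    print(msg)
```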
2. Try random input
Try passing random numbers instead of actual data and see if the error behaves the same way. If it does, it’s a sure sign that your net is turning data into garbage at some point. Try debugging layer by layer (or op by op) and see where things go wrong.
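A sketch of that comparison, assuming you can wrap "run the model and compute the loss" in a function `loss_fn(inputs, targets)` (a hypothetical interface). The random inputs keep the real batch's shape; if a network that has trained for a while scores about the same on noise as on real data, it is not using its input. The two toy "models" below are illustrative stand-ins, not real networks.

```python
import random


def random_like(batch, rng):
    """Gaussian noise with the same shape as `batch` (a list of feature lists)."""
    return [[rng.gauss(0.0, 1.0) for _ in sample] for sample in batch]


def ignores_input(loss_fn, batch, targets, rel_tol=0.05, seed=0):
    """True if the loss on random inputs is within `rel_tol` of the real loss."""
    rng = random.Random(seed)
    real = loss_fn(batch, targets)
    noise = loss_fn(random_like(batch, rng), targets)
    return abs(real - noise) <= rel_tol * max(abs(real), 1e-12)


def mse(preds, targets):
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


# Hypothetical stand-ins for "run the model and return the loss".
def uses_input(inputs, targets):
    return mse([x[0] for x in inputs], targets)  # predicts from feature 0


def constant(inputs, targets):
    return mse([0.5 for _ in inputs], targets)  # ignores its input entirely


batch = [[0.1, 2.0], [0.9, -1.0], [0.4, 0.3]]
targets = [0.1, 0.9, 0.4]
print(ignores_input(uses_input, batch, targets))
print(ignores_input(constant, batch, targets))
```

The relative tolerance is an arbitrary choice for illustration; on an untrained network the two losses may legitimately be close, so run the check after some training.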
3. Check the data loader
Your data might be fine but the code that passes the input to the net might be broken. Print the input of the first layer before any operations and check it.
4. Make sure input is connected to output
Check if a few input samples have the correct labels. Also make sure shuffling input samples works the same way for output labels.
5. Is the relationship between input and output too random?
Maybe the non-random part of the relationship between the input and output is too small compared to the random part (one could argue that stock prices are like this). In other words, the inputs are not sufficiently related to the outputs. There isn’t a universal way to detect this, as it depends on the nature of the data.
6. Is there too much noise in the dataset?
This happened to me once when I scraped an image dataset off a food site. There were so many bad labels that the network couldn’t learn. Check a bunch of input samples manually and see if labels seem off.
The cutoff point is up for debate, as this paper got above 50% accuracy on MNIST using 50% corrupted labels.
7. Shuffle the dataset
If your dataset hasn’t been shuffled and has a particular order to it (e.g. ordered by label), this could negatively impact learning. Shuffle your dataset to avoid this, and make sure you shuffle inputs and labels together.
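Shuffling inputs and labels "together" simply means applying one shared permutation to both. A minimal sketch, assuming the dataset fits in two Python lists; `shuffle_together` is a hypothetical helper name, and frameworks usually handle this for you when inputs and labels live in the same dataset object.

```python
import random


def shuffle_together(inputs, labels, seed=None):
    """Shuffle inputs and labels with one shared permutation.

    Shuffling the two lists independently would silently break the
    input-label pairing; applying the same index order keeps it intact.
    """
    if len(inputs) != len(labels):
        raise ValueError("inputs and labels must have the same length")
    order = list(range(len(inputs)))
    random.Random(seed).shuffle(order)
    return [inputs[i] for i in order], [labels[i] for i in order]


# Hypothetical dataset sorted by label, where label = input % 3.
xs = sorted(range(12), key=lambda v: v % 3)
ys = [v % 3 for v in xs]
sx, sy = shuffle_together(xs, ys, seed=42)
assert all(x % 3 == y for x, y in zip(sx, sy))  # pairs survived the shuffle
```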
8. Reduce class imbalance
Are there 1000 class A images for every class B image? Then you might need to balance your loss function or try other class imbalance approaches.
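One common way to balance the loss is to weight each sample by the inverse frequency of its class. The sketch below uses the n_samples / (n_classes * count) heuristic, the same formula scikit-learn uses for `class_weight="balanced"`; the helper names are hypothetical, and other approaches such as oversampling the minority class are not shown.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * count[c]).

    With these weights every class contributes the same total weight to the
    loss, so a rare class is not drowned out by a frequent one.
    """
    counts = Counter(labels)
    n, k = len(labels), len(counts)
    return {c: n / (k * count) for c, count in counts.items()}


def weighted_mean_loss(per_sample_losses, labels, weights):
    """Weighted average of per-sample losses using the class weights."""
    total = sum(weights[y] * loss for loss, y in zip(per_sample_losses, labels))
    norm = sum(weights[y] for y in labels)
    return total / norm


labels = ["A"] * 1000 + ["B"] * 1
weights = balanced_class_weights(labels)
print(weights)  # {'A': 0.5005, 'B': 500.5}
```

With these weights, a model that gets every class A sample right and the single class B sample wrong still ends up with a weighted mean loss of 0.5 rather than roughly 0.001.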
9. Do you have enough training examples?
If you are training a net from scratch (i.e. not fine-tuning), you probably need lots of data. For image classification, a common rule of thumb is 1000 images per class or more.
10. Make sure your batches don’t contain a single label
This can happen in a sorted dataset (e.g. the first 10k samples all belong to the same class). It is easily fixed by shuffling the dataset.
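A quick way to spot this is to look at the labels batch by batch, in the order the loader will serve them. A minimal sketch, assuming the labels are available as a list in iteration order; `single_label_batches` is a hypothetical helper. With few classes and small batches an occasional single-label batch after shuffling is just chance, while a long run of them points to an ordered dataset.

```python
def single_label_batches(labels, batch_size):
    """Indices of batches (in the given order) whose labels are all the same."""
    bad = []
    for start in range(0, len(labels), batch_size):
        batch = labels[start:start + batch_size]
        if len(set(batch)) == 1:
            bad.append(start // batch_size)
    return bad


sorted_labels = [0] * 8 + [1] * 8
print(single_label_batches(sorted_labels, batch_size=4))  # [0, 1, 2, 3]
print(single_label_batches([0, 1] * 8, batch_size=4))     # []
```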
11. Reduce batch size
This paper points out that having a very large batch can reduce the generalization ability of the model.
Addition 1. Use a standard dataset (e.g. MNIST, CIFAR-10)
Thanks to @hengcherkeng for this one:
When testing a new network architecture or writing a new piece of code, use the standard datasets first, instead of your own data. There are many reference results for these datasets, and they are known to be ‘solvable’. There will be no issues of label noise, train/test distribution difference, too much difficulty in the dataset, etc.
II. Data Normalization/Augmentation issues
12. Standardize the features
Did you standardize your input to have zero mean and unit variance?
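A minimal sketch of feature standardization in plain Python, assuming tabular inputs given as lists of numeric rows; in practice you would likely use your framework's or scikit-learn's equivalent. The helper names are hypothetical. The statistics are fitted once, on training rows, and then reused for every split, which is also what item 15 below asks for.

```python
import statistics


def fit_standardizer(train_rows):
    """Per-feature mean and standard deviation, computed on training rows only."""
    columns = list(zip(*train_rows))
    means = [statistics.fmean(col) for col in columns]
    # pstdev = population std; constant features (std 0) fall back to 1.0.
    stds = [statistics.pstdev(col) or 1.0 for col in columns]
    return means, stds


def standardize(rows, means, stds):
    """Apply (x - mean) / std feature-wise using previously fitted statistics."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train = [[1.0, 200.0], [2.0, 400.0], [3.0, 600.0]]
means, stds = fit_standardizer(train)
z = standardize(train, means, stds)
for col in zip(*z):
    print(round(statistics.fmean(col), 6), round(statistics.pstdev(col), 6))
```

Each standardized training column should come out with mean 0 and standard deviation 1; validation and test rows are passed through `standardize` with the same `means` and `stds`.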
13. Do you have too much data augmentation?
Augmentation has a regularizing effect. Too much of this combined with other forms of regularization (weight L2, dropout, etc.) can cause the net to underfit.
14. Check the preprocessing of your pretrained model
If you are using a pretrained model, make sure you use the same normalization and preprocessing that were used when the model was trained. For example, should an image pixel be in the range [0, 1], [-1, 1] or [0, 255]?
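The three ranges differ only by a linear map from 8-bit pixel values. A minimal sketch, assuming 8-bit input pixels; `rescale_pixel` is a hypothetical helper, and which target range is right has to come from the pretrained model's own documentation or training code. Some models expect more than a rescale: many torchvision ImageNet models, for instance, also subtract a per-channel mean and divide by a per-channel standard deviation after scaling to [0, 1].

```python
def rescale_pixel(value, lo, hi):
    """Linearly map an 8-bit pixel value (0..255) onto the range [lo, hi].

    Which range is correct, e.g. (0, 1), (-1, 1) or (0, 255), depends on how
    the pretrained model was trained, so check its documentation or code.
    """
    if not 0 <= value <= 255:
        raise ValueError(f"expected an 8-bit pixel value, got {value}")
    return lo + (hi - lo) * value / 255.0


for lo, hi in [(0.0, 1.0), (-1.0, 1.0), (0.0, 255.0)]:
    print((lo, hi), rescale_pixel(0, lo, hi), rescale_pixel(255, lo, hi))
```

Feeding [0, 255] inputs to a model trained on [-1, 1] will not raise any error; it just produces poor predictions, which is why this mismatch is easy to miss.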
15. Check the preprocessing for train/validation/test set
CS231n points out a common pitfall:
“… any preprocessing statistics (e.g. the data mean) must only be computed on the training data, and then applied to the validation/test data. E.g. computing the mean and subtracting it from every image across the entire dataset and then splitting the data into train/val/test splits would be a mistake.”
Also, check for different preprocessing in each sample or batch.
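One way to catch preprocessing that silently differs between samples or batches is to run each sample through the pipeline on its own and again inside a batch, then compare. The sketch below assumes 1-D numeric samples and a preprocessing function that takes a list (a hypothetical interface); `batch_mean_center` is a deliberately buggy example that centers with the current batch's mean instead of a fixed training-set statistic. Layers such as batch normalization in training mode are legitimately batch-dependent; this check is meant for the input pipeline.

```python
import statistics


def batch_mean_center(batch):
    # Buggy on purpose: subtracts the mean of the *current batch*, so the
    # same sample is transformed differently depending on its batch-mates.
    m = statistics.fmean(batch)
    return [x - m for x in batch]


def make_fixed_center(train_values):
    # Correct: the mean is computed once on training data and then frozen.
    m = statistics.fmean(train_values)
    return lambda batch: [x - m for x in batch]


def is_batch_independent(preprocess, samples, batch_size):
    """True if every sample gets the same value alone and inside a batch."""
    alone = [preprocess([x])[0] for x in samples]
    batched = []
    for start in range(0, len(samples), batch_size):
        batched.extend(preprocess(samples[start:start + batch_size]))
    return all(abs(a - b) < 1e-12 for a, b in zip(alone, batched))


samples = [1.0, 2.0, 3.0, 10.0]
print(is_batch_independent(batch_mean_center, samples, batch_size=2))
print(is_batch_independent(make_fixed_center([1.0, 2.0, 3.0]), samples, batch_size=2))
```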