Digits Case Study

This demonstrates a poorly performing decision tree, identifies the problem, and fixes it with the use of random forests.

View Code on Github

Let’s look at an example of a decision tree whose performance is not optimal because its flaws are being exacerbated by misuse. We start with a machine learning model that determines the value of a handwritten digit. This model is a basic decision tree classifier.

It is built to work on 8x8 grayscale images of digits, represented as 64-dimensional data points (one light-intensity value per pixel). Here’s a sample of a few of the data points:

Figure 1: a sample of handwritten digits in the dataset
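
If you want to follow along, the dataset can be loaded directly from scikit-learn via load_digits (the same ‘digits’ dataset referenced below); here is a minimal sketch of loading and inspecting it:

from sklearn.datasets import load_digits

digits = load_digits()

# Each image is stored both as a flat 64-value row and as an 8x8 grid.
print(digits.data.shape)    # (1797, 64)
print(digits.images.shape)  # (1797, 8, 8)
print(digits.target[:10])   # the true digit labels for the first few images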

When we train this model, we do what we always do and split the data in half: one half for training and one half for testing. The following code counts how many of the test data points the model classified correctly:

# results holds the model's guesses for the test images and
# answers holds the true labels for those same images.
guessAnswerPair = zip(results, answers)
correct = 0
total = 0

for (guess, answer) in guessAnswerPair:
    if guess == answer:
        correct += 1
    total += 1
View Code on Github
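
For context, here is a minimal sketch of the kind of setup that could produce results and answers above, assuming a scikit-learn DecisionTreeClassifier and a 50/50 split (the variable names are illustrative; the exact training code is on Github):

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()

# Split the 1797 images in half: one half to train on, one half to test on.
trainData, testData, trainLabels, answers = train_test_split(
    digits.data, digits.target, test_size=0.5, random_state=0)

model = DecisionTreeClassifier()
model.fit(trainData, trainLabels)

# 'results' holds the model's guesses for the unseen test images.
results = model.predict(testData)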

When we run the test, we correctly classify 734 out of 898 images, an 82% success rate. This isn’t bad at all, but we can likely improve on it. To do so, we should consider what we know about the weaknesses of decision trees and what we know about the data the model is learning from, so that we can make targeted modifications to our model.

We know that decision trees suffer from high variance, which means they are likely to overfit to noisy data. So here comes the part that requires a bit of analysis: given the weaknesses of our model, is the data likely to exacerbate those flaws?
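
One rough way to see this variance in practice is to train the same kind of tree on a few different random halves of the data and compare the test scores; a wide spread is a symptom of that sensitivity. This is an illustrative sketch, not part of the original code:

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()

# Same kind of model, different random halves of the data: the scores vary
# from split to split, which reflects the high variance described above.
for seed in range(5):
    trainData, testData, trainLabels, testLabels = train_test_split(
        digits.data, digits.target, test_size=0.5, random_state=seed)
    tree = DecisionTreeClassifier(random_state=seed)
    tree.fit(trainData, trainLabels)
    print(seed, tree.score(testData, testLabels))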

Let’s start by looking at how much data we have; in total, our ‘digits’ dataset contains 1797 entries. Unfortunately, in this example we don’t know how the data was collected (for instance, were random people asked online to draw digits with their mouse, or were they highly skilled calligraphers?). The origin of the dataset would have been useful to know because it would have indicated how ‘noisy’ the data is, by which we mean how often it contains outliers that stray from our typical expectation.

We can, however, examine random samples of the data and get a good idea this way. From looking through the data with the data visualiser (which uses matplotlib; code available on Github), it is quite clear that the data we’re using is a mishmash of handwriting, ranging from well-defined digits to, frankly, something that looks closer to Chinese. Given that decision trees are prone to overfitting, this tells us that the model could be fitting to the wrong pieces of data.
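
The visualiser itself is on Github; a minimal sketch of the same idea, using matplotlib to draw a handful of randomly chosen samples, might look like this:

import random

import matplotlib.pyplot as plt
from sklearn.datasets import load_digits

digits = load_digits()

# Pick a few random digits and draw them side by side so we can eyeball
# how messy the handwriting gets.
indices = random.sample(range(len(digits.images)), 8)
fig, axes = plt.subplots(1, len(indices))
for ax, i in zip(axes, indices):
    ax.imshow(digits.images[i], cmap='gray_r')
    ax.set_title(int(digits.target[i]))
    ax.axis('off')
plt.show()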

Fortunately, we’ve just discussed reducing overfitting with decision trees, and the solution was random forests. With random forests, for every odd ‘bad 2’ that a single decision tree might latch onto, like this one:

Figure 2: a sample of a poorly written "2"

(a digit which, clearly, many ‘3’s could resemble, and which therefore generalises poorly to what a 2 looks like), there will be plenty of other trees in the forest that never saw this 2 in the first place, and so its influence will be outweighed.
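
In code, moving from a single decision tree to a random forest is essentially a swap of the classifier class. Here is a minimal sketch, assuming the same scikit-learn setup and variables as the earlier split (the exact code behind the numbers below is on Github):

from sklearn.ensemble import RandomForestClassifier

# An ensemble of trees, each trained on a different random sample of the
# training data, in place of the single DecisionTreeClassifier.
model = RandomForestClassifier(n_estimators=100)
model.fit(trainData, trainLabels)
results = model.predict(testData)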

So if we adapt our model to learn with random forests instead, and run the same tests with the same training data, this time we correctly identify 823 out of 898 images, a success rate of 92%. This case study shows how important it is, when trying to improve effectiveness, to analyse both the performance of your current model and the nature of the dataset.

References

1. Dataset available from scikit-learn: http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html