Hands-on Project: Digit classification with K-Nearest Neighbors and Data Augmentation
April 17, 2018
In this hands-on project, we’ll apply the K-Nearest Neighbors (KNN) algorithm to handwritten digit classification. Our main objectives are: a) to learn how to experiment with various hyperparameters, b) to introduce two evaluation metrics, classification accuracy and the confusion matrix, c) to develop intuition about how KNN works, and d) to use this intuition, along with data augmentation, to improve classification accuracy further.
Overview
This assignment guides you through using KNN for handwritten digit classification. Take the quiz at the bottom once you complete the assignment, as it asks for your results from the assignment. The last few questions in the quiz also guide you through some data augmentation techniques that you can use to further improve the accuracy of your KNN model.
Project Template on Google Colaboratory
Work on this project directly in-browser via Google Colaboratory. The link above is a starter template that you can save to your own Google Drive and work on. Google Colab is a free tool that lets you run small Machine Learning projects through your web browser. If you’re unfamiliar with Google Colaboratory, read this 1 min tutorial. Note that, for this project, you’ll have to upload the dataset linked below to Google Colab after saving the notebook to your own Google Drive.
Dataset
The MNIST dataset is a widely used machine learning dataset for the handwritten digit recognition task. Here are some sample images from the dataset.
Samples from MNIST hand-written digit dataset (16 samples are shown for each label)
We’ll work with a smaller subset of the dataset. You can access it at the following links: mnist_10000.pkl.gz and mnist_1000.pkl.gz. The first one consists of 10,000 training samples (plus 2,000 validation and 2,000 test samples), and the second one consists of 1,000 training samples (plus 200 validation and 200 test samples).
Loading the dataset
You can load the dataset with the following code. We are going to use Python 3.6 for this project, which ships with the pickle and gzip modules. We hope you’ve installed NumPy from earlier exercises; if not, run the following in the terminal: pip3 install numpy
import pickle, gzip
import numpy as np

# Load the (train, validation, test) splits from the gzipped pickle file.
f = gzip.open('mnist_10000.pkl.gz', 'rb')
trainData, trainLabels, valData, valLabels, testData, testLabels = pickle.load(f, encoding='latin1')
f.close()

print("training data points: {}".format(len(trainLabels)))
print("validation data points: {}".format(len(valLabels)))
print("testing data points: {}".format(len(testLabels)))
trainData is a NumPy array with shape (10000, 784). Each row is a data point: an array of 784 values, the pixels of the 28 x 28 image arranged row by row. A pixel value of 0.0 denotes white (background), a pixel value of 1.0 denotes black (foreground), and values in between denote intermediate intensities.
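As a quick sanity check, you can confirm the shape and pixel range yourself (an optional snippet of ours; it assumes the loading code above has already run):

print(trainData.shape)                   # (10000, 784)
print(trainData.min(), trainData.max())  # pixel values lie in [0.0, 1.0]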
Looking at the images
You can use the following snippet to look at specific images. First, install the OpenCV package by running the following:
pip3 install opencv-python
Here’s the code to see the training images:
import cv2

# Reshape the flat 784-pixel array back into a 28 x 28 image and display it.
image = trainData[0]
image = image.reshape((28, 28))
cv2.imshow("Image", image)
cv2.waitKey(0)  # wait for a key press; without this the window closes immediately
Note that OpenCV launches in a window separate from the terminal, and may take a few seconds to load up before you can see the first image in your dataset. Press any key (with the image window focused) to close it.
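If you’d rather view many digits at once, here’s a small sketch (our own addition, not part of the assignment) that tiles the first 100 training images into a single 10 x 10 grid:

import cv2
import numpy as np

# Stack the first 100 training images into a 10 x 10 grid of 28 x 28 tiles.
rows = []
for i in range(10):
    row = np.hstack([trainData[i * 10 + j].reshape((28, 28)) for j in range(10)])
    rows.append(row)
grid = np.vstack(rows)

cv2.imshow("First 100 training images", grid)
cv2.waitKey(0)  # press any key to close the window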
Choosing the best hyperparameters
Next, we will use the sklearn package’s KNeighborsClassifier implementation (which is quite optimized) on the mnist_10000 dataset. The quiz will ask for the results of the following tasks. (Use Euclidean distance for all tasks.) Before we get started, make sure the scikit-learn package is installed:
pip3 install scikit-learn
Now, let’s run K-Nearest Neighbors for a number of values of K and measure the accuracy for each:
Task 1: Try the following values of K, and note the classification accuracy on the validation data for each. K = 1, 3, 5, 9, 15, 25
from sklearn.neighbors import KNeighborsClassifier

for k in [1, 3, 5, 9, 15, 25]:
    # Train on the training split, then evaluate on the validation split.
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(trainData, trainLabels)
    score = model.score(valData, valLabels)
    print(k, score)
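If you’d rather not transcribe the scores by hand, a small variation (our own convenience, not required by the task) records the sweep and picks the best K automatically:

# Record the validation accuracy for each K, then pick the best one.
results = {}
for k in [1, 3, 5, 9, 15, 25]:
    model = KNeighborsClassifier(n_neighbors=k)
    model.fit(trainData, trainLabels)
    results[k] = model.score(valData, valLabels)

best_k = max(results, key=results.get)
print("best K:", best_k, "validation accuracy:", results[best_k])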
Task 2: For the best performing value of K, calculate and note the classification accuracy on the test data.
best_k = ...  # fill in the best-performing value of K from Task 1
model = KNeighborsClassifier(n_neighbors=best_k)
model.fit(trainData, trainLabels)
score = model.score(testData, testLabels)
predictions = model.predict(testData)  # keep these predictions for Tasks 3 and 4
print(score)
Task 3: Inspect the performance per class, i.e. precision, recall, and F-score for each digit. (Hint: see sklearn.metrics.classification_report.)
from sklearn.metrics import classification_report
print(classification_report(testLabels, predictions))
Task 4: Inspect the confusion matrix, i.e. when the correct label was digit I, how many times did the model predict digit J. (Hint: see sklearn.metrics.confusion_matrix.)
from sklearn.metrics import confusion_matrix
print(confusion_matrix(testLabels, predictions))
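To quickly surface the model’s most frequent confusion, you can zero out the diagonal (the correct predictions) and look for the largest remaining entry. This is a small sketch of ours, reusing the predictions from Task 2:

import numpy as np
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(testLabels, predictions)
errors = cm.copy()
np.fill_diagonal(errors, 0)  # keep only the off-diagonal (wrong) predictions
i, j = np.unravel_index(errors.argmax(), errors.shape)
print("true digit {} was predicted as {} most often ({} times)".format(i, j, errors[i, j]))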
Solution
The full code for the solution is available here: Solution to Hands-on Project: Digit classification with K-Nearest Neighbors and Data Augmentation. We highly encourage you to look at it only if you’re stuck and cannot proceed further.
(Bonus) Implement KNN yourself!
If you’d like to practice implementing KNN yourself (not the main focus of this assignment), use the mnist_1000 file so that you don’t have to wait a long while for your code to run. You can use the sklearn package’s KNeighborsClassifier to check your implementation (compare the predictions output by scikit-learn with those output by your code).
The pseudocode for KNN is as follows (see the sketch after this list for one possible translation into code):
- Compute the distance between the current sample and every sample in the training data (use Euclidean distance).
- Determine the K closest training samples (use K = 5).
- Check which label is the most common among those K training samples. That label is the prediction.
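Here is one way that pseudocode might translate into NumPy. This is a sketch of ours, not the official solution, and predict_knn is a name we made up:

import numpy as np
from collections import Counter

def predict_knn(trainData, trainLabels, sample, k=5):
    # Step 1: Euclidean distance from the sample to every training point.
    dists = np.sqrt(((trainData - sample) ** 2).sum(axis=1))
    # Step 2: indices of the K closest training samples.
    nearest = np.argsort(dists)[:k]
    # Step 3: the most common label among those K neighbors is the prediction.
    votes = Counter(np.asarray(trainLabels)[nearest])
    return votes.most_common(1)[0][0]

# Check a handful of predictions against scikit-learn's output.
mine = [predict_knn(trainData, trainLabels, x) for x in valData[:20]]
print(mine)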
(Bonus) Improving accuracy further with Data Augmentation
One simple way to improve accuracy further is to try different distance metrics based on your intuition (see the KNeighborsClassifier documentation).
The quiz has some questions that guide you through improving the accuracy further using data augmentation techniques, one of which is sketched below. (Finish the above tasks before starting the quiz.)
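We won’t spoil the quiz, but to illustrate the idea: one common augmentation for MNIST is to shift each training image by one pixel in each direction, which makes the classifier less sensitive to small translations. The sketch below is our own illustration (shift_image and the variable names are made up), not necessarily the technique the quiz has in mind:

import numpy as np

def shift_image(flat_image, dx, dy):
    # Shift a flattened 28 x 28 image by (dx, dy) pixels, padding with background (0.0).
    img = flat_image.reshape((28, 28))
    shifted = np.roll(np.roll(img, dy, axis=0), dx, axis=1)
    # Zero out the wrapped-around rows/columns so the shift pads with background.
    if dy > 0: shifted[:dy, :] = 0.0
    elif dy < 0: shifted[dy:, :] = 0.0
    if dx > 0: shifted[:, :dx] = 0.0
    elif dx < 0: shifted[:, dx:] = 0.0
    return shifted.reshape(784)

# Add four shifted copies of the training set (left, right, up, down by one pixel).
augmented = [trainData]
augmented_labels = [np.asarray(trainLabels)]
for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
    augmented.append(np.array([shift_image(x, dx, dy) for x in trainData]))
    augmented_labels.append(np.asarray(trainLabels))
bigTrainData = np.vstack(augmented)
bigTrainLabels = np.concatenate(augmented_labels)

You can then refit KNeighborsClassifier on bigTrainData and bigTrainLabels and compare the validation accuracy against your Task 1 results.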
Solution on Google Colaboratory
The complete notebook, with all the cells executed, is available via Google Colaboratory using the link above. Google Colab is a free tool that lets you run small Machine Learning experiments through your browser. If you’re unfamiliar with Google Colaboratory, read this 1 min tutorial. Note that, for this project, you’ll have to upload the dataset to Google Colab after saving the notebook to your own Google Drive.