Darian Nwankwo

I'm a Software Engineer interested in practical applications of Machine Learning and Data Science.

MNIST's Handwritten Digits: I can model that

31 Jul 2019 » python, machine-learning, sklearn

Useful Notes

  • Consider this blog post a tutorial and follow along or read at your leisure.
  • Whenever you see '...' in a code snippet, it means everything remained the same from before.


I recently completed my first year of graduate school. Although I sat in one machine learning course during my first semester and participated in another the next semester, I did not learn much about deploying machine learning models for others to interact with. After much thought and heated debate with myself, I realized that I was approaching the discipline completely wrong. My first mistake was expecting a course to make me an expert in this area of study. My second was not using my curiosity to build projects that exercise what I’ve learned, regardless of their difficulty.

I Can Machine Learning That, Watch!

Thus, here we are, doing the classic ‘Hello, World!’ of machine learning: training a model to recognize the handwritten digits provided in the MNIST database. If you follow the steps below, we’ll go from Machine Learning Zero to … Machine Learning 0.1 😅. This stuff takes time, but eventually, you can become a Machine Learning Hero! That’s what I want to be when I grow up.

Optional Step - For Unix-like Users

When using your personal system for development, it’s good practice to keep your documents organized and to have a dedicated location for your work. The convention I follow is to put all of my work related to software development, data science, and machine learning (and any other coding) in a directory named Development in my Home directory. I navigate my system using the terminal because it’s quite powerful once you understand it. Let’s start off by creating our directories.

mkdir ~/Development && cd ~/Development
mkdir mnist && cd mnist

Now we’ve created the directory Development in our Home directory, with the project directory mnist inside it. The ~ character is a shortcut for your Home directory, since its name varies per user.
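If you’re not on a Unix-like system (or you’d rather stay in Python), here is a rough equivalent using the standard library’s pathlib. The helper name make_project_dir is a hypothetical choice of mine, not something the rest of the tutorial depends on.

```python
from pathlib import Path


def make_project_dir(base: Path, name: str = "mnist") -> Path:
    """Create base/Development/<name>, including any missing parent folders, and return it.

    make_project_dir is a hypothetical helper for this tutorial, not a library function.
    """
    project = base / "Development" / name
    # parents=True also creates Development; exist_ok=True makes re-running harmless
    project.mkdir(parents=True, exist_ok=True)
    return project
```

Calling make_project_dir(Path.home()) should give you the same ~/Development/mnist layout as the commands above. Unlike a bare mkdir, it won’t complain if the folders already exist.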


Before moving on, make sure the following Python packages are installed:

  • scikit-learn
  • numpy
  • matplotlib
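To double check that these packages are installed before going further, here’s a small sketch that asks Python whether it can locate each one without actually importing it. Note that scikit-learn is imported under the name sklearn; missing_packages is a hypothetical helper name.

```python
import importlib.util
from typing import List

# Import names for the packages above (scikit-learn is imported as "sklearn")
REQUIRED = ["sklearn", "numpy", "matplotlib"]


def missing_packages(names: List[str]) -> List[str]:
    """Return the import names that Python cannot locate (hypothetical helper)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("All set!" if not missing else "Missing: " + ", ".join(missing))
```

If anything shows up as missing, install it with your package manager of choice (for example pip) before moving on.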

1. Collect the data

Before we do anything else, we need to collect our data. The data can be found on the MNIST database website, which has the four files we’ll need for our endeavor: train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz, and t10k-labels-idx1-ubyte.gz. After you’ve downloaded the files, store them in a location that can be easily accessed. You can drag and drop them to your desired location, but throughout this tutorial, I will be using the terminal to move files around.

cd ~/Development/mnist
mkdir data
mv ~/Downloads/*.gz data

Now that we’ve moved all of our files to the data directory inside of mnist, let’s uncompress them.

cd data
gunzip *.gz

Now let’s move back to the root directory for this project.

cd ..
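If gunzip isn’t available (on Windows, for example), Python’s built-in gzip module can do the same job. This is a sketch that assumes the .gz files sit in data/ and that you run it from the mnist directory; gunzip_all is a hypothetical name, and unlike gunzip it leaves the original .gz files in place.

```python
import gzip
import shutil
from pathlib import Path
from typing import List


def gunzip_all(directory: Path) -> List[Path]:
    """Decompress every *.gz file in `directory` next to the original (hypothetical helper).

    "train-images-idx3-ubyte.gz" becomes "train-images-idx3-ubyte", and so on.
    """
    written = []
    for gz_path in sorted(directory.glob("*.gz")):
        out_path = gz_path.with_suffix("")  # drop only the trailing ".gz"
        with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
            # Stream in chunks rather than loading a whole file into memory
            shutil.copyfileobj(src, dst)
        written.append(out_path)
    return written
```

Running gunzip_all(Path("data")) from the mnist directory should leave the four uncompressed files alongside their .gz originals.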

2. Parse the data

Now we need to understand what the mumbo jumbo inside these files actually is. The site that hosts the data also provides a specification for the file format. Each file is one fat chunk of bytes, so we need to write some code to figure out what our information looks like. Let’s create a file and get started.

touch main.py

Now, open your working directory mnist in your favorite text editor (mine is VS Code) and let’s get to work. The first thing we are going to do is write a function to parse our data.

import numpy as np

def parse_data(path_to_file: str, offset: int) -> np.ndarray:
    """Returns a numpy array of ints read from an unsigned byte (IDX) file,
    skipping the first `offset` bytes of header information."""
    with open(path_to_file, "rb") as data:
        # Each byte is one unsigned 8-bit integer; the header bytes are not data
        payload = np.frombuffer(data.read(), dtype=np.uint8, offset=offset)
    return payload

if __name__ == "__main__":
    # File names as they appear after running gunzip in the data directory
    X_train = parse_data("./data/train-images-idx3-ubyte", 16)
    y_train = parse_data("./data/train-labels-idx1-ubyte", 8)
    X_test = parse_data("./data/t10k-images-idx3-ubyte", 16)
    y_test = parse_data("./data/t10k-labels-idx1-ubyte", 8)

With the function above, we can parse the files we downloaded earlier. We specify the file’s path and an offset, and we get back everything after the header as an array of integers. Each file has a different offset for where the relevant data begins. For the “training set label file,” that offset is 8, since the first 8 bytes contain the magic number (4 bytes) and the number of items (4 bytes). For the image files, the offset is 16, since the header also stores the number of rows (4 bytes) and the number of columns (4 bytes) of each image.
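Rather than hard-coding 16 and 8, you can also read the header itself. Here is a sketch based on my reading of the IDX layout described on the MNIST site: two zero bytes, a type byte (0x08 for unsigned bytes), a byte holding the number of dimensions, then one big-endian 32-bit size per dimension. The function name read_idx_header is hypothetical.

```python
import struct
from typing import Tuple


def read_idx_header(path_to_file: str) -> Tuple[int, Tuple[int, ...], int]:
    """Return (type code, dimension sizes, header length in bytes) for an IDX file.

    read_idx_header is a hypothetical helper, not part of numpy or scikit-learn.
    """
    with open(path_to_file, "rb") as data:
        # Magic number: 2 zero bytes, 1 type-code byte, 1 byte for the number of dimensions
        zeros, type_code, ndim = struct.unpack(">HBB", data.read(4))
        if zeros != 0:
            raise ValueError(f"{path_to_file} does not look like an IDX file")
        # One big-endian unsigned 32-bit integer per dimension
        dims = struct.unpack(f">{ndim}I", data.read(4 * ndim))
    # The pixel/label bytes start right after the magic number and the sizes
    return type_code, dims, 4 + 4 * ndim
```

If everything lines up with the spec, read_idx_header("./data/train-images-idx3-ubyte") should report dimensions (60000, 28, 28) and a header length of 16, while the label files should report a single dimension and a header length of 8, matching the offsets passed to parse_data.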

3. Reformat data for training

Up next is reformatting our data. Right now, we have a one dimensional numpy array for each of our variables (X_train, y_train, X_test, y_test). This is a problem when we want to train a model to make reasonable predictions about future observations, because our data is just one big chunk of information with no notion of where one image ends and the next begins. What we’ll do now is reshape our one dimensional arrays into their appropriate shapes.


...

if __name__ == "__main__":
    ...
    # Reshape training/test sets s.t. each row corresponds to an input vector
    X_train = np.reshape(X_train, (60000, 784))
    # Keep the labels one dimensional (one digit per input vector), the target shape scikit-learn expects
    y_train = np.reshape(y_train, (60000,))
    X_test = np.reshape(X_test, (10000, 784))
    y_test = np.reshape(y_test, (10000,))

In other words, if we read the information from our data source correctly, we should be able to reshape the one dimensional array for X_train into a two dimensional array where 60,000 is the number of input vectors and 784 is the number of features for each input vector. The same logic applies to y_train, X_test, and y_test. If the sizes don’t line up, np.reshape raises a ValueError, which makes this a handy sanity check on our parsing.
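To convince yourself that reshaping doesn’t scramble anything, here’s a toy example. It uses a stand-in array whose values are just their own positions (not real MNIST data) and shows that NumPy’s default row-major order puts elements i * 784 through (i + 1) * 784 - 1 into row i, and that pixel (row, col) of an image sits at row * 28 + col within that row. flat_index is a hypothetical helper.

```python
import numpy as np

IMAGE_SIDE = 28
FEATURES = IMAGE_SIDE * IMAGE_SIDE  # 784 pixels per image

# Stand-in for a parsed file: 3 "images" whose pixel values are just their flat positions
flat = np.arange(3 * FEATURES)
X = np.reshape(flat, (3, FEATURES))


def flat_index(image: int, row: int, col: int) -> int:
    """Position of pixel (row, col) of a given image in the original 1-D array (row-major order)."""
    return image * FEATURES + row * IMAGE_SIDE + col


# Row i of X holds flat[i * 784 : (i + 1) * 784], so the reshape keeps images intact
assert np.array_equal(X[1], flat[FEATURES:2 * FEATURES])
# Reshaping one row to 28 x 28 lines its pixels back up with the image grid
image = np.reshape(X[2], (IMAGE_SIDE, IMAGE_SIDE))
print(image[5, 7] == flat[flat_index(2, 5, 7)])  # True
```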

3.1. Visualizing a digit

When dealing with machine learning, visualization is an important tool, since reasoning in multiple dimensions quickly becomes challenging once we pass the third dimension. Heck, I sometimes struggle visualizing things in three dimensions, but that is probably a discussion to be had with my optometrist. Let’s use our friendly neighborhood matplotlib to see what’s going on here. Each input vector has 784 features, and we know that each image is a 28 x 28 grid of pixels (28 x 28 = 784). Let’s grab an input vector, reshape it, and plot it using our handy-dandy plotting library.

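import matplotlib.pyplot as plt
import random
...


if __name__ == "__main__":
    ...
    # Visualize a single instance (randrange excludes the upper bound, so the index is always valid)
    rand_ndx = random.randrange(X_train.shape[0])
    singleton = np.reshape(X_train[rand_ndx], (28, 28))
    plt.imshow(singleton, cmap="gray")
    plt.show()  # opens a window; the script continues once it is closed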

Now when you run your code, you should see a random instance from your training set. Here is what I got:

[Image: a random handwritten digit from the training set]
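If you’re working over SSH or otherwise without a display, a text-only rendering can stand in for the plot. This is a minimal sketch assuming raw (unscaled) pixel values from 0 to 255, where the MNIST spec says 0 is background and 255 is foreground; the threshold of 128 and the name ascii_digit are arbitrary choices of mine.

```python
import numpy as np


def ascii_digit(pixels: np.ndarray, threshold: int = 128) -> str:
    """Render a 784-long vector of 0-255 pixel values as 28 lines of '#' (ink) and '.' (background)."""
    grid = np.reshape(pixels, (28, 28))
    rows = ("".join("#" if value >= threshold else "." for value in row) for row in grid)
    return "\n".join(rows)
```

With the variables from the plotting snippet, print(ascii_digit(X_train[rand_ndx])) should print a rough outline of the same digit.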

4. Choosing a model

There are several types of machine learning algorithms we could choose to solve the task at hand. We’ll keep things fairly simple and train a multi-layer perceptron (MLP), a type of feedforward neural network. Our neural network will have a single hidden layer with a little more than half as many neurons as the input layer: (784 + 10) // 2 = 397, the average of the number of input features and the number of output classes. Before we train our model, we are going to standardize our training and test sets. Standardization rescales each feature so that, over the training set, it has a mean of 0 and a variance of 1.
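Under the hood, standardization computes z = (x - mean) / std for each feature, where mean and std are that feature’s mean and standard deviation over the training set. Here is a NumPy sketch of the idea (standardize is a hypothetical name, and this is not scikit-learn’s actual code). One wrinkle with MNIST: some border pixels are blank in every image, so their standard deviation is 0; the sketch leaves those unscaled, which I believe mirrors how StandardScaler treats zero-variance features.

```python
import numpy as np


def standardize(X_train: np.ndarray, X_test: np.ndarray):
    """Rescale each feature (column) to mean 0 and variance 1 using training-set statistics only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    # Constant columns (e.g. always-blank pixels) get a scale of 1 to avoid dividing by 0
    std[std == 0] = 1.0
    # The test set reuses the training mean/std, so nothing about it leaks into training
    return (X_train - mean) / std, (X_test - mean) / std
```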

from sklearn.preprocessing import StandardScaler
...


if __name__ == "__main__":
    ...
    # Standardize our data since neural networks are sensitive to feature scale.
    # Fit the scaler on the training set only, then apply the same transformation to the test set.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
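Before handing everything to scikit-learn, it may help to see roughly what a one-hidden-layer MLP computes. The sketch below assumes ReLU hidden units (MLPClassifier’s default activation) and a softmax output layer, which is what MLPClassifier uses for multi-class problems. The weights are random and untrained, so the predictions are meaningless; training is the process of adjusting W1, b1, W2 and b2 to reduce the classification loss on the training set, with alpha controlling an L2 penalty on the weights. All names here are my own.

```python
import numpy as np


def relu(z: np.ndarray) -> np.ndarray:
    return np.maximum(z, 0)


def softmax(z: np.ndarray) -> np.ndarray:
    # Subtract each row's max first for numerical stability; each row then sums to 1
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def forward(X, W1, b1, W2, b2):
    """Class probabilities from a one-hidden-layer MLP: softmax(relu(X W1 + b1) W2 + b2)."""
    hidden = relu(X @ W1 + b1)
    return softmax(hidden @ W2 + b2)


# Untrained, randomly initialized weights just to show the shapes involved
rng = np.random.default_rng(0)
n_features, n_hidden, n_classes = 784, (784 + 10) // 2, 10
W1 = rng.normal(scale=0.01, size=(n_features, n_hidden))
b1 = np.zeros(n_hidden)
W2 = rng.normal(scale=0.01, size=(n_hidden, n_classes))
b2 = np.zeros(n_classes)

X_batch = rng.normal(size=(5, n_features))  # 5 fake standardized input vectors
probs = forward(X_batch, W1, b1, W2, b2)
print(probs.shape)           # (5, 10): one probability per digit for each input
print(probs.argmax(axis=1))  # predicted digit = most probable class
```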

Now that we’ve scaled our data, let’s train our model. Since we are using scikit-learn, most of the heavy lifting is taken care of for us, but I encourage you to explore the mathematics behind this model and build some understanding of what is being done. Now for the training.


from sklearn.neural_network import MLPClassifier
...


if __name__ == "__main__":
    ...
    # Create neural network with 1 hidden layer
    feature_count = X_train.shape[1]
    label_count = 10
    hlayer_size = (feature_count + label_count) // 2
    model = MLPClassifier(solver="lbfgs", alpha=1e-5, hidden_layer_sizes=(hlayer_size,))
    model.fit(X_train, y_train)
    # score() returns the mean accuracy on the given data
    print(f"Model Accuracy: {model.score(X_test, y_test)}")
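The number printed above is the mean accuracy, i.e. the fraction of test digits the model labels correctly. To see which digits get mixed up with which, a confusion matrix helps. Here is a NumPy sketch (accuracy and confusion_matrix are hypothetical helper names; scikit-learn also ships sklearn.metrics.confusion_matrix if you’d rather not roll your own).

```python
import numpy as np


def accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of predictions that match the true labels."""
    # ravel() guards against a (n, 1) column of labels broadcasting against a (n,) array
    return float(np.mean(np.ravel(y_true) == np.ravel(y_pred)))


def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int = 10) -> np.ndarray:
    """counts[i, j] = number of samples whose true digit is i and predicted digit is j."""
    counts = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at is unbuffered, so repeated (true, predicted) pairs are all counted
    np.add.at(counts, (np.ravel(y_true), np.ravel(y_pred)), 1)
    return counts
```

After training, something like y_pred = model.predict(X_test) followed by print(confusion_matrix(y_test, y_pred)) should show most counts on the diagonal; large off-diagonal entries point at digit pairs the model confuses.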

And voilà! You’ve successfully trained a neural network to recognize handwritten digits. Keep in mind that scikit-learn provided us with a lot of tools that hide the details of the learning process, so do explore on your own. Scikit-learn is a great library for rapidly prototyping and testing out the different machine learning methods you learn along your journey to becoming a Machine Learning Hero! I’m on the journey myself, and every day is a day where I learn just a little bit more. Take it step-by-step and do not cheat yourself. If you know you can walk, then why crawl? Move at a pace that is both reasonable and challenging for you! If you have any questions or want to stay in contact, follow me on social media or send me an email and I’d be happy to chat. ✊🏾✌🏾