MNIST Lab

Do you ever struggle to read your own handwriting? Computers do too.

Model specifications

Your goal is to classify 28x28 grayscale images of handwritten digits (0-9) in TensorFlow.js. The dataset for this project is the MNIST database. MNIST is hands-down the most popular dataset in machine learning, so prevalent that most machine-learning libraries include it by default.

Goal for this lab

If using MNIST is almost a cliché, why do it?

A key difference is that we will implement MNIST in a modern web browser, using JavaScript. Yes, JavaScript. For those concerned about the performance implications compared to lower-level languages like C++: our code will still get parallelism through WebGL.[1]

Getting Started

Getting started requires a fair number of tedious software installs. However, getting set up now is essential for later labs.

Node install

Node allows JavaScript to run in your terminal. While we will not actually use Node to run our TensorFlow.js code, having Node lets us use JavaScript tooling, which is a big help.

MacOS / Linux

On macOS, installing through a package manager like Homebrew is recommended:

brew install node

Alternatively, download an installer from Node's website.

Windows

Like macOS / Linux users, Windows users can also install via a package manager (winget). However, most people will not have a package manager set up on their Windows machine, so most people should install through Node's website.
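Whichever route you take, it is worth confirming the install before moving on. Both commands should print a version number:

```shell
node --version   # prints something like v20.11.0
npm --version    # prints something like 10.2.4
```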

NPM Package Installs

npm create vite@latest mnist_in_js -- --template vanilla # scaffolds a project using vanilla JS

What did we just do?

Vite creates a template project using vanilla JavaScript. You get a really basic HTML web page as well as a src directory for your code.

Creating a project with Vite also lets you easily manage third-party JavaScript packages, so we are now going to install TensorFlow.js.

First, we have to enter the new project we just created:

cd mnist_in_js

Next, install TensorFlow.js:

npm install @tensorflow/tfjs@latest # installs the latest version of TensorFlow.js in your project

Getting TensorFlow.js set up

In the src directory of the new project, there is a main.js file. By default, this contains some demo code. Delete all of it and replace it with the following:

import * as tf from "@tensorflow/tfjs";

const model = tf.sequential({
  layers: [
    tf.layers.dense({
      units: 1,
      inputShape: [1] // each input is a single scalar number
    })
  ]
});

model.compile({
  loss: 'meanSquaredError',
  optimizer: 'sgd'
});

// Tensors take the data first, then the shape that data is in
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

await model.fit(xs, ys, {epochs: 250});

Great! You have just created and trained your first model.
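For the curious: the xs/ys pairs above follow the line y = 2x - 1, so after enough epochs the model should approximate that line. A quick plain-JavaScript check (no TensorFlow required) confirms every label matches that relation:

```javascript
// The training pairs above were generated from y = 2x - 1.
const xsData = [-1, 0, 1, 2, 3, 4];
const ysData = [-3, -1, 1, 3, 5, 7];

const line = (x) => 2 * x - 1;
const recomputed = xsData.map(line);

console.log(recomputed); // [ -3, -1, 1, 3, 5, 7 ], identical to ysData
```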

The full API is documented on the TensorFlow.js website.

Task on your own

Use TensorFlow.js to create a model that classifies MNIST images.

As a helper, here is the data loading code:

import * as tf from "@tensorflow/tfjs";

export const IMAGE_H = 28;
export const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
  "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";
const MNIST_LABELS_PATH =
  "https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8";

/**
 * A class that fetches the sprited MNIST dataset and provides data as
 * tf.Tensors.
 */
export class MnistData {
  constructor() {}

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement("canvas");
    const ctx = canvas.getContext("2d");
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = "";
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer = new ArrayBuffer(
          NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4
        );

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
            datasetBytesBuffer,
            i * IMAGE_SIZE * chunkSize * 4,
            IMAGE_SIZE * chunkSize
          );
          ctx.drawImage(
            img,
            0,
            i * chunkSize,
            img.width,
            chunkSize,
            0,
            0,
            img.width,
            chunkSize
          );

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] = await Promise.all([
      imgRequest,
      labelsRequest,
    ]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Slice the images and labels into train and test sets.
    this.trainImages = this.datasetImages.slice(
      0,
      IMAGE_SIZE * NUM_TRAIN_ELEMENTS
    );
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels = this.datasetLabels.slice(
      0,
      NUM_CLASSES * NUM_TRAIN_ELEMENTS
    );
    this.testLabels = this.datasetLabels.slice(
      NUM_CLASSES * NUM_TRAIN_ELEMENTS
    );
  }

  /**
   * Get all training data as a data tensor and a labels tensor.
   *
   * @returns
   *   xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTrainExamples, 10]`.
   */
  getTrainData() {
    const xs = tf.tensor4d(this.trainImages, [
      this.trainImages.length / IMAGE_SIZE,
      IMAGE_H,
      IMAGE_W,
      1,
    ]);
    const labels = tf.tensor2d(this.trainLabels, [
      this.trainLabels.length / NUM_CLASSES,
      NUM_CLASSES,
    ]);
    return { xs, labels };
  }

  /**
   * Get all test data as a data tensor and a labels tensor.
   *
   * @param {number} numExamples Optional number of examples to get. If not
   *     provided,
   *   all test examples will be returned.
   * @returns
   *   xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
   *   labels: The one-hot encoded labels tensor, of shape
   *     `[numTestExamples, 10]`.
   */
  getTestData(numExamples) {
    let xs = tf.tensor4d(this.testImages, [
      this.testImages.length / IMAGE_SIZE,
      IMAGE_H,
      IMAGE_W,
      1,
    ]);
    let labels = tf.tensor2d(this.testLabels, [
      this.testLabels.length / NUM_CLASSES,
      NUM_CLASSES,
    ]);

    if (numExamples != null) {
      xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
      labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
    }
    return { xs, labels };
  }
}
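One detail worth understanding before you build the classifier: the labels file stores each label as 10 bytes, a one-hot vector, which is why the slicing above multiplies by NUM_CLASSES. Recovering the digit from a one-hot row is just an argmax; a minimal plain-JavaScript sketch:

```javascript
const NUM_CLASSES = 10;

// One-hot bytes for a single example, here the digit 3.
const labelBytes = new Uint8Array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0]);
console.assert(labelBytes.length === NUM_CLASSES);

// Argmax over the NUM_CLASSES bytes recovers the class index.
const digit = labelBytes.indexOf(Math.max(...labelBytes));
console.log(digit); // 3
```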
[1] WebGL is a browser API for parallel computation on the GPU. It is most commonly used for graphics and is how engines such as Unity publish to the web.