MNIST Lab
Do you ever struggle to read your own handwriting? Computers do too.
Model specifications
Your goal is to classify 28x28 grayscale images of handwritten digits (0-9) in TensorFlow JS. The dataset for this project is the MNIST Database. MNIST is hands-down the most popular dataset in machine learning, so prevalent that it is included by default in most machine-learning libraries.
Goal for this lab
If using MNIST is almost a cliché, why do it?
The key difference is that we will implement MNIST in a modern web browser, using JavaScript. Yes, JavaScript. For those concerned about the performance implications compared to lower-level languages like C++, our code will still run its heavy computation in parallel through WebGL¹.
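Once TensorFlow JS is installed later in this lab, you can ask it which backend it is running on; in most modern browsers this reports the WebGL backend. A minimal sketch:
import * as tf from "@tensorflow/tfjs";

// Wait for TensorFlow JS to choose and initialize a backend,
// then report which one it picked (usually "webgl" in the browser).
await tf.ready();
console.log(tf.getBackend());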
Getting Started
Getting started requires a fair number of tedious software installs. However, getting set up now is essential for later labs.
Node install
Node allows JavaScript to be run from your terminal. While we will not actually be using Node to run our TensorFlow JS code, having Node lets us use JavaScript tooling, which is a big help.
MacOS / Linux
On MacOS, installing through a package manager like Homebrew is recommended.
brew install node
Alternatively, navigate to Node's website and download an installer.
Windows
Like MacOS / Linux users, Windows users can also install via winget; however, most people will not have a package manager set up on their Windows machine. For this reason, most Windows users should install through Node's website.
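Whichever route you take, you can confirm the install worked by checking the versions from your terminal (the exact version numbers will vary by release):
node --version
npm --version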
NPM Package Installs
npm create vite@latest mnist_in_js -- --template vanilla # creates a project using vanilla JS
What did we just do?
Vite creates a template project using vanilla JavaScript. You get a basic HTML web page as well as a src directory to put your code in.
Creating a project with Vite also lets you easily manage third-party JavaScript packages, so we are now going to install TensorFlow JS.
First, we have to enter the new project we just created:
cd mnist_in_js
Next, install the template's own dependencies and then TensorFlow JS.
npm install # installs the dependencies listed in the template's package.json (including Vite)
npm install @tensorflow/tfjs # installs the latest version of TensorFlow JS in your project
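If you want to confirm everything is wired up before going further, you can start the dev server with npm run dev and temporarily log the TensorFlow JS version from main.js. A minimal sketch:
import * as tf from "@tensorflow/tfjs";

// Prints the installed TensorFlow JS version to the browser's developer console.
console.log(tf.version.tfjs);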
Getting TensorFlow set up
In the src directory of your new project, there is a main.js file. By default, it contains some demo code. Delete all of the default code and replace it with the following:
import * as tf from "@tensorflow/tfjs";

// A single dense layer mapping one input number to one output number.
const model = tf.sequential({
  layers: [
    tf.layers.dense({
      units: 1,
      inputShape: [1], // a scalar number
    }),
  ],
});

model.compile({
  loss: "meanSquaredError",
  optimizer: "sgd",
});

// Tensors take the data first, then the shape that data is in.
const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

await model.fit(xs, ys, { epochs: 250 });
Great! You have just created your first model.
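To see what the model learned, you can ask it for a prediction on a new input. The training data above follows y = 2x - 1, so the prediction for x = 10 should land close to 19. A minimal sketch:
// Predict on a single new value and print the result to the console.
model.predict(tf.tensor2d([10], [1, 1])).print();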
Task on your own
Use TensorFlow JS to create a model that classifies MNIST images.
As a helper, here is the data loading code (one possible starting point for the model itself is sketched after it):
import * as tf from "@tensorflow/tfjs";
export const IMAGE_H = 28;
export const IMAGE_W = 28;
const IMAGE_SIZE = IMAGE_H * IMAGE_W;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;
const NUM_TRAIN_ELEMENTS = 55000;
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
const MNIST_IMAGES_SPRITE_PATH =
"https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";
const MNIST_LABELS_PATH =
"https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8";
/**
* A class that fetches the sprited MNIST dataset and provides data as
* tf.Tensors.
*/
export class MnistData {
constructor() {}
async load() {
// Make a request for the MNIST sprited image.
const img = new Image();
const canvas = document.createElement("canvas");
const ctx = canvas.getContext("2d");
const imgRequest = new Promise((resolve, reject) => {
img.crossOrigin = "";
img.onload = () => {
img.width = img.naturalWidth;
img.height = img.naturalHeight;
const datasetBytesBuffer = new ArrayBuffer(
NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4
);
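// Each row of the sprite image is one flattened 28x28 digit, so the sprite
// is decoded in chunks of 5,000 rows to keep the canvas a manageable size.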
const chunkSize = 5000;
canvas.width = img.width;
canvas.height = chunkSize;
for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
const datasetBytesView = new Float32Array(
datasetBytesBuffer,
i * IMAGE_SIZE * chunkSize * 4,
IMAGE_SIZE * chunkSize
);
ctx.drawImage(
img,
0,
i * chunkSize,
img.width,
chunkSize,
0,
0,
img.width,
chunkSize
);
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
for (let j = 0; j < imageData.data.length / 4; j++) {
// All channels hold an equal value since the image is grayscale, so
// just read the red channel.
datasetBytesView[j] = imageData.data[j * 4] / 255;
}
}
this.datasetImages = new Float32Array(datasetBytesBuffer);
resolve();
};
img.src = MNIST_IMAGES_SPRITE_PATH;
});
const labelsRequest = fetch(MNIST_LABELS_PATH);
const [imgResponse, labelsResponse] = await Promise.all([
imgRequest,
labelsRequest,
]);
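// The labels file is a flat buffer of one-hot uint8 vectors,
// NUM_CLASSES (10) bytes per example.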
this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());
// Slice the images and labels into train and test sets.
this.trainImages = this.datasetImages.slice(
0,
IMAGE_SIZE * NUM_TRAIN_ELEMENTS
);
this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
this.trainLabels = this.datasetLabels.slice(
0,
NUM_CLASSES * NUM_TRAIN_ELEMENTS
);
this.testLabels = this.datasetLabels.slice(
NUM_CLASSES * NUM_TRAIN_ELEMENTS
);
}
/**
* Get all training data as a data tensor and a labels tensor.
*
* @returns
* xs: The data tensor, of shape `[numTrainExamples, 28, 28, 1]`.
* labels: The one-hot encoded labels tensor, of shape
* `[numTrainExamples, 10]`.
*/
getTrainData() {
const xs = tf.tensor4d(this.trainImages, [
this.trainImages.length / IMAGE_SIZE,
IMAGE_H,
IMAGE_W,
1,
]);
const labels = tf.tensor2d(this.trainLabels, [
this.trainLabels.length / NUM_CLASSES,
NUM_CLASSES,
]);
return { xs, labels };
}
/**
* Get all test data as a data tensor and a labels tensor.
*
* @param {number} numExamples Optional number of examples to get. If not
* provided,
* all test examples will be returned.
* @returns
* xs: The data tensor, of shape `[numTestExamples, 28, 28, 1]`.
* labels: The one-hot encoded labels tensor, of shape
* `[numTestExamples, 10]`.
*/
getTestData(numExamples) {
let xs = tf.tensor4d(this.testImages, [
this.testImages.length / IMAGE_SIZE,
IMAGE_H,
IMAGE_W,
1,
]);
let labels = tf.tensor2d(this.testLabels, [
this.testLabels.length / NUM_CLASSES,
NUM_CLASSES,
]);
if (numExamples != null) {
xs = xs.slice([0, 0, 0, 0], [numExamples, IMAGE_H, IMAGE_W, 1]);
labels = labels.slice([0, 0], [numExamples, NUM_CLASSES]);
}
return { xs, labels };
}
}
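Here is one possible starting point, not the only reasonable architecture: a small fully connected network trained with the MnistData class above. It assumes the loader is saved in a file named data.js next to main.js (the filename, layer sizes, and training hyperparameters are all assumptions you can adjust).
import * as tf from "@tensorflow/tfjs";
import { MnistData } from "./data.js"; // assumed filename for the loader above

// A small fully connected classifier: flatten the 28x28x1 image, one hidden
// layer, then a 10-way softmax over the digit classes.
const model = tf.sequential();
model.add(tf.layers.flatten({ inputShape: [28, 28, 1] }));
model.add(tf.layers.dense({ units: 128, activation: "relu" }));
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

model.compile({
  optimizer: "adam",
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});

// Load the data, train, then evaluate on a held-out test set.
const data = new MnistData();
await data.load();

const { xs, labels } = data.getTrainData();
await model.fit(xs, labels, {
  epochs: 5,
  batchSize: 128,
  validationSplit: 0.15,
});

const test = data.getTestData(1000);
const [testLoss, testAccuracy] = model.evaluate(test.xs, test.labels);
testAccuracy.print();
Once this trains end to end, swapping the flatten/dense stack for convolutional layers (tf.layers.conv2d and tf.layers.maxPooling2d) is a natural extension.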
¹ WebGL is a browser API that enables parallel computation on the GPU from the web. It is most commonly used for graphics and is how engines such as Unity can publish to the web.