The code that accompanies this article can be downloaded here.

A couple of days ago, news surfaced on the web about an AI that could detect shoplifters even before they commit the crime. Not long after that, we could read about a GAN network that can create photorealistic images from simple sketches. Even though this news left me amazed, I was hardly surprised. Machine learning and deep learning dominate the image classification and segmentation fields, and engineers are coming up with more and more interesting solutions. From Facebook tag suggestions to self-driving cars, neural networks have really taken over this world.

In fact, behind all these successes lies the concept of Convolutional Neural Networks, which we explained in this article. This type of neural network was created back in the 1990s by Yann LeCun, today’s director of AI research at Facebook. Similar to other ideas in the field, this one also has roots in biology. Researchers discovered that individual neurons in the visual cortex respond to stimuli only in a restricted region of the visual field known as the receptive field.

Because these fields of different neurons overlap, together they cover the entire visual field. This effectively means that certain neurons are activated only if a certain attribute is present in the visual field, for example, a horizontal edge. So, one group of neurons will be “fired up” if there is a horizontal edge in your visual field, and a different group will be activated if there is, let’s say, a vertical edge.

For example, take a look at this image and tell us what you see:

This is a well-known optical illusion, which first appeared in a German humor magazine back in 1892. As you may notice, you can see either a duck or a rabbit, depending on how you observe the image. What is happening in this and similar illusions is that they use the previously mentioned functions of the visual cortex to confuse us. Take a look at the same image below:

If your attention wanders to the area where the red square is, you would say that you see a duck. However, if you focus on the area of the image marked with the blue square, you would say that you see a rabbit. Meaning, when you observe certain features of the image, different groups of neurons get activated, and you classify the image either as a rabbit or as a duck. This is exactly the functionality that Convolutional Neural Networks utilize. They detect features in the image and then classify the image based on that information.

Convolutional Neural Networks Structure

We will not go into the details of how CNNs work; you can read all about that in this article. However, we are going to emphasize some major components. First, the so-called convolutional layers detect features of the image. These layers use filters to detect low-level features, like edges and curves, as well as higher-level features, like a face or a hand. Then the Convolutional Neural Network uses activation layers to introduce non-linearity, without which the model could not capture complex patterns. After that, additional layers for compressing the image (pooling) and flattening the data are used. Finally, this information is passed into a neural network, called the Fully-Connected Layer in the world of Convolutional Neural Networks. Again, the goal of this article is to show you how to implement all these concepts, so more details about these layers, how they work and what the purpose of each of them is can be found here.


In one of the previous articles, we implemented this type of neural network using Python and Keras. We created a neural network that is able to detect and classify handwritten digits. For that purpose, we used the MNIST dataset, a well-known dataset in the world of neural networks. It extends its predecessor, NIST, and has a training set of 60,000 samples and a test set of 10,000 images of handwritten digits. All digits have been size-normalized and centered, and the size of the images is fixed at 28×28 pixels. This is why this dataset is so popular.

MNIST Dataset samples

Using Convolutional Neural Networks, we can get almost human-level results. The record, when it comes to prediction accuracy on this dataset, is held by the Parallel Computing Center (Khmelnitskiy, Ukraine). They used an ensemble of only 5 convolutional neural networks and got an error rate of 0.21 percent; basically, they give correct results in 99.79% of cases. Awesome, isn’t it?

For the implementation in JavaScript that we plan to do here, we will also need a simple HTTP server. I am using this one, which can be easily installed using npm:

> npm install http-server

It is run using the command:

> http-server

Loading Data

To load the aforementioned data faster, the folks from Google provided us with this sprite file, along with code for managing it. However, I had to tweak this code a little bit to better fit my needs, so here is how it looks:

/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
const MNIST_LABELS_PATH =
    'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextDataBatch(batchSize, test = false) {
    if (!test) {
      return this.nextBatch(
          batchSize, [this.trainImages, this.trainLabels], () => {
            this.shuffledTrainIndex =
                (this.shuffledTrainIndex + 1) % this.trainIndices.length;
            return this.trainIndices[this.shuffledTrainIndex];
          });
    }
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}

Here are some major points of this MnistData class. Images and their labels are loaded into the fields trainImages, testImages, trainLabels and testLabels by the load method, which needs to be called first. Once that is done, we can use the nextDataBatch method to get batches of data from these datasets. Underneath, this method calls the nextBatch function with different images and labels. This function also converts the data into tensors, which is necessary for processing.
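To make the batching logic concrete, here is a tiny framework-free sketch of the slicing that nextBatch performs (plain JavaScript with hypothetical toy data: 4-pixel "images" instead of the real 784-pixel ones, and no tfjs):

```javascript
// Images are stored as one flat Float32Array, IMAGE_SIZE values per image.
// A batch copies the selected images into a contiguous buffer, back to back.
const IMAGE_SIZE = 4; // toy size instead of 784

const images = new Float32Array([
  0, 1, 2, 3,      // image 0
  10, 11, 12, 13,  // image 1
  20, 21, 22, 23,  // image 2
]);

function sliceBatch(data, indices) {
  const batch = new Float32Array(indices.length * IMAGE_SIZE);
  indices.forEach((idx, i) => {
    // Same arithmetic as nextBatch: slice one image, place it at slot i.
    const image = data.slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
    batch.set(image, i * IMAGE_SIZE);
  });
  return batch;
}

const batch = sliceBatch(images, [2, 0]);
console.log(Array.from(batch)); // [20, 21, 22, 23, 0, 1, 2, 3]
```

The real nextBatch does exactly this for both images and one-hot labels, then wraps the buffers into tf.tensor2d so the model can consume them.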


The whole code that accompanies this blog post can be found here.

Let’s start from the index.html file of our solution. In the previous article, we presented several ways of installing TensorFlow.js. One of them was integrating it within a script tag of the HTML file. That is what is done here:

<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>TensorFlow.js Convolutional Neural Networks</title>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest/dist/tf.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@latest/dist/tfjs-vis.umd.min.js"></script>
  <script src="./data.js" type="module"></script>
  <script src="./script.js" type="module"></script>
</head>
<body></body>
</html>

Note that we add one script tag for TensorFlow.js and an additional one for tfjs-vis, a small library for in-browser visualization. In this HTML file, we also import the data.js file, which should be located in the same folder as the index.html file, and the script.js file, where our complete implementation is located. In order to run this whole process, all you have to do is open index.html inside your browser.

Now, let’s examine the script.js file, where the majority of the implementation is located. Here is how the main run function looks:

async function run() {
    const data = await getData();
    await displayData(data, 30);

    const model = createModel();
    tfvis.show.modelSummary({name: 'Model Architecture'}, model);

    await trainModel(model, data, 20);
    await evaluateModel(model, data);
}

document.addEventListener('DOMContentLoaded', run);

You can notice that this function is similar to the one from the previous article, and it reveals the workflow of the application. In the beginning, we load the data using the getData function:

async function getData() {
    const data = new MnistData();
    await data.load();
    return data;
}

First, we create an object of the MnistData class, which is located in the data.js file. Then we call the load function mentioned above; it initializes the properties of the created object. Once this is done, we can return the object and use it for displaying input data in the browser with the displayData function:

async function singleImagePlot(image) {
    const canvas = document.createElement('canvas');
    canvas.width = 28;
    canvas.height = 28;
    canvas.style = 'margin: 4px;';
    await tf.browser.toPixels(image, canvas);
    return canvas;
}

async function displayData(data, numOfImages = 10) {
    const inputDataSurface =
        tfvis.visor().surface({ name: 'Input Data Examples', tab: 'Input Data'});

    const examples = data.nextDataBatch(numOfImages, true);

    for (let i = 0; i < numOfImages; i++) {
        const image = tf.tidy(() => {
            return examples.xs
                .slice([i, 0], [1, examples.xs.shape[1]])
                .reshape([28, 28, 1]);
        });

        const canvas = await singleImagePlot(image);
        inputDataSurface.drawArea.appendChild(canvas);

        image.dispose();
    }
}

In this function, we first create a new tab called Input Data. Then we get a batch of test data using the nextDataBatch function from the MnistData class. We iterate through those images and convert each one from a tensor into data that can be displayed. Finally, we show them in the created tab:

Once the data is visualized, we can proceed to the more fun part of the implementation and create the model. This is done in the createModel function:

function createModel() {
    const cnn = tf.sequential();

    cnn.add(tf.layers.conv2d({
        inputShape: [28, 28, 1],
        kernelSize: 5,
        filters: 8,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    cnn.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    cnn.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    cnn.add(tf.layers.maxPooling2d({poolSize: [2, 2], strides: [2, 2]}));

    cnn.add(tf.layers.flatten());

    cnn.add(tf.layers.dense({
        units: 10,
        kernelInitializer: 'varianceScaling',
        activation: 'softmax'
    }));

    cnn.compile({
        optimizer: tf.train.adam(),
        loss: 'categoricalCrossentropy',
        metrics: ['accuracy'],
    });

    return cnn;
}

In order to better understand the layers used in this method, please refer to this article. Basically, we create two convolutional layers that are followed by max-pooling layers. Finally, we flatten the data into an array and put it through a fully connected layer, which is in this case just one dense layer with 10 neurons. This last layer is actually the output layer, which predicts the class of the image. The model is then compiled with categorical cross-entropy loss and the Adam optimizer. Once we print the summary of the model with tfjs-vis, here is what we get:
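As a quick sanity check on that summary, the output shapes can be computed by hand. This is a rough sketch in plain JavaScript, assuming 'valid' padding (the tfjs default) for both the convolutions and the pooling layers:

```javascript
// Output size of a 'valid' convolution or pooling window:
// floor((in - window) / stride) + 1
const outSize = (inSize, window, stride) =>
    Math.floor((inSize - window) / stride) + 1;

let size = 28;                 // input: 28x28x1
size = outSize(size, 5, 1);    // conv2d, kernelSize 5 -> 24x24x8
size = outSize(size, 2, 2);    // maxPooling2d       -> 12x12x8
size = outSize(size, 5, 1);    // conv2d, kernelSize 5 -> 8x8x16
size = outSize(size, 2, 2);    // maxPooling2d       -> 4x4x16

const flattened = size * size * 16; // flatten feeds this many values
console.log(size, flattened);       // 4 256
```

So the dense output layer receives 256 values per image, which matches what the model summary reports.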

Cool! Now, we have prepared input data and model, so we can train it. This is done inside trainModel function:

function getBatch(data, size, test = false) {
    // Helper: fetch a batch and reshape the flat images for the conv layers.
    return tf.tidy(() => {
        const d = data.nextDataBatch(size, test);
        return [d.xs.reshape([size, 28, 28, 1]), d.labels];
    });
}

async function trainModel(model, data, epochs) {
    const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
    const container = {
        name: 'Model Training', styles: { height: '1000px' }
    };
    const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

    const batchSize = 512;
    const [trainX, trainY] = getBatch(data, 5500);
    const [testX, testY] = getBatch(data, 1000, true);

    return model.fit(trainX, trainY, {
        batchSize: batchSize,
        validationData: [testX, testY],
        epochs: epochs,
        shuffle: true,
        callbacks: fitCallbacks
    });
}

In essence, we get a batch of train data and a batch of test data. Then we call the fit method on our model and pass the train data for training and the test data for validation. Metrics like loss and accuracy are displayed after each epoch using tfjs-vis:

The final step in this process is the evaluation of our model. For that purpose, we display accuracy per digit and use the concept of a confusion matrix. This matrix is just a table that is often used to describe the performance of a classification model. This is all done in the evaluateModel function:

const classNames = ['Zero', 'One', 'Two', 'Three', 'Four',
                    'Five', 'Six', 'Seven', 'Eight', 'Nine'];

function predict(model, data, testDataSize = 500) {
    const testData = data.nextDataBatch(testDataSize, true);
    const testxs = testData.xs.reshape([testDataSize, 28, 28, 1]);
    const labels = testData.labels.argMax(1);
    const preds = model.predict(testxs).argMax(1);
    return [preds, labels];
}

async function displayAccuracyPerClass(model, data) {
    const [preds, labels] = predict(model, data);
    const classAccuracy = await tfvis.metrics.perClassAccuracy(labels, preds);
    const container = {name: 'Accuracy', tab: 'Evaluation'};
    tfvis.show.perClassAccuracy(container, classAccuracy, classNames);
}

async function displayConfusionMatrix(model, data) {
    const [preds, labels] = predict(model, data);
    const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
    const container = {name: 'Confusion Matrix', tab: 'Evaluation'};
    tfvis.render.confusionMatrix(
        container, {values: confusionMatrix, tickLabels: classNames});
}

async function evaluateModel(model, data) {
    await displayAccuracyPerClass(model, data);
    await displayConfusionMatrix(model, data);
}

As you can see, this function uses helper functions like displayAccuracyPerClass, displayConfusionMatrix and predict to make these graphs:

We are getting pretty good results with this simple model and with just 20 epochs. We could improve these results by adding additional convolutional layers or increasing the number of epochs. This would, of course, have an impact on the length of the training process.
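As a side note, the confusion matrix shown in the Evaluation tab is just a table of counts, rows for true classes and columns for predicted classes. It can be illustrated with a tiny framework-free sketch (hypothetical toy labels, not the model's actual predictions):

```javascript
// Build a confusion matrix from parallel arrays of true labels and predictions.
// matrix[t][p] counts how often true class t was predicted as class p.
function confusionMatrix(labels, preds, numClasses) {
  const matrix = Array.from({length: numClasses},
                            () => new Array(numClasses).fill(0));
  labels.forEach((trueClass, i) => {
    matrix[trueClass][preds[i]] += 1;
  });
  return matrix;
}

const labels = [0, 0, 1, 1, 2]; // true digits (toy example with 3 classes)
const preds  = [0, 1, 1, 1, 2]; // predicted digits
console.log(confusionMatrix(labels, preds, 3));
// -> [[1, 1, 0], [0, 2, 0], [0, 0, 1]]
```

Correct predictions land on the diagonal; the single off-diagonal entry shows one digit 0 misclassified as 1. tfvis.metrics.confusionMatrix computes the same kind of table from tensors.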


In this article, we got a chance to see how we can utilize Convolutional Neural Networks with TensorFlow.js. We learned how to manipulate the layers that are specific to this type of neural network and how to run them inside a browser.

Thank you for reading!

Read more posts from the author at Rubik’s Code.
