In the previous article, we had a chance to explore transfer learning with TensorFlow 2. We used several huge pre-trained models: VGG16, GoogLeNet and ResNet. These architectures were all trained on the ImageNet dataset, and their weights are stored so they can be reused. We specialized them for the “Cats vs Dogs” dataset, which contains 23,262 images of cats and dogs. Many pre-trained models are available in the tensorflow.keras.applications module. In essence, there are two ways to use them: as an out-of-the-box solution, or with transfer learning.

Large datasets are usually used to build general-purpose solutions, so you can take a model pre-trained on one of them and specialize it for a certain problem. That is exactly what we did in the previous article. This way you can use some of the most famous neural networks without losing too much time and too many resources on training. Additionally, you can fine-tune these models by modifying the behavior of chosen layers, and that is what we will do today. In the previous article's experiments, we got the best results with the ResNet architecture, so let’s try to fine-tune that model and maybe get even better results.


The problem with the ImageNet-winning architectures that came before ResNet was that they kept getting deeper. For example, AlexNet had 5 convolutional layers, while VGG and GoogLeNet had 19 and 22 layers respectively. Networks this deep are hard to train because of the vanishing gradient problem. During backpropagation, the gradient flows from the top layers down to the lower layers. By the chain rule, it is multiplied by the local derivative of every layer it passes through. When these factors are small, the repeated multiplication can make the gradient vanishingly small by the time it reaches the lower layers. In turn, those layers stop learning and the network’s performance degrades.
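To make the repeated-multiplication argument concrete, here is a toy calculation. It assumes a chain of sigmoid layers with unit weights, which is a deliberate simplification: real networks such as ResNet use ReLU and batch normalization. The derivative of the sigmoid is at most 0.25, so the fraction of the gradient that survives shrinks geometrically with depth:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_derivative(x: float) -> float:
    s = sigmoid(x)
    return s * (1.0 - s)  # maximum value 0.25, reached at x = 0

def gradient_scale(depth: int, pre_activation: float = 0.0) -> float:
    """Product of `depth` local derivatives: how much of the upstream gradient
    survives backpropagation through `depth` sigmoid layers (toy model, weights = 1)."""
    scale = 1.0
    for _ in range(depth):
        scale *= sigmoid_derivative(pre_activation)
    return scale

for depth in (5, 19, 22, 50):
    print(f"{depth:>3} layers: gradient scaled by {gradient_scale(depth):.3e}")
```

Even in this best case (every pre-activation at 0), 22 layers shrink the gradient by a factor of about 10^-13.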

Residual Networks (ResNet) try to address that problem with so-called “identity shortcut connections”, or residual blocks.

In essence, ResNet follows VGG’s 3×3 convolutional layer design, where each convolutional layer is followed by a batch normalization layer and a ReLU activation function. The difference is that, before the final ReLU, ResNet adds the block’s input to the output of the convolutional layers. In one variation, the input is first passed through a 1×1 convolution layer so that its shape matches the output it is added to.
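Below is a minimal Keras sketch of the block just described, written with the tf.keras functional API. It follows the original (v1) ordering: convolution, then batch normalization, then ReLU, with the input added before the final ReLU. The helper name residual_block and the tensor shapes are illustrative choices, not code from this article. Also note that ResNet101V2, which is used later, applies batch normalization and ReLU *before* each convolution (the "pre-activation" variant).

```python
import tensorflow as tf

def residual_block(x, filters, stride=1):
    """Basic v1-style residual block: two 3x3 conv + batch norm layers and a shortcut."""
    shortcut = x
    y = tf.keras.layers.Conv2D(filters, 3, strides=stride, padding="same", use_bias=False)(x)
    y = tf.keras.layers.BatchNormalization()(y)
    y = tf.keras.layers.ReLU()(y)
    y = tf.keras.layers.Conv2D(filters, 3, padding="same", use_bias=False)(y)
    y = tf.keras.layers.BatchNormalization()(y)
    # When the block changes the spatial size or channel count, project the input
    # with a 1x1 convolution so it can be added to the residual branch
    if stride != 1 or x.shape[-1] != filters:
        shortcut = tf.keras.layers.Conv2D(filters, 1, strides=stride, use_bias=False)(x)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    # Inject the (possibly projected) input before the final ReLU
    y = tf.keras.layers.Add()([y, shortcut])
    return tf.keras.layers.ReLU()(y)

inputs = tf.keras.Input(shape=(32, 32, 16))
x = residual_block(inputs, 16)             # same shape: identity shortcut
outputs = residual_block(x, 32, stride=2)  # new shape: 1x1 projection shortcut
model = tf.keras.Model(inputs, outputs)
print(tuple(model(tf.zeros((2, 32, 32, 16))).shape))  # (2, 16, 16, 32)
```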

The core idea is that a deeper network should not produce a higher training error than a shallower one. The authors of ResNet argue that if the added layers simply performed identity mappings (that is, they did nothing), the error should stay the same. Their hypothesis is that it is easier for the stacked layers to fit a residual, the difference between the desired output and the input, than to fit the complete desired mapping directly. This is accomplished with residual blocks.
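The identity argument can be checked numerically with a toy, fully connected version of a block. This uses NumPy only, and the function names and weight shapes are hypothetical. If the residual branch F(x) outputs zero, the block passes its non-negative input through unchanged. A plain two-layer block with no shortcut, by contrast, would have to learn identity weights to do the same:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def plain_block(x, w1, w2):
    """Two stacked layers without a shortcut: ReLU(W2 . ReLU(W1 . x))."""
    return relu(w2 @ relu(w1 @ x))

def residual_block(x, w1, w2):
    """The same layers plus an identity shortcut added before the final ReLU."""
    return relu(x + w2 @ relu(w1 @ x))

x = np.array([0.5, 0.0, 1.2, 3.0])  # output of a previous ReLU, so non-negative
zeros = np.zeros((4, 4))
identity = np.eye(4)

print(np.allclose(residual_block(x, zeros, zeros), x))     # True: F(x) = 0 already gives identity
print(np.allclose(plain_block(x, zeros, zeros), x))        # False: the plain block outputs zeros
print(np.allclose(plain_block(x, identity, identity), x))  # True with identity weights
```

For the residual block, "doing nothing" is the easy default of zero weights. That is why adding such blocks should not make the training error worse.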

This is what the complete ResNet architecture looks like:


This implementation is divided into several sections. First, we implement a class that is in charge of loading and preparing the data. Then we import a pre-trained model and build a class that modifies its top layers. We also fine-tune the model by configuring some of its layers to be trainable. Finally, we run the training and evaluation processes. Before everything, of course, we have to import some libraries and define some global constants:

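import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

IMG_SIZE = 160
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)  # height, width, RGB channels
BATCH_SIZE = 32  # assumed value; the original listing does not show it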

Data Loader

You may recognize this class from the previous article. It is in charge of loading the data from the “Cats vs Dogs” dataset and preparing it for processing. Images in this dataset are not normalized and have different shapes, so this class takes care of that as well.

This class has several methods, only one of which is “public”. Here is what each of them does:

  • _prepare_data – Internal method used to resize and normalize the images from the dataset. Called from the constructor.
  • _resize_sample – Internal method used to resize a single image.
  • _prepare_batches – Internal method used to create batches from the images. Creates train_batches, validation_batches and test_batches, which are used in the training and evaluation processes.
  • get_random_raw_images – Method used to get a given number of random images from the raw, unprocessed data.
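The class body is not reproduced here, but the two preprocessing helpers can be sketched with tf.data. Several parts of this sketch are assumptions, because the original code is not shown: the normalization to [-1, 1], the shuffle buffer size and the batch size. The functions are hypothetical standalone stand-ins for _resize_sample and _prepare_batches, and a small synthetic dataset of differently sized images replaces cats_vs_dogs so the example runs offline:

```python
import numpy as np
import tensorflow as tf

IMG_SIZE = 160
BATCH_SIZE = 32        # assumed value
SHUFFLE_BUFFER = 1000  # assumed value

def resize_sample(image, label):
    """Hypothetical stand-in for _resize_sample: cast, scale to [-1, 1], resize."""
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1.0  # [0, 255] -> [-1, 1] (assumed normalization)
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    return image, label

def prepare_batches(train, validation, test):
    """Hypothetical stand-in for _prepare_batches: shuffle the training data, then batch."""
    train_batches = train.shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)
    validation_batches = validation.batch(BATCH_SIZE)
    test_batches = test.batch(BATCH_SIZE)
    return train_batches, validation_batches, test_batches

def fake_samples():
    """Synthetic (image, label) pairs of different shapes, standing in for cats_vs_dogs."""
    rng = np.random.default_rng(0)
    for i, (h, w) in enumerate([(200, 300), (375, 500), (120, 90), (64, 64)]):
        yield rng.integers(0, 256, size=(h, w, 3), dtype=np.uint8), i % 2

raw = tf.data.Dataset.from_generator(
    fake_samples,
    output_signature=(tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
                      tf.TensorSpec(shape=(), dtype=tf.int64)))

processed = raw.map(resize_sample)
train_batches, validation_batches, test_batches = prepare_batches(processed, processed, processed)

images, labels = next(iter(train_batches))
print(images.shape, float(tf.reduce_min(images)), float(tf.reduce_max(images)))
```

After resize_sample, every image has the same shape, which is what allows the samples to be stacked into batches.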

However, most of the work happens in the constructor of the class. Let’s take a closer look:

def __init__(self, image_size, batch_size):
    self.image_size = image_size
    self.batch_size = batch_size

    # 80% train data, 10% validation data, 10% test data.
    # Older tensorflow_datasets releases expressed this as
    # tfds.Split.TRAIN.subsplit(weighted=(8, 1, 1)); that API has since been deprecated
    # and removed, and the percentage-slicing syntax below replaces it.
    splits = ['train[:80%]', 'train[80%:90%]', 'train[90%:]']

    (self.train_data_raw, self.validation_data_raw, self.test_data_raw), self.metadata = tfds.load(
        'cats_vs_dogs', split=splits,
        with_info=True, as_supervised=True)

    # Get the number of train examples
    self.num_train_examples = int(self.metadata.splits['train'].num_examples * 80 / 100)
    self.get_label_name = self.metadata.features['label'].int2str

    # Pre-process data (resize, normalize and batch)
    self._prepare_data()

First, we store the image and batch sizes, which are injected through the constructor's parameters. The dataset is not already split into training and testing data, so we split it ourselves: 80% for training, 10% for validation and 10% for testing. This is a really useful feature of TensorFlow Datasets, because we stay within the TensorFlow ecosystem and don’t have to involve other libraries like Pandas or scikit-learn. Once the data is split, we calculate the number of training samples and call the helper function that prepares the data for training. All we need to do after this is instantiate an object of this class and have fun with the loaded data:

data_loader = DataLoader(IMG_SIZE, BATCH_SIZE)

plt.figure(figsize=(10, 8))
for i, (img, label) in enumerate(data_loader.get_random_raw_images(20)):
    plt.subplot(4, 5, i + 1)
    plt.imshow(img)
    plt.title("{} – {}".format(data_loader.get_label_name(label), img.shape))
    plt.axis('off')
plt.show()

Here is the output:

Loading & Fine Tuning

Let’s see how we can load and fine-tune pre-trained models. In the previous article, we just added top layers on an existing pre-trained model and trained only those, not the whole network. We used the pre-trained model as a feature extractor and then specialized it for our own problem. We are going to do the same thing here, but we are also going to “unfreeze” some of the layers in the pre-trained model that are close to the top. The intuition is that lower convolutional layers detect low-level features like edges and curves, which are useful for almost any image. The higher layers are more specialized, so retraining them lets the network adapt to the features that matter for our specific problem. This is how it is done:

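# Load ResNet101V2 with ImageNet weights, without its classification head
resnet_base = tf.keras.applications.ResNet101V2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')
resnet_base.trainable = True

# Freeze the first 100 layers; only the layers from index 100 onward will be fine-tuned
from_layer = 100
for layer in resnet_base.layers[:from_layer]:
    layer.trainable = False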
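A quick way to confirm what the freezing loop does is to count frozen and trainable layers. To avoid downloading anything, this sketch builds the same ResNet101V2 architecture with weights=None (a deliberate swap: only the layer structure matters for the check), then applies the same loop:

```python
import tensorflow as tf

# weights=None builds the same architecture without downloading the ImageNet weights
base = tf.keras.applications.ResNet101V2(input_shape=(160, 160, 3),
                                         include_top=False, weights=None)
base.trainable = True
from_layer = 100
for layer in base.layers[:from_layer]:
    layer.trainable = False

frozen_layers = [layer for layer in base.layers if not layer.trainable]
trainable_layers = [layer for layer in base.layers if layer.trainable]
print(f"{len(base.layers)} layers: {len(frozen_layers)} frozen, {len(trainable_layers)} trainable")
print(f"{len(base.trainable_weights)} trainable weight tensors, "
      f"{len(base.non_trainable_weights)} non-trainable")
```

One detail worth knowing when fine-tuning: in TF 2.x, setting trainable = False on a BatchNormalization layer also makes it run in inference mode. As a result, the frozen layers keep their pre-trained normalization statistics.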

Pre-trained models are located in tensorflow.keras.applications, so the first thing we need to do is load ResNet from there. Notice that the include_top parameter is set to False. This means the model comes without its classification layers, so we need to add our own top layers to make it applicable to our concrete problem. Then we mark the model as trainable and freeze every layer below from_layer, so that only the layers from index 100 onward are fine-tuned. Once that is done, we need to add the top layers. We do that using the Wrapper class. This class accepts an injected pre-trained model and adds one Global Average Pooling layer and one Dense layer:

class Wrapper(tf.keras.Model):
    def __init__(self, base_model):
        super(Wrapper, self).__init__()
        self.base_model = base_model
        # Collapse the base model's feature maps into one vector per image
        self.average_pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
        # A single unit with no activation: outputs one logit for the cat/dog decision
        self.output_layer = tf.keras.layers.Dense(1)

    def call(self, inputs):
        x = self.base_model(inputs)
        x = self.average_pooling_layer(x)
        output = self.output_layer(x)
        return output
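The Wrapper class subclasses tf.keras.Model. The same head (global average pooling plus a one-unit Dense layer) can also be expressed with tf.keras.Sequential. In the minimal sketch below, build_classifier is a hypothetical helper, and a tiny convolutional model stands in for the ResNet base so the example runs without downloading weights:

```python
import tensorflow as tf

def build_classifier(base_model):
    """Same head as the Wrapper class (pooling + one-unit Dense), built with Sequential."""
    return tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),  # (batch, h, w, c) -> (batch, c)
        tf.keras.layers.Dense(1),                  # one logit per image, no activation
    ])

# A tiny stand-in for the ResNet base, so the sketch runs without downloading weights
tiny_base = tf.keras.Sequential([
    tf.keras.Input(shape=(160, 160, 3)),
    tf.keras.layers.Conv2D(8, 3, strides=2, activation="relu"),
])
model = build_classifier(tiny_base)
logits = model(tf.zeros((2, 160, 160, 3)))
print(tuple(logits.shape))  # (2, 1)
```

With the real base, build_classifier(resnet_base) should behave like Wrapper(resnet_base). The subclassed version is more flexible if the forward pass ever needs custom logic. The Sequential version is a bit easier to inspect, because summary() works without calling the model first.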

Then we can create the actual model for classifying the Cats vs Dogs dataset and compile it:

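base_learning_rate = 0.0001
resnet = Wrapper(resnet_base)
# Compile arguments are not shown in the original; RMSprop is an assumed optimizer choice
resnet.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate),
               loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),  # Dense(1) outputs logits
               metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])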
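The Dense(1) output layer in Wrapper has no activation, so the model produces logits rather than probabilities. The loss has to be told about this. The sketch below uses made-up labels and logits to show that BinaryCrossentropy(from_logits=True) on raw logits matches the same loss computed on sigmoid probabilities, and matches the textbook formula:

```python
import tensorflow as tf

labels = tf.constant([[0.0], [1.0], [1.0]])
logits = tf.constant([[-2.0], [0.5], [3.0]])  # made-up raw outputs of a Dense(1) layer

bce_from_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True)
bce_from_probs = tf.keras.losses.BinaryCrossentropy(from_logits=False)

loss_from_logits = float(bce_from_logits(labels, logits))
loss_from_probs = float(bce_from_probs(labels, tf.sigmoid(logits)))  # apply sigmoid ourselves

# Manual check: mean of -[y*log(p) + (1 - y)*log(1 - p)] with p = sigmoid(logit)
p = tf.sigmoid(logits)
manual = float(tf.reduce_mean(-(labels * tf.math.log(p) + (1.0 - labels) * tf.math.log(1.0 - p))))

print(loss_from_logits, loss_from_probs, manual)  # all three agree up to rounding
```

A related detail: in tf.keras, the plain 'accuracy' string metric resolves to binary accuracy with a 0.5 threshold applied directly to the model output. For logits, that is not the same as a probability of 0.5. BinaryAccuracy(threshold=0.0) compares the logits against 0 instead, which corresponds to sigmoid(logit) > 0.5.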

Training & Evaluation

Training is done by passing the data to the model's fit method:

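EPOCHS = 10  # assumed; the original epoch count is not shown
history = resnet.fit(data_loader.train_batches, epochs=EPOCHS,
                     validation_data=data_loader.validation_batches)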
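One epoch of fit corresponds to one pass over train_batches. The number of batches per epoch follows from the 80% split computed in the DataLoader constructor. Here is a quick check of that arithmetic, assuming the 23,262-image figure quoted at the start of the article and a batch size of 32 (an assumption):

```python
import math

NUM_EXAMPLES = 23262  # size of cats_vs_dogs, as quoted at the start of the article
BATCH_SIZE = 32       # assumed value

# Same formula as the DataLoader constructor (80% of the data), truncated to an integer
num_train_examples = int(NUM_EXAMPLES * 80 / 100)
# One epoch = one pass over train_batches; the last batch may be smaller
steps_per_epoch = math.ceil(num_train_examples / BATCH_SIZE)
print(num_train_examples, steps_per_epoch)  # 18609 582
```

TensorFlow Datasets rounds percentage split boundaries itself, so the real training split may differ from this count by an example.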

The training history looks a little wild: the loss spikes sharply at several points during training.
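Curves like these can be drawn from the History object that fit returns. The minimal sketch below assumes the model was compiled with a metric named 'accuracy' and trained with validation data, so that History.history contains the keys 'loss', 'val_loss', 'accuracy' and 'val_accuracy'. The numbers in example_history are made up and serve only to show the dictionary layout:

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering so the sketch also runs without a display
import matplotlib.pyplot as plt

def plot_history(history_dict):
    """Plot training vs. validation loss and accuracy from a Keras History.history dict."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    ax_loss.plot(history_dict["loss"], label="training")
    ax_loss.plot(history_dict["val_loss"], label="validation")
    ax_loss.set_title("Loss")
    ax_acc.plot(history_dict["accuracy"], label="training")
    ax_acc.plot(history_dict["val_accuracy"], label="validation")
    ax_acc.set_title("Accuracy")
    for ax in (ax_loss, ax_acc):
        ax.set_xlabel("Epoch")
        ax.legend()
    return fig

# Made-up values, only to show the expected dictionary layout
example_history = {
    "loss": [0.35, 0.20, 0.90, 0.15],
    "val_loss": [0.30, 0.25, 1.40, 0.20],
    "accuracy": [0.88, 0.93, 0.80, 0.95],
    "val_accuracy": [0.90, 0.92, 0.70, 0.94],
}
fig = plot_history(example_history)
```

After a real training run, plot_history(history.history) draws the actual curves. In a notebook, drop the matplotlib.use line and call plt.show().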

However, evaluation on the test set shows that the result is a little better than in the previous article:

# Evaluate on the held-out test split (iterates over the whole dataset)
loss, accuracy = resnet.evaluate(data_loader.test_batches)
print("Loss: {:.2f}".format(loss))
print("Accuracy: {:.2f}".format(accuracy))

Here is the result:

Loss: 0.43
Accuracy: 0.98

So, accuracy improved by about 0.01, or roughly one percentage point.
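Accuracy is the fraction of test images whose predicted class matches the label. For the model's single-logit output, the predicted class comes from thresholding the logit at 0, which is equivalent to a sigmoid probability above 0.5. In the minimal NumPy sketch below, logits_to_labels is a hypothetical helper and the logits and labels are made up:

```python
import numpy as np

def logits_to_labels(logits):
    """Map Dense(1) logits to 0/1 class ids: sigmoid(z) > 0.5 exactly when z > 0."""
    return (np.asarray(logits).reshape(-1) > 0).astype(int)

logits = np.array([[-3.1], [0.2], [4.7], [-0.4]])  # made-up model outputs
labels = np.array([0, 1, 1, 1])                    # made-up ground truth
predicted = logits_to_labels(logits)
accuracy = float(np.mean(predicted == labels))
print(predicted, accuracy)  # [0 1 1 0] 0.75
```

With the trained model, the same helper can be applied to resnet.predict(data_loader.test_batches), and data_loader.get_label_name turns the class ids into readable names.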


In this article, we extended our knowledge of transfer learning. We saw how fine-tuning a pre-trained model can further improve the results we get from it.

Thank you for reading!

Read more posts from the author at Rubik’s Code.