Ungraded Lab: Knowledge Distillation


Welcome! During this ungraded lab you will apply a model compression technique known as knowledge distillation, in which a student model "learns" from a more complex model known as the teacher. In particular you will:

  1. Define a Distiller class with the custom logic for the distillation process.
  2. Train the teacher model, a CNN that implements regularization via dropout.
  3. Train a student model (a smaller version of the teacher without regularization) by using knowledge distillation.
  4. Train another student model, called student_scratch, from scratch without distillation.
  5. Compare the three models.

This notebook is based on this official Keras tutorial.

If you want a more theoretical approach to this topic, be sure to check out the paper by Hinton et al. (2015).

Let's get started!

Imports
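A minimal set of imports that covers the rest of this lab could look like the following (it is assumed here that the dataset is fetched with tensorflow_datasets):

```python
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds  # assumed source for the Cats vs Dogs dataset
from tensorflow import keras
```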

Prepare the data

For this lab you will use the Cats vs Dogs dataset, which is composed of many images of cats and dogs alongside their respective labels.

Begin by downloading the data:
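One way to do this is through TensorFlow Datasets. The 80/10/10 train/validation/test split shown below is an illustrative assumption:

```python
# Download the Cats vs Dogs dataset. It only ships a "train" split,
# so train/validation/test partitions are carved out of it here (assumed 80/10/10).
(train_ds, val_ds, test_ds), ds_info = tfds.load(
    "cats_vs_dogs",
    split=["train[:80%]", "train[80%:90%]", "train[90%:]"],
    as_supervised=True,  # yield (image, label) pairs
    with_info=True,
)
```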

Preprocess the data for training by normalizing pixel values, reshaping them and creating batches of data:
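A sketch of this step; the target image size and batch size are illustrative assumptions:

```python
IMG_SIZE = 224   # assumed target resolution
BATCH_SIZE = 32  # assumed batch size

def preprocess(image, label):
    # Resize every image to a fixed shape and scale pixel values to [0, 1].
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

train_batches = train_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_batches = val_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_batches = test_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
```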

Code the custom Distiller model

To implement the distillation process you will create a custom Keras model which you will name Distiller. To do this you need to override some of the vanilla methods of a keras.Model to include the custom logic for knowledge distillation, namely the methods in charge of compiling, training and evaluating the model: compile, train_step and test_step (see the sketch below).

To learn more about customizing models check out the official docs.
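A sketch of the Distiller class, closely following the official Keras tutorial this notebook is based on (the exact code in the lab may differ in small details):

```python
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn,
                distillation_loss_fn, alpha=0.1, temperature=3):
        # student_loss_fn compares student predictions against the true labels,
        # distillation_loss_fn compares the softened student and teacher distributions.
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        x, y = data

        # The teacher only provides soft targets; its weights are never updated.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)

            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Compute and apply gradients for the student's weights only.
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in compile().
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(
            {"student_loss": student_loss, "distillation_loss": distillation_loss}
        )
        return results

    def test_step(self, data):
        x, y = data

        # Evaluation only looks at the student and the ground-truth labels.
        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)

        self.compiled_metrics.update_state(y, y_prediction)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
```

Notice that the loss in train_step is a weighted sum: alpha weighs the regular supervised loss while (1 - alpha) weighs the distillation loss computed on the temperature-softened probability distributions.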

Teacher and student models

For the models you will use a standard CNN architecture that implements regularization via some dropout layers (in the case of the teacher), but it could be any Keras model.

Define the create_model functions to create models with the desired architecture using Keras' Sequential Model.
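A sketch of what these two functions could look like. The specific layer sizes are illustrative assumptions; the key points are that the teacher is deeper and uses dropout, and that neither model ends with a softmax activation because the Distiller works with raw logits:

```python
def create_big_model():
    # Teacher: a deeper CNN with dropout layers for regularization.
    return keras.Sequential([
        keras.layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
        keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(128, 3, activation="relu", padding="same"),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Conv2D(256, 3, activation="relu", padding="same"),
        keras.layers.MaxPooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(2),  # raw logits, no softmax
    ], name="teacher")


def create_small_model():
    # Student: fewer layers and no dropout (no regularization).
    return keras.Sequential([
        keras.layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
        keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        keras.layers.MaxPooling2D(),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(2),  # raw logits, no softmax
    ], name="student")
```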

Notice that create_small_model returns a simplified version (fewer layers and no regularization) of the model that create_big_model returns:

There are two important things to notice:

  - Neither model applies a softmax activation in its last layer, since the raw logits are required for the distillation loss.
  - The dropout layers are present only in the teacher, so the student carries no regularization of its own.

Remember that the student model can be thought of as a simplified (or compressed) version of the teacher model.

Check the actual difference in number of trainable parameters (weights and biases) between both models:
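Assuming the functions sketched above, count_params gives a quick comparison:

```python
# Instantiate both models and count their trainable parameters (weights and biases).
teacher = create_big_model()
student = create_small_model()

teacher_params = teacher.count_params()
student_params = student.count_params()

print(f"Teacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"The teacher has roughly {teacher_params / student_params:.1f}x more parameters.")
```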

Train the teacher

In knowledge distillation it is assumed that the teacher has already been trained, so the natural first step is to train the teacher. You will do so for a total of 8 epochs:
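A sketch of this step; the choice of Adam as the optimizer is an assumption:

```python
# Compile the teacher with a loss that expects raw logits, then train for 8 epochs.
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

teacher_history = teacher.fit(train_batches, epochs=8, validation_data=val_batches)
```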

Train a student from scratch for reference

In order to assess the effectiveness of the distillation process, train a model that is equivalent to the student but without doing knowledge distillation. Notice that the training is done for only 5 epochs:
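A sketch of this step, reusing create_small_model so this baseline has exactly the same architecture as the student:

```python
# A second copy of the student architecture, trained with plain supervision only.
student_scratch = create_small_model()

student_scratch.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

scratch_history = student_scratch.fit(
    train_batches, epochs=5, validation_data=val_batches
)
```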

Knowledge Distillation

To perform the knowledge distillation process you will use the custom model you previously coded. To do so, begin by creating an instance of the Distiller class and passing in the student and teacher models. Then compile it with the appropriate parameters and train it!
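A sketch of this step; the values used for alpha and temperature are illustrative assumptions:

```python
# Wrap the trained teacher and the fresh student in the custom Distiller.
distiller = Distiller(student=student, teacher=teacher)

distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.1,       # weight of the hard-label loss (assumed value)
    temperature=10,  # softens the probability distributions (assumed value)
)

distiller_history = distiller.fit(train_batches, epochs=5, validation_data=val_batches)
```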

The two student models are trained for only 5 epochs, unlike the teacher, which was trained for 8. This is done to showcase that knowledge distillation allows for quicker training times, as the student learns from an already trained model.

Comparing the models

To compare the models you can check the sparse_categorical_accuracy of each one on the test set:
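Assuming the variables defined in the previous steps, the comparison boils down to calling evaluate on each model:

```python
# Evaluate the three models on the test set.
print("Teacher:                ", teacher.evaluate(test_batches, verbose=0))
print("Student (distilled):    ", distiller.evaluate(test_batches, verbose=0))
print("Student (from scratch): ", student_scratch.evaluate(test_batches, verbose=0))
```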

The teacher model yields a higher accuracy than the two student models. This is expected since it was trained for more epochs while using a bigger architecture.

Notice that the student without distillation was outperformed by the student with knowledge distillation.

Since you saved the training history of each model you can create a plot for a better comparison of the two student models.
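A sketch of such a plot; the val_sparse_categorical_accuracy key assumes the metric name used when compiling the models above:

```python
# Compare the validation accuracy of the two student models across epochs.
acc_key = "val_sparse_categorical_accuracy"
epochs = range(1, len(scratch_history.history[acc_key]) + 1)

plt.plot(epochs, distiller_history.history[acc_key], label="Student with distillation")
plt.plot(epochs, scratch_history.history[acc_key], label="Student from scratch")
plt.xlabel("Epoch")
plt.ylabel("Validation accuracy")
plt.title("Student models comparison")
plt.legend()
plt.show()
```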

This plot is very interesting because it shows that the distilled version outperformed the one trained from scratch in almost all of the epochs when evaluated on the validation set. At the same time, the student without distillation reaches a higher training accuracy, which is a sign that it is overfitting more than the distilled model. This hints that the distilled model was able to benefit from the regularization that the teacher implemented! Pretty cool, right?


Congratulations on finishing this ungraded lab! Now you should have a clearer understanding of what knowledge distillation is and how it can be implemented using TensorFlow and Keras.

This process is widely used for model compression and has proven to perform really well. In fact, you might have heard about DistilBERT, which is a smaller, faster, cheaper and lighter version of BERT.

Keep it up!