Deep Learning Advanced Part 6: Inductive Bias and Knowledge Distillation Explained
In the previous article, we walked through the full logic of Vision Transformer (ViT): cutting an image into patches, treating them as tokens, and feeding them into a Transformer Encoder for global modeling.
However, we also mentioned an unavoidable pain point of ViT:
Without sufficiently large data scales, ViT is often difficult to train well.
From a paradigm perspective, this is because ViT is essentially a "weak prior, strong data-driven" modeling approach.
Let's dig further into this question:
Why does ViT require large amounts of data to perform well, while CNN remains effective with small data?
We've already discussed prior knowledge in our earlier hyperspectral imaging articles. We know that convolutional networks build a large number of visual priors into their structure, while ViT is more "data-driven": spatial relationships are mainly learned statistically, which makes it highly sensitive to data scale and training recipes.
But jumping directly to the next improvement of ViT based solely on these general concepts feels somewhat thin.
Therefore, this article mainly introduces two important concepts in Deep Learning:
- Inductive Bias
- Distillation
Once these concepts are clear, you'll be ready to follow one of the main lines of ViT improvement.
1. Prior Information and Inductive Bias
First, we need to understand two highly related but not equivalent concepts: Prior Information and Inductive Bias.
1.1 What is Prior Information?
Let's simply restate it with a one-sentence definition:
Prior information is the "laws about the world" that we already know before learning.
It doesn't come from data but from experience or cognition.
For example, in vision tasks we naturally know that images are continuous, adjacent pixels are more strongly correlated, objects have structure, and so on. A few examples:
- Structural priors: In a face, the eyes are above the nose and the mouth is below it, forming a stable overall spatial arrangement.
- Local correlation priors: Adjacent pixels usually belong to the same object. In a sky region, for example, color and texture are smooth and similar within a local neighborhood rather than changing abruptly.
- Continuity priors: Edges and contours are usually continuous. A road or an object boundary, for example, does not randomly break off or jump between adjacent positions.
All of these are descriptions of the real world, and that is what prior information is.
1.2 What is Inductive Bias?
In comparison, inductive bias is a more "model-perspective" concept:
Inductive bias is the mechanism by which a model "tends to choose a certain type of solution" during the learning process. It usually comes from prior information.
It is not specific knowledge itself, but design choices we build into the model based on priors, which make certain things easier for the model to learn and others harder.
In summary: Inductive bias determines the model's "learning direction."
At this point, we can more completely explain the initial question: Why does ViT require large amounts of data to perform well, while CNN remains effective with small data?
1.3 CNN and ViT
Let's first summarize using the concepts just introduced: CNN and ViT are essentially two different choices on the question of "whether to introduce inductive bias."
It's worth emphasizing that the difference in data dependency caused by inductive bias is only relative; deep learning methods are all data-driven.
Let's start with CNN. CNN's modeling logic does one thing: it writes prior information into the model structure, and that is exactly what inductive bias looks like.
Specifically:
- Because of locality priors, we designed convolutional kernels to restrict the model to only focus on local regions.
- We use network hierarchical structures, gradually combining from local to global, which is actually also an embodiment of structural priors.
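To make "local to global" concrete, here is a small sketch of how the receptive field grows as 3×3 convolutions are stacked. It assumes stride 1, no dilation, and no pooling (a simplification: real CNNs also downsample, which grows the field faster), in which case each extra k×k layer widens the field by k - 1 pixels. The function name `receptive_field` is our own illustration, not a library API.

```python
def receptive_field(num_layers: int, kernel_size: int = 3) -> int:
    """Side length (in input pixels) seen by one output unit of a stack of
    stride-1, undilated convolutions with no pooling in between."""
    if num_layers < 0 or kernel_size < 1:
        raise ValueError("num_layers must be >= 0 and kernel_size >= 1")
    # A single pixel sees itself; each k x k layer adds (k - 1) pixels of context.
    return 1 + num_layers * (kernel_size - 1)


for layers in (1, 2, 3, 5, 10):
    rf = receptive_field(layers)
    print(f"{layers:2d} stacked 3x3 conv layers -> {rf}x{rf} receptive field")
```

Each layer only ever looks at a 3×3 neighborhood, yet deeper units cover larger regions (two layers see 5×5, ten layers see 21×21). Global context has to be assembled step by step, which is the structural prior described above.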
You'll notice that a CNN has already been "told" how to understand images before training even begins.
From a mathematical perspective, CNN learning doesn't happen in a completely free space; it searches for solutions within a strongly constrained function space.
This is strong inductive bias. The direct result is that even with limited data, the model converges faster, learns reasonable structures more easily, and is less likely to learn the wrong thing.
The cost is that the model's expressive power is limited by its structure, which caps its upper bound: the best way for a model to learn does not necessarily follow the logic we humans understand.
In ViT, we made almost the opposite choice: trying not to write priors into the structure but letting data learn them.
Concretely, ViT uses no convolution: all patch tokens interact globally through attention, which means much weaker inductive bias at the level of spatial hierarchy.
In other words: The model initially doesn't know "what is local structure" or "what is spatial relationship."
Therefore, ViT's learning process searches for solutions in an almost unconstrained huge function space.
This brings stronger expressive power, but training difficulty significantly increases, making it highly sensitive to data scale and training strategies.
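One rough way to quantify this difference is to count, for a single layer, how many (output position, input position) pairs can influence each other. The sketch below assumes a 14×14 grid of positions (what a 224×224 image cut into 16×16 patches produces) and compares a 3×3 zero-padded convolution with one unmasked self-attention layer; the helper names are illustrative only.

```python
def conv_dependencies(h: int, w: int, k: int = 3) -> int:
    """Number of (output, input) position pairs linked by one k x k convolution
    with zero padding on an h x w grid (stride 1)."""
    r = k // 2
    total = 0
    for i in range(h):
        for j in range(w):
            # Count the kernel-window neighbors that fall inside the grid.
            rows = min(i + r, h - 1) - max(i - r, 0) + 1
            cols = min(j + r, w - 1) - max(j - r, 0) + 1
            total += rows * cols
    return total


def attention_dependencies(h: int, w: int) -> int:
    """Unmasked self-attention: every token can attend to every token, itself included."""
    n = h * w
    return n * n


grid = 14  # 224 / 16 = 14 patches per side, i.e. 196 tokens
print("3x3 conv:      ", conv_dependencies(grid, grid))
print("self-attention:", attention_dependencies(grid, grid))
```

Under these assumptions this prints 1600 for the convolution and 38416 for attention: each convolution output sees at most 9 neighbors, while every attention output can draw on all 196 tokens from the very first layer. Which of those links actually matter has to be learned from data, which is one way to see why ViT is so data-hungry.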
That was a lot of abstraction, so let's use an analogy:
CNN is like "finding a path with a map." You have a general sense of direction, so pathfinding skills can be trained more stably.
ViT is like "trial and error in an unknown environment." It takes massive training, but afterwards it has stronger pathfinding skills.
The "map" is inductive bias.
1.4 Summary
Summarizing this part's content as a rule:
The weaker the inductive bias, the stronger the model's dependency on data.
At this point, the next question naturally arises:
If we don't want to change ViT's structure but hope it performs better on small data, what should we do?
The answer is the content of the next part: Distillation.
2. Distillation
In one sentence: distillation means having a "small model" imitate the output of a "large model," so that it learns better decision-making.
Taken alone, this sentence might remind you of the Transfer Learning we introduced earlier. Both seem to "borrow a strong model's capabilities," but the underlying logic is different:
- Transfer Learning: Take "already learned parameters" and use them.
- Distillation: Don't directly use parameters but imitate the model's behavior (output distribution).
Before going into detail, let's define the two core roles:
- Teacher: Usually a stronger model that's already trained well.
- Student: The target model we actually want to train.
These names are standard in the literature. Now let's walk through the distillation procedure.
Note up front that the idea behind distillation is general, but it comes in many forms and is implemented differently for different types of tasks. Here we use the most basic case, classification, to demonstrate.
2.1 Preparing the Teacher
To perform distillation, you first need to prepare the Teacher.
The Teacher is usually an already trained strong model; most commonly, a pretrained model is used directly.
Of course, in some research settings or special tasks, no ready-made strong model is available. In that case, the workflow is to first train a Teacher that performs as well as possible, then use it to distill the Student.
This might seem like a roundabout approach, but the point is that we usually aren't after the strongest possible model; we want one that is "strong enough and cheap enough."
In any case, one important principle applies here:
The Teacher doesn't have to be "large," but it must be "more reliable than the Student."
Otherwise, the Student learns the wrong knowledge, and distillation ends up hurting performance instead of helping.
2.2 Soft Labels
Once the Teacher is ready, the second step is to run the data through the Teacher to obtain soft labels.
The idea is simple:
For each input image, the Teacher outputs not just the final category but a complete probability distribution over the (C) categories:
[p_t = (p_1, p_2, \ldots, p_C)]
This distribution contains similarity relationship information between categories.
For example:
[\text{cat} = 0.7,\quad \text{dog} = 0.2,\quad \text{car} = 0.1]
This distribution is simply the softmax output of the Teacher's final layer.
There is one more step. In practice, distillation usually introduces a temperature parameter (T) that "softens" both the Teacher's soft labels and the Student's output before the loss is computed. With (z_i) denoting the logit (pre-softmax output) for category (i):
[p_i = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}]
With (T > 1), the distribution becomes smoother: low-probability categories are "amplified," so the relative relationships among the non-target categories (for example, that the image looks more like a "dog" than a "car") carry more weight.
(T) is a hyperparameter that is frequently tuned in practice.
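The temperature-scaled softmax takes only a few lines. This sketch uses plain Python; the logits are made-up values for [cat, dog, car], chosen so that T = 1 gives roughly the 0.7 / 0.2 / 0.1 split from the example above.

```python
import math


def softmax_with_temperature(logits: list[float], T: float = 1.0) -> list[float]:
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), computed in a numerically stable way."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtracting the max prevents overflow and does not change the result
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


teacher_logits = [3.0, 1.8, 1.1]  # hypothetical Teacher logits for [cat, dog, car]
for T in (1.0, 4.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Raising T to 4 keeps "cat" as the top category but flattens the distribution, so the relationship between "dog" and "car" becomes a larger part of what the Student sees.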
In this way, the Teacher gives us soft labels for the dataset, carrying richer information than hard labels: which categories are "close to correct," how "hesitant" the model is, and the rough shape of its decision boundaries.
2.3 Student Imitates Teacher
This is the core of distillation. In simple terms:
Design the Student's loss function so that, besides learning the true labels, the Student also learns the Teacher's output soft labels, i.e., fits the Teacher's output probability distribution.
The math below is a bit tedious, but once the logic above is clear, it isn't hard to follow:
First, we need to introduce KL Divergence (full name Kullback–Leibler divergence).
For discrete distributions, KL divergence is defined as:
[D_{KL}(P \parallel Q) = \sum_{i} P(i)\log \frac{P(i)}{Q(i)}]
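In distillation, (P) is the Teacher's distribution and (Q) is the Student's, so it is usually written as (D_{KL}(p_t \parallel p_s)). Expanding the logarithm:
[D_{KL}(p_t \parallel p_s) = \sum_{i} p_t(i)\log p_t(i) - \sum_{i} p_t(i)\log p_s(i)]
The first sum depends only on the Teacher, which stays fixed while the Student trains. Minimizing KL therefore gives the Student exactly the same gradients as minimizing the cross-entropy form, which many implementations use instead:
[\mathcal{L}_{KD} = - \sum_{i} p_t(i) \log p_s(i)]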
Where:
- (p_t): the Teacher's distribution (the reference to match).
- (p_s): the Student's distribution (the one being trained).
Intuitively, KL measures how far the Student's predicted distribution is from the Teacher's.
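The relationship between KL divergence and the cross-entropy form can be checked numerically. Expanding the logarithm gives H(p_t, p_s) = H(p_t) + D_KL(p_t || p_s), where H(p_t) is the Teacher's own entropy; since the Teacher is fixed during Student training, the two losses differ only by a constant. A minimal sketch, using natural logarithms and the usual convention 0 · log 0 = 0:

```python
import math


def kl_divergence(p_t, p_s):
    """D_KL(p_t || p_s) = sum_i p_t(i) * log(p_t(i) / p_s(i))."""
    return sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s) if pt > 0)


def cross_entropy(p_t, p_s):
    """H(p_t, p_s) = -sum_i p_t(i) * log p_s(i)."""
    return -sum(pt * math.log(ps) for pt, ps in zip(p_t, p_s) if pt > 0)


def entropy(p):
    """H(p) = -sum_i p(i) * log p(i); depends on the Teacher only."""
    return -sum(x * math.log(x) for x in p if x > 0)


p_t = [0.7, 0.2, 0.1]  # Teacher distribution
for p_s in ([0.5, 0.4, 0.1], [0.6, 0.3, 0.1]):  # two candidate Students
    gap = cross_entropy(p_t, p_s) - kl_divergence(p_t, p_s)
    print(f"KL={kl_divergence(p_t, p_s):.4f}  CE={cross_entropy(p_t, p_s):.4f}  "
          f"CE-KL={gap:.4f}  H(p_t)={entropy(p_t):.4f}")
```

The CE - KL column equals H(p_t) for every Student, so minimizing either loss moves the Student in the same direction.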
Let's look at an example (natural logarithm throughout):
| Category | (p_t) (Teacher) | (p_s) (Student) | Ratio (\frac{p_t}{p_s}) | (\log(\frac{p_t}{p_s})) | (p_t \cdot \log(\frac{p_t}{p_s})) |
|---|---|---|---|---|---|
| Cat | 0.7 | 0.5 | 1.4 | 0.336 | 0.236 |
| Dog | 0.2 | 0.4 | 0.5 | -0.693 | -0.139 |
| Car | 0.1 | 0.1 | 1.0 | 0.000 | 0.000 |
Summing the last column:
[D_{KL}(p_t \parallel p_s) = 0.236 - 0.139 + 0 = 0.097]
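The terms break down as follows:
- Cat: the Student's probability is too low (0.5 < 0.7), so there is error.
- Dog: the Student's probability is too high (0.4 > 0.2), which is also error, even though this individual term is negative.
- Car: the two agree, so the term is 0.
Individual terms can be negative, but KL divergence as a whole is always non-negative and equals 0 only when the two distributions are identical. So the smaller the result, the closer the two distributions are.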
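The table can be reproduced in a few lines. This sketch uses the natural logarithm; without rounding the intermediate values the total is about 0.0969, so any small difference from the table comes from rounding.

```python
import math

categories = ["cat", "dog", "car"]
p_t = [0.7, 0.2, 0.1]  # Teacher, from the table
p_s = [0.5, 0.4, 0.1]  # Student, from the table

# Per-category contribution p_t(i) * log(p_t(i) / p_s(i)); individual terms may be negative.
terms = [pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s)]
for name, term in zip(categories, terms):
    print(f"{name}: {term:+.4f}")

kl = sum(terms)
print(f"D_KL(p_t || p_s) = {kl:.4f}")
```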
For ordinary classification, the loss is the cross-entropy against the true label, where (y_i) is 1 for the correct category and 0 otherwise:
[\mathcal{L}_{CE} = -\sum_{i=1}^{C} y_i \log p_i]
Finally, Student's loss function is a combination of both:
[\mathcal{L} = \alpha \mathcal{L}_{CE} + (1 - \alpha)\mathcal{L}_{KD}]
Here (\alpha) is a weighting hyperparameter, and the two loss terms play different roles:
- (\mathcal{L}_{CE}): Tells Student "what the standard answer is."
- (\mathcal{L}_{KD}): Tells Student "how a stronger model thinks."
In the original distillation method, the gradients coming from the softened term shrink roughly as (1/T^2), so the KD term is multiplied by (T^2) to keep its contribution comparable as (T) changes. Some implementations omit this factor and let (\alpha) absorb the scale; it's enough to be aware of it.
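Putting the pieces together, here is a minimal sketch of the combined loss for a single example. It assumes the convention of the original distillation method: the hard-label term uses the Student's ordinary T = 1 softmax, while the soft term compares both models at temperature T. It uses the KL form for the soft term, which differs from the cross-entropy form only by a Teacher-dependent constant. The function name, the defaults alpha = 0.5 and T = 4.0, and the `scale_by_T2` switch are illustrative choices, not values from the text.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def distillation_loss(student_logits, teacher_logits, label,
                      alpha=0.5, T=4.0, scale_by_T2=True):
    """alpha * CE(true label, student) + (1 - alpha) * [T^2] * KL(teacher_T || student_T)."""
    # Hard-label term: with a one-hot y, -sum_i y_i log p_i reduces to -log p_label.
    ce = -math.log(softmax(student_logits)[label])
    # Soft-label term: both distributions are softened with the same temperature T.
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kd = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    if scale_by_T2:
        kd *= T * T  # keeps the soft term's gradient scale roughly independent of T
    return alpha * ce + (1 - alpha) * kd


# Hypothetical logits for [cat, dog, car]; the true label is cat (index 0).
student = [2.0, 1.0, 0.5]
teacher = [3.0, 1.8, 1.1]
print(distillation_loss(student, teacher, label=0))
print(distillation_loss(student, teacher, label=0, alpha=1.0))  # hard-label CE only
```

Setting alpha = 1 recovers plain supervised training, and a Student whose logits equal the Teacher's pays only the alpha-weighted hard-label cost.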
Training with this loss through ordinary backpropagation yields the distilled Student model.
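To see "Student imitates Teacher" in motion, here is a toy sketch that skips the network entirely and treats the Student's logits for one example as free parameters, updated by plain gradient descent. It relies on two standard softmax gradients: the hard-label term gives softmax(z) - y, and T² · KL(p_t || softmax(z / T)) gives T · (p_s - p_t). The logits, learning rate, and step count are arbitrary illustrative choices.

```python
import math


def softmax(logits, T=1.0):
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def kl(p, q):
    return sum(a * math.log(a / b) for a, b in zip(p, q))


def train_student(teacher_logits, label, alpha, T=4.0, lr=0.5, steps=300):
    """Toy distillation: gradient descent on free student logits (no network, one example)."""
    z = [0.0] * len(teacher_logits)          # the Student starts with no preference
    p_t = softmax(teacher_logits, T)         # fixed soft labels from the Teacher
    for _ in range(steps):
        q = softmax(z)                       # Student at T = 1 (hard-label term)
        p_s = softmax(z, T)                  # Student at temperature T (soft term)
        grad = [alpha * (q[i] - (1.0 if i == label else 0.0))
                + (1 - alpha) * T * (p_s[i] - p_t[i])
                for i in range(len(z))]
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return z


teacher = [3.0, 1.8, 1.1]  # hypothetical Teacher logits for [cat, dog, car]
T = 4.0
start = kl(softmax(teacher, T), softmax([0.0, 0.0, 0.0], T))
for alpha in (0.0, 0.5):
    z = train_student(teacher, label=0, alpha=alpha, T=T)
    end = kl(softmax(teacher, T), softmax(z, T))
    print(f"alpha={alpha}: KL at T={T} went from {start:.4f} to {end:.6f}")
```

With alpha = 0 the Student only imitates, so its softened distribution ends up essentially matching the Teacher's. With alpha = 0.5 the hard label also pulls the "cat" logit upward, so the Student settles slightly sharper than the Teacher; this is the trade-off that alpha controls.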
2.4 Summary
Distillation may look like a shortcut: faced with a powerful model, the Student simply learns its answer distribution.
But distillation does have theoretical grounding and practical value. What is shown here is only its most basic form; we'll expand on it in detail later.
Returning to ViT: we already know its problem is that "the search space is too free." Distillation's role here is to:
Artificially introduce a "soft constraint," narrow the search space, make optimization more stable, thereby reducing data dependency.
In essence, this still makes use of the Teacher's inductive bias.
By the same token, when the Teacher itself is biased, this constraint also indirectly limits the Student's upper bound, which is why tuning (\alpha) is so important.
With these two concepts in hand, we can move on to ViT's next improvement.