WEBVTT

00:00:00.000 --> 00:00:04.639
Right now, in your pocket, your phone is running

00:00:04.639 --> 00:00:07.440
an AI that, honestly, by all rights, it should

00:00:07.440 --> 00:00:10.179
just melt its processor. Yeah, and completely

00:00:10.179 --> 00:00:12.240
drain your battery in like three minutes flat.

00:00:12.480 --> 00:00:15.000
Exactly. I mean, we are so used to picturing

00:00:15.000 --> 00:00:19.199
artificial intelligence as this sprawling, intimidating

00:00:19.199 --> 00:00:21.300
thing, like acres of humming servers in these

00:00:21.300 --> 00:00:23.960
massive data centers. Right, with liquid cooling

00:00:23.960 --> 00:00:27.010
systems working overtime. Yeah, pulling gigawatts

00:00:27.010 --> 00:00:30.510
from the grid just to power a single trillion

00:00:30.510 --> 00:00:34.030
parameter neural network. Yet somehow, you can

00:00:34.030 --> 00:00:35.869
just step away from all that massive infrastructure,

00:00:36.109 --> 00:00:38.609
pull out your smartphone, and run an incredibly

00:00:38.609 --> 00:00:41.149
sophisticated model right there on the palm of

00:00:41.149 --> 00:00:43.409
your hand. It's wild to think about. It really

00:00:43.409 --> 00:00:46.229
is. And that massive disconnect between the data

00:00:46.229 --> 00:00:48.719
center and the pocket. That's the central mystery

00:00:48.719 --> 00:00:50.759
we're pulling apart today. And it is arguably

00:00:50.759 --> 00:00:53.140
the defining engineering challenge of modern

00:00:53.140 --> 00:00:54.820
machine learning right now. I mean, we construct

00:00:54.820 --> 00:00:57.960
these gargantuan models, sometimes massive ensembles

00:00:57.960 --> 00:00:59.979
of many models working together, because they

00:00:59.979 --> 00:01:02.399
have this immense capacity to absorb knowledge.

00:01:02.920 --> 00:01:05.420
But they are incredibly inefficient. They don't

00:01:05.420 --> 00:01:08.700
fully utilize that capacity during everyday tasks.

00:01:09.310 --> 00:01:12.150
Evaluating a massive model is astronomically

00:01:12.150 --> 00:01:14.530
expensive computationally. Even if it's just

00:01:14.530 --> 00:01:16.730
doing something simple, right? Exactly. Even

00:01:16.730 --> 00:01:19.870
if it is only using a tiny fraction of its theoretical

00:01:19.870 --> 00:01:23.010
brain power to, say, recognize a voice command

00:01:23.010 --> 00:01:26.450
or translate a menu, the ultimate industry imperative

00:01:26.450 --> 00:01:29.349
right now is economical deployment. So figuring

00:01:29.349 --> 00:01:32.780
out how to run these vast, powerful models

00:01:32.780 --> 00:01:35.980
on highly restricted, much less powerful hardware.

00:01:36.079 --> 00:01:38.000
Right, without losing their core capabilities.

00:01:38.180 --> 00:01:40.200
Well, I see your visual backdrop has shifted

00:01:40.200 --> 00:01:44.980
to this dense, high-tech cityscape, interwoven

00:01:44.980 --> 00:01:48.200
with a glowing digital neural network grid. Sets

00:01:48.200 --> 00:01:50.599
the perfect mood. I do like it. So for you listening,

00:01:50.680 --> 00:01:53.019
we are pulling our facts today from a comprehensive

00:01:53.019 --> 00:01:55.480
Wikipedia article on a machine learning method

00:01:55.480 --> 00:01:58.390
called knowledge distillation, or model distillation

00:01:58.390 --> 00:02:00.590
as it's sometimes called. Right. And our mission

00:02:00.590 --> 00:02:02.849
for this deep dive is to understand the mechanics

00:02:02.849 --> 00:02:05.590
of how scientists actually take these giant teacher

00:02:05.590 --> 00:02:07.969
neural networks and shrink them down into smaller

00:02:07.969 --> 00:02:10.840
student networks while preserving the actual

00:02:10.840 --> 00:02:12.460
validity of the knowledge, which is the hard

00:02:12.460 --> 00:02:15.240
part. And to give this some context, knowledge

00:02:15.240 --> 00:02:19.099
distillation is not some experimental fringe

00:02:19.099 --> 00:02:21.379
theory. It's being used right now. Oh, everywhere.

00:02:21.500 --> 00:02:24.280
Yeah. It is a foundational technique actively

00:02:24.280 --> 00:02:26.740
deployed across a massive variety of machine

00:02:26.740 --> 00:02:30.000
learning applications today. It's used in object

00:02:30.000 --> 00:02:33.180
detection for autonomous vehicles, acoustic models

00:02:33.180 --> 00:02:36.919
for speech recognition, large-scale natural language

00:02:36.919 --> 00:02:39.810
processing, and even non-grid data structures

00:02:39.810 --> 00:02:42.210
like graph neural networks. OK, let's unpack

00:02:42.210 --> 00:02:44.650
this a bit. Before we even touch the complex

00:02:44.650 --> 00:02:47.189
math of how this transfer actually happens, we

00:02:47.189 --> 00:02:49.189
need to understand what is actually being transferred.

00:02:49.550 --> 00:02:51.469
The knowledge itself. Yeah, because when we say

00:02:51.469 --> 00:02:54.169
we are passing knowledge from a teacher to a

00:02:54.169 --> 00:02:56.250
student, what does that actually mean to a machine?

00:02:56.750 --> 00:02:59.550
Let's use a multiple choice test analogy. I like

00:02:59.550 --> 00:03:02.289
that. Let's hear it. OK, imagine a smaller, less

00:03:02.289 --> 00:03:05.330
capable model, like a C-average student. They

00:03:05.330 --> 00:03:07.830
studied just enough raw data to know that the

00:03:07.830 --> 00:03:10.750
correct answer to question four is A. They just

00:03:10.750 --> 00:03:13.310
memorized the hard fact. Right, rote memorization.

00:03:13.650 --> 00:03:16.550
Exactly. Yeah. But the large model, the valedictorian,

00:03:16.650 --> 00:03:19.490
doesn't just know that A is correct. It knows

00:03:19.490 --> 00:03:22.629
the underlying mechanics of why A is right. More

00:03:22.629 --> 00:03:25.210
importantly, it recognizes that answer B is a

00:03:25.210 --> 00:03:27.750
cleverly designed distractor. And answer C is

00:03:27.750 --> 00:03:30.810
just completely ridiculous. Yes. So my question

00:03:30.810 --> 00:03:34.569
is: is that nuance, that specific understanding

00:03:34.569 --> 00:03:37.810
of the wrong answers, the actual knowledge

00:03:37.810 --> 00:03:40.370
we are trying to bottle up? What's fascinating

00:03:40.370 --> 00:03:42.930
here is that your analogy maps perfectly to the

00:03:42.930 --> 00:03:45.169
math. It really does. If you take a small model

00:03:45.169 --> 00:03:47.270
and a large model and train them both on the

00:03:47.270 --> 00:03:50.469
exact same raw data, the small model simply lacks

00:03:50.469 --> 00:03:52.750
the architectural capacity. Like the internal

00:03:52.750 --> 00:03:55.270
parameters. Exactly. It lacks the parameters

00:03:55.270 --> 00:03:57.810
to form a concise, efficient representation of

00:03:57.810 --> 00:04:00.250
the knowledge. The raw data is just too noisy

00:04:00.250 --> 00:04:02.750
for it. But the big model can handle it. Right.

00:04:03.150 --> 00:04:05.650
The large model, because of its vast size, can

00:04:05.650 --> 00:04:08.270
filter that noise and learn that concise representation.

00:04:08.789 --> 00:04:11.469
It encodes that deep nuance in something called

00:04:11.469 --> 00:04:13.969
pseudo-likelihoods. OK, pseudo-likelihoods. Let's

00:04:13.969 --> 00:04:15.909
translate that for a second. Is that just the

00:04:15.909 --> 00:04:19.209
model's internal confidence score for every possible

00:04:19.209 --> 00:04:22.980
answer? Essentially, yeah. When a large classification

00:04:22.980 --> 00:04:25.980
model makes a prediction, it assigns a very large

00:04:25.980 --> 00:04:28.579
numerical value to the correct output. But it

00:04:28.579 --> 00:04:30.819
doesn't just output a zero for everything else.

00:04:30.939 --> 00:04:33.600
It gives them a score, too. Right. It assigns

00:04:33.600 --> 00:04:36.740
smaller, highly specific values to all the incorrect

00:04:36.740 --> 00:04:39.600
outputs. And the specific distribution of those

00:04:39.600 --> 00:04:42.680
values across the entire board provides a literal

00:04:42.680 --> 00:04:45.660
map of how the large model structures its understanding

00:04:45.660 --> 00:04:48.579
of the world. So it tells us that an image of

00:04:48.579 --> 00:04:51.750
a dog is very likely a dog, somewhat likely a

00:04:51.750 --> 00:04:55.129
cat, and absolutely not a minivan. Perfect example.

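NOTE A minimal Python sketch of the soft-output idea, assuming made-up teacher logits for the dog/cat/minivan example (the numbers are illustrative, not from the source). The one-hot hard label keeps only "dog"; the teacher's softmax output also records that "cat" is far more plausible than "minivan". The helper name `softmax` is hypothetical.
```python
import math
def softmax(logits):
    """Convert raw logits into pseudo-probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]
classes = ["dog", "cat", "minivan"]
teacher_logits = [9.0, 5.0, -4.0]  # hypothetical teacher scores for one dog image
hard_target = [1.0, 0.0, 0.0]      # one-hot ground truth: only "dog" counts
soft_target = softmax(teacher_logits)
for name, p in zip(classes, soft_target):
    print(f"{name:8s} {p:.6f}")
```
Training the student to match `soft_target` rather than `hard_target` exposes the cat-versus-minivan structure, although at this default setting the minivan entry is only about 2e-6.
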
00:04:55.470 --> 00:04:57.829
So instead of forcing the small model to learn

00:04:57.829 --> 00:05:00.389
from the chaotic raw data, we train the large

00:05:00.389 --> 00:05:03.430
model first. Then we distill its knowledge by

00:05:03.430 --> 00:05:05.709
training the small model to mimic that exact

00:05:05.709 --> 00:05:08.629
distribution of values. The soft output. Exactly,

00:05:08.730 --> 00:05:11.170
the soft output of the large model across all

00:05:11.170 --> 00:05:13.610
classes. So we are teaching the student the

00:05:13.610 --> 00:05:15.470
exact shape of the teacher's thought process,

00:05:15.529 --> 00:05:17.870
not just giving it the final answer key. I'm

00:05:17.870 --> 00:05:19.750
a little lost on the mechanics of this, though,

00:05:19.959 --> 00:05:22.980
because if a teacher model is highly confident

00:05:22.980 --> 00:05:25.800
and highly accurate, the probability it spits

00:05:25.800 --> 00:05:27.839
out for the correct answer is going to be something

00:05:27.839 --> 00:05:31.920
like, what, 99.9%? Usually, yeah. So the probability

00:05:31.920 --> 00:05:34.319
for the wrong answers, those nuances we supposedly

00:05:34.319 --> 00:05:38.699
want to capture, might be like 0.001%. If those

00:05:38.699 --> 00:05:41.600
numbers are that microscopic, how do we mathematically

00:05:41.600 --> 00:05:45.540
force a small, low-capacity neural network to

00:05:45.540 --> 00:05:48.180
pay any attention to them? Well, the answer lies

00:05:48.180 --> 00:05:50.579
in a mathematical manipulation called temperature.

00:05:50.660 --> 00:05:53.660
Temperature, like heat. Not literal heat, no.

00:05:53.839 --> 00:05:56.300
In the final layer of a classification network,

00:05:56.439 --> 00:05:58.860
there is a function that converts the raw numbers

00:05:58.860 --> 00:06:01.079
the network generates, which are called logit

00:06:01.079 --> 00:06:04.160
values, into pseudo-probabilities that add up

00:06:04.160 --> 00:06:06.759
to 100%. Okay. Tracking with you. In a standard

00:06:06.759 --> 00:06:08.920
setup, this equation includes a parameter called

00:06:08.920 --> 00:06:11.240
temperature, represented by the letter T, which

00:06:11.240 --> 00:06:13.379
is normally just set to 1. A standard operating

00:06:13.379 --> 00:06:15.319
procedure. But in the knowledge distillation

00:06:15.319 --> 00:06:17.879
process, the researchers intentionally set that

00:06:17.879 --> 00:06:20.579
temperature parameter to a very high value. I

00:06:20.579 --> 00:06:22.560
have to pause you there because that feels totally

00:06:22.560 --> 00:06:26.110
counterintuitive. If we already have a transfer

00:06:26.110 --> 00:06:28.389
data set where we know the hundred percent correct

00:06:28.389 --> 00:06:31.430
answers, the ground truth, why are we ignoring

00:06:31.430 --> 00:06:33.490
that hard truth to look at the teacher's soft

00:06:33.490 --> 00:06:35.569
guesses? It does seem backward at first. And

00:06:35.569 --> 00:06:37.250
then jacking up the temperature on top of it.

00:06:37.610 --> 00:06:40.370
Isn't artificially heating up the math just like

00:06:40.370 --> 00:06:42.850
adding unnecessary noise to the system? It seems

00:06:42.850 --> 00:06:45.459
like it would muddy the waters, yeah. But a high

00:06:45.459 --> 00:06:47.720
temperature fundamentally changes the shape of

00:06:47.720 --> 00:06:50.379
the data the student model sees. It converts

00:06:50.379 --> 00:06:53.360
those rigid logit values into a much softer,

00:06:53.959 --> 00:06:55.959
flatter distribution of pseudo-probabilities.

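NOTE A minimal Python sketch of temperature scaling, assuming hypothetical dog/cat/minivan logits (illustrative values only). Each logit is divided by T before the softmax, so raising T squashes the top probability, lifts the tiny ones, and increases the entropy of the output. The helper names `softmax_t` and `entropy` are illustrative, not from any library.
```python
import math
def softmax_t(logits, T=1.0):
    """Softmax with temperature T: divide the logits by T before normalizing."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]
def entropy(p):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)
teacher_logits = [9.0, 5.0, -4.0]  # hypothetical logits for dog, cat, minivan
for T in (1.0, 5.0, 20.0):
    p = softmax_t(teacher_logits, T)
    print(f"T={T}: " + ", ".join(f"{q:.4f}" for q in p) + f"  entropy={entropy(p):.3f}")
```
At T=1 the minivan probability is on the order of 1e-6; at T=20 it is large enough to carry a usable training signal.
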
00:06:56.100 --> 00:06:58.360
So it squashes the big numbers and raises the

00:06:58.360 --> 00:07:01.980
small ones. Right. Suddenly that 99 .9 % confidence

00:07:01.980 --> 00:07:06.199
spike drops down and the tiny 0 .001 % values

00:07:06.199 --> 00:07:09.259
rise up. By flattening the distribution, you

00:07:09.259 --> 00:07:11.480
drastically increase the entropy of the output.

00:07:11.689 --> 00:07:14.910
meaning there is more spread, so the subtle differences

00:07:14.910 --> 00:07:17.430
between the wrong answers become highly visible

00:07:17.430 --> 00:07:20.009
to the student. Exactly. The student model suddenly

00:07:20.009 --> 00:07:22.449
has significantly more actionable information

00:07:22.449 --> 00:07:24.490
to learn from compared to just looking at hard

00:07:24.490 --> 00:07:27.290
targets, like a simple zero or one. Wow. But

00:07:27.290 --> 00:07:30.250
there is a secondary, deeply mathematical benefit

00:07:30.250 --> 00:07:33.129
to this, too. Making the distribution softer

00:07:33.129 --> 00:07:35.930
actually reduces the variance of the gradient

00:07:35.930 --> 00:07:38.050
between different data records during the training

00:07:38.050 --> 00:07:40.029
process. OK, let's break down gradient variance

00:07:40.029 --> 00:07:42.949
for a second. Why does lower variance matter

00:07:42.949 --> 00:07:45.860
to the student model? Think of training a model

00:07:45.860 --> 00:07:48.540
like trying to find the lowest point in a valley

00:07:48.540 --> 00:07:51.120
while you're blindfolded. You take a step, you

00:07:51.120 --> 00:07:53.000
feel the slope, and you move downhill. Right.

00:07:53.300 --> 00:07:56.240
That slope is the gradient. If the variance is

00:07:56.240 --> 00:07:58.120
high, it's like trying to find the bottom of

00:07:58.120 --> 00:08:00.459
the valley during an earthquake. Oh, man. Right.

00:08:00.939 --> 00:08:04.079
Every step gives you wildly different, jagged

00:08:04.079 --> 00:08:06.959
feedback. You have to take very tiny, cautious

00:08:06.959 --> 00:08:09.759
steps, which means a really low learning rate.

00:08:09.879 --> 00:08:11.920
Because you don't want to fall off a cliff. Exactly.

00:08:12.279 --> 00:08:14.600
But if you lower the variance, the ground stops

00:08:14.600 --> 00:08:17.420
shaking. The feedback is smooth and consistent.

00:08:18.019 --> 00:08:20.579
Because of that, you can use a much higher learning

00:08:20.579 --> 00:08:23.670
rate. The small model can take large, confident

00:08:23.670 --> 00:08:26.529
steps, learning faster and far more efficiently.

00:08:26.949 --> 00:08:29.930
Okay, so turning up the heat melts the rigid

00:08:29.930 --> 00:08:33.389
answers down into a softer, richer soup of information.

00:08:34.169 --> 00:08:36.289
And the student absorbs it faster because the

00:08:36.289 --> 00:08:38.149
mathematical terrain is smoother. That's a great

00:08:38.149 --> 00:08:40.769
way to put it. But what about the actual undeniable

00:08:40.769 --> 00:08:42.990
truth of the data? Like, if we have an image

00:08:42.990 --> 00:08:45.929
of a dog and we know it's a dog, do we just throw

00:08:45.929 --> 00:08:48.289
that fact out entirely to focus on the teacher's

00:08:48.289 --> 00:08:50.429
softened guesses? No, no. The ground truth is

00:08:50.429 --> 00:08:52.970
still incredibly valuable if you have it. The

00:08:52.970 --> 00:08:55.330
distillation process uses a cross-entropy loss

00:08:55.330 --> 00:08:57.610
function to measure the difference between the

00:08:57.610 --> 00:09:00.350
small model's outputs and the large model's soft

00:09:00.350 --> 00:09:02.950
outputs at that high temperature. OK. But if

00:09:02.950 --> 00:09:05.009
the actual ground truth is available for the

00:09:05.009 --> 00:09:07.490
transfer data set, you add a second component

00:09:07.490 --> 00:09:10.029
to the loss function. You calculate the error

00:09:10.029 --> 00:09:12.990
against the known true label, but you compute

00:09:12.990 --> 00:09:15.070
that part at the normal temperature where T is

00:09:15.070 --> 00:09:17.269
equal to 1. So you anchor the student to the

00:09:17.269 --> 00:09:20.490
hard truth, but you let it explore the soft nuance

00:09:20.490 --> 00:09:23.340
of the teacher's logic at the same time. Exactly.

00:09:23.700 --> 00:09:26.000
And to ensure the math doesn't tear itself apart

00:09:26.200 --> 00:09:28.340
trying to balance those two competing things,

00:09:28.840 --> 00:09:30.899
the component dealing with the large model's

00:09:30.899 --> 00:09:34.220
soft outputs is multiplied by a factor of T squared.

00:09:34.379 --> 00:09:37.019
Why T squared? Because as you increase the temperature,

00:09:37.440 --> 00:09:39.460
the mathematical gradient of the loss naturally

00:09:39.460 --> 00:09:42.440
scales down by a factor of 1 over T squared.

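NOTE A minimal Python sketch of the combined loss described above, assuming plain Python lists of logits and an unweighted sum of the two terms (real setups often add a mixing weight). The soft term compares student and teacher at temperature T and is multiplied by T^2; the hard term uses the true label at T = 1. The gradient helper illustrates why the T^2 factor is there: without it, the soft-term gradient shrinks roughly like 1/T^2 once T is large. All function names and logit values are hypothetical.
```python
import math
def softmax_t(logits, T=1.0):
    """Softmax at temperature T (logits divided by T before normalizing)."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
def cross_entropy(target, predicted):
    """Cross-entropy of a predicted distribution against a target distribution."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)
def distillation_loss(student_logits, teacher_logits, true_label, T):
    """T**2 * soft cross-entropy at temperature T, plus hard cross-entropy at T = 1."""
    soft_term = cross_entropy(softmax_t(teacher_logits, T), softmax_t(student_logits, T))
    one_hot = [1.0 if i == true_label else 0.0 for i in range(len(student_logits))]
    hard_term = cross_entropy(one_hot, softmax_t(student_logits, 1.0))
    return T ** 2 * soft_term + hard_term
def soft_term_grad(student_logits, teacher_logits, T, scale=True):
    """Gradient of the soft term w.r.t. the student logits, (q_i - p_i) / T, optionally times T**2."""
    q = softmax_t(student_logits, T)
    p = softmax_t(teacher_logits, T)
    factor = T ** 2 if scale else 1.0
    return [factor * (qi - pi) / T for qi, pi in zip(q, p)]
student = [2.0, 1.0, 0.5]   # hypothetical student logits
teacher = [3.0, 0.5, -1.0]  # hypothetical teacher logits
print("loss at T=4:", distillation_loss(student, teacher, true_label=0, T=4.0))
for T in (1.0, 4.0, 20.0):
    raw = max(abs(g) for g in soft_term_grad(student, teacher, T, scale=False))
    scaled = max(abs(g) for g in soft_term_grad(student, teacher, T))
    print(f"T={T}: unscaled {raw:.5f}, times T^2 {scaled:.5f}")
```
For large T, each tenfold increase in T should shrink the unscaled gradient about a hundredfold, while the T^2-scaled gradient stays nearly flat.
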
00:09:42.899 --> 00:09:46.419
So multiplying it by T squared acts as a counterbalance,

00:09:46.779 --> 00:09:48.919
keeping the relative contributions of the two parts

00:09:48.919 --> 00:09:52.019
roughly unchanged as T varies. So we are going through

00:09:52.019 --> 00:09:55.080
all this mathematical gymnastics, adjusting temperatures,

00:09:55.500 --> 00:09:58.220
balancing loss functions with T-squared multipliers,

00:09:58.559 --> 00:10:01.480
generating soft pseudo probabilities, all just

00:10:01.480 --> 00:10:04.039
to teach a small model. It's a lot of work. But

00:10:04.039 --> 00:10:06.259
wouldn't it be vastly simpler to just take a

00:10:06.259 --> 00:10:08.879
digital buzzsaw to the big model? Like why not

00:10:08.879 --> 00:10:11.500
literally chop pieces off the teacher until it

00:10:11.500 --> 00:10:13.409
fits on a smartphone? Well, that is the core

00:10:13.409 --> 00:10:15.669
distinction between knowledge distillation and

00:10:15.669 --> 00:10:18.029
model compression. Model compression is exactly

00:10:18.029 --> 00:10:20.269
what you just described. It aims to decrease

00:10:20.269 --> 00:10:22.789
the size of the large model itself without training

00:10:22.789 --> 00:10:25.090
a new one, generally preserving the overall

00:10:25.090 --> 00:10:27.570
architecture and the nominal number of parameters,

00:10:28.070 --> 00:10:30.830
but decreasing the bits per parameter to save

00:10:30.830 --> 00:10:33.509
physical space. So compression is like taking

00:10:33.509 --> 00:10:37.330
a massive uncompressed raw photo and saving it

00:10:37.330 --> 00:10:40.289
as a smaller JPEG. It's the same image. It just

00:10:40.289 --> 00:10:42.549
takes up less physical memory. That is the intent,

00:10:42.629 --> 00:10:45.250
yeah. But if you look closely at the math in

00:10:45.250 --> 00:10:47.730
the source material, there is a wild twist here.

00:10:47.850 --> 00:10:50.889
Lay it on me. Under one specific condition,

00:10:51.490 --> 00:10:53.870
the assumption that the logit values have a mean

00:10:53.870 --> 00:10:56.889
of zero, something fascinating happens. If you

00:10:56.889 --> 00:10:58.909
take the derivative of the knowledge

00:10:58.909 --> 00:11:01.330
distillation loss with respect to the student's logits

00:11:01.330 --> 00:11:04.090
for large values of temperature, it simplifies drastically.

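NOTE A minimal Python check of the high-temperature limit, assuming hypothetical logits that are shifted to zero mean. With the T^2 factor, the gradient of the soft term with respect to the student logits, T*(q_i - p_i), should approach (z_i - v_i)/N, which is the gradient of the squared-error logit-matching loss sum((z - v)^2)/(2N) (z = student logits, v = teacher logits, N = number of classes). The helper names are illustrative.
```python
import math
def softmax_t(logits, T):
    """Softmax at temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
def zero_mean(logits):
    """Shift logits so they average to zero (the stated assumption)."""
    mu = sum(logits) / len(logits)
    return [z - mu for z in logits]
def kd_grad(student, teacher, T):
    """T**2-scaled gradient of the soft cross-entropy: T * (q_i - p_i)."""
    q = softmax_t(student, T)
    p = softmax_t(teacher, T)
    return [T * (qi - pi) for qi, pi in zip(q, p)]
def logit_matching_grad(student, teacher):
    """Gradient of sum((z - v)**2) / (2N): (z_i - v_i) / N."""
    n = len(student)
    return [(z - v) / n for z, v in zip(student, teacher)]
student = zero_mean([2.0, 1.0, 0.5])   # hypothetical, shifted to mean 0
teacher = zero_mean([3.0, 0.5, -1.0])  # hypothetical, shifted to mean 0
target = logit_matching_grad(student, teacher)
for T in (1.0, 10.0, 100.0):
    gap = max(abs(a - b) for a, b in zip(kd_grad(student, teacher, T), target))
    print(f"T={T}: max gap to logit matching = {gap:.5f}")
```
The gap shrinks as T grows, which is the sense in which distillation and logit matching coincide at high temperature.
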
00:11:04.370 --> 00:11:06.409
How drastically? It mathematically reduces to

00:11:06.409 --> 00:11:08.529
a formula that directly matches the logits of

00:11:08.529 --> 00:11:11.169
the two models. Which means, under those conditions,

00:11:11.370 --> 00:11:14.129
the incredibly complex act of knowledge distillation

00:11:14.129 --> 00:11:16.870
reduces to the logit matching used in model

00:11:16.870 --> 00:11:19.409
compression. Wait, really? They become the same

00:11:19.409 --> 00:11:22.149
thing? Yes. They are two sides of the same algorithmic

00:11:22.149 --> 00:11:25.629
coin. That is a beautiful piece of mathematical

00:11:25.629 --> 00:11:28.210
convergence. I love that. But let's say we do

00:11:28.210 --> 00:11:30.629
want to take that digital buzzsaw and actually

00:11:30.629 --> 00:11:33.409
compress the model. How do we know what parts

00:11:33.409 --> 00:11:36.929
to cut? The source mentions a pruning algorithm

00:11:36.929 --> 00:11:39.889
with a rather intense name, optimal brain damage.

00:11:40.169 --> 00:11:44.049
Right, optimal brain damage, or OBD. It is a highly

00:11:44.049 --> 00:11:46.830
deliberate surgical pruning process. It doesn't

00:11:46.830 --> 00:11:49.370
just chop things off randomly. I would hope not

00:11:49.370 --> 00:11:51.710
with a name like that. The algorithm works in

00:11:51.710 --> 00:11:54.230
a loop until the model reaches the desired level

00:11:54.230 --> 00:11:57.820
of sparsity. First, you train the network until

00:11:57.820 --> 00:12:01.100
you have a reasonable solution. Second, you compute

00:12:01.100 --> 00:12:03.799
what are called the saliencies for each individual

00:12:03.799 --> 00:12:06.899
parameter in the network. Saliencies, meaning

00:12:06.899 --> 00:12:09.779
importance. Roughly, yes. Finally, you delete

00:12:09.779 --> 00:12:12.039
the parameters with the lowest saliency, which

00:12:12.039 --> 00:12:15.259
permanently fixes their value to zero. So what

00:12:15.259 --> 00:12:17.519
does this all mean? It sounds exactly like playing

00:12:17.519 --> 00:12:20.190
a game of Jenga with a neural network. Jenga.

00:12:20.389 --> 00:12:21.629
Oh, I see where you're going with this. Yeah,

00:12:21.710 --> 00:12:23.789
you have this massive tower of blocks, the parameters.

00:12:24.210 --> 00:12:26.169
The network is trained. The tower is standing

00:12:26.169 --> 00:12:28.629
perfectly tall, but you need to make the tower

00:12:28.629 --> 00:12:31.120
lighter to fit on a phone. You can't just swipe

00:12:31.120 --> 00:12:32.659
your hand through the middle of it. You have

00:12:32.659 --> 00:12:35.120
to poke at individual blocks to see which ones

00:12:35.120 --> 00:12:37.720
are truly structural and which ones are just

00:12:37.720 --> 00:12:40.019
resting there doing nothing. If a block is loose,

00:12:40.059 --> 00:12:41.940
it has low saliency and you pull it out. And

00:12:41.940 --> 00:12:44.059
if a block is load-bearing, it has high saliency

00:12:44.059 --> 00:12:46.460
and you have to leave it alone. Exactly. Well,

00:12:46.460 --> 00:12:48.460
let's follow the physics of your Jenga tower

00:12:48.460 --> 00:12:52.860
into the actual math. In OBD, we calculate that

00:12:52.860 --> 00:12:55.620
load-bearing saliency by analyzing the loss

00:12:55.620 --> 00:12:59.190
function. The idea is to approximate the loss

00:12:59.190 --> 00:13:01.409
function in the immediate neighborhood of an

00:13:01.409 --> 00:13:04.269
optimal parameter using a Taylor expansion. OK,

00:13:04.350 --> 00:13:06.950
keep the math accessible for me here. What is

00:13:06.950 --> 00:13:09.649
a Taylor expansion actually doing to our Jenga

00:13:09.649 --> 00:13:12.870
tower? Fair question. A Taylor expansion lets

00:13:12.870 --> 00:13:15.129
us examine the local math around a parameter

00:13:15.129 --> 00:13:17.850
without having to re-evaluate the entire network

00:13:17.850 --> 00:13:20.879
after each deletion. It is the mathematical equivalent of

00:13:20.879 --> 00:13:23.279
lightly tapping a Jenga block with your finger

00:13:23.279 --> 00:13:25.860
to measure its friction without actually pulling

00:13:25.860 --> 00:13:27.600
it out and waiting to see if the whole tower

00:13:27.600 --> 00:13:30.340
crashes. That makes total sense. Because the

00:13:30.340 --> 00:13:32.519
model is already trained to an optimal point,

00:13:32.980 --> 00:13:35.320
the first derivative of the loss function is

00:13:35.320 --> 00:13:38.399
roughly zero. Meaning the slope is flat. The

00:13:38.399 --> 00:13:40.899
tower is perfectly balanced. Exactly. The first

00:13:40.899 --> 00:13:42.740
derivative doesn't tell us what will happen if

00:13:42.740 --> 00:13:44.919
we change things because everything is currently

00:13:44.919 --> 00:13:48.879
stable. So OBD uses a technique called second-order

00:13:48.879 --> 00:13:51.799
backpropagation to find the second derivative

00:13:51.799 --> 00:13:54.419
of the loss. Okay, second derivative. If the

00:13:54.419 --> 00:13:56.899
first derivative is the slope, the second derivative

00:13:56.899 --> 00:13:59.940
is the curvature. It tells us whether our parameter

00:13:59.940 --> 00:14:03.440
is sitting in a gentle wide bowl or at the bottom

00:14:03.440 --> 00:14:06.659
of a sharp steep ravine. And if it's in a gentle

00:14:06.659 --> 00:14:09.360
bowl, removing that parameter, pulling that block,

00:14:09.759 --> 00:14:12.360
doesn't change the overall shape of the network's

00:14:12.360 --> 00:14:14.480
error very much. Right. But if it's in a sharp

00:14:14.480 --> 00:14:17.679
ravine, pulling it causes the error to skyrocket.

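NOTE A minimal Python sketch of the OBD loop described here, assuming a toy noise-free linear regression with loss sum((y - w.x)^2)/(2n), so the diagonal second derivative has the closed form h_kk = mean(x_k^2). The saliency of weight k is s_k = h_kk * w_k^2 / 2 (cross derivatives ignored); here one lowest-saliency weight is deleted per round, frozen at zero, and the survivors are retrained. The data, true weights, and helper names are all hypothetical.
```python
import random
def loss(w, X, y):
    """Mean squared error / 2 for a linear model y_hat = w . x."""
    n = len(X)
    return sum((yi - sum(wk * xk for wk, xk in zip(w, xi))) ** 2 for xi, yi in zip(X, y)) / (2 * n)
def train(w, mask, X, y, lr=0.1, steps=300):
    """Plain gradient descent; parameters with mask == 0 stay fixed at zero."""
    n, d = len(X), len(w)
    w = [wk * mk for wk, mk in zip(w, mask)]
    for _ in range(steps):
        grad = [0.0] * d
        for xi, yi in zip(X, y):
            r = sum(wk * xk for wk, xk in zip(w, xi)) - yi
            for k in range(d):
                grad[k] += r * xi[k] / n
        w = [(wk - lr * gk) * mk for wk, gk, mk in zip(w, grad, mask)]
    return w
def saliencies(w, X):
    """OBD saliency s_k = h_kk * w_k**2 / 2, with h_kk = mean(x_k**2) for this loss."""
    n = len(X)
    return [sum(xi[k] ** 2 for xi in X) / n * w[k] ** 2 / 2 for k in range(len(w))]
def optimal_brain_damage(X, y, keep):
    """Train, score, and delete the lowest-saliency live weight until `keep` remain."""
    d = len(X[0])
    mask = [1.0] * d
    w = train([0.0] * d, mask, X, y)
    while sum(mask) > keep:
        s = saliencies(w, X)
        live = [k for k in range(d) if mask[k]]
        victim = min(live, key=lambda k: s[k])
        mask[victim] = 0.0          # deleted weight is frozen at zero
        w = train(w, mask, X, y)    # retrain the survivors
    return w, mask
random.seed(0)
true_w = [3.0, 0.0, -2.0, 0.05, 0.0]  # hypothetical: two weights really matter
X = [[random.gauss(0, 1) for _ in range(5)] for _ in range(100)]
y = [sum(a * b for a, b in zip(true_w, xi)) for xi in X]
w, mask = optimal_brain_damage(X, y, keep=2)
print("mask:", mask)
print("weights:", [round(v, 3) for v in w])
print("final loss:", loss(w, X, y))
```
Deleting one weight per round keeps the sketch short; the loop as described in the conversation can remove several low-saliency parameters per round before retraining.
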
00:14:17.779 --> 00:14:20.120
You've got it. The second derivative mathematically

00:14:20.120 --> 00:14:23.059
estimates how much the error would increase

00:14:23.059 --> 00:14:25.980
if that specific parameter is deleted. That's

00:14:25.980 --> 00:14:28.460
incredibly precise. It is. And to save on precious

00:14:28.460 --> 00:14:31.179
computing power, the OBD algorithm ignores the

00:14:31.179 --> 00:14:33.110
cross derivatives. Cross derivatives, which would

00:14:33.110 --> 00:14:35.549
be calculating how two different blocks might

00:14:35.549 --> 00:14:37.429
interact if you pulled them both at the same

00:14:37.429 --> 00:14:40.210
exact time. By focusing only on the individual

00:14:40.210 --> 00:14:43.789
parameter, the saliency becomes a direct, cheap

00:14:43.789 --> 00:14:46.389
approximation of the penalty for deleting

00:14:46.389 --> 00:14:49.090
it. I find this entirely fascinating because

00:14:49.090 --> 00:14:53.110
we are discussing elegant, complex mathematical

00:14:53.110 --> 00:14:55.970
solutions. Temperature scaling, second-order

00:14:55.970 --> 00:14:59.309
backpropagation, Taylor expansions, and it all

00:14:59.309 --> 00:15:02.169
sounds like cutting-edge magic designed specifically

00:15:02.169 --> 00:15:04.610
for the smartphone era. It does feel very modern.

00:15:04.730 --> 00:15:06.309
It feels like something invented five minutes

00:15:06.309 --> 00:15:09.379
ago in Silicon Valley. But looking at our source

00:15:09.379 --> 00:15:12.220
material, the foundational ideas behind compressing

00:15:12.220 --> 00:15:15.120
and teaching networks have a surprisingly deep

00:15:15.120 --> 00:15:17.460
and rich history. Oh, absolutely. The history

00:15:17.460 --> 00:15:20.039
of AI development is rarely a timeline of sudden,

00:15:20.179 --> 00:15:23.519
spontaneous breakthroughs. It is a slow, methodical

00:15:23.519 --> 00:15:26.500
layering of ideas over decades. Decades. Yeah.

00:15:26.519 --> 00:15:28.539
The very first instance of model compression

00:15:28.539 --> 00:15:30.740
or pruning didn't happen in the smartphone era

00:15:30.740 --> 00:15:34.379
at all. It happened in the USSR in 1965. 1965.

00:15:35.080 --> 00:15:37.659
The computing power available then was practically

00:15:37.659 --> 00:15:40.899
non-existent by today's standards. A computer

00:15:40.899 --> 00:15:43.220
took up a whole room and had less memory than

00:15:43.220 --> 00:15:46.139
a digital watch. Why were they trying to prune

00:15:46.139 --> 00:15:48.539
neural networks back then? Precisely because

00:15:48.539 --> 00:15:51.220
computing power was so scarce. They had absolutely

00:15:51.220 --> 00:15:55.090
no choice. Researchers Alexey Ivakhnenko and Valentin

00:15:55.090 --> 00:15:57.269
Lapa were developing what became known as the

00:15:57.269 --> 00:16:00.450
group method of data handling. They were training

00:16:00.450 --> 00:16:03.250
deep networks layer by layer through regression

00:16:03.250 --> 00:16:06.309
analysis. And because they physically could not

00:16:06.309 --> 00:16:08.929
store massive inefficient models, they would

00:16:08.929 --> 00:16:11.549
evaluate the hidden units during training. And

00:16:11.549 --> 00:16:13.850
any unit that was superfluous was immediately

00:16:13.850 --> 00:16:16.710
pruned away using a separate validation data

00:16:16.710 --> 00:16:18.870
set. Wow, they were forced to be highly efficient.

00:16:19.129 --> 00:16:21.009
Necessity driving invention. And those concepts

00:16:21.009 --> 00:16:23.129
just kept evolving. The source notes that in

00:16:23.129 --> 00:16:25.470
1988 we got methods like biased weight decay,

00:16:25.850 --> 00:16:28.250
and in 1989 that optimal brain damage algorithm

00:16:28.250 --> 00:16:30.750
we just broke down was formally introduced by

00:16:30.750 --> 00:16:33.049
Yann LeCun and his colleagues. But when we hit

00:16:33.049 --> 00:16:35.210
1991, here's where it gets really interesting.

00:16:35.590 --> 00:16:37.629
The history takes a turn that sounds like pure

00:16:37.629 --> 00:16:40.250
science fiction. Ah! You are looking at Juergen

00:16:40.250 --> 00:16:42.850
Schmidhuber's work with recurrent neural networks,

00:16:43.149 --> 00:16:46.750
or... RNNs. Yes. The underlying problem he was

00:16:46.750 --> 00:16:49.230
tackling was sequence prediction for long sequences

00:16:49.230 --> 00:16:53.070
of data. To solve it, Schmidhuber used two entirely

00:16:53.070 --> 00:16:56.090
separate networks. Right. He called one the automatizer,

00:16:56.350 --> 00:16:57.950
which was tasked with predicting the sequence.

00:16:58.409 --> 00:17:00.669
The other he called the chunker, which had the

00:17:00.669 --> 00:17:02.850
job of predicting the errors the automatizer

00:17:02.850 --> 00:17:05.930
was making. But the sci-fi twist is that, simultaneously,

00:17:05.930 --> 00:17:08.589
the automatizer is actively trying to predict

00:17:08.589 --> 00:17:11.910
the chunker's internal states. It creates a closed,

00:17:12.309 --> 00:17:14.769
highly complex loop of self-correction. It's

00:17:14.769 --> 00:17:17.490
like having two people driving a car. One person

00:17:17.490 --> 00:17:19.329
is steering and the other person is just watching

00:17:19.329 --> 00:17:21.769
for potholes. Every time there's a pothole, the

00:17:21.769 --> 00:17:24.569
watcher yells out. Okay, I follow. But eventually,

00:17:24.809 --> 00:17:27.950
the person steering memorizes the watcher's behavior

00:17:27.950 --> 00:17:30.930
so perfectly that they learn to avoid the potholes

00:17:30.930 --> 00:17:34.069
before the watcher even has to yell. The automatizer

00:17:34.069 --> 00:17:36.930
learns the chunker's internal state so well that

00:17:36.930 --> 00:17:39.410
it starts fixing its own errors preemptively.

00:17:39.569 --> 00:17:41.490
That's a great analogy. And the chunker, the

00:17:41.490 --> 00:17:44.390
watcher falls asleep. It becomes completely obsolete.

00:17:44.529 --> 00:17:47.069
It's an incredible philosophical concept. One

00:17:47.069 --> 00:17:50.089
digital brain literally swallowing the function

00:17:50.089 --> 00:17:52.410
of another brain until it doesn't need it anymore.

00:17:53.210 --> 00:17:56.049
Only one RNN is left standing. It's a profound

00:17:56.049 --> 00:17:58.849
conceptual leap. And if we connect this to the

00:17:58.849 --> 00:18:01.609
bigger picture, the idea of one network explicitly

00:18:01.609 --> 00:18:04.670
observing and internalizing the behavior of another,

00:18:05.190 --> 00:18:08.170
that paved the way for the broader teacher-student

00:18:08.170 --> 00:18:10.490
network configurations we use today. It set the

00:18:10.490 --> 00:18:14.230
stage. Exactly. By 1992, researchers in statistical

00:18:14.230 --> 00:18:16.450
mechanics were heavily studying these teacher-student

00:18:16.450 --> 00:18:18.890
setups using theoretical frameworks

00:18:18.890 --> 00:18:21.250
like committee machines and parity machines.

00:18:21.730 --> 00:18:24.089
The idea was firmly taking root in the mathematical

00:18:24.089 --> 00:18:26.230
community. And it really accelerated from there.

00:18:26.490 --> 00:18:29.869
By 2006, researchers like Buciluă, Caruana, and Niculescu-Mizil were doing

00:18:29.869 --> 00:18:32.470
what they explicitly termed model compression

00:18:32.470 --> 00:18:35.289
by training smaller models on massive amounts

00:18:35.289 --> 00:18:37.109
of pseudo-data. Right, data that was labeled

00:18:37.109 --> 00:18:39.529
by a high-performing ensemble of larger models.

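NOTE
The 2006 model-compression recipe described in this passage (label pseudo-data with a strong ensemble, then train a small model to match the ensemble's logits) can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the original authors' implementation: teachers and student are toy linear models, ensemble members are combined by averaging their logits, plain Gaussian samples stand in for the authors' pseudo-data generator, and all function names are hypothetical.
```python
import random
def linear_logits(weights, x):
    # Toy linear model (assumption): one weight row per class, one logit per class
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]
def ensemble_logits(teachers, x):
    # Combine ensemble members by averaging their logits (an assumed combination rule)
    per_teacher = [linear_logits(t, x) for t in teachers]
    return [sum(col) / len(per_teacher) for col in zip(*per_teacher)]
def train_student(teachers, n_features, n_classes, n_pseudo=300, epochs=30, lr=0.05, seed=0):
    rng = random.Random(seed)
    # Pseudo-data: Gaussian samples stand in for a real pseudo-data generator (assumption)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(n_features)] for _ in range(n_pseudo)]
    # The ensemble labels the pseudo-data; no ground-truth labels are needed
    targets = [ensemble_logits(teachers, x) for x in xs]
    student = [[0.0] * n_features for _ in range(n_classes)]
    for _ in range(epochs):
        for x, target in zip(xs, targets):
            out = linear_logits(student, x)
            for c in range(n_classes):
                # Gradient of 0.5 * (student logit - ensemble logit)^2 for this class
                err = out[c] - target[c]
                for j in range(n_features):
                    student[c][j] -= lr * err * x[j]
    return student
```
Because the averaged linear teachers are themselves a linear map, this student can recover it almost exactly; with real nonlinear teachers the student only approximates the ensemble.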
00:18:39.670 --> 00:18:41.829
And they optimized it specifically to match the

00:18:41.829 --> 00:18:44.390
logits, utilizing the very mechanisms we discussed

00:18:44.390 --> 00:18:46.839
earlier. Which brings the lineage right to the

00:18:46.839 --> 00:18:49.759
doorstep of the modern era. In 2015, Geoffrey

00:18:49.759 --> 00:18:52.720
Hinton, Oriol Vinyals, and Jeff Dean published

00:18:52.720 --> 00:18:55.119
a seminal preprint that formally coined

00:18:55.119 --> 00:18:58.380
the term knowledge distillation. Ah, there's

00:18:58.380 --> 00:19:01.400
the term. Yep. They introduced the specific temperature

00:19:01.400 --> 00:19:04.319
math we explored earlier, demonstrating profound

00:19:04.319 --> 00:19:07.500
results in image classification tasks. They took

00:19:07.500 --> 00:19:10.859
decades of disparate theories, from 1960s pruning

00:19:10.859 --> 00:19:13.680
to 1990s teacher-student dynamics, and unified

00:19:13.680 --> 00:19:17.230
them. Wow. That 2015 paper is largely responsible

00:19:17.230 --> 00:19:19.809
for the framework we rely on today to put these

00:19:19.809 --> 00:19:22.369
massive data center models onto edge devices.

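NOTE
The temperature-scaled objective associated with the 2015 paper can be sketched as follows. This is a minimal, stdlib-only illustration under assumptions: the mixing weight alpha, the default temperature of 4, and the function names are illustrative choices. Following Hinton et al. (2015), it mixes a soft cross-entropy between the teacher's and student's softened distributions (scaled by T squared) with an ordinary cross-entropy on the true label at T = 1.
```python
import math
def softmax(logits, temperature=1.0):
    # Divide logits by T, then subtract the max before exponentiating for numerical stability
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]
def cross_entropy(target_probs, predicted_probs):
    # Terms with zero target probability contribute nothing and are skipped
    return -sum(p * math.log(q) for p, q in zip(target_probs, predicted_probs) if p > 0)
def distillation_loss(student_logits, teacher_logits, true_label, temperature=4.0, alpha=0.5):
    # Soft term: the student imitates the teacher's softened distribution at temperature T
    soft_targets = softmax(teacher_logits, temperature)
    soft_preds = softmax(student_logits, temperature)
    # Scaling by T^2 keeps soft-term gradients comparable in size as T changes
    soft = cross_entropy(soft_targets, soft_preds) * temperature ** 2
    # Hard term: ordinary cross-entropy against the true label at T = 1
    hard = -math.log(softmax(student_logits)[true_label])
    return alpha * soft + (1 - alpha) * hard
```
Raising the temperature flattens the teacher's distribution, so the "wrong" classes receive more probability mass for the student to learn from.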
00:19:22.730 --> 00:19:24.109
It is quite a journey. So let's bring it all

00:19:24.109 --> 00:19:26.250
together for you listening. We started with this

00:19:26.250 --> 00:19:29.710
massive real-world problem. AI models are simply

00:19:29.710 --> 00:19:31.589
too big for the pockets we want to keep them

00:19:31.589 --> 00:19:34.490
in. Way too big. Right. And we explored how the

00:19:34.490 --> 00:19:36.970
solution isn't just mindlessly chopping them

00:19:36.970 --> 00:19:38.970
up, but actually using them as highly nuanced

00:19:38.970 --> 00:19:41.740
teachers. By turning up the mathematical temperature,

00:19:42.240 --> 00:19:44.380
we force small student models to look past the

00:19:44.380 --> 00:19:47.500
obvious right answers and learn the soft, subtle,

00:19:47.640 --> 00:19:49.779
pseudo-likelihoods of the large models. The

00:19:49.779 --> 00:19:53.039
incredibly valuable wrong answers. Exactly. We

00:19:53.039 --> 00:19:55.519
looked at the surgical precision of optimal brain

00:19:55.519 --> 00:19:58.680
damage using local curvature and Taylor expansions

00:19:58.680 --> 00:20:01.519
to carefully pull the non-structural Jenga blocks

00:20:01.519 --> 00:20:05.319
out of a network. And we traced this incredible

00:20:05.319 --> 00:20:08.420
60-year lineage from Soviet regression pruning

00:20:08.420 --> 00:20:12.299
in 1965 through Schmidhuber's brain-eating networks

00:20:12.299 --> 00:20:15.440
in the 90s, all the way to the sophisticated

00:20:15.440 --> 00:20:17.559
distillation running invisibly on your phone

00:20:17.559 --> 00:20:20.420
right now. It is a perfect example of the compounding

00:20:20.420 --> 00:20:22.859
nature of scientific discovery. However, before

00:20:22.859 --> 00:20:26.359
we finish, there is one final, deeply counterintuitive

00:20:26.359 --> 00:20:29.140
detail buried in the source text that flips the

00:20:29.140 --> 00:20:31.400
entire premise of our discussion completely on

00:20:31.400 --> 00:20:33.569
its head. Wait, flips it. We've been talking

00:20:33.569 --> 00:20:35.329
entirely about shrinking things down. What's

00:20:35.329 --> 00:20:37.849
the twist? Well, we have spent this whole time

00:20:37.849 --> 00:20:40.170
detailing how to take big, massive models and

00:20:40.170 --> 00:20:41.970
distill their knowledge downward into smaller,

00:20:42.250 --> 00:20:44.589
restricted models. But the text briefly notes

00:20:44.589 --> 00:20:47.269
a much less common, highly experimental technique

00:20:47.269 --> 00:20:49.869
called reverse knowledge distillation. Reverse

00:20:49.869 --> 00:20:52.170
distillation, meaning the knowledge is transferred

00:20:52.170 --> 00:20:54.880
from a smaller model up to a larger one. That

00:20:54.880 --> 00:20:58.160
is the mechanism. A massive, high-capacity model

00:20:58.160 --> 00:21:01.299
acts as the student to a restricted, low-capacity

00:21:01.299 --> 00:21:03.799
teacher. That is baffling. Right. Why would a

00:21:03.799 --> 00:21:06.759
massive, powerful, billion-parameter AI ever

00:21:06.759 --> 00:21:09.180
need to learn from a smaller, simpler one? That's

00:21:09.180 --> 00:21:11.960
a great question. Like, what kind of unique foundational

00:21:11.960 --> 00:21:15.420
rules does a restricted model possess that a

00:21:15.420 --> 00:21:18.880
sprawling supercomputer might actually miss precisely

00:21:18.880 --> 00:21:20.799
because it has too much capacity? It's a real

00:21:20.799 --> 00:21:23.059
paradox. It is a fascinating inversion of the

00:21:23.059 --> 00:21:25.220
teacher-student dynamic. And for you, as you

00:21:25.220 --> 00:21:27.339
put your phone back into your pocket today, that

00:21:27.339 --> 00:21:29.680
is the puzzle to ponder until our next deep dive.

00:21:30.319 --> 00:21:31.180
Thanks for learning with us.
