Deriving `cross_entropy` loss
At the Recurse Center, I've been implementing a language model from scratch by following along with the Stanford CS336 class online.
Part of the first assignment implements cross-entropy loss from scratch. Even though I've seen cross-entropy quite a few times in other material, I keep finding myself confused about its relation to other concepts like KL divergence, softmax, "log probs", and "NLL loss", and about how it's actually computed numerically.
I thought it'd be helpful for myself1 to write a bit about this to sharpen my intuition for the situation.
A basic understanding of loss functions, model training, and softmax is probably a prerequisite here, as I don't want to get too into the weeds starting from first principles. I found this video and this video do a good job of explaining how cross-entropy loss is defined.
what is cross-entropy loss for?
Cross-entropy loss is most commonly used as the loss function for training models (training loss) where we want the model to learn a distribution. For instance, in classification tasks ("what's in this image?"), cross-entropy is used to train the predictions the model is making ("35% likely this is a cat, 65% likely this is a bus") to more closely match what's actually in the image ("100% bus").
The training output label is usually formed as a one-hot vector. For instance, the one-hot for our image classification task might have a probability for each output label we might assign to the image. For the vector [cat_prob, bus_prob, house_prob, scarecrow_prob, ...]
we might have:
- training data one-hot:
[0, 1, 0, ...]
- predicted probabilities before training:
[.01, .04, .02, ...]
- predicted probabilities after training:
[.01, .95, .001, ...]
LLMs aren't classification tasks, but similarly use one-hots during training, e.g. if the token sequence is "the quick brown fox" the training data one-hot might represent the token for "jumped".
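For concreteness, here's a minimal made-up sketch of how you might build that kind of one-hot target in PyTorch, whether the classes are image labels or tokens in a vocabulary (the [cat, bus, house, scarecrow] ordering is just the hypothetical one from above):

```python
import torch as t
import torch.nn.functional as F

# hypothetical 4-class setup: [cat, bus, house, scarecrow]
true_class = t.tensor(1)  # the image is a bus
one_hot = F.one_hot(true_class, num_classes=4).float()
print(one_hot)  # tensor([0., 1., 0., 0.])
```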
deriving cross_entropy
for the one-hot case
Our goal here is to go from a somewhat impenetrable definition of cross-entropy and understand it enough to get to a working pytorch function we could actually run.
definitions
To warm us up, we first need a few definitions.
First, we use $P$ as the true distribution and $Q$ for the predicted distribution. For our ML purposes, $P$ corresponds to the training data and $Q$ corresponds to the predictions from our model. The index $i$, as in $P(i)$, tells you we're looking at the probability at index $i$.
We also have $H$, the entropy of a distribution: $H(P, Q)$ is the cross-entropy of the two distributions and $H(P)$ is the entropy of $P$ itself. The technical definition for entropy is something about "how much information" is in the distribution, but I found that somewhat unhelpful to think about. We will soon realize that we don't need to worry about the definition for this one too much.
$D_{KL}(P \parallel Q)$ is the KL divergence between the two distributions. The KL divergence is a distance measure that tells us how different $P$ and $Q$ are. You could read this term as "the KL divergence from $Q$ to $P$" or better yet, "how surprised ($D$) you are by the prediction ($Q$), given the training data ($P$)".
cross entropy
With those definitions in mind, we can now reveal what we're working with. Cross-entropy is defined as:

$$H(P, Q) = H(P) + D_{KL}(P \parallel Q)$$
We're going to take this formula and break it down piece by piece, taking advantage of a few simplifications that are possible for our one-hot vector case. (Note that many of the assumptions we'll make do not apply if the target distribution is not one-hot!)
measuring cross-entropy loss as a proxy for KL divergence
To start out, let's label the terms:

$$\underbrace{H(P, Q)}_{\text{cross-entropy}} = \underbrace{H(P)}_{\text{entropy of }P} + \underbrace{D_{KL}(P \parallel Q)}_{\text{KL divergence}}$$

Using the definitions we have above, the cross-entropy can be read as the "amount of difference between the two distributions" ($D_{KL}(P \parallel Q)$) plus "the inherent variation that is part of the training data" ($H(P)$).
eliminating the entropy term
You might be wondering how to calculate the $H(P)$ term, but luckily we don't have to worry about it for long2.
For training loss, we can drop the $H(P)$ term with a couple of justifications:
- $H(P)$ as a term never accounts for $Q$ (literally, the parameters of $H(P)$ don't include $Q$), so varying $Q$ by training our model will never impact $H(P)$. The term is just inherent to the training data itself. You could think of this as a floor for the amount of loss we'll see during training.
- For one-hot vectors specifically, $H(P)$ is always going to be $0$ anyways (a quick check of this follows below), i.e. an image will always be labeled "this picture is 100% a cat", not "this picture is 90% a cat, 10% a bus" (assuming that catbus is not one of our output classifications).
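To convince ourselves of that second point, we can plug a one-hot vector into the standard definition of entropy (using the usual convention that $0 \log 0 = 0$):

$$H(P) = -\sum_i P(i) \log P(i) = -\big(\underbrace{1 \cdot \log 1}_{=\,0} + \underbrace{0 \cdot \log 0}_{=\,0} + \dots\big) = 0$$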
Having convinced ourselves of dropping $H(P)$, we now have:

$$H(P, Q) = \underbrace{H(P)}_{=\,0} + D_{KL}(P \parallel Q) = D_{KL}(P \parallel Q)$$
We can now see that when we use cross-entropy loss, we're really just using KL divergence3.
I like this understanding of cross-entropy loss much better. If we talk about KL divergence, a measure of distance, it makes some intuitive sense why you'd want to minimize the distance between the model predictions and training data. Cross-entropy in my mind is a bit more hand-wavy, with something about information theory and bits that makes the picture fuzzier.
calculating kl divergence
Now we know that cross-entropy loss simplifies down to KL divergence4. Let's look at its definition5:

$$D_{KL}(P \parallel Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}$$
In other words, we sum over the distribution, looking at the difference between the true and predicted probabilities at each index (the log-ratio term), weighted by the true probability (the $P(i)$ term).
For our one-hot case, we can simplify this equation by dropping the summation and $P(i)$ terms6. This is because the training data's $P$ is always going to be a one-hot vector, so $P(i)$ is going to be 0 everywhere except the one true label, where it is 1, e.g. for the label where the image is a bus:

$$D_{KL}(P \parallel Q) = \underbrace{P(\text{cat}) \log \frac{P(\text{cat})}{Q(\text{cat})}}_{P(\text{cat})\,=\,0} + \underbrace{P(\text{bus}) \log \frac{P(\text{bus})}{Q(\text{bus})}}_{P(\text{bus})\,=\,1} + \dots$$

We can remove all of the terms where $P(i)$ is 0, simplifying down to:

$$D_{KL}(P \parallel Q) = P(\text{bus}) \log \frac{P(\text{bus})}{Q(\text{bus})} = \log \frac{P(\text{bus})}{Q(\text{bus})}$$
manipulating logarithms
From here, if we remember our quotient rule ($\log \frac{a}{b} = \log a - \log b$), we can continue to break our formula down:

$$D_{KL}(P \parallel Q) = \log P(\text{bus}) - \log Q(\text{bus})$$

Since $P(\text{bus})$ for a one-hot vector is going to be $1$, we can replace the $\log P(\text{bus})$ term with $\log 1 = 0$, i.e. just $-\log Q(\text{bus})$:

$$D_{KL}(P \parallel Q) = -\log Q(\text{bus})$$
negative log likelihood (NLL)
We can see that we've now arrived at negative log likelihood! As in literally, the cross-entropy function boils down to taking the negative of the log of the likelihood (probability).
Calculating cross-entropy is the same as calculating KL divergence is the same as calculating NLL:

$$H(P, Q) = D_{KL}(P \parallel Q) = -\log Q(\text{true class})$$
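As a quick worked example with made-up numbers: if the training label is "bus" and the model's predicted probabilities for [cat, bus, house] are $[0.1, 0.8, 0.1]$, then

$$H(P, Q) = D_{KL}(P \parallel Q) = -\log Q(\text{bus}) = -\log 0.8 \approx 0.223$$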
Though we've seen that these are equivalent, I'll continue to use the term cross-entropy to be consistent.
writing it up in pytorch
Finally, we can take a stab at a pytorch function. Since we know that $Q$ is the predictions from our model, $Q(\text{true class})$ here can be read as "the model output for the true class of the training example".
import torch as t
from jaxtyping import Float
def cross_entropy(q: Float[t.Tensor, "d"], p_index: int) -> float:
    return -t.log(q[p_index]).item()
where `q` corresponds to $Q$, the output probabilities from our model.
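As a quick sanity check, calling the function above with some made-up probabilities (same hypothetical [cat, bus, house] setup as before):

```python
q = t.tensor([0.1, 0.8, 0.1])        # model's predicted probabilities
loss = cross_entropy(q, p_index=1)   # true class is "bus"
print(loss)                          # ≈ 0.2231, i.e. -log(0.8)
```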
numerical stability and floating point error
At this point, we could be pretty much done if it weren't for the reality of how floating point numbers are computed. If we were to open up the hood on our model, we'd see that output probabilities come from a softmax operation. Softmax takes raw outputs $z$ from the model and shapes them into a probability distribution that adds up to 100%. It's defined as:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$
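For reference, here's what that softmax step looks like in PyTorch on some made-up raw outputs:

```python
z = t.tensor([2.0, 1.0, 0.1])   # raw model outputs ("logits")
probs = t.softmax(z, dim=-1)
print(probs)                    # tensor([0.6590, 0.2424, 0.0986])
print(probs.sum())              # ≈ 1.0, a proper probability distribution
```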
If we wrap this up in our cross-entropy function (with $i$ as the index of the true class), we get

$$-\log Q(i) = -\log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right)$$

which is a bit weird since we're taking the log of an exponential. Computers calculating `log(exp(...))` might end up with large or small intermediate numbers that are difficult to represent in floating point.
We'll do some "numerical stability tricks" to make it less likely we accidentally overflow or underflow the value of this composed calculation. Note that this is somewhat similar to what we might do symbolically by hand, i.e. first cancel terms that can be simplified before computing the value.
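To make the failure mode concrete, here's a small made-up example of what happens if we compute softmax and then take the log naively on extreme logits:

```python
z = t.tensor([1000.0, 1.0, -1000.0])
print(z.exp())                   # tensor([inf, 2.7183, 0.]) -- e^1000 overflows to inf
naive_probs = z.exp() / z.exp().sum()
print(t.log(naive_probs))        # tensor([nan, -inf, -inf]) -- inf/inf becomes nan
```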
Again using our quotient rule, we can first break this up:

$$-\log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right) = -\left(\log e^{z_i} - \log \sum_j e^{z_j}\right)$$

then rearrange and simplify the $\log e^{z_i}$ term (since $\log e^{z_i} = z_i$):

$$= \log \sum_j e^{z_j} - z_i$$
or in pseudocode:
log(sum(exp(z))) - z[true_class_index]
Also note that we still have a `log(sum(exp(z)))`, which can't be reduced symbolically but still pushes values through the exponential, where they can easily overflow. Here, we can do a trick where we subtract the max from `z` first to keep all of the values stably around 0 in the exponential's domain7.
Putting everything together now, we can write:
def cross_entropy(logits: Float[t.Tensor, "d"], p_index: int) -> float:
    log_probs = log_softmax(logits)
    return -log_probs[p_index].item()
where `log_softmax` is:
def log_softmax(logits: Float[t.Tensor, "d"]) -> t.Tensor:
    m = logits.max()
    v = logits - m
    return v - t.log(v.exp().sum())
Note that we now have to take the pre-softmax'd output, usually called a logit, to let us simplify the expression8. This mismatch (i.e. why doesn't the `cross_entropy` function take the raw probabilities instead of the logits?) always felt a bit confusing to me. We can now see that it's necessary to allow for these stability tricks.
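As a quick check with the same kind of made-up extreme logits as before, this version agrees with PyTorch's built-in `t.log_softmax` and doesn't blow up:

```python
logits = t.tensor([1000.0, 1.0, -1000.0])
print(log_softmax(logits))  # tensor([0., -999., -2000.]) -- no inf or nan
print(t.allclose(log_softmax(logits), t.log_softmax(logits, dim=-1)))  # True
```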
Finally, we still have `t.log(v.exp().sum())`, which is still not ideal. We can do the logsumexp trick here, which again involves subtracting and re-adding the max to keep the values in a good range9:
def log_softmax(logits: Float[t.Tensor, "d"]) -> t.Tensor:
    m = logits.max()
    v = logits - m
    return v - logsumexp(v)

def logsumexp(x: Float[t.Tensor, "d"]) -> t.Tensor:
    m = x.max()
    return m + t.log((x - m).exp().sum())
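And again on made-up extreme values, this matches PyTorch's built-in `t.logsumexp`:

```python
x = t.tensor([1000.0, 1.0, -1000.0])
print(t.allclose(logsumexp(x), t.logsumexp(x, dim=-1)))  # True
```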
bringing in the batch dimension
If we want to run this function across a whole batch of examples at once, we lastly need to do some dimensional wrangling:
from jaxtyping import Int  # for the integer tensor of true-class indices

def cross_entropy(logits: Float[t.Tensor, "b d"], p_index: Int[t.Tensor, "b"]) -> t.Tensor:
    log_probs = log_softmax(logits, dim=-1)
    return -log_probs.gather(dim=-1, index=p_index.unsqueeze(-1)).squeeze(-1).mean()

def log_softmax(logits: Float[t.Tensor, "b d"], dim: int) -> t.Tensor:
    m = logits.max(dim=dim, keepdim=True).values
    v = logits - m
    return v - logsumexp(v, dim=dim)

def logsumexp(x: t.Tensor, dim: int) -> t.Tensor:
    m = x.max(dim=dim, keepdim=True).values
    return m + t.log((x - m).exp().sum(dim=dim, keepdim=True))
Note that we take the mean loss over the batch, since we're trying to get a single loss number across all of the examples.
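As a final sanity check on some random made-up inputs, this batched version should match PyTorch's built-in `torch.nn.functional.cross_entropy`, which defaults to the same mean reduction:

```python
logits = t.randn(4, 10)            # batch of 4 examples, 10 classes
targets = t.randint(0, 10, (4,))   # true class index for each example
ours = cross_entropy(logits, targets)
builtin = t.nn.functional.cross_entropy(logits, targets)
print(t.allclose(ours, builtin))   # True
```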
And here we are! As a quick summary, we've shown:
- how cross-entropy loss and KL divergence are related
- where "negative log likelihood" comes from
- why the `cross_entropy` function takes the pre-softmax'd logits instead of the output probabilities directly
The true implementation for `cross_entropy` in pytorch is a bit more complicated because it also deals with the case where the training distribution isn't a one-hot vector.
footnotes
1. As an editorial note, a friend pointed out to me that this post is written more for myself than for an external audience. I generally agree and think I will mark these as "notes on..." in the title.
2. If you do find yourself interested in what's going on with this term, this video is helpful, but I think it's best watched after having a clear sense of KL divergence first.
3. Why do we use the cross-entropy nomenclature? This was not clear to me, since it seems like the optimizer will never be interested in the entropy term for training loss.
4. That is, only for our one-hot case!
5. Again, I won't go into details, but the linked videos in the intro are helpful for deriving these.
6. If our training data is not one-hot, we'd still need the sum and $P(i)$ terms.
7. This is because $\frac{e^{z_i - m}}{\sum_j e^{z_j - m}} = \frac{e^{z_i}}{\sum_j e^{z_j}}$, i.e. adding or subtracting any constant in the exponentiated terms is symbolically equivalent.
8. This is also implemented in pytorch as `t.log_softmax()`.
9. This is also implemented in pytorch as `t.logsumexp()`.