stephen wan

Deriving `cross_entropy` loss

At the Recurse Center, I've been implementing a language model from scratch by following along with the Stanford CS336 class online.

Part of the first assignment is implementing cross-entropy loss from scratch. Even though I've seen cross-entropy quite a few times in other material, I keep finding myself confused about how it relates to other concepts like KL divergence, softmax, "log probs", and "NLL loss", and how it's actually computed numerically.

I thought it'd be helpful for myself[1] to write a bit about this to sharpen my intuition.

A basic understanding of loss functions, model training, and softmax is probably a prerequisite here, as I don't want to get too into the weeds starting from first principles. I found that this video and this video do a good job of explaining how cross-entropy loss is defined.

what is cross-entropy loss for?

Cross-entropy loss is most commonly used as the loss function for training models (training loss) where we want the model to learn a distribution. For instance, in classification tasks ("what's in this image?"), cross-entropy is used to train the predictions the model is making ("35% likely this is a cat, 65% likely this is a bus") to more closely match what's actually in the image ("100% bus").

The training output label is usually formed as a one-hot vector. For instance, the one-hot for our image classification task has a probability for each output label we might assign to the image. For [cat_prob, bus_prob, house_prob, scarecrow_prob, ...] we might have [0, 1, 0, 0, ...] for an image that really is a bus.

LLMs aren't classification tasks in the same sense, but they similarly use one-hots during training: e.g. if the token sequence is "the quick brown fox", the training one-hot might represent the next token, "jumped".
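To make this concrete, here's a tiny sketch (the numbers are invented) of a predicted distribution next to its one-hot training label:

import torch as t

# Hypothetical model prediction over [cat, bus, house, scarecrow]
q = t.tensor([0.35, 0.65, 0.0, 0.0])  # "35% likely cat, 65% likely bus"

# One-hot training label: the image really is a bus
p = t.tensor([0.0, 1.0, 0.0, 0.0])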

deriving cross_entropy for the one-hot case

Our goal here is to go from a somewhat impenetrable definition of cross-entropy and understand it enough to get to a working pytorch function we could actually run.

definitions

To warm us up, we first need a few definitions.

First, we use P as the true distribution and Q as the predicted distribution. For our ML purposes, P corresponds to the training data and Q corresponds to the predictions from our model. The index i tells you we're looking at the probability P(i) at position i of the distribution.

We also have H(), the entropy of a distribution. H(P,Q) is the cross-entropy of the two distributions and H(P) is the entropy of P itself. The technical definition of entropy is something about "how much information" is in the distribution, but I found that somewhat unhelpful to think about. We'll soon see that we don't need to worry about this definition too much.

KL(P‖Q) is the KL divergence between the two distributions. The KL divergence is a distance-like measure that tells us how different P and Q are. You could read this term as "the KL divergence from P to Q" or, better yet, "how surprised (KL) you are by the prediction (Q), given the training data (P)".

cross entropy

With those definitions in mind, we can now reveal what we're working with. Cross-entropy is defined as:

$$\underbrace{H(P,Q)}_{\text{cross entropy}} = KL(P \,\|\, Q) + H(P)$$

We're going to take this formula and break it down piece by piece, taking advantage of a few simplifications that are possible for our one-hot vector case. (Note that many of the assumptions we'll make do not apply if the target distribution is not one-hot!)

measuring cross-entropy loss as a proxy for KL divergence

To start out, let's label the terms:

$$\underbrace{H(P,Q)}_{\text{cross entropy}} = \underbrace{KL(P \,\|\, Q)}_{\text{KL divergence}} + \underbrace{H(P)}_{\text{inherent entropy}}$$

Using the definitions we have above, the cross-entropy can be read as the "amount of difference between the two distributions" (KL(P‖Q)) plus "the inherent variation that is part of the training data" (H(P)).
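Here's a quick numerical sketch of this decomposition, using two made-up (non-one-hot) distributions so that all three terms are nonzero, and the standard formula for each term:

import torch as t

p = t.tensor([0.7, 0.2, 0.1])  # "true" distribution
q = t.tensor([0.5, 0.3, 0.2])  # "predicted" distribution

cross_entropy = -(p * t.log(q)).sum()     # H(P, Q)
kl_divergence = (p * t.log(p / q)).sum()  # KL(P ‖ Q)
entropy_p = -(p * t.log(p)).sum()         # H(P)

print(cross_entropy)              # ≈ 0.8869
print(kl_divergence + entropy_p)  # ≈ 0.8869 as well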

eliminating the H(P) entropy term

You might be wondering how to calculate this H(P) term, but luckily we won't have to worry about it for long[2].

For training loss, we can drop the H(P) term with a couple of justifications: (1) H(P) depends only on the training data and not on the model, so it's a constant that never affects the gradients the optimizer follows, and (2) for a one-hot P, H(P) is exactly 0 anyway.

Having convinced ourselves of dropping H(P), we now have:

$$\underbrace{H(P,Q)}_{\text{cross entropy}} = \underbrace{KL(P \,\|\, Q)}_{\text{KL divergence}}$$

We can now see that when we use cross-entropy loss, we're really just using KL divergence[3].

I like this understanding of cross-entropy loss much better. If we talk about KL divergence, a measure of distance, it makes intuitive sense why you'd want to minimize the distance between the model predictions and the training data. Cross-entropy, in my mind, is a bit more hand-wavy, with something about information theory and bits that makes the picture fuzzier.
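As a quick sanity check on dropping H(P): for a one-hot target the entropy really is zero. A small sketch, using the convention that 0 · log(0) counts as 0:

import torch as t

p = t.tensor([0.0, 1.0, 0.0, 0.0])  # one-hot training label
nonzero = p[p > 0]                  # skip the 0 * log(0) terms
entropy_p = -(nonzero * t.log(nonzero)).sum()
print(entropy_p)  # tensor(0.)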

calculating kl divergence

Now we know that cross-entropy loss simplifies down to KL divergence[4]. Let's look at its definition[5]:

$$\underbrace{D_{KL}(P \,\|\, Q)}_{\text{KL divergence}} = \sum_i P(i) \log\frac{P(i)}{Q(i)}$$

In other words, we sum over the distribution, looking at the difference (the log ratio term) between the true and predicted probabilities for each outcome, weighted by the true probability (the P(i) term).

For our one-hot case, we can simplify this equation by dropping the summation and P(i) terms[6]. This is because the training data's P is always going to be a one-hot vector, so P(i) is going to be 0 everywhere except the one true label, where P(i) is 1, e.g. for the label where the image is a bus:

$$D_{KL}(P \,\|\, Q) = \underbrace{0 \cdot \log\frac{P(0)}{Q(0)}}_{\text{cat prob}} + \underbrace{1 \cdot \log\frac{P(1)}{Q(1)}}_{\text{bus prob}} + \underbrace{0 \cdot \log\frac{P(2)}{Q(2)}}_{\text{scarecrow prob}} + \ldots$$

We can remove all of the terms where P(i) is 0, simplifying down to:

$$\underbrace{D_{KL}(P \,\|\, Q)}_{\text{KL divergence}} = \underbrace{\log\frac{P(i)}{Q(i)}}_{\text{term for true label only}}$$

manipulating logarithms

From here, if we remember our quotient rule (log(a/b) = log(a) - log(b)), we can continue to break our formula down:

$$\underbrace{D_{KL}(P \,\|\, Q)}_{\text{KL divergence}} = \log(P(i)) - \log(Q(i))$$

Since P(i) for a one-hot vector is going to be 1, we can replace the log(P(i)) term with log(1), i.e. just 0:

$$\underbrace{D_{KL}(P \,\|\, Q)}_{\text{KL divergence}} = \text{CE} = 0 - \log(Q(i))$$

negative log likelihood (NLL)

We can see that we've now arrived at negative log likelihood! As in, literally, the cross-entropy function boils down to taking the negative of the log of the likelihood (probability).

Calculating cross-entropy is the same as calculating KL divergence, which is the same as calculating NLL:

$$\text{NLL} = \text{CE} = \text{KL divergence} = -\log(Q(\text{true class}))$$
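Here's a small sketch (reusing the made-up bus example) checking that the full KL sum with a one-hot P really does collapse to the negative log of the predicted probability of the true class:

import torch as t

q = t.tensor([0.35, 0.65, 0.0, 0.0])  # model predictions
p = t.tensor([0.0, 1.0, 0.0, 0.0])    # one-hot: the true class ("bus") is index 1

mask = p > 0  # the zero entries of p contribute nothing to the sum
kl = (p[mask] * t.log(p[mask] / q[mask])).sum()
nll = -t.log(q[1])

print(kl, nll)  # both ≈ 0.4308, i.e. -log(0.65)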

Though we've seen that these are equivalent, I'll continue to use the term cross-entropy to be consistent.

writing it up in pytorch

Finally, we can take a stab at a pytorch function. Since we know that Q is the predictions from our model, Q(true class) here can be read as "the model output for the true class of the training example".

import torch as t
from jaxtyping import Float

def cross_entropy(q: Float[t.Tensor, "d"], p_index: int) -> float:
    return -t.log(q[p_index]).item()

where q corresponds to Q, the vector of output probabilities from our model, and p_index is the index of the true class.
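For example, with the made-up bus prediction from earlier (reusing the import and function above):

q = t.tensor([0.35, 0.65, 0.0, 0.0])
print(cross_entropy(q, 1))  # ≈ 0.4308, i.e. -log(0.65) for the "bus" class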

numerical stability and floating point error

At this point, we could be pretty much done if it weren't for the reality of how floating point arithmetic works. If we were to open up the hood on our model, we'd see that the output probabilities come from a softmax() operation. Softmax takes the raw outputs from the model and shapes them into a probability distribution that adds up to 100%. It's defined as:

$$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

If we wrap this up in our cross entropy function, we get

$$\text{CE} = -\log(\text{softmax}(z_i)) = -\log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right)$$

which is a bit weird since we're taking the log of an exponential. Computers calculating log(exp()) can end up with intermediate values that are too large or too small to represent in floating point.
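To see why this matters, here's a sketch with a deliberately naive softmax (a direct translation of the formula, with no stability tricks) fed a large but plausible logit:

import torch as t

def naive_softmax(z: t.Tensor) -> t.Tensor:
    # Direct translation of the formula: exponentiate, then normalize
    return z.exp() / z.exp().sum()

z = t.tensor([1000.0, 0.0])
print(naive_softmax(z))          # tensor([nan, 0.]) -- exp(1000) overflows to inf
print(-t.log(naive_softmax(z)))  # the nan then propagates straight into our loss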

We'll do some "numerical stability tricks" to make it less likely that we accidentally overflow or underflow the value of this composed calculation. Note that this is somewhat similar to what we might do symbolically by hand, i.e. first cancel terms that can be simplified before computing the value.

Again using our quotient rule, we can first break this up:

$$\text{CE} = -\log\left(e^{z_i}\right) + \log\left(\sum_j e^{z_j}\right)$$

then rearrange, and simplify the $-\log(e^{z_i})$ term down to $-z_i$:

$$\text{CE} = \log\left(\sum_j e^{z_j}\right) - z_i$$

or in pseudocode:

log(sum(exp(z))) - z[true_class_index]
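A quick check that this rearranged form matches taking -log(softmax(...)) directly, using small logits where the naive route still behaves (index 0 standing in for the true class):

import torch as t

z = t.tensor([2.0, -1.0, 0.5])  # modest logits, so nothing overflows yet
i = 0                           # pretend the true class is index 0

direct = -t.log(t.softmax(z, dim=0)[i])
rearranged = t.log(z.exp().sum()) - z[i]
print(direct, rearranged)       # both ≈ 0.2413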

Also note that we still have a log(sum(exp(z))) term, which can't be reduced symbolically, and whose exp() can overflow or underflow for large-magnitude values of z. Here, we can do a trick where we subtract the max from z first to keep all of the values stably around 0 before exponentiating[7].

Putting everything together now, we can write:

def cross_entropy(logits: Float[t.Tensor, "d"], p_index: int) -> float:
    log_probs = log_softmax(logits)
    return -log_probs[p_index].item()

where log_softmax is:

def log_softmax(logits: Float[t.Tensor, "d"]) -> t.Tensor:
    m = logits.max()
    v = logits - m
    return v - t.log(v.exp().sum())

Note that we now have to take the pre-softmax'd output, usually called the logits, to let us simplify the expression[8]. This mismatch (i.e. why doesn't the cross_entropy function take the raw probabilities instead of the logits?) always felt a bit confusing to me. We can now see that it's necessary to allow for these stability tricks.

Finally, we still have t.log(v.exp().sum()), which is not ideal. We can use the logsumexp trick here, which again involves subtracting and then re-adding the max to keep the values in a good range[9]:

def log_softmax(logits: Float[t.Tensor, "d"]) -> t.Tensor:
    m = logits.max()
    v = logits - m
    return v - logsumexp(v)

def logsumexp(x: Float[t.Tensor, "d"]) -> t.Tensor:
    m = x.max()  # note: the max of x, the argument, not the outer logits
    return m + t.log((x - m).exp().sum())
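We can spot-check these against PyTorch's built-ins from the footnotes (reusing the import and functions above), including an extreme logit to confirm nothing blows up:

z = t.tensor([1000.0, 0.0, -5.0])

print(log_softmax(z))           # [0., -1000., -1005.]
print(t.log_softmax(z, dim=0))  # same values from the built-in

print(logsumexp(z))             # tensor(1000.)
print(t.logsumexp(z, dim=0))    # same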

bringing in the batch dimension

If we want to run this function across a whole batch of examples at once, we lastly need to do some dimensional wrangling:

from jaxtyping import Float, Int

def cross_entropy(logits: Float[t.Tensor, "b d"], p_index: Int[t.Tensor, "b"]) -> t.Tensor:
    log_probs = log_softmax(logits, dim=-1)
    return -log_probs.gather(dim=-1, index=p_index.unsqueeze(-1)).squeeze(-1).mean()

def log_softmax(logits: Float[t.Tensor, "b d"], dim: int) -> t.Tensor:
    m = logits.max(dim=dim, keepdim=True).values
    v = logits - m
    return v - logsumexp(v, dim=dim)

def logsumexp(x: Float[t.Tensor, "b d"], dim: int) -> t.Tensor:
    m = x.max(dim=dim, keepdim=True).values
    return m + t.log((x - m).exp().sum(dim=dim, keepdim=True))

Note that we take the mean loss over the batch, since we're trying to get a single loss number across all of the examples.
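As a final sanity check (a sketch, not part of the derivation), we can compare this batched version against PyTorch's own torch.nn.functional.cross_entropy, which also defaults to a mean reduction over the batch:

import torch as t
import torch.nn.functional as F

logits = t.randn(4, 10)           # a batch of 4 examples, 10 classes
targets = t.randint(0, 10, (4,))  # the true class index for each example

print(cross_entropy(logits, targets))    # our batched implementation above
print(F.cross_entropy(logits, targets))  # PyTorch's built-in; the values should match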

And here we are! As a quick summary, we've shown that for one-hot training targets, cross-entropy loss reduces to KL divergence, which reduces to the negative log likelihood of the true class, and that computing it stably means working from the logits with the max-subtraction and logsumexp tricks rather than taking log(softmax(...)) directly.

The real implementation of cross_entropy in pytorch is a bit more complicated because it also deals with the case where the training distribution isn't a one-hot vector.

footnotes

  1. As an editorial note, a friend pointed out to me that this post is written more for myself than for an external audience. I generally agree and think I will mark these as "notes on..." in the title.

  2. If you do find yourself interested in what's going on with this term, this video is helpful, but I think it's best watched after having a clear sense of KL divergence first.

  3. Why do we use the cross-entropy nomenclature? This was not clear to me, since it seems like the optimizer will never be interested in the H(P) entropy term for training loss.

  4. That is, only for our one-hot case!

  5. Again, I won't go into details but the linked videos in the intro are helpful for deriving these.

  6. If our training data is not one-hot, we'd still need the sum and P(i) terms.

  7. This is because $\frac{e^{z_i}}{\sum_j e^{z_j}} = \frac{e^{z_i - c}}{\sum_j e^{z_j - c}} = \frac{e^{z_i}\,e^{-c}}{\sum_j e^{z_j}\,e^{-c}} = \frac{e^{z_i}}{\sum_j e^{z_j}}$, i.e. adding or subtracting any constant inside the exponents leaves the result unchanged.

  8. This is also implemented in pytorch as t.log_softmax().

  9. This is also implemented in pytorch as t.logsumexp().

#article