After seeing this comment, I think that if I were to rewrite this post it would have been better to use the KL divergence rather than the simple ΔCE metric I used. I think they're subtly different.
Per the TL implementation for CE, I'm calculating $\mathrm{CE}_j = -\frac{1}{N}\sum_i \ln p_{ij}$, where $i$ is the batch dimension, $j$ is the context position, and $p_{ij}$ is the probability assigned to the correct next token.
So $\Delta\mathrm{CE}_j = \mathrm{CE}_j^{\text{patched}} - \mathrm{CE}_j^{\text{baseline}} = \frac{1}{N}\sum_i \left(\ln p_{ij} - \ln q_{ij}\right)$, where $p_{ij}$ is the baseline probability and $q_{ij}$ is the patched probability.
So this is missing a factor of $p_{ij}$ to be the true KL divergence.
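For concreteness, this is roughly the computation I mean. It's a minimal sketch in plain PyTorch rather than the actual TransformerLens code, and the tensors are dummy stand-ins for a real clean/patched forward pass:

```python
import torch
import torch.nn.functional as F

def per_position_ce(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Negative log prob of the correct next token; shape [batch, pos - 1]."""
    log_probs = F.log_softmax(logits, dim=-1)
    # Next-token prediction: the logits at position j are scored against token j + 1.
    pred_log_probs = log_probs[:, :-1].gather(-1, tokens[:, 1:, None]).squeeze(-1)
    return -pred_log_probs

# Dummy shapes standing in for a real run: [batch, pos, d_vocab] logits, [batch, pos] tokens.
batch, pos, d_vocab = 4, 16, 50257
tokens = torch.randint(0, d_vocab, (batch, pos))
clean_logits = torch.randn(batch, pos, d_vocab)    # baseline forward pass
patched_logits = torch.randn(batch, pos, d_vocab)  # forward pass with the patched activation

ce_clean = per_position_ce(clean_logits, tokens)             # CE_j terms under the baseline
ce_patched = per_position_ce(patched_logits, tokens)         # CE_j terms after patching
delta_ce_per_position = (ce_patched - ce_clean).mean(dim=0)  # ΔCE_j, averaged over the batch
```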
I think it is the same. When training next-token predictors, we model the ground-truth probability distribution as having probability 1 for the actual next token and 0 for every other token in the vocab. This is how the cross-entropy loss simplifies to the negative log likelihood. You can see that the TransformerLens implementation doesn't match the full equation for cross-entropy loss because it uses this simplification.
So the missing factor of $p_{ij}$ would just be 1, I think.
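To illustrate that simplification with made-up numbers (a sketch, not anything from the post): cross-entropy against a one-hot target collapses to the negative log probability of the single target token.

```python
import torch
import torch.nn.functional as F

# Made-up logits over a tiny 5-token vocab; token 2 plays the role of the actual next token.
logits = torch.tensor([[2.0, -1.0, 0.5, 0.3, -0.7]])
target = torch.tensor([2])

# Full cross-entropy H(p, q) = -sum_x p(x) ln q(x), with p one-hot on the target ...
one_hot = F.one_hot(target, num_classes=logits.shape[-1]).float()
full_ce = -(one_hot * F.log_softmax(logits, dim=-1)).sum(dim=-1)

# ... collapses to the negative log likelihood of the target token alone.
nll = -F.log_softmax(logits, dim=-1)[0, target]

print(full_ce.item(), nll.item(), F.cross_entropy(logits, target).item())  # all identical
```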
Oh! You’re right, thanks for walking me through that, I hadn’t appreciated that subtlety. Then in response to the first question: yep! ΔCE = KL Divergence.
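Writing that out explicitly, on my reading, with $P_{\text{gt}}$ the one-hot ground-truth distribution and $x^{*}$ the actual next token:

$$
D_{\mathrm{KL}}\!\left(P_{\text{gt}} \,\middle\|\, Q\right)
= \sum_{x} P_{\text{gt}}(x) \ln \frac{P_{\text{gt}}(x)}{Q(x)}
= 1 \cdot \ln \frac{1}{Q(x^{*})}
= -\ln Q(x^{*}),
$$

which is exactly the per-token cross-entropy term under the one-hot simplification described above.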