Fun write-up. Rookie question: what’s K-composition? I can’t find it on Google.
Also: the log_softmax part seems important, but I don’t think I understand how it could happen. The derivative of log(1+x) is approximately 1 when x is small, regardless of x, so it should be stable as long as you’re not computing it by finite differences. Autograd should get you the right answer at any precision. That seems to be what I observe (the gradients at x=1e-8 equal those at x=0, which should be correct):
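Roughly the check I have in mind (a minimal sketch using torch.log1p; any equivalent form of log(1+x) behaves the same):

import torch
# Gradient of log(1+x) via autograd at x = 1e-8 and at x = 0
for val in (1e-8, 0.0):
    x = torch.tensor(val, requires_grad=True)
    torch.log1p(x).backward()
    print(val, x.grad)  # both gradients come out as 1.0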
K-composition as a concept was introduced by Anthropic in their initial post on Transformer Circuits. In general, the output of an attention head in an earlier layer can influence the query, key, or value computation of an attention head in a later layer.
K-composition refers to the case in which the key computation is influenced. In a model without nonlinearities or layernorms you can measure this simply by looking at how strongly the output matrix of head 1 and the key matrix of head 2 compose (more precisely, by looking at the Frobenius norm of their product relative to the product of the individual norms). I also tried to write a bit about it here.
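As a minimal sketch of that score (placeholder shapes and random matrices standing in for the trained weights):

import torch

# Placeholder sizes; in a real model these come from the trained weights.
d_model, d_head = 64, 16
W_O = torch.randn(d_head, d_model)  # output matrix of head 1 (writes into the residual stream)
W_K = torch.randn(d_model, d_head)  # key matrix of head 2 (reads from the residual stream)

# Composition score: Frobenius norm of the product, relative to the
# product of the individual Frobenius norms.
score = torch.linalg.matrix_norm(W_O @ W_K) / (
    torch.linalg.matrix_norm(W_O) * torch.linalg.matrix_norm(W_K)
)
print(score.item())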
Ah, log_softmax is not log(1+x) (that’s just an intermediate step in the calculation). You need to have large disparities in your numbers to get the error. Try log_softmax([0, 15]) vs log_softmax([0, 18]).
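For instance, comparing the forward values in float32 vs float64 (a quick sketch; the float32 result for the larger disparity collapses to exactly 0.0, while float64 still resolves it):

import torch
import torch.nn.functional as F

for x in (15., 18.):
    f32 = F.log_softmax(torch.tensor([0., x]), dim=-1)[1]
    f64 = F.log_softmax(torch.tensor([0., x], dtype=torch.float64), dim=-1)[1]
    # In float32 the x=18 entry rounds to exactly 0.0; float64 keeps roughly -exp(-x).
    print(x, f32.item(), f64.item())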
Of course, my mistake! However, I still don’t see the problem. Trying those values, the gradients look fine whether I compute them in float32 or float64, for both pairs of numbers. The float32 and float64 gradients match out to the 7th decimal digit. Maybe I don’t understand what the problem is? Here’s a minimal counter-example:
import torch
import torch.nn.functional as F

# float32, logits [0, 16]
a = torch.tensor([0., 16.], requires_grad=True)
al = F.log_softmax(a, dim=-1)
print('Log softmax output for 16', al)
al[1].backward()  # gradient of the second component
print('Log softmax grad for 16', a.grad)

# float32, logits [0, 17]
b = torch.tensor([0., 17.], requires_grad=True)
bl = F.log_softmax(b, dim=-1)
print('Log softmax output for 17', bl)
bl[1].backward()
print('Log softmax grad for 17', b.grad)

# float64, logits [0, 16]
a = torch.tensor([0., 16.], requires_grad=True, dtype=torch.float64)
al = F.log_softmax(a, dim=-1)
print('Log softmax output for 16', al)
al[1].backward()
print('Log softmax grad for 16', a.grad)

# float64, logits [0, 17]
b = torch.tensor([0., 17.], requires_grad=True, dtype=torch.float64)
bl = F.log_softmax(b, dim=-1)
print('Log softmax output for 17', bl)
bl[1].backward()
print('Log softmax grad for 17', b.grad)
This outputs:
Log softmax output for 16 tensor([-1.6000e+01, -1.1921e-07], grad_fn=<LogSoftmaxBackward0>)
Log softmax grad for 16 tensor([-1.1254e-07, 1.1921e-07])
Log softmax output for 17 tensor([-17., 0.], grad_fn=<LogSoftmaxBackward0>)
Log softmax grad for 17 tensor([-4.1399e-08, 0.0000e+00])
Log softmax output for 16 tensor([-1.6000e+01, -1.1254e-07], dtype=torch.float64,
grad_fn=<LogSoftmaxBackward0>)
Log softmax grad for 16 tensor([-1.1254e-07, 1.1254e-07], dtype=torch.float64)
Log softmax output for 17 tensor([-1.7000e+01, -4.1399e-08], dtype=torch.float64,
grad_fn=<LogSoftmaxBackward0>)
Log softmax grad for 17 tensor([-4.1399e-08, 4.1399e-08], dtype=torch.float64)
Thanks for writing this out. The difference from my version is that you take the gradient of the second component, while I took the gradient of the sum of the log_softmax outputs, which pushes the gradients towards +1 or −1 and hides the problem. I’m still confused about how the large effects you see could come down to a gradient difference of −4.1399e-08 versus 0, though. AdamW includes an ‘epsilon’ term of (by default) 1e-8 in the denominator, so I don’t see how a difference at that scale can change anything significantly. I assume you’re using the default epsilon value?
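For reference, this is what I mean by taking the gradient of the sum (a minimal sketch reusing your [0., 16.] example; the gradients come out as essentially +1 and −1, so the tiny discrepancy is invisible):

import torch
import torch.nn.functional as F

a = torch.tensor([0., 16.], requires_grad=True)
# Backprop through the *sum* of the log_softmax outputs rather than a single component.
F.log_softmax(a, dim=-1).sum().backward()
print(a.grad)  # approximately [+1, -1]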