How is this translational symmetry measure checking for the translational symmetry of the circuit? QK, for example, is being used as a bilinear form, so it’s not clear to me what the “difference in the values” is mapping onto here (since I think these “numbers” actually correspond to unique embeddings). More broadly, do you have a good sense of how to interpret these bilinear forms? There is clearly a lot of structure in the standard weight basis in these pictures, and I’m not sure exactly what it means. I’m guessing the sections that look rather empty correspond to the “model learns to specialize on certain parts of the vocabulary for xyz head” behaviour, potentially associated with some sort of one-hot or generally standard-basis-privileged situation. Let me know if I’m misunderstanding something dumb. I haven’t seen this being done much elsewhere, but it would be nice to have a GitHub repo, because it’s really easy to resolve these questions by reading PyTorch code. Is it available somewhere?
One other thing I’m curious about is results for more control experiments. For example, for the noise: if you fully noised the output (i.e., output a random permutation), we should expect the model to fail to learn anything at all, and to fail to get a high LLC, right? It’s also possible to noise by inserting new elements into the output (or input… I guess it’s equivalent) to replace others, while keeping the order the same. In that case, maybe the network can learn what the ordering is even if it doesn’t know exactly which outputs will be there in the end, so even with very high amounts of noise a “structured” solution makes sense (though I reckon the way you propagate loss will matter here).
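Concretely, something like this is what I have in mind for the two noising schemes (a hypothetical sketch of a data pipeline, not anything from your codebase):

```python
import torch

def fully_noised_target(sorted_target: torch.Tensor) -> torch.Tensor:
    # Control 1: output a random permutation of the sorted target, destroying
    # all structure -- presumably the model learns nothing and the LLC stays low.
    perm = torch.randperm(sorted_target.shape[-1])
    return sorted_target[..., perm]

def element_replacement_target(sorted_target: torch.Tensor, vocab_size: int, p: float) -> torch.Tensor:
    # Control 2: replace each element with a fresh random token with probability p,
    # but keep the output sorted, so the ordering rule stays learnable even when
    # the model can't know exactly which elements will appear in the output.
    mask = torch.rand(sorted_target.shape) < p
    random_tokens = torch.randint(0, vocab_size, sorted_target.shape)
    corrupted = torch.where(mask, random_tokens, sorted_target)
    return corrupted.sort(dim=-1).values
```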
The code is currently not public. We intend to make it public once we have finished a few more projects with the same codebase. One of the things we would like to look at is varying the amount of noise. I don’t have great intuitions for what the loss landscape of a model trained on a finite random dataset will look like.
As to the translational symmetry of the circuits, the measure just sums the absolute differences between adjacent elements parallel to the diagonal, does the same for adjacent elements perpendicular to the diagonal, and takes the difference of the two sums. The intuition is that if the circuit has translational symmetry, the relationship between vocabulary elements i and j should be the same as the relationship between i+1 and j+1. We subtract the sum over lines perpendicular to the diagonal so that the measure does not become very large for a circuit that is simply uniform in all directions. We expect the circuits to have translational symmetry because we expect the sorting to work the same way across the whole vocabulary (except at the first and last vocabulary elements). If you compare two numbers a and b for the purpose of sorting, the only thing that should matter is the difference between a and b, not their absolute scale. When a circuit does something like “each vocabulary element attends to the smallest number larger than itself”, that should depend only on the differences between itself and the other numbers, not on their overall magnitude. I do agree that our translational symmetry measure is somewhat arbitrary, and that we could instead have looked at the standard deviation along lines parallel and perpendicular to the diagonal, or something like that. I expect the outcome would have been largely the same.
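In pseudo-code, the measure looks roughly like this (a simplified sketch of what I described above, not our actual implementation, so the sign and normalisation conventions may differ):

```python
import torch

def translational_symmetry(circuit: torch.Tensor) -> torch.Tensor:
    """Rough sketch of the measure for a (vocab x vocab) circuit matrix M.

    Sums |M[i, j] - M[i+1, j+1]| over neighbours parallel to the diagonal,
    does the same for neighbours perpendicular to the diagonal
    (|M[i, j+1] - M[i+1, j]|), and takes the difference of the two sums.
    Small parallel differences mean the circuit treats (i, j) like (i+1, j+1),
    i.e. it is translationally symmetric; subtracting the perpendicular term
    keeps a circuit that is uniform in every direction from scoring well.
    """
    parallel = (circuit[:-1, :-1] - circuit[1:, 1:]).abs().sum()
    perpendicular = (circuit[:-1, 1:] - circuit[1:, :-1]).abs().sum()
    # Under this sign choice, larger values mean more translational symmetry.
    return perpendicular - parallel
```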
As to how to interpret the circuits, Callum goes into some more detail on interpreting the final form of the baseline 2-head model here (select [October] Solutions in the menu on the left).