This is very neat. I definitely agree that I find the discontinuity from the first transformer block surprising. One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model. You could also see if the resulting matrix has any relationship to the embedding matrix (e.g. are the two matrices farther apart or closer together than would be expected by chance?). One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model’s guess is given that input or whether it’s just being stored in parts of the embedding space that aren’t very relevant to the output (if it’s the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
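A minimal sketch of the per-layer probe idea, on synthetic data rather than real model activations. Everything here is an assumption for illustration: the toy "activations" are just the input embedding plus noise that grows with depth, and the names `fit_linear_probe` and `probe_accuracy` are hypothetical helpers, not part of any library. With a real model you would replace `acts` with activations captured at each layer.

```python
import numpy as np

def fit_linear_probe(acts, tokens, vocab_size):
    """Least-squares linear map from activations (n, d) to one-hot tokens (n, vocab)."""
    targets = np.eye(vocab_size)[tokens]
    weights, *_ = np.linalg.lstsq(acts, targets, rcond=None)
    return weights  # shape (d, vocab)

def probe_accuracy(weights, acts, tokens):
    """Fraction of positions where the probe's argmax recovers the input token."""
    return float(np.mean((acts @ weights).argmax(axis=1) == tokens))

rng = np.random.default_rng(0)
vocab_size, d_model, n_layers, n_train, n_test = 50, 64, 4, 2000, 2000
embed = rng.normal(size=(vocab_size, d_model))  # hypothetical embedding matrix
tokens = rng.integers(0, vocab_size, size=n_train + n_test)

accuracies = []
for layer in range(n_layers):
    # ASSUMPTION: toy activations = input embedding + noise that grows with depth.
    noise_std = 0.5 + layer
    acts = embed[tokens] + noise_std * rng.normal(size=(tokens.size, d_model))
    # One separate probe per layer, fit on the first half, scored on the second.
    probe = fit_linear_probe(acts[:n_train], tokens[:n_train], vocab_size)
    accuracies.append(probe_accuracy(probe, acts[n_train:], tokens[n_train:]))

for layer, acc in enumerate(accuracies):
    print(f"layer {layer}: input-reconstruction accuracy {acc:.3f}")
```

The interesting part with a real model would be the shape of the accuracy curve across layers, and how each fitted `probe` compares to the embedding matrix.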
One thing which occurred to me that might be interesting to do is to try and train a linear model to reconstitute the input from the activations at different layers to get an idea of how the model is encoding the input. You could either train one linear model on data randomly sampled from different layers, or a separate linear model for each layer, and then see if there are any interesting patterns like whether the accuracy increases or decreases as you get further into the model.
That’s a great idea!
One possible hypothesis that this might let you test is whether the information about the input is being stored indirectly via what the model’s guess is given that input or whether it’s just being stored in parts of the embedding space that aren’t very relevant to the output (if it’s the latter, the linear model should put a lot of weight on basis elements that have very little weight in the embedding matrix).
Hmm… I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they’re near-impossible so you don’t need to track them closely), and then smoothly subtracted out at the end. There’s probably a way to check if that’s happening.
Thanks! I’d be quite excited to know what you find if you end up trying it.
Hmm… I guess there is some reason to think the basis elements have special meaning (as opposed to the elements of any other basis for the same space), since the layer norm step operates in this basis.
But I doubt there are actually individual components the embedding cares little about, as that seems wasteful (you want to compress 50K into 1600 as well as you possibly can), and if the embedding cares about them even a little bit then the model needs to slot in the appropriate predictive information, eventually.
Thinking out loud, I imagine there might be a pattern where embeddings of unlikely tokens (given the context) are repurposed in the middle for computation (you know they’re near-impossible so you don’t need to track them closely), and then smoothly subtracted out at the end. There’s probably a way to check if that’s happening.
I wasn’t thinking you would do this with the natural component basis—though it’s probably worth trying that also—but rather doing some sort of matrix decomposition on the embedding matrix to get a basis ordered by importance (e.g. using PCA or NMF—PCA is simpler though I know NMF is what OpenAI Clarity usually uses when they’re trying to extract interpretable basis elements from neural network activations) and then seeing what the linear model looks like in that basis. You could even just do something like what you’re saying and find some sort of basis ordered by the frequency of the tokens that each basis element corresponds to (though I’m not sure exactly what the right way would be to generate such a basis).
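A rough sketch of the PCA variant, assuming you already have an embedding matrix of shape (vocab, d_model) and a fitted linear probe of shape (d_model, vocab). Both are random placeholders here, and `pca_basis` and `probe_weight_by_component` are hypothetical helper names. The idea is to rotate the probe into the embedding's principal-component basis and see how much of its weight sits on the low-variance directions.

```python
import numpy as np

def pca_basis(embed):
    """Principal directions of the embedding rows, ordered by explained variance."""
    centered = embed - embed.mean(axis=0, keepdims=True)
    _, sing_vals, vt = np.linalg.svd(centered, full_matrices=False)
    explained_var = sing_vals**2 / (embed.shape[0] - 1)
    return vt, explained_var  # rows of vt are the basis vectors

def probe_weight_by_component(probe, basis):
    """L2 norm of the probe's weights along each basis direction."""
    return np.linalg.norm(basis @ probe, axis=1)

rng = np.random.default_rng(0)
vocab_size, d_model = 200, 16
# ASSUMPTION: random stand-ins for the real embedding matrix and a fitted probe.
embed = rng.normal(size=(vocab_size, d_model)) * np.linspace(3.0, 0.1, d_model)
probe = rng.normal(size=(d_model, vocab_size))

basis, explained_var = pca_basis(embed)
weight_profile = probe_weight_by_component(probe, basis)

# Share of the probe's squared weight on the less important half of the basis.
tail_share = (weight_profile[d_model // 2:] ** 2).sum() / (weight_profile**2).sum()
print(f"share of probe weight on low-variance components: {tail_share:.2f}")
```

If the input were being stored in directions the embedding barely uses, you would expect `tail_share` to be large for the real probe; for the random placeholder above the number carries no meaning.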
I also thought of PCA/SVD, but I imagine matrix decompositions like these would be misleading here.
What matters here (I think) is not some basis of N_emb orthogonal vectors in embedding space, but some much larger set of ~exp(N_emb) almost orthogonal vectors. We only have 1600 degrees of freedom to tune, but they’re continuous degrees of freedom, and this lets us express >>1600 distinct vectors in vocab space as long as we accept some small amount of reconstruction error.
I expect GPT and many other neural models are effectively working in such a space of nearly orthogonal vectors, and picking/combining elements of it. A decomposition into orthogonal vectors won’t really illuminate this. I wish I knew more about this topic. Are there standard techniques?
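A quick numerical illustration of the "more nearly orthogonal vectors than dimensions" point. This uses random unit vectors, not anything extracted from a model, and `max_abs_cosine` is a hypothetical helper. It draws more vectors than there are dimensions and reports the worst pairwise overlap; for random directions the typical cosine is on the order of 1/sqrt(dim), which is 0.025 at dim = 1600.

```python
import numpy as np

def max_abs_cosine(n_vectors, dim, seed=0):
    """Largest |cosine similarity| between distinct random unit vectors."""
    rng = np.random.default_rng(seed)
    vecs = rng.normal(size=(n_vectors, dim))
    vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
    gram = vecs @ vecs.T          # pairwise cosines, since rows are unit length
    np.fill_diagonal(gram, 0.0)   # ignore each vector's overlap with itself
    return float(np.abs(gram).max())

dim, n_vectors = 1600, 2000       # more vectors than dimensions
worst = max_abs_cosine(n_vectors, dim)
print(f"{n_vectors} vectors in {dim} dims, worst |cos| = {worst:.3f}")
```

No set of 2000 vectors in 1600 dimensions can be exactly orthogonal, but the overlaps stay small, which is the sense in which a small tolerated error buys many more usable directions.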
You might want to look into NMF, which, unlike PCA/SVD, doesn’t aim to create an orthogonal projection. It works well for interpretability because its components cannot cancel each other out, which makes its features more intuitive to reason about. I think it is essentially what you want, although I don’t think it will allow you to find directly the ‘larger set of almost orthogonal vectors’ you’re looking for.
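For concreteness, here is a small NMF sketch using the standard multiplicative-update rule for the Frobenius objective, written in plain NumPy. It is a toy under stated assumptions: the data matrix `X` is a synthetic non-negative matrix, and the function names `nmf` and `relative_error` are illustrative only. NMF needs non-negative input, so applying it to real activations assumes something like post-ReLU values; signed embeddings would need a different treatment.

```python
import numpy as np

def nmf(X, n_components, n_iter=500, eps=1e-9, seed=0):
    """Factor non-negative X (n, m) as W @ H with W >= 0 and H >= 0."""
    rng = np.random.default_rng(seed)
    W = rng.random((X.shape[0], n_components)) + eps
    H = rng.random((n_components, X.shape[1])) + eps
    for _ in range(n_iter):
        # Multiplicative updates keep every entry non-negative.
        H *= (W.T @ X) / (W.T @ W @ H + eps)
        W *= (X @ H.T) / (W @ H @ H.T + eps)
    return W, H

def relative_error(X, W, H):
    """Frobenius reconstruction error relative to the norm of X."""
    return float(np.linalg.norm(X - W @ H) / np.linalg.norm(X))

rng = np.random.default_rng(1)
# ASSUMPTION: synthetic non-negative data built from 3 non-negative parts.
X = rng.random((40, 3)) @ rng.random((3, 30))

W, H = nmf(X, n_components=3)
W_start, H_start = nmf(X, n_components=3, n_iter=1)
err_fit = relative_error(X, W, H)
err_start = relative_error(X, W_start, H_start)
print(f"relative error after 1 update: {err_start:.3f}, after 500: {err_fit:.3f}")
```

Because `W` and `H` are both non-negative, the rows of `H` can only add to a reconstruction and never cancel each other, which is the property described above. For real work, scikit-learn's `sklearn.decomposition.NMF` is a maintained alternative to this hand-rolled loop.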