Neuroscientist turned Interpretability Researcher. Starting Simplex, an AI Safety Research Org.
Adam Shai
Thanks! I’ll have more thorough results to share about layer-wise representations of the MSP soon. I’ve already run some of the analysis concatenating the residual streams over all layers with the RRXOR process, and it is quite interesting. It seems there’s a lot more to explore in the relationship between the number of states in the generative model, the number of layers in the transformer, the residual stream dimension, and the token vocab size. All of these (I think) play some role in how the MSP is represented in the transformer. For RRXOR, things do look crisper when concatenating.
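To make concrete what I mean by concatenating, here is a rough sketch (not our actual pipeline; `acts` and `beliefs` are placeholder names I'm making up for the per-layer activations and the ground-truth belief states):

```python
# Rough sketch: regress ground-truth belief states onto the concatenation of
# every layer's residual stream.
# `acts` is assumed to map layer index -> array of shape (n_inputs, d_model),
# and `beliefs` to be an array of shape (n_inputs, n_states).
import numpy as np
from sklearn.linear_model import LinearRegression

def fit_belief_projection(acts: dict, beliefs: np.ndarray) -> LinearRegression:
    # Stack each layer's activations side by side: (n_inputs, n_layers * d_model)
    X = np.concatenate([acts[layer] for layer in sorted(acts)], axis=1)
    return LinearRegression().fit(X, beliefs)
```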
Even in cases where redundant info is discarded, we should be able to see the distinctions somewhere in the transformer. One thing I’m keen on exploring is exactly such a case, where we can very concretely follow the path/circuit through which redundant info is first distinguished and then collapsed.
That is a fair summary.
Thanks!
One way to construct an HMM is by finding all past histories of tokens that condition the future tokens with the same probability distribution, and making each such equivalence class a hidden state in your HMM. The conditional distributions then determine the arrows coming out of each state and which state you go to next. This is called the “epsilon machine” in Comp Mech, and it is unique. It is one presentation of the data-generating process, but in general there are an infinite number of HMM presentations that would generate the same data. The epsilon machine is a particular type of HMM presentation: it is the smallest one, and its hidden states are the minimal sufficient statistics for predicting the future based on the past. The epsilon machine is one of the most fundamental things in Comp Mech, but I didn’t talk about it in this post. In the future we plan to make a more generic Comp Mech primer that will go through these and other concepts.
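To illustrate just the equivalence-classing step (this is only a toy sketch with finite histories, not a full epsilon-machine reconstruction; `sequences`, `L`, and `tol` are names I'm making up for the example): group histories by their empirical next-token distributions and treat each group as a candidate hidden state.

```python
# Toy sketch: group length-L histories into candidate causal states by their
# empirical next-token distributions. A real epsilon-machine reconstruction
# handles arbitrarily long histories, statistical testing, and transient states.
from collections import defaultdict
import numpy as np

def candidate_causal_states(sequences, L=3, tol=1e-2):
    # Count next-token occurrences for every length-L history.
    counts = defaultdict(lambda: defaultdict(int))
    for seq in sequences:
        for i in range(L, len(seq)):
            history = tuple(seq[i - L:i])
            counts[history][seq[i]] += 1

    # Normalize counts into conditional distributions P(next token | history).
    vocab = sorted({tok for seq in sequences for tok in seq})
    dists = {
        h: np.array([c[t] for t in vocab], dtype=float) / sum(c.values())
        for h, c in counts.items()
    }

    # Merge histories whose conditional distributions agree (within tol);
    # each resulting equivalence class is a candidate hidden state.
    states = []  # list of (representative distribution, [histories])
    for h, d in dists.items():
        for rep, members in states:
            if np.max(np.abs(rep - d)) < tol:
                members.append(h)
                break
        else:
            states.append((d, [h]))
    return states
```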
The interpretability of these simplexes is an issue that’s in my mind a lot these days. The short answer is I’m still wrestling with it. We have a rough experimental plan to go about studying this issue but for now, here are some related questions I have in my mind:
What is the relationship between the belief states in the simplex and what mech interp people call “features”?
What are the information-theoretic aspects of natural language (or coding databases, or some other interesting training data) that we can instantiate in toy models, so that we can use our understanding of these toy systems to test whether similar findings apply to real systems?
For something like situational awareness, I have the beginnings of a story in my head, but it’s too handwavy to share right now. For something slightly more mundane like out-of-distribution generalization or transfer learning or abstraction, the idea would be to use our ability to formalize data-generating structure as HMMs, and then do theory and experiments on what it would mean for a transformer to understand that, e.g., two HMMs have similar hidden/abstract structure but different vocabs.
Hopefully we’ll have a lot more to say about this kind of thing soon!
Oh wait, one thing that looks not quite right is the initial distribution. Instead of starting randomly, we begin with the optimal initial distribution, which is the steady-state distribution. It can be computed by finding the eigenvector of the transition matrix that has an eigenvalue of 1. Maybe in practice that doesn’t matter much for mess3, but in general it could.
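Concretely, a minimal sketch of that computation (assuming a row-stochastic transition matrix `T`):

```python
# Steady-state distribution of a row-stochastic transition matrix T,
# i.e. the left eigenvector of T with eigenvalue 1, normalized to sum to 1.
import numpy as np

def stationary_distribution(T: np.ndarray) -> np.ndarray:
    eigvals, eigvecs = np.linalg.eig(T.T)   # left eigenvectors of T
    idx = np.argmin(np.abs(eigvals - 1.0))  # eigenvalue closest to 1
    pi = np.real(eigvecs[:, idx])
    return pi / pi.sum()                    # normalize to a probability vector
```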
I should have explained this better in my post.
For every input into the transformer (of every length up to the context window length), we know the ground-truth belief state that comp mech says an observer should have over the HMM states. In this case, this is 3 numbers, so for each input we have a 3D ground-truth vector. Also, for each input we have the residual stream activation (in this case a 64D vector). To find the projection we just use standard linear regression (as implemented in sklearn) between the 64D residual stream vectors and the 3D (really 2D, since the beliefs sum to 1) ground-truth vectors. Does that make sense?
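In code, the regression is roughly the following (with placeholder arrays standing in for the actual activations and belief states):

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Placeholder arrays standing in for the real data: one 64D residual stream
# vector per input, and one 3D ground-truth belief state per input.
resid = np.random.randn(10_000, 64)
beliefs = np.random.dirichlet(np.ones(3), size=10_000)

reg = LinearRegression().fit(resid, beliefs)  # learn the (affine) projection
projected = reg.predict(resid)                # points to plot in the belief simplex
```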
Everything looks right to me! This is the annoying problem where people forget to write down the actual parameters they used in their work (sorry).
Try x=0.05, alpha=0.85. I’ve edited the footnote with this info as well.
That sounds interesting. Do you have a link to the apperception paper?
That’s an interesting framing. From my perspective that is still just local next-token accuracy (cross-entropy, more precisely), but averaged over all subsets of the data up to the context length. That is distinct from, e.g., an objective function that explicitly involves not just the next token but multiple future tokens in what is needed to minimize loss. Does that distinction make sense?
One conceptual point I’d like to get across is that even though the equation for the predictive cross-entropy loss only has the next token at a given context window position in it, the states internal to the transformer have the information for predictions into the infinite future.
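In symbols, I just mean the standard autoregressive cross-entropy, something like

$$\mathcal{L}(\theta) \;=\; -\,\mathbb{E}_{x_{1:T}}\!\left[\frac{1}{T}\sum_{t=1}^{T} \log p_\theta\!\left(x_t \mid x_{<t}\right)\right]$$

Each term only scores the very next token, yet minimizing the sum pushes the internal state at position $t$ to carry whatever information about the future is available from the past.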
This is a slightly different issue than how one averages over training data, I think.
Thanks! I appreciate the critique. From this comment and from John’s it seems correct and I’ll keep it in mind for the future.
On the question, by optimize the representation do you mean causally intervene on the residual stream during inference (e.g. a patching experiment)? Or do you mean something else that involves backprop? If the first, then we haven’t tried, but definitely want to! It could be something someone does at the Hackathon, if interested ;)
Cool question. This is one of the things we’d like to explore more going forward. We are fairly sure this is quite nuanced and has to do with the relationship between the (minimal) state of the generative model, the token vocab size, and the residual stream dimensionality.
On your last question: I believe so, but one would have to do the experiment! It totally should be done. Check out the Hackathon if you are interested ;)
This looks highly relevant, thanks!
Good catch! That should be eta_00, thanks! I’ll change it tomorrow.
Cool idea! I don’t know enough about GANs and their loss so I don’t have a prediction to report right now. If it is the case that GAN loss should really give generative and not predictive structure, this would be a super cool experiment.
The structure of generation for this particular process is just 3 points equidistant from each other, no fractal. But in general the shape of generation is a pretty nuanced issue, because it’s nontrivial to know for sure that you have the minimal structure of generation. There’s a lot more to say about this, but @Paul Riechers knows these nuances better than I do, so I will leave it to him!
Responding in reverse order:
If there’s literally a linear projection of the residual stream into two dimensions which directly produces that fractal, with no further processing/transformation in between “linear projection” and “fractal”, then I would change my mind about the fractal structure being mostly an artifact of the visualization method.
There is literally a linear projection (well, we allow a constant offset actually, so affine) of the residual stream into two dimensions which directly produces that fractal. There are no distributions in the middle or anything. I suspect the offset is not necessary but I haven’t checked ::adding to to-do list::
Edit: the offset isn’t necessary. There is literally a linear projection of the residual stream into 2D which directly produces the fractal.
But the “fractal-ness” is mostly an artifact of the MSP as a representation-method IIUC; the stochastic process itself is not especially “naturally fractal”.
(As I said I don’t know the details of the MSP very well; my intuition here is instead coming from some background knowledge of where fractals which look like those often come from, specifically chaos games.)
I’m not sure I’m following, but the MSP is naturally fractal (in this case), at least in my mind. The MSP is a stochastic process, but it’s a very particular one: it’s the stochastic process of how an optimal observer’s beliefs (about which state an HMM is in) change upon seeing emissions from that HMM. The set of optimal beliefs itself is fractal in nature (for this particular case).
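For concreteness, here is a rough sketch of how those belief states can be enumerated (toy code, not what we actually ran; `Ts` is assumed to map each emission to its labeled transition matrix, with `Ts[x][i, j] = P(emit x, go to state j | state i)`): start from an initial belief and repeatedly apply the Bayesian update for each possible emission. Plotting the resulting points in the simplex is what produces the fractal.

```python
# Toy sketch: enumerate the belief states reachable from an initial belief b0
# under every emission sequence up to a given depth.
import numpy as np

def belief_states(Ts: dict, b0: np.ndarray, depth: int = 10) -> list:
    beliefs, frontier = [b0], [b0]
    for _ in range(depth):
        next_frontier = []
        for b in frontier:
            for T in Ts.values():
                unnorm = b @ T                    # unnormalized posterior over next hidden state
                if unnorm.sum() > 1e-12:          # skip emissions with zero probability
                    next_frontier.append(unnorm / unnorm.sum())  # Bayesian belief update
        beliefs.extend(next_frontier)
        frontier = next_frontier
    return beliefs  # points in the probability simplex
```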
Chaos games look very cool, thanks for that pointer!
Can you elaborate on how the fractal is an artifact of how the data is visualized?
From my perspective, the fractal is there because we chose this data-generating structure precisely because it has this fractal pattern as its Mixed State Presentation (i.e., we chose it because then the ground truth would be a fractal, which felt like highly nontrivial structure to us, and thus a good falsifiable test that this framework is at all relevant for transformers. Also, yes, it is pretty :) ). The fractal is a natural consequence of that choice of data-generating structure: it is what Computational Mechanics says is the geometric structure of synchronization for the HMM. That there is a linear 2D plane in the residual stream, such that projecting onto it produces that same fractal, seems highly non-artifactual, and is what we were testing.
Though it should be said that an HMM with a fractal MSP is a quite generic choice. It’s remarkably easy to get such fractal structures: if you randomly choose an HMM from the space of HMMs for a given number of states and vocab size, you will often get synchronization structures with infinitely many transient states and fractals.
This isn’t a proof of that previous claim, but here are some examples of fractal MSPs from https://arxiv.org/abs/2102.10487:
Transformers Represent Belief State Geometry in their Residual Stream
I find this focus on task structure and task decomposition to be incredibly important when thinking about what neural networks are doing, what they could be doing in the future, and how they are doing it. The manner in which a system understands/represents/instantiates task structures and puts them in relation to one another is, as far as I can tell, just a more concrete way of asking “what is it that this neural network knows? What cognitive abilities does it have? What abstractions is it making? Under what out-of-distribution inputs will it succeed/fail?” etc.
This comment isn’t saying anything that wasn’t in the post, just wanted to express happiness and solidarity with this framing!
I do wonder if the tree structure of which-task followed by task algorithm is what we should expect in general. I have nothing super concrete to say here; my feeling is just that the ways in which a neural network can represent structures and put them in relation to each other may be instantiated differently than a tree (with that specific ordering). The onus is probably on me here, though: I should come up with a set of tasks in certain relations that aren’t most naturally described with tree structures.
Another question that comes to mind: is there a hard distinction between categorizing which sub-task one is in and the algorithm that carries out the computation for a specific sub-task? Is it all just tasks all the way down?
I think you might need to change permissions on your github repository?
The blog post linked says it’s from August. Is there something new I’m missing?
Thanks John and David for this post! This post has really helped people to understand the full story. I’m especially interested in thinking more about plans for how this type of work can be helpful for AI safety. I do think the one you presented here is a great one, but I hope there are other potential pathways. I have some ideas, which I’ll present in a post soon, but my views on this are still evolving.