With that in mind, the real hot possibility is the inverse of what Shai and his coresearchers did. Rather than start with a toy model with some known nice latents, start with a net trained on real-world data, and go look for self-similar sets of activations in order to figure out what latent variables the net models its environment as containing. The symmetries of the set would tell us something about how the net updates its distributions over latents in response to inputs and time passing, which in turn would inform how the net models the latents as relating to its inputs, which in turn would inform which real-world structures those latents represent.
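For concreteness, here is a minimal toy sketch of the forward direction that such an inverse approach would rely on: recursively computing Bayesian belief states for a small hidden Markov model. This is my own illustration, not code from the post, and the transition/emission matrices are made up; the point is just that the set of belief points visited under recursive updating is the kind of structured subset (of the probability simplex, or of activation space under a linear map) one would go looking for.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up 3-state HMM, purely for illustration.
T = np.array([[0.85, 0.10, 0.05],   # T[i, j] = P(next state j | current state i)
              [0.05, 0.85, 0.10],
              [0.10, 0.05, 0.85]])
E = np.array([[0.90, 0.05, 0.05],   # E[i, o] = P(observation o | hidden state i)
              [0.05, 0.90, 0.05],
              [0.05, 0.05, 0.90]])

def update_belief(belief, obs):
    """One step of Bayesian filtering: propagate through T, then condition on obs."""
    predicted = T.T @ belief
    posterior = predicted * E[:, obs]
    return posterior / posterior.sum()

belief = np.ones(3) / 3              # uniform prior over hidden states
state = rng.integers(3)
beliefs = []
for _ in range(10_000):
    state = rng.choice(3, p=T[state])   # advance the hidden state
    obs = rng.choice(3, p=E[state])     # emit an observation from the new state
    belief = update_belief(belief, obs)
    beliefs.append(belief.copy())

beliefs = np.array(beliefs)  # points in the 2-simplex; this is the set whose
print(beliefs[:5])           # geometry/symmetries one would look for in activations
```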
Thank you, this was very much the paragraph I was missing to understand why comp mech might be useful for interpretability.
How sure are we that models will keep tracking Bayesian belief states, and so allow this inverse reasoning to be used, when they don’t have enough space and compute to actually track a distribution over latent states?
Approximating those distributions by something like ‘peak position plus spread’ seems like the kind of thing a model might do to save space.
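To make that concrete (a toy sketch of my own, assuming the latent is ordinal so that "peak position plus spread" is even meaningful): a full belief vector over s states gets compressed to two numbers and reconstructed as a discretised Gaussian.

```python
import numpy as np

s = 50
states = np.arange(s)

# Some full belief vector over an ordinal latent (numbers made up).
belief = np.exp(-0.5 * ((states - 31.3) / 2.4) ** 2)
belief /= belief.sum()

peak = float(states @ belief)                             # "peak position"
spread = float(np.sqrt(((states - peak) ** 2) @ belief))  # "spread"

# Reconstruct an approximate belief from just those two numbers.
approx = np.exp(-0.5 * ((states - peak) / spread) ** 2)
approx /= approx.sum()

print(peak, spread, np.abs(belief - approx).max())  # 2 numbers instead of s, small error here
```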
One obvious guess there would be that the factorization structure is exploited, e.g. independence and especially conditional independence/DAG structure. And then a big question is how distributions of conditionally independent latents in particular end up embedded.
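A minimal sketch of what exploiting independence buys (again my own illustration, with made-up numbers): for n independent latents the joint belief is just the outer product of the per-latent marginals, so only n vectors of length s need to be tracked rather than one of length s^n.

```python
import numpy as np

rng = np.random.default_rng(0)
n, s = 4, 3

# One belief vector per latent (random, just for illustration).
marginals = rng.dirichlet(np.ones(s), size=n)

# The explicit joint belief over all latents is the outer product of the marginals.
joint = marginals[0]
for m in marginals[1:]:
    joint = np.multiply.outer(joint, m)

print(marginals.size, "numbers tracked vs", joint.size, "in the explicit joint")  # 12 vs 81
```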
There are some theoretical reasons to expect linear representations for variables which are causally separable / independent. See recent work from Victor Veitch’s group, e.g. Concept Algebra for (Score-Based) Text-Controlled Generative Models, The Linear Representation Hypothesis and the Geometry of Large Language Models, On the Origins of Linear Representations in Large Language Models.
Separately, there are theoretical reasons to expect convergence to approximate causal models of the data generating process, e.g. Robust agents learn causal world models.
Linearity might also make it (provably) easier to find the concepts; see Learning Interpretable Concepts: Unifying Causal Representation Learning and Foundation Models.
Right. If I have n fully independent latent variables that suffice to describe the state of the system, each of which can be in one of s different states, then even tracking the probability of every state for every latent with a p-bit precision float will only take me about n×s×p bits. That's actually not that bad compared to the n×log₂(s) bits for just tracking some max likelihood guess.
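Plugging in some made-up numbers to check that comparison:

```python
import math

n, s, p = 100, 10, 16   # hypothetical: 100 latents, 10 states each, 16-bit probabilities

full_distribution_bits = n * s * p       # track every state's probability for every latent
max_likelihood_bits = n * math.log2(s)   # track only a single best guess per latent

print(full_distribution_bits, round(max_likelihood_bits))  # 16000 vs 332
```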