This is a completely fair suggestion. I’ll look into training a fully-fledged SAE with the same number of features for the full training duration.
tdooms
One caveat that I want to highlight is that there was a bug when training the tokenized SAEs for the expansion sweep: the lookup table wasn't learned but remained at the hard-coded values...
They are therefore quite suboptimal. Due to compute constraints, I haven't re-run that experiment (the 64x SAEs take quite a while to train).
Anyway, I think the main question you want answered is whether the 8x tokenized SAE beats the 64x normal SAE, which it does. The 64x SAE is improving slightly quicker near the end of training; note that I only used 130M tokens.
Below is an NMSE plot for k=30 across expansion factors (the CE is about the same, albeit slightly less impacted by the size increase). The “tokenized” label indicates the non-learned lookup and “Learned” is the working tokenized setup.
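For concreteness, a rough sketch of the metric, assuming the usual normalization of the reconstruction error by the input energy (the exact normalization used for the plot isn't spelled out here):

```python
import torch

def nmse(x, x_hat):
    """Normalized MSE: squared reconstruction error divided by the squared
    norm of the input, averaged over the batch (assumed definition)."""
    return (x - x_hat).pow(2).sum(-1).mean() / x.pow(2).sum(-1).mean()
```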
That’s awesome to hear. While we are not especially familiar with circuit analysis, anecdotally we’ve heard that some circuit features are very disappointing (such as the “Mary” feature for IOI; I believe this is also the case in Othello SAEs, where many features just describe the last move). This was a partial motivation for this work.
Regarding similar tokenized features: maybe I’m misunderstanding, but this seems like a problem for any decoder-like structure. In the lookup table, though, I think this behaviour is somewhat attenuated by the strict manual trigger, which encourages the lookup table to learn exact features instead of means.
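To make the mechanism concrete, here is a minimal sketch of what such a lookup-table SAE can look like. This is illustrative only (class and parameter names are made up; the actual implementation is in the repo linked below), not our exact code:

```python
import torch
import torch.nn as nn

class TokenizedSAE(nn.Module):
    """Sketch of the idea: a top-k SAE whose reconstruction includes a
    per-token lookup term triggered directly by the input token id,
    rather than by the encoder."""

    def __init__(self, d_model, n_features, vocab_size, k=30):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_model, n_features)
        self.decoder = nn.Linear(n_features, d_model)
        # One d_model-sized vector per vocabulary entry (the lookup table).
        self.lookup = nn.Embedding(vocab_size, d_model)

    def forward(self, x, token_ids):
        acts = torch.relu(self.encoder(x))
        # Keep only the top-k activations per position (sparsity constraint).
        topk = torch.topk(acts, self.k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, topk.indices, topk.values)
        # The lookup contribution fires whenever the token occurs; this is the
        # "strict manual trigger" mentioned above.
        recon = self.decoder(sparse) + self.lookup(token_ids)
        return recon, sparse
```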
We used a Google Drive folder to store most of the runs (https://drive.google.com/drive/folders/1ERSkdA_yxr7ky6AItzyst-tCtfUPy66j?usp=sharing). We use a somewhat unusual naming scheme: if there is a “t” in the suffix of the name, the run is tokenized. Some runs may be old and may not fully work; if you run into any issues, feel free to reach out.
The code in the research repo (specifically https://github.com/tdooms/tokenized-sae/blob/main/base.py#L119) should work to load them in.
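If that loader doesn’t cover a particular run, a first step could be to inspect the raw checkpoint directly. This is purely hypothetical (the file name is made up, and the checkpoints may not be plain state dicts):

```python
import torch

# Assumes the Drive folder was synced locally; per the naming scheme above,
# a "t" in the suffix marks a tokenized run.
state = torch.load("runs/gpt2-8x-t.pt", map_location="cpu")
print(state.keys())  # see which tensors (encoder, decoder, lookup, ...) are stored
```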
Please keep in mind that these are currently more of a proof of concept and are likely undertrained. We were hoping to determine the level of interest in this technique before training a proper suite.
Tokens are indeed only a specific instantiation of hardcoding “known” features into an SAE; there are lots of interesting sparse features one could consider, which may speed up training even further.
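As a sketch of that generalization (illustrative only, not something we have implemented in this exact form), one can freeze part of the dictionary to known directions and learn the rest:

```python
import torch
import torch.nn as nn

class PartiallyHardcodedSAE(nn.Module):
    """Sketch only: `known` holds feature directions we already trust
    (token embeddings, n-gram directions, ...); they stay frozen while
    the remaining dictionary rows are learned as usual."""

    def __init__(self, d_model, n_learned, known):
        super().__init__()
        self.encoder = nn.Linear(d_model, known.shape[0] + n_learned)
        self.register_buffer("known", known)  # frozen, receives no gradients
        self.learned = nn.Parameter(torch.randn(n_learned, d_model) * 0.02)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        acts = torch.relu(self.encoder(x - self.b_dec))
        dictionary = torch.cat([self.known, self.learned], dim=0)
        return acts @ dictionary + self.b_dec, acts
```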
I like the suggestion of trying to find the “enriched” token representations. While our work shows that such representations are likely bigrams and trigrams, using an extremely sparse SAE (say at layer 1 or 2) to reveal them could also work. While this approach still has the drawback of requiring an encoder, that encoder can be shared across SAEs, which is still a large decrease in complexity. The encoder will probably also be simpler since it acts earlier in the model.
This idea can be applied recursively across a suite of SAEs, where each layer adds to a pool of hardcoded features. In other words, the SAE at each layer has its own encoder/decoder, and its decoder is copied (and fine-tuned) into later layers. This would allow a feature to be traced through the model more faithfully than is currently possible.
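A very rough sketch of what we have in mind, purely illustrative (none of this is implemented; all names are placeholders):

```python
import torch
import torch.nn as nn

class RecursiveSAESuite(nn.Module):
    """Illustrative only. Layer l's dictionary is the (copied, fine-tunable)
    pool of features discovered at layers < l, plus a block of fresh features
    that layer l learns itself and then contributes to the pool."""

    def __init__(self, d_model, n_new_per_layer, n_layers):
        super().__init__()
        self.saes = nn.ModuleList()
        pool_size = 0
        for _ in range(n_layers):
            pool_size += n_new_per_layer
            self.saes.append(nn.ModuleDict({
                "encoder": nn.Linear(d_model, pool_size),
                "decoder": nn.Linear(pool_size, d_model),
            }))

    def seed_from_previous(self, layer):
        # Copy the previous layer's decoder columns into the first columns of
        # this layer's decoder; they stay trainable, i.e. they are fine-tuned.
        if layer == 0:
            return
        prev = self.saes[layer - 1]["decoder"].weight  # (d_model, prev_pool)
        with torch.no_grad():
            self.saes[layer]["decoder"].weight[:, : prev.shape[1]].copy_(prev)

    def forward(self, x, layer):
        sae = self.saes[layer]
        acts = torch.relu(sae["encoder"](x))
        return sae["decoder"](acts), acts
```

The shared-encoder variant from the previous paragraph would correspond to also tying the encoders across layers instead of training one per layer.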
We hadn’t considered this, since our idea was that the encoder could perhaps use the full information to better predict features. However, this seems worthwhile to at least try. I’ll look into this soon; thanks for the inspiration.