Been thinking a bit about latent reasoning. Here’s an interesting confusion I’ve run into.
Consider COCONUT vs. Geiping et al. Geiping et al. do recurrent passes in between the generation of each new token, while COCONUT turns a section of the CoT into a recurrent state. Which is better / how are they different, safety-wise?
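To make the contrast concrete, here's a minimal toy sketch of the two recurrence schemes as I understand them. This is not either paper's actual code: the GRU cell stands in for a stack of transformer layers, attention over the prior sequence is elided, and the function names and loop counts are made up for illustration.

```python
# Toy contrast between the two recurrence schemes. Everything here is a
# stand-in: `core` replaces a transformer stack, and loop counts are arbitrary.
import torch
import torch.nn as nn

VOCAB, D = 1000, 64
embed = nn.Embedding(VOCAB, D)   # token id -> input vector
core = nn.GRUCell(D, D)          # stand-in for the recurrent block
unembed = nn.Linear(D, VOCAB)    # hidden state -> logits over tokens

def recurrent_depth_step(token_id, n_iters=4):
    """Geiping-et-al-style: iterate the core several times at one token
    position, re-injecting the token embedding (the skip path) every
    iteration, then unembed and emit an ordinary discrete token."""
    e = embed(token_id)
    h = torch.zeros(1, D)
    for _ in range(n_iters):
        h = core(e, h)               # embedding re-enters each iteration
    return unembed(h).argmax(-1)     # only a discrete token crosses positions

def continuous_thought_rollout(token_id, n_latent_steps=4):
    """COCONUT-style: the last hidden state is fed back directly as the next
    position's input "embedding"; nothing is unembedded or sampled in
    between, so the full D-dimensional state flows forward."""
    x = embed(token_id)
    h = torch.zeros(1, D)
    for _ in range(n_latent_steps):
        h = core(x, h)
        x = h                        # hidden state becomes the next input
    return unembed(h).argmax(-1)     # read out only when latent mode ends

# Both consume a token id; they differ in what gets passed forward.
tok = torch.tensor([42])
print(recurrent_depth_step(tok), continuous_thought_rollout(tok))
```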
Intuitively, COCONUT strikes me as very scary, because it makes the CoT illegible. We could try to read it by coaxing each latent back to the nearest token, but the whole point is to allow reasoning that passes along more state than one token can capture. If the method works as advertised, that kind of oversight will be lossy.
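Concretely, the "coax it back to the nearest token" readout I have in mind is just a logit-lens-style projection through the unembedding. A hypothetical sketch, with a random matrix standing in for the model's real unembedding:

```python
# Hypothetical logit-lens-style readout of a continuous "thought".
# The unembedding here is random; in practice it would be the model's own
# unembedding / tied embedding matrix.
import torch

VOCAB, D = 1000, 64
unembed = torch.randn(VOCAB, D)          # stand-in for W_U

def nearest_token_readout(latent, k=3):
    """Collapse a D-dim latent to its k nearest tokens in logit space.
    Whatever structure in `latent` isn't captured by these few token
    directions is exactly what the overseer doesn't get to see."""
    logits = unembed @ latent            # (VOCAB,)
    return logits.topk(k).indices        # a few token ids: a lossy summary

latent_thought = torch.randn(D)          # pretend this came out of the model
print(nearest_token_readout(latent_thought))
```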
Intuitively, Geiping et al. seems better. They use skip connections in the recurrence, so maybe their method keeps the logit lens applicable. It still increases the maximum serial depth between overseeable tokens, but it seems no more dangerous than a non-recurrent model of equivalent depth.
But isn’t the COCONUT method roughly equivalent to just doing the Geiping et al. method for only one token? Why does it seem so much scarier? Do the skip connections really make that much difference? Doesn’t COCONUT effectively have skip connections anyway, because it autoregressively generates new “tokens”? And it’ll have the logit lens property too, since it uses the same residual stream with the same feature directions.
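To pin down what the two pictures disagree about at a single position, here's my own toy framing (not code from either paper) of the discrete bottleneck that ordinary autoregression inserts between forward passes, versus COCONUT's direct feedback:

```python
# The step ordinary autoregressive generation inserts between forward passes,
# versus COCONUT's direct feedback. Names and structure are my own illustration.
import torch
import torch.nn as nn

VOCAB, D = 1000, 64
embed = nn.Embedding(VOCAB, D)
unembed = nn.Linear(D, VOCAB)

def discrete_bottleneck(h):
    """Unembed, sample, reembed: the state is collapsed to one token id, and
    only that token's embedding re-enters the model at the next position."""
    tok = torch.multinomial(unembed(h).softmax(-1), 1).squeeze(-1)
    return embed(tok)

def coconut_feedback(h):
    """No bottleneck: the full hidden state re-enters the next position."""
    return h

h = torch.randn(1, D)
print(discrete_bottleneck(h).shape, coconut_feedback(h).shape)
```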
Thinking this through made me realize that the logit lens result has led me to think of within-forward-pass cognition as very myopic and next-token-oriented. Relatedly, it’s hard for me to imagine far-ranging or highly consequentialist cognition happening within a forward pass; I’m generally more comfortable thinking of that stuff as happening within the CoT. But now that I articulate it explicitly, that’s sort of a weird view: why does inserting an “unembed, sample, reembed w/ skip” block every so often make such a difference? The fact that COCONUT works is an update against that view.