I feel like the term “amortization” in ML/CS has a couple of meanings. Do you just mean redistributing compute from training to inference?
I think this is an interesting model, but I also think part of the use of CoT is more specific to the language/logic context: literally thinking step by step (which sometimes lets you split problems into subproblems). In some limit, there would be exponentially few examples in the training data of directly “thinking n steps ahead”, so a transformer wouldn’t be able to learn to do this at all (at least without some impressive RL). Imagine training a chess-playing computer by only showing it every 10th move of each game: with enough inference power, a very powerful system would probably be able to reconstruct the rules of chess as the best way of making sense of the regularities in the data, but this is in some sense exponentially harder than learning from every move.
Ah, I think the notion of amortized inference that you’re using encapsulates what I’m saying about chess. I’m still a little confused about the scope of the concept, though: do you have a good cached explanation?
Yes, I’m thinking of that line of work. I actually think the first few paragraphs of this paper do a better job of getting at the vibes I want (and I should emphasize these are vibes that I have, not any kind of formal understanding). So here’s my try at a cached explanation of the concept of amortized inference I’m trying to evoke:
A lot of problems are really hard, and the algorithmic/reasoning path from the question to the answer is many steps long. But in some cases humans seem to be much faster than that (perhaps at the cost of some error, but even so, they are both fast and quite good at the task). The idea is that in these settings the human brain is performing amortized inference: because they’ve seen similar examples of the task’s input/output relation before, they can use that direct mapping as a kind of bootstrap for the new task at hand, saving a lot of inference time.
Now that I’ve typed that out, it feels maybe similar to your stuff about heuristics?
Big caveat here: it’s quite possible I’m misunderstanding amortized inference (maybe @jessicata can help here?), as well as reaching with the connection to your work.
I’m not sure this captures what you mean, but if you see a query, do a bunch of reasoning, and get an answer, then you can build a dataset of (query, well-thought guess). Then you can train an AI model on that.
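A minimal sketch of that recipe, under loudly stated assumptions: the task (solve x**3 + x = q for x ≥ 0), the grid of training queries, and the "model" (linear interpolation over the dataset, standing in for a neural network) are all hypothetical choices for illustration. The slow path runs dozens of bisection steps per query; the amortized path answers a new query by reading off and interpolating what the slow path already computed.

```python
# Hypothetical toy: amortize an expensive "reasoning" procedure into a cheap direct map.

def slow_solve(q, tol=1e-9):
    """Expensive path: bisection search for x >= 0 with x**3 + x == q. Assumes q >= 0."""
    lo, hi = 0.0, max(1.0, q)  # x**3 + x is increasing, and hi**3 + hi >= q
    steps = 0
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if mid**3 + mid < q:
            lo = mid
        else:
            hi = mid
        steps += 1
    return (lo + hi) / 2, steps


SPACING = 0.5
TRAIN_QUERIES = [SPACING * i for i in range(201)]  # 0.0, 0.5, ..., 100.0

# Dataset of (query, well-thought guess), paid for once up front by running the slow path.
dataset = [(q, slow_solve(q)[0]) for q in TRAIN_QUERIES]


def make_amortized_model(dataset, spacing):
    """Cheap path: map a query straight to an answer, with no search over x.
    Assumes the training queries are 0, spacing, 2*spacing, ... (hypothetical setup)."""
    answers = [x for _, x in sorted(dataset)]

    def model(q):
        t = min(max(q / spacing, 0.0), len(answers) - 1.0)  # clamp to the training range
        i = min(int(t), len(answers) - 2)
        frac = t - i
        return answers[i] * (1 - frac) + answers[i + 1] * frac

    return model


model = make_amortized_model(dataset, SPACING)
```

For example, `slow_solve(37.3)` needs a few dozen bisection steps, while `model(37.3)` returns a close approximation in constant time; the cost has moved from answering each query to building the dataset.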
AlphaZero sorta works like this, because it can make a “well-thought guess” (take value and/or Q network, do an iteration of minimax, then make the value/Q network more closely approximate that, in a fixed point fashion)
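A toy sketch of the fixed-point idea in that parenthetical, not AlphaZero itself (which uses Monte Carlo tree search with a neural network rather than one ply of minimax over a table). Assumptions: a hypothetical game where players alternately remove 1 or 2 stones and whoever takes the last stone wins, and a lookup table `V[n]` standing in for the value network (value for the player about to move, +1 win, -1 loss).

```python
def backup(V, n):
    """One ply of (nega)minimax on top of the current value estimates."""
    return max(-V[n - m] for m in (1, 2) if m <= n)


def train_values(max_stones=12, alpha=0.5, sweeps=200):
    V = [0.0] * (max_stones + 1)
    V[0] = -1.0  # no stones left: the opponent took the last one, so the mover has lost
    for _ in range(sweeps):
        for n in range(1, max_stones + 1):
            # Nudge the estimate toward its own one-step lookahead; iterate to a fixed point.
            V[n] += alpha * (backup(V, n) - V[n])
    return V


V = train_values()
```

At the fixed point, piles that are multiples of 3 have value -1 and all others +1, which matches the known solution of this game.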
Learning stochastic inverses is a specific case of “learn to automate Bayesian inference by taking forward samples and learning the backwards model”. It could be applied to LLMs, for example, by starting with a forwards LLM and then using it to train an LLM that predicts things out of order.
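A minimal sketch of "take forward samples, learn the backwards model", assuming a hypothetical three-valued latent `z` with a noisy observation `x`. The backwards model q(z | x) is fit by counting, a tabular stand-in for a neural network; it never evaluates Bayes' rule, yet approaches the exact posterior, which is computed separately only for comparison.

```python
import random
from collections import Counter

PRIOR = [0.5, 0.3, 0.2]  # hypothetical p(z) for z in {0, 1, 2}


def likelihood(x, z):
    """Hypothetical p(x | z): correct 80% of the time, otherwise one of the other two values."""
    return 0.8 if x == z else 0.1


def sample_forward(rng):
    """Run the generative model forwards: z ~ p(z), then x ~ p(x | z)."""
    z = rng.choices(range(3), weights=PRIOR)[0]
    x = rng.choices(range(3), weights=[likelihood(xv, z) for xv in range(3)])[0]
    return z, x


def learn_inverse(num_samples=100_000, seed=0):
    """Fit q(z | x) from forward samples only, by counting joint (z, x) occurrences."""
    rng = random.Random(seed)
    joint = Counter(sample_forward(rng) for _ in range(num_samples))
    inverse = {}
    for x in range(3):
        total = sum(joint[(z, x)] for z in range(3))
        inverse[x] = [joint[(z, x)] / total for z in range(3)]
    return inverse


def exact_posterior(x):
    """Bayes' rule, for comparison only: p(z | x) is proportional to p(z) p(x | z)."""
    unnorm = [PRIOR[z] * likelihood(x, z) for z in range(3)]
    total = sum(unnorm)
    return [u / total for u in unnorm]


inverse = learn_inverse()
```

Once fit, answering "which z explains this x?" is a single lookup, which is the amortized part.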
Paul Christiano’s iterated amplification and distillation is applying this idea to ML systems with a human feedback element. If you can expend a bunch of compute to get a good answer, you can train a weaker system to approximate that answer. Or, if you can expend a bunch of compute to get a good rating for answers, you can use that as RL feedback.
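A toy sketch of the amplify-then-distill loop, not Christiano's actual formulation. Assumptions: the questions are "what is the n-th Fibonacci number?", the overseer only knows the base cases and how to combine two sub-answers, and the "model" is a lookup table that is distilled by fitting it exactly to the amplified system's answers. The names `amplify`, `distill`, and `iterate` are hypothetical.

```python
def amplify(model, n):
    """Answer question n by decomposing it into subquestions the current model answers in one call."""
    if n < 2:
        return n  # base cases the overseer can answer directly
    return model.get(n - 1, 0) + model.get(n - 2, 0)  # an untrained model answers 0


def distill(model, questions):
    """'Train' a new model to imitate the amplified system (here: an exact table fit)."""
    return {n: amplify(model, n) for n in questions}


def iterate(rounds, questions):
    model = {}  # untrained model
    for _ in range(rounds):
        model = distill(model, questions)
    return model


trained = iterate(20, range(31))
```

Each amplification step only makes two cheap model calls, but after k rounds the distilled table answers every question up to n = k correctly in one lookup, even though the direct computation for large n involves a deep recursion.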
Broadly, I take o3 as evidence that Christiano’s work is on the right track with respect to alignment of near-term AI systems. That is, o3 shows that hard questions can be decomposed into easy ones, in a way that involves training weaker models to be part of a big computation. (I don’t understand the o3 details that well, given they’re partially private, but I’m assuming this describes the general outlines.) So I think the sort of schemes Christiano has described will be helpful for both alignment and capabilities, and will scale pretty well to impressive systems.
I’m not sure if there’s a form of amortized inference that you think this doesn’t cover well.
Thanks, this is helpful. I’m still a bit unclear about how to use the word/concept “amortized inference” correctly. Is the first example you gave, of training an AI model on (query, well-thought guess), an example of amortized inference, relative to training on (query, a bunch of reasoning + well-thought out guess)?
I don’t habitually use the concept so I don’t have an opinion on how to use the term.