I think these papers provide sufficient behavioral evidence that transformers are implementing something close to gradient descent in their weights. Garg et al. 2022 train 12-layer GPT-style transformers on few-shot learning tasks and show that they can learn 2-layer MLPs in context. The transformer’s performance closely matches that of an MLP trained with gradient descent (GD) for 5000 steps on the same few-shot examples, and it cannot be explained by simpler heuristics such as averaging the K nearest neighbors among the few-shot examples. Since the inputs are fairly high-dimensional, I don’t think the model can perform this well just by memorizing the weights it saw during training. The model is also fairly robust to distribution shifts in the inputs at test time, so whatever it has learned should be pretty similar to a general-purpose learning algorithm.
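To make the two baselines in that comparison concrete, here is a minimal sketch of what "an MLP with GD on the few-shot examples" and "averaging the K nearest neighbors" could look like. This is not the code from Garg et al.: the hidden width, learning rate, initialization, squared-error loss, and the function names (`knn_predict`, `mlp_forward`, `fit_mlp_gd`) are all my own assumptions, and only the 5000 steps is taken from the description above.

```python
import numpy as np

def knn_predict(xs, ys, x_query, k=3):
    """Heuristic baseline: average the labels of the k nearest few-shot examples."""
    dists = np.linalg.norm(xs - x_query, axis=1)
    nearest = np.argsort(dists)[:k]
    return ys[nearest].mean()

def mlp_forward(params, x):
    """2-layer ReLU MLP without biases (an assumed architecture)."""
    W1, w2 = params
    return np.maximum(x @ W1, 0.0) @ w2

def fit_mlp_gd(xs, ys, hidden=32, steps=5000, lr=1e-2, seed=0):
    """Fit the MLP to the few-shot examples with full-batch gradient descent
    on the loss 0.5 * mean squared error. Hyperparameters are assumptions."""
    rng = np.random.default_rng(seed)
    n, d = xs.shape
    W1 = rng.normal(size=(d, hidden)) / np.sqrt(d)
    w2 = rng.normal(size=hidden) / np.sqrt(hidden)
    for _ in range(steps):
        h_pre = xs @ W1                           # (n, hidden) pre-activations
        h = np.maximum(h_pre, 0.0)                # ReLU
        err = h @ w2 - ys                         # (n,) residuals
        grad_w2 = h.T @ err / n
        grad_h = np.outer(err, w2) * (h_pre > 0)  # backprop through the ReLU
        grad_W1 = xs.T @ grad_h / n
        W1 -= lr * grad_W1
        w2 -= lr * grad_w2
    return W1, w2

if __name__ == "__main__":
    # Toy task: few-shot examples labeled by a random "teacher" MLP.
    rng = np.random.default_rng(0)
    d, n_shots = 20, 40
    teacher = (rng.normal(size=(d, 16)) / np.sqrt(d), rng.normal(size=16) / 4.0)
    xs = rng.normal(size=(n_shots, d))
    ys = mlp_forward(teacher, xs)
    x_query = rng.normal(size=d)
    y_true = mlp_forward(teacher, x_query)

    params = fit_mlp_gd(xs, ys)
    print("GD-trained MLP error:", (mlp_forward(params, x_query) - y_true) ** 2)
    print("3-NN averaging error:", (knn_predict(xs, ys, x_query) - y_true) ** 2)
```

The behavioral argument is then about which of these curves the transformer's in-context predictions track as the number of few-shot examples grows. A toy run like this one says nothing about the transformer itself; it only pins down what the two reference points are.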
I think there is also some mechanistic evidence that transformers implement some sort of iterative optimization algorithm over a quantity stored in the residual stream. In one of the papers mentioned above (Akyurek et al. 2022), the authors trained a probe to extract the ground-truth weights of the linear model being learned in context from the residual stream, and it appears to work to some extent. The diagrams seem to show that the probe gets more accurate when trained on activations from later layers, so it seems likely that the transformer is iteratively refining its estimate of the weights.
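For readers who haven't seen probing before, here is a rough sketch of the kind of per-layer experiment I have in mind. It assumes a linear ridge-regression probe, which may not match the probe architecture used in the paper, and it runs on synthetic stand-in activations rather than anything extracted from a real transformer. The helper `synthetic_activations` is hypothetical and deliberately bakes in noise that shrinks with depth, so the output only shows how the measurement works and is not evidence for the claim.

```python
import numpy as np

def fit_linear_probe(acts, targets, ridge=1e-3):
    """Ridge-regression probe from activations (n, d_model) to weights (n, d)."""
    d_model = acts.shape[1]
    gram = acts.T @ acts + ridge * np.eye(d_model)
    return np.linalg.solve(gram, acts.T @ targets)

def probe_error(acts, targets, train_frac=0.8, ridge=1e-3):
    """Fit the probe on a train split and return held-out mean squared error."""
    n_train = int(train_frac * len(acts))
    probe = fit_linear_probe(acts[:n_train], targets[:n_train], ridge)
    pred = acts[n_train:] @ probe
    return float(np.mean((pred - targets[n_train:]) ** 2))

def synthetic_activations(weights, n_layers=4, d_model=32, seed=0):
    """HYPOTHETICAL stand-in for per-layer residual-stream activations.
    Each layer is a random linear encoding of the true weights plus noise
    whose scale is assumed to shrink with depth."""
    rng = np.random.default_rng(seed)
    n, d = weights.shape
    layers = []
    for layer in range(n_layers):
        encoder = rng.normal(size=(d, d_model)) / np.sqrt(d)
        noise = rng.normal(size=(n, d_model)) / (layer + 1)
        layers.append(weights @ encoder + noise)
    return layers

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    true_weights = rng.normal(size=(500, 4))  # one ground-truth weight vector per task
    for layer, acts in enumerate(synthetic_activations(true_weights)):
        print(f"layer {layer}: held-out probe MSE = {probe_error(acts, true_weights):.4f}")
```

With real activations, one row of `acts` would be the residual stream at some fixed position for one in-context task, and `targets` would hold that task's ground-truth weight vector. A held-out error that drops from layer to layer is the pattern the diagrams seem to show.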