The second way gradient descent is used is to find the activations of the neural network, in this case, to find the values of the r, given a single input image. This is not standard from the perspective of modern deep learning, but should set off mesa-optimizer alarm bells.
Actually this is completely standard in DL, as it's the core functionality of the transformer architecture. Transformers consist of a stack of homogeneous layers that update a shared residual stream, so they are residual networks, and thus equivalent to an unrolled iterative estimator with untied weights. In this form it's more obvious that transformers are equivalent to (constrained) unrolled recurrent networks that perform K iterative sub-inference steps on a probability map (the residual stream) per main token prediction step.
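To make the "iterative estimator" framing concrete, here's a minimal numpy sketch of a residual stream being refined by K additive sub-steps. The shapes, the placeholder attn/ff sub-modules, and the shared (tied) weights are illustrative assumptions, not any real model's code; a real transformer uses a separate (attention, FF) pair per layer plus normalization.

```python
import numpy as np

d, T, K = 64, 16, 12                      # model width, sequence length, number of sub-steps

# one shared depth-2 FF sub-module (tied-weight form; real models untie per layer)
W1 = 0.02 * np.random.randn(d, 4 * d)     # expand
W2 = 0.02 * np.random.randn(4 * d, d)     # project back

def attn(x):                              # placeholder for the attention sub-step
    return 0.02 * x

def ff(x):                                # depth-2 feedforward sub-step
    return np.maximum(x @ W1, 0.0) @ W2

x = np.random.randn(T, d)                 # the residual stream: one running estimate per token
for k in range(K):                        # K iterative refinement sub-steps
    x = x + attn(x)                       # each sub-step only *adds* a correction to the stream,
    x = x + ff(x)                         # so the stream is a progressively refined estimate
logits = x @ np.random.randn(d, 100)      # final readout from the refined estimate
```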
In this sense they are much closer to brain architecture than most realize. Each main 'layer' (or sub-step, in the shared-weight form) has two sub-modules: a pure depth-2 feedforward network, which is a close match for the cerebellum in the brain, and the attention sub-module, which provides short-term contextual memory and is thus arguably more similar to the recurrent modules of the cortex/hippocampus (due to the equivalence of fast-weight RNNs and transformer attention modules).
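The fast-weight-RNN/attention equivalence is easiest to see for linear attention (this is the Schlag/Irie/Schmidhuber result; softmax attention is the nonlinear analogue). A rough sketch, where the feature map phi and the unnormalized read-out are illustrative assumptions:

```python
import numpy as np

d = 8
W_fast = np.zeros((d, d))                  # fast weights = short-term contextual memory

def phi(u):                                # simple positive feature map (an assumption)
    return np.maximum(u, 0.0) + 1e-3

def step(k, v, q):
    """One token step: write (k, v) into the fast weights, then read with q."""
    global W_fast
    W_fast = W_fast + np.outer(v, phi(k))  # Hebbian-style outer-product write
    return W_fast @ phi(q)                 # read-out = unnormalized linear attention over the prefix

# Feeding tokens one at a time (RNN form) gives the same outputs as computing
# linear attention over the whole prefix in parallel (transformer form).
for t in range(5):
    k = v = q = np.random.randn(d)
    out = step(k, v, q)
```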
So if you imagine implementing a large transformer physically with a huge but slow circuit, each main layer (with its own weights) would correspond to a cortical-cerebellar module pair, with the cortical module as the attention part and the cerebellar module as the FF part. Full pipeline parallelization would be used, of course, so that temporal flow down the transformer depth corresponds to flow from sensory modules up to higher-level modules and then down to motor modules in the brain.
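A toy schedule makes the pipelining picture concrete (purely illustrative; token and layer counts are arbitrary):

```python
# With full pipelining, at clock tick t the circuit for layer L is busy with token (t - L),
# so every layer (i.e. every cortical-cerebellar module pair) works on a different token
# at the same time, and the slow-but-wide hardware keeps up with the token stream.
def pipeline_schedule(n_tokens, n_layers):
    for t in range(n_tokens + n_layers - 1):
        active = [(t - L, L) for L in range(n_layers) if 0 <= t - L < n_tokens]
        print(f"tick {t}: " + ", ".join(f"token {tok} in layer {L}" for tok, L in active))

pipeline_schedule(n_tokens=4, n_layers=3)
```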
There are some differences, of course: the brain's architecture is truly recurrent, which would be like allowing connections from higher levels back down to lower levels across the transformer depth. Transformers don't allow this because it would break their parallelization-over-time strategy, which is their most serious weakness compared to the more fully universal RNN architecture of the brain.
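Here's a miniature numpy illustration of why top-down recurrence breaks time-parallelism; the shapes, the single feedback connection, and the unmasked future positions being set to -inf are all toy assumptions:

```python
import numpy as np

T, d = 6, 4
X = np.random.randn(T, d)
Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))

# Transformer-style: a causal mask lets all T positions be computed in one shot.
scores = (X @ Wq) @ (X @ Wk).T / np.sqrt(d)
scores[np.triu_indices(T, k=1)] = -np.inf        # mask out future positions
A = np.exp(scores); A /= A.sum(-1, keepdims=True)
Y_parallel = A @ (X @ Wv)                        # all outputs at once, parallel over time

# Truly recurrent (brain-like) net: layer 1 at time t needs layer 2's state from t-1,
# so the time loop cannot be collapsed into one big matrix op.
W12, W21 = np.random.randn(d, d), np.random.randn(d, d)
h1 = h2 = np.zeros(d)
for t in range(T):
    h1 = np.tanh(X[t] + h2 @ W21)                # bottom-up input + top-down feedback
    h2 = np.tanh(h1 @ W12)
```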
There are many, many other convergent details, of course: the bits per weight of the most efficient transformers are converging to around 4 bits per weight, similar to the brain; the number of weights per neuron is similar; the learned internal representations are broadly similar or functionally equivalent; etc.
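For the 4-bits-per-weight point, here's a cartoon of symmetric int4 weight quantization with a single per-tensor scale; real low-bit schemes use per-group scales and other tricks, so treat this purely as an illustration of what "4 bits per weight" means:

```python
import numpy as np

W = np.random.randn(256, 256).astype(np.float32)
scale = np.abs(W).max() / 7.0                                # map weights into the int4 range
W_q = np.clip(np.round(W / scale), -8, 7).astype(np.int8)    # 16 levels, i.e. 4 bits of storage per weight
W_deq = W_q.astype(np.float32) * scale                       # dequantize for use in matmuls
print("mean abs quantization error:", np.abs(W - W_deq).mean())
```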