This paper finds critical periods in neural networks, a phenomenon already well known in many animals. h/t Turntrout
An SLT story that seems plausible to me:
We can model the number of epochs as a temperature: more epochs correspond to a less noisy Gibbs sampler. Earlier in training we are sampling points from a noisier distribution, so the full singularity (the point reached when training on the full distribution) and the ablated singularity (the point reached when ablating data during the critical period) are treated as roughly the same. As we decrease the temperature, they start to differentiate. If during that window we also ablate part of the dataset, we will see a divergence between the sampling results on the full distribution and on the ablated distribution. Because both regions are singular, and likely connected only by very low-dimensional low-loss manifolds, the sampling process gets stuck in the ablated region.
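One way to make the temperature analogy concrete (my gloss, not something from the paper) is to treat late training as approximately sampling from a tempered Gibbs posterior, with the number of epochs playing the role of the inverse temperature β:

```latex
% Tempered (Gibbs) posterior over parameters w given n data points,
% with empirical loss L_n(w), prior \varphi(w), and inverse temperature \beta.
% The story above reads "more epochs" as "larger \beta", i.e. a colder, less noisy sampler.
p_\beta(w \mid D_n) = \frac{\varphi(w)\, e^{-n\beta L_n(w)}}{Z_n(\beta)},
\qquad
Z_n(\beta) = \int \varphi(w)\, e^{-n\beta L_n(w)}\, dw .
```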
The paper uses the trace of the FIM, which unknowingly measures the degeneracy of the current point. A better measure would be the learning coefficient. This also suggests that higher learning coefficients produce less adaptable models.
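For reference, here is a minimal sketch of the quantity the paper tracks: a Monte Carlo estimate of tr(FIM) for a classifier, with labels sampled from the model's own predictive distribution. The function name and defaults are mine, not the paper's.

```python
import torch
import torch.nn.functional as F

def fim_trace_estimate(model, inputs, n_label_samples=1):
    """Monte Carlo estimate of tr(FIM) = E_x E_{y ~ p(y|x,w)} ||grad_w log p(y|x,w)||^2.

    `inputs` is a batch of examples; labels are sampled from the model's own
    predictive distribution, as in the definition of the Fisher information.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    total, count = 0.0, 0
    for x in inputs:                                  # per-example gradients
        logits = model(x.unsqueeze(0))
        log_probs = F.log_softmax(logits, dim=-1)
        probs = log_probs.exp()
        for _ in range(n_label_samples):
            y = torch.multinomial(probs, 1).item()    # y ~ p(y | x, w)
            grads = torch.autograd.grad(log_probs[0, y], params, retain_graph=True)
            total += sum((g ** 2).sum().item() for g in grads)
            count += 1
    return total / count
```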
One thing this might let us do: if we can model the above process directly, we can figure out how the number of epochs corresponds to temperature.
Another thing the above story suggests is that while intra-epoch training is not really path-dependent, inter-epoch training potentially is, in the sense that the choice of which singular region to sample from is sometimes not recoverable if it turns out to have been a bad one.
Thinking about more direct tests here...
The obvious thing to do, which tests the assumption of the above model, but not the model itself, is to see whether the RLCT decreases as you increase the number of epochs. This is a very easy experiment.
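A minimal sketch of that experiment, assuming a standard PyTorch classification setup: train a fresh copy of the model for each epoch budget, then estimate the local learning coefficient at the endpoint with a hand-rolled SGLD sampler in the style of the λ̂ estimator of Lau et al. The helper names and all hyperparameters (step size, localization strength γ, chain length) are placeholders, not tuned values.

```python
import copy
import math
from itertools import islice

import torch
import torch.nn.functional as F

def estimate_llc(model, loader, n_data, n_steps=1000, n_burnin=300,
                 step_size=1e-5, gamma=100.0, device="cpu"):
    """SGLD estimate of the local learning coefficient at the model's current
    parameters: lambda-hat = n * beta * (E_w[L_n(w)] - L_n(w*)), with beta = 1/log n,
    sampling from the posterior tempered to beta and localized around w* by gamma."""
    beta = 1.0 / math.log(n_data)
    model = model.to(device)
    anchor = [p.detach().clone() for p in model.parameters()]
    sampler = copy.deepcopy(model)
    params = list(sampler.parameters())

    # crude estimate of L_n(w*) from a few batches at the anchor point
    with torch.no_grad():
        anchor_losses = [F.cross_entropy(model(x.to(device)), y.to(device)).item()
                         for x, y in islice(loader, 10)]
    loss_star = sum(anchor_losses) / len(anchor_losses)

    chain_losses, step = [], 0
    while step < n_steps:
        for x, y in loader:
            if step >= n_steps:
                break
            x, y = x.to(device), y.to(device)
            loss = F.cross_entropy(sampler(x), y)
            grads = torch.autograd.grad(loss, params)
            with torch.no_grad():
                for p, g, a in zip(params, grads, anchor):
                    drift = n_data * beta * g + gamma * (p - a)   # tempered + localizing term
                    p.add_(-0.5 * step_size * drift
                           + math.sqrt(step_size) * torch.randn_like(p))
            if step >= n_burnin:
                chain_losses.append(loss.item())
            step += 1

    return n_data * beta * (sum(chain_losses) / len(chain_losses) - loss_star)

def rlct_vs_epochs(make_model, loader, n_data, epoch_counts, lr=1e-3, device="cpu"):
    """Train a fresh model for each epoch budget, then estimate lambda-hat at the endpoint."""
    results = {}
    for n_epochs in epoch_counts:
        model = make_model().to(device)
        opt = torch.optim.SGD(model.parameters(), lr=lr)
        for _ in range(n_epochs):
            for x, y in loader:
                opt.zero_grad()
                F.cross_entropy(model(x.to(device)), y.to(device)).backward()
                opt.step()
        results[n_epochs] = estimate_llc(model, loader, n_data, device=device)
    return results
```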
Actually, this experiment is maybe slightly less straightforward than it looks: as you increase the control parameter β, you add pressure both to decrease L_n and to decrease λ, and it may just be cheaper to decrease L_n than λ.
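In free-energy terms (my gloss, dropping lower-order terms), the worry is that the loss term scales linearly in β while the degeneracy term only picks up a log β:

```latex
% Leading-order free energy of the tempered posterior local to a point w^* with
% empirical loss L_n(w^*) and (local) learning coefficient \lambda.
% The loss term grows like \beta, the degeneracy term only like \log\beta,
% so lowering F_n(\beta) by lowering L_n may be cheaper than lowering \lambda.
F_n(\beta) = -\log Z_n(\beta) \approx n\beta\, L_n(w^*) + \lambda \log(n\beta).
```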