Very exciting that JumpReLU works well with STE gradient estimation! I think this fixes one of the biggest flaws with TopK, which is that having a fixed number k of active latents on every token is kind of wonky. I also like the argument in section 4 a lot, in particular the point that this works because we’re optimizing the expectation of the loss. Because the features are so sparse, I wonder if using a KDE whose state persists across a few recent steps would substantially reduce gradient noise.
Thanks! I was worried about how well KDE would work because of sparsity, but in practice it seems to work quite well. My intuition is that the penalty only matters for the more frequent features (i.e. features that typically do fire, perhaps multiple times, within a batch of 4096), and that the threshold is only updated by a small amount at each step, so the noise averages out reliably over many steps. In any case, we tried increasing the batch size and turning on momentum (with the idea that this would reduce noise), but neither seemed to particularly help performance (keeping the number of tokens constant). One thing I suspect would further improve things is a per-feature bandwidth parameter, as I imagine the constant bandwidth we use is too narrow for some features and too wide for others; but to do this we would also need to find a reliable way to adapt the bandwidth during training.
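
For readers who want something concrete, here is a rough sketch of what a KDE-style estimate of the threshold gradient for the L0 penalty could look like. It is not the code used for the paper. The rectangle kernel, the function names (`rectangle_kernel`, `l0_threshold_grad`), the array shapes, and the toy data are all assumptions made for illustration. The `bandwidth` argument accepts either a single constant or one value per feature, which is the hypothetical per-feature variant mentioned above.

```python
import numpy as np


def rectangle_kernel(u):
    # Assumed kernel: K(u) = 1 for |u| < 1/2, else 0 (integrates to 1).
    return (np.abs(u) < 0.5).astype(np.float64)


def l0_threshold_grad(pre_acts, theta, bandwidth):
    """Estimate d E[L0] / d theta for each feature from one batch.

    pre_acts:  (batch, n_features) pre-activations
    theta:     (n_features,) JumpReLU thresholds
    bandwidth: scalar, or (n_features,) for a per-feature kernel width
    """
    eps = np.broadcast_to(np.asarray(bandwidth, dtype=np.float64), theta.shape)
    u = (pre_acts - theta) / eps
    # The pseudo-derivative of H(z - theta) w.r.t. theta is taken to be
    # -(1/eps) * K((z - theta) / eps). Averaged over the batch, this is
    # minus a kernel density estimate of the pre-activation density at theta.
    return -(rectangle_kernel(u) / eps).mean(axis=0)


# Toy usage with made-up Gaussian pre-activations (illustration only).
rng = np.random.default_rng(0)
pre_acts = rng.normal(size=(4096, 3))
theta = np.array([0.0, 1.0, 3.0])

grad_constant = l0_threshold_grad(pre_acts, theta, 0.1)
grad_per_feature = l0_threshold_grad(pre_acts, theta, np.array([0.05, 0.1, 0.5]))
```

With a narrow bandwidth, only the few samples that land within `bandwidth / 2` of the threshold contribute, so a rarely firing feature gets a zero estimate on most batches. That is the noise concern raised in the first comment. How to choose or adapt a per-feature bandwidth during training is left open here, as it is in the reply.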