Yeah, I expect that this kind of thing might work, though it would 2x the cost of inference. An alternative is “attention head probes”, MLP probes, and things like that (which don’t increase inference cost), plus maybe different training losses for the probe (here we train per sequence position and aggregate with max). I expect something in this reference class to work as well as RR, though it might require RR-levels of tuning to actually get there (which is why I don’t consider this kind of probing a baseline you ought to try).
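For concreteness, here is a minimal sketch of the “train per sequence position, aggregate with max” probe. It assumes a linear probe on a single layer’s activations and per-token labels (e.g. every position of a harmful completion labeled 1). The function names are hypothetical and this is not the exact probe setup used in the post:

```python
import numpy as np

def per_position_logits(hidden, w, b):
    """Linear probe logit at every position. hidden: (seq_len, d_model), w: (d_model,), b: scalar."""
    return hidden @ w + b

def per_position_bce(hidden, labels, w, b):
    """Training loss: mean binary cross-entropy over positions; labels: (seq_len,) in {0, 1}."""
    z = per_position_logits(hidden, w, b)
    # Numerically stable BCE with logits: log(1 + exp(z)) - y * z
    return float(np.mean(np.logaddexp(0.0, z) - labels * z))

def sequence_score(hidden, w, b):
    """Inference: take the max logit over positions, so the sequence is flagged if any position fires."""
    z = per_position_logits(hidden, w, b)
    return float(1.0 / (1.0 + np.exp(-z.max())))
```

Max aggregation means a single high-scoring token is enough to flag the whole sequence, which suits cases where harmful content appears only partway through a completion.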
Why would it 2x the cost of inference? To be clear, my suggested baseline is “attach exactly the same LoRA adapters that were used for RR, plus one additional linear classification head, then train on an objective similar to RR’s but with the rerouting loss replaced by a classification loss for the classification head.” Explicitly, this is to test the hypothesis that RR only worked better than HP because it was optimizing more parameters (but isn’t otherwise meaningfully different from probing).
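A rough sketch of what that objective could look like, under heavy simplifying assumptions: a single linear layer stands in for the transformer, the head reads that layer’s LoRA-adapted output, and the retain term is a squared L2 distance to the frozen model’s activations on benign data (modeled on RR’s retain loss). All names (`adapted_hidden`, `baseline_loss`, `retain_coef`) are hypothetical; the trainable parameters would be `A`, `B`, `w_head`, and `b_head`:

```python
import numpy as np

def adapted_hidden(x, W, A, B, scale):
    """Toy one-layer 'model': frozen W plus a LoRA update scale * B @ A.
    x: (n, d_in), W: (d_out, d_in), A: (r, d_in), B: (d_out, r)."""
    return x @ W.T + scale * (x @ A.T) @ B.T

def bce_with_logits(z, y):
    """Mean binary cross-entropy from logits z and labels y in {0, 1}."""
    return float(np.mean(np.logaddexp(0.0, z) - y * z))

def baseline_loss(x_harm, x_benign, W, A, B, scale, w_head, b_head, retain_coef=1.0):
    """RR-style objective with the rerouting term swapped for a classification term."""
    h_harm = adapted_hidden(x_harm, W, A, B, scale)
    h_benign = adapted_hidden(x_benign, W, A, B, scale)
    # Classification term: the linear head should output 1 on harmful inputs, 0 on benign ones.
    z = np.concatenate([h_harm @ w_head, h_benign @ w_head]) + b_head
    y = np.concatenate([np.ones(len(x_harm)), np.zeros(len(x_benign))])
    cls_loss = bce_with_logits(z, y)
    # Retain term (kept from RR): benign activations should stay close to the frozen model's.
    h_benign_frozen = x_benign @ W.T
    retain_loss = float(np.mean(np.sum((h_benign - h_benign_frozen) ** 2, axis=1)))
    return cls_loss + retain_coef * retain_loss
```

With `B` initialized to zero (the usual LoRA init) the retain term starts at 0, so the adapters only drift from the base model to the extent that doing so helps the head classify.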
(Note that LoRA adapters can be merged into model weights for inference.)
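To illustrate why merging removes the extra cost: the low-rank update can be folded into the base weight as W' = W + scale * B @ A (in standard LoRA, scale = alpha / r), and the merged layer gives the same outputs with a single matmul of the original shape. A minimal sketch with hypothetical function names:

```python
import numpy as np

def lora_forward(x, W, A, B, scale):
    """Unmerged LoRA layer: frozen path plus low-rank path.
    x: (n, d_in), W: (d_out, d_in), A: (r, d_in), B: (d_out, r)."""
    return x @ W.T + scale * (x @ A.T) @ B.T

def merge_lora(W, A, B, scale):
    """Fold the low-rank update into the base weight; the result has W's shape."""
    return W + scale * B @ A
```

After merging, the only inference overhead of the baseline is the single linear classification head.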
(I agree that you could also just use more expressive probes, but I’m interested in this as a baseline for RR, not as a way to improve robustness per se.)
I was imagining doing two forward passes, one with and one without the LoRAs, but you had in mind adding a “keep behavior the same” loss in addition to the classification loss, right? I guess that would work; good point.