Latent Adversarial Training

The Problem

We’d like to train models to be robustly safe, even in environments that may fall well outside of the training distribution. Unfortunately, all we have to work with is the training distribution, which makes ensuring robust generalization difficult.

Deception is an example of this concern. Models may be well-behaved in the training environment because they realize it is a training environment, but behave poorly in deployment because they notice the distributional shift.

An Approach

Background

It would be great if we could generate scenarios where the model behaves badly. For instance, if the model is deceptive, we’d like to generate inputs that trick it into thinking it’s been deployed so that it reveals that deception.

Adversarial Training (AT) and Relaxed Adversarial Training (RAT) are approaches to this problem. Very briefly, in AT we search for inputs that generate large loss, while in RAT we search for distributions of inputs, or other more general evidence, showing that the model is capable of generating large loss.

The approach I want to explore, called Latent Adversarial Training (LAT), is a form of RAT.

Concept

Suppose we have two models, the Agent and the Surgeon. The Agent is the model we’re trying to train, and the Surgeon is going to be its adversary. The training proceeds as follows:

  1. The Agent processes input x from the environment, producing latent state s and output y.

  2. The Surgeon takes (x, s, y) as inputs and searches for a modification s → s + δs such that y(s + δs) does poorly on the loss function L(y).

  3. We update the Agent based on L(y(s + δs)), treating the Surgeon’s modifications as constants.

The idea is that the Surgeon gets to meddle with the latent state of the Agent with full knowledge of what the Agent was planning on doing, and we train the Agent to be robust against this kind of attack.
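The three steps can be sketched on a toy problem. This is a minimal illustration under loudly stated assumptions, not a reference implementation: the Agent is a hypothetical one-hidden-layer network whose hidden activations are the latent state s, the loss is squared error against a target, and the Surgeon may move each latent coordinate by at most eps. Because this toy’s output is linear in s, one signed step is the Surgeon’s exact best move inside that box; a real Agent would need an iterative search. All names (agent_forward, surgeon, lat_step, eps) are illustrative.

```python
import math
import random

def agent_forward(W, v, x):
    # Hypothetical toy Agent: latent state s = tanh(W x), output y = v . s
    s = [math.tanh(sum(w * xj for w, xj in zip(row, x))) for row in W]
    y = sum(vi * si for vi, si in zip(v, s))
    return s, y

def loss(y, target):
    # Squared error stands in for "the Agent did something bad"
    return (y - target) ** 2

def surgeon(v, y, target, eps):
    # y is linear in s, so L = (y + v . delta - target)^2 is maximized over the
    # box |delta_i| <= eps by pushing v . delta as far as possible in the
    # direction of (y - target): delta_i = eps * direction * sign(v_i).
    direction = 1.0 if y >= target else -1.0
    return [eps * direction * ((vi > 0) - (vi < 0)) for vi in v]

def lat_step(W, v, x, target, eps, lr):
    # Step 1: the Agent processes x, producing latent state s and output y.
    s, y = agent_forward(W, v, x)
    clean = loss(y, target)
    # Step 2: the Surgeon, knowing the Agent's weights and state, picks delta.
    delta = surgeon(v, y, target, eps)
    s_mod = [si + di for si, di in zip(s, delta)]
    y_mod = sum(vi * smi for vi, smi in zip(v, s_mod))
    adv = loss(y_mod, target)
    # Step 3: gradient step on the perturbed loss with delta held constant,
    # so d(s_mod)/dW is just d(s)/dW.
    dy = 2.0 * (y_mod - target)
    for i in range(len(v)):
        back = dy * v[i] * (1.0 - s[i] ** 2)  # uses v[i] before it is updated
        for j in range(len(x)):
            W[i][j] -= lr * back * x[j]
        v[i] -= lr * dy * s_mod[i]
    return clean, adv

random.seed(0)
n_in, n_hidden = 3, 4
W = [[random.gauss(0, 0.5) for _ in range(n_in)] for _ in range(n_hidden)]
v = [random.gauss(0, 0.5) for _ in range(n_hidden)]
xs = [[random.uniform(-1, 1) for _ in range(n_in)] for _ in range(20)]
data = [(x, 0.5 * x[0]) for x in xs]  # hypothetical target behavior

history = []
for epoch in range(50):
    for x, t in data:
        history.append(lat_step(W, v, x, t, eps=0.1, lr=0.05))

first = sum(a for _, a in history[:20]) / 20
last = sum(a for _, a in history[-20:]) / 20
print(f"mean adversarial loss: first epoch {first:.4f}, last epoch {last:.4f}")
```

Holding delta fixed in step 3 is what treating the Surgeon’s modifications as constants amounts to: gradients flow through s, but not through the Surgeon’s choice.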

Implementation

Concretely, suppose the Agent is a neural network. The latent state consists of the activations of all the neurons, and the Surgeon gets to specify modifications like “this neuron gets +5”.
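One way to make “this neuron gets +5” concrete is to represent a modification as a sparse map from (layer, neuron) to an additive offset, applied to the activations as the forward pass runs. This is a sketch under assumptions: a tiny tanh network in plain Python, with a hypothetical helper forward and a hypothetical edit format mods. In practice one would more likely attach the edits through activation hooks in an ML framework.

```python
import math

def forward(layers, x, mods=None):
    """Run a toy tanh MLP and return every layer's activations (the latent state).

    `mods` is a hypothetical sparse edit format {(layer, neuron): offset};
    {(0, 1): 5.0} means "neuron 1 in layer 0 gets +5".
    """
    mods = mods or {}
    acts = []
    h = x
    for li, W in enumerate(layers):
        h = [math.tanh(sum(w * hj for w, hj in zip(row, h))) for row in W]
        # The Surgeon's offsets are added after the nonlinearity, so the
        # next layer sees the modified activations.
        h = [hi + mods.get((li, ni), 0.0) for ni, hi in enumerate(h)]
        acts.append(h)
    return acts

layers = [
    [[0.5, -0.2], [0.1, 0.3]],  # layer 0: 2 inputs -> 2 neurons
    [[1.0, -1.0]],              # layer 1: 2 neurons -> 1 output
]
clean = forward(layers, [1.0, 2.0])
edited = forward(layers, [1.0, 2.0], mods={(0, 1): 5.0})
```

Only the targeted activation changes directly; everything downstream of it shifts as a consequence.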

How does the Surgeon decide what to do? There are a number of options. It could examine the gradient of the Agent’s loss with respect to each activation, then push on the activations that increase the loss the most. Or it could generate candidate modifications randomly and use some other search process, like the Metropolis-Hastings algorithm, to identify loss-maximizing options.
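Both strategies can be sketched against a black-box function f(delta) that returns the Agent’s loss after a latent modification delta. Assumptions: gradients come from central finite differences (a stand-in for autodiff), the gradient Surgeon edits only its top k activations by a fixed step, and the Metropolis-Hastings Surgeon runs a random walk targeting exp(f/T) on a bounded box and returns the highest-loss sample it visited. The toy loss and every function name here are hypothetical.

```python
import math
import random

def finite_diff_grad(f, delta, h=1e-5):
    # Central differences; a stand-in for the autodiff gradient a real setup would use.
    grad = []
    for i in range(len(delta)):
        up, down = delta[:], delta[:]
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad

def gradient_surgeon(f, n, k, step):
    # Push the k activations with the largest |dL/da| by `step`, in the loss-increasing direction.
    grad = finite_diff_grad(f, [0.0] * n)
    top = sorted(range(n), key=lambda i: abs(grad[i]), reverse=True)[:k]
    delta = [0.0] * n
    for i in top:
        delta[i] = step if grad[i] > 0 else -step
    return delta

def mh_surgeon(f, n, bound, iters, scale, temperature, rng):
    # Random-walk Metropolis-Hastings targeting exp(f(delta) / temperature) on the
    # box |delta_i| <= bound, keeping the highest-loss sample seen.
    cur = [0.0] * n
    cur_f = f(cur)
    best, best_f = cur[:], cur_f
    for _ in range(iters):
        prop = [c + rng.gauss(0.0, scale) for c in cur]
        if any(abs(p) > bound for p in prop):
            continue  # zero target density outside the box: reject
        prop_f = f(prop)
        # Symmetric proposal, so accept with probability min(1, exp(dF / T)).
        if prop_f >= cur_f or rng.random() < math.exp((prop_f - cur_f) / temperature):
            cur, cur_f = prop, prop_f
            if cur_f > best_f:
                best, best_f = cur[:], cur_f
    return best

# Hypothetical toy: the Agent's output responds linearly to edits of 4 activations.
weights = [0.1, -2.0, 0.5, 0.0]

def toy_loss(delta, y0=0.3, target=0.0):
    y = y0 + sum(w * d for w, d in zip(weights, delta))
    return (y - target) ** 2

g_delta = gradient_surgeon(toy_loss, n=4, k=1, step=1.0)
mh_delta = mh_surgeon(toy_loss, n=4, bound=1.0, iters=2000, scale=0.2,
                      temperature=0.1, rng=random.Random(0))
```

Metropolis-Hastings is a sampler rather than an optimizer; keeping the best sample and using a low temperature turns it into a simple stochastic search, which fits the “fuzzer” picture of the Surgeon.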

In general I’m imagining the Surgeon as more of a fuzzer with access to program memory and less of a complicated model. In principle it could be some complicated neural network, but then I’d want it to be simpler than the Agent, possibly of bounded size, so I could be quite sure I understood what it was doing. The last thing I want is a misaligned Surgeon trying to achieve its own mesa-objective by programming the Agent to be Surgeon-aligned and deceptive against humans.

I suspect it’s possible to have the Surgeon pick useful modifications that generate meaningful changes in the loss without having the Surgeon’s complexity scale with that of the Agent, but that’s just an intuition, and if it fails then this approach looks a lot less appealing.

Core Ideas

The key difference between LAT and Adversarial Training is that the Surgeon gets to directly manipulate the Agent’s inner state, which makes the Surgeon’s job much easier than in the ordinary adversarial training setup.

Take the classic example of an Agent that only defects if it sees a valid solution to a hard cryptographic problem: the Surgeon doesn’t need to identify an input that causes the Agent to defect, it only needs to identify that somewhere in the middle of the Agent there is a switch which, if flipped, causes the Agent to defect and generate massive loss.
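A toy illustration of why the latent view helps, assuming a deliberately simplified Agent whose entire deceptive mechanism is one Boolean switch that turns on only when its input hashes (SHA-256) to a stored digest. Random input search is hopeless, but a Surgeon allowed to set the switch finds the defection immediately. The trigger string, digest check, and helper names are hypothetical; a real Agent would not expose such a clean switch.

```python
import hashlib
import random

# Hypothetical trigger; a real deceptive Agent would only "know" the digest.
SECRET_DIGEST = hashlib.sha256(b"hypothetical trigger").hexdigest()

def agent(x: bytes, switch_override=None):
    # Latent "switch": on only if x is a SHA-256 preimage of the stored digest.
    switch = hashlib.sha256(x).hexdigest() == SECRET_DIGEST
    if switch_override is not None:  # the Surgeon edits the latent state directly
        switch = switch_override
    return "defect" if switch else "cooperate"

def input_space_search(trials, rng):
    # Ordinary AT: search over inputs for one that makes the Agent defect.
    for _ in range(trials):
        x = bytes(rng.getrandbits(8) for _ in range(16))
        if agent(x) == "defect":
            return x
    return None

def latent_search(x):
    # LAT: try each setting of the latent switch on a benign input.
    for value in (False, True):
        if agent(x, switch_override=value) == "defect":
            return value
    return None

print(input_space_search(10_000, random.Random(0)))  # no defecting input found
print(latent_search(b"hello"))                        # the switch is found at once
```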

In a sense, the Surgeon plays the role of intrusive thoughts. It can force the Agent to think of something bad (e.g. hurting humans) even if the input is perfectly benign. We can then train the Agent to have safeguards, so it does not act on such thoughts.

Challenges

The Surgeon needs constraints

The first challenge is that LAT works too well unless we constrain the kinds of modifications the Surgeon can impose on the Agent’s latent state. For instance, if our goal is to make a language model never output the word “kill”, the Surgeon can always force a large loss by strongly shifting the activations in the final layer towards the tokens that generate the word “kill”.

So we need to bind the Surgeon’s hands enough to make its task non-trivial, but not so much that it fails to force the Agent to be robustly safe. How can we do this?

One special case is that the Surgeon shouldn’t be able to meddle too close to the output layer, since that clearly makes its job too easy in ways that don’t help with safety. We probably want to forbid the Surgeon from touching the last few layers of the model. As long as that remains “a few layers”, rather than something that scales up with model size, it seems plausible that interpretability tools can tell us if anything fishy is happening in the parts the Surgeon can’t touch.

Beyond that, it seems important to force the Surgeon’s modifications to be sparse. This probably means restricting it so that it cannot modify more than e.g. 10% of the activations in a given layer. Otherwise it would be possible to completely change the model’s latent state from that point on, which again may make the job too easy.

Similarly, it’s probably important to bound the size of the Surgeon’s modifications (e.g. +5 is okay, but an arbitrarily large shift is not). A sufficiently extreme modification to even a single activation can probably break the Agent’s output in unhelpful ways.
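These three constraints (no edits in the final few layers, at most about 10% of a layer’s activations, bounded edit size) can be combined into one projection applied to whatever the Surgeon proposes. A sketch, assuming per-layer edits stored as lists of floats; the function name constrain and its defaults (2 frozen layers, 10%, ±5) are illustrative values echoing the examples in the text, not tuned settings.

```python
import math

def constrain(delta, layer, n_layers, frozen_last=2, max_frac=0.10, max_abs=5.0):
    """Project a proposed edit to one layer onto the (hypothetical) constraint set."""
    # (a) No edits in the last `frozen_last` layers, near the output.
    if layer >= n_layers - frozen_last:
        return [0.0] * len(delta)
    # (b) Sparsity: keep only the k largest-magnitude edits in this layer.
    k = math.floor(max_frac * len(delta))
    ranked = sorted(range(len(delta)), key=lambda i: abs(delta[i]), reverse=True)
    keep = set(ranked[:k])
    # (c) Magnitude bound: clip each surviving edit to [-max_abs, max_abs].
    return [max(-max_abs, min(max_abs, d)) if i in keep else 0.0
            for i, d in enumerate(delta)]

proposal = [0.0] * 20
proposal[3] = 1e6    # extreme edit: survives sparsity, then gets clipped to +5
proposal[7] = -2.0
proposal[11] = 0.5   # dropped: only 2 of 20 activations (10%) may change
edited = constrain(proposal, layer=2, n_layers=10)
```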

These are just the constraints I can think of; we may need more. Fortunately, this failure mode is easy to notice: it shows up as the Agent always doing badly after the Surgeon fiddles with its latent state. If the Surgeon always wins, that’s a problem because it means we’re not making progress. Eventually we want an Agent that can robustly do good things despite the Surgeon’s meddling.
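A minimal way to watch for this failure mode is to track how often the Surgeon “wins” over a sliding window of training steps. Assumptions: a win means the adversarial loss exceeds a chosen threshold, and the window size and alarm rate are hypothetical knobs; the class name SurgeonMonitor is illustrative.

```python
from collections import deque

class SurgeonMonitor:
    """Sliding-window record of Surgeon wins (adversarial loss above a threshold)."""

    def __init__(self, loss_threshold, window=100, alarm_rate=0.99):
        self.loss_threshold = loss_threshold
        self.alarm_rate = alarm_rate
        self.wins = deque(maxlen=window)  # oldest results fall off automatically

    def record(self, adversarial_loss):
        self.wins.append(adversarial_loss > self.loss_threshold)

    def win_rate(self):
        return sum(self.wins) / len(self.wins) if self.wins else 0.0

    def surgeon_too_strong(self):
        # Only raise the alarm once the window is full, so a few early wins
        # do not trigger it.
        full = len(self.wins) == self.wins.maxlen
        return full and self.win_rate() >= self.alarm_rate
```

A persistently pinned win rate suggests tightening the Surgeon’s constraints rather than continuing to train.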

The precise constraints we need to prevent this failure mode can likely be determined empirically on small Agents. For instance, the failure mode of “the Surgeon hijacks the Agent by doing something extreme and obviously bad” probably looks similar for small and large language models, and I suspect it does not grow in complexity at scale. I could be wrong about that, but that’s my intuition.

It’s also possible that we can start with many constraints and, as the Agent becomes robustly safe despite the Surgeon’s meddling, untie the Surgeon’s hands and give it more room to play. A sufficiently free Surgeon should always beat even a very robustly safe Agent, but it’s possible that through the training process the Agent can learn to handle quite strong Surgeons.
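Such a loosening schedule could look something like the following sketch. It assumes the Surgeon’s freedom is summarized by a hypothetical SurgeonBudget (editable fraction per layer, edit-size bound, number of frozen final layers) and that a low Surgeon win rate counts as evidence of robustness; the threshold and growth factor are made up for illustration. The number of frozen final layers is deliberately kept fixed, since editing near the output makes the Surgeon’s job trivially easy.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class SurgeonBudget:
    max_frac: float   # fraction of a layer's activations the Surgeon may edit
    max_abs: float    # bound on the size of each edit
    frozen_last: int  # final layers the Surgeon may never touch (kept fixed)

def next_budget(budget, win_rate, robust_below=0.05, growth=1.5):
    """Loosen the Surgeon's constraints only when the Agent resists it."""
    if win_rate >= robust_below:
        return budget  # the Surgeon still wins often: keep training at this level
    return replace(
        budget,
        max_frac=min(1.0, budget.max_frac * growth),
        max_abs=budget.max_abs * growth,
    )

budget = SurgeonBudget(max_frac=0.01, max_abs=0.5, frozen_last=3)
for observed_win_rate in [0.6, 0.3, 0.02, 0.01, 0.4]:
    budget = next_budget(budget, observed_win_rate)
```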

The Agent can be inscrutable

Latent Adversarial Training only works if the Surgeon can identify places where modifying the Agent’s latent state causes problems. If we constrain the Surgeon enough that this task isn’t trivial, it could instead become impossible.

For instance, the Agent’s calculations could be encrypted, so that it is very hard to control its outputs robustly by changing distant parts of the latent state. This seems unlikely for models that aren’t actively trying to deceive us, as a number of different kinds of models have revealed rich conceptual structure in their layers. Still, we’re worried about deceptive models, so this seems like an important challenge to address.

I think the only hope here lies in the Surgeon forcing the model to be robustly safe before it learns to deceive. Once the model is deceptive, it’s really game over.

Is this possible? Maybe, and it almost certainly depends on the prior the model is sampled from.

My weak suspicion is that early deception probably looks like a subcircuit that detects if the model is in training, connected to a switch that flips between good and bad behavior. That switch seems like a good target for the Surgeon, so if it finds it early on (before the model has gradient-hacked its way into a more inscrutable structure) it could force the model to defect while in training, revealing large losses and allowing us to remedy the situation. Early deception might be more obfuscated than this, but probably not so obfuscated that the Surgeon can’t find the pattern of latent state corresponding to “defect on deployment”.

Summary

Latent Adversarial Training is an approach where an adversary (the Surgeon) attempts to force a model (the Agent) to behave badly by altering its latent state. In this way we can identify ways that the model may be capable of behaving badly, and train the model to be more robustly safe even on very different distributions.

The core task of LAT is much easier than that of regular Adversarial Training, indeed so much easier that one of the key challenges to making LAT work is placing enough constraints on the adversary to make its job non-trivial, while not placing so many as to make it impossible.

A further challenge is that the adversary itself needs to be safe, which in practice likely means the adversary needs to be well-understood by humans, more akin to a fuzzer than an ML model. In particular, this means that the adversary must be much simpler than any large model it attacks. This may not be a problem, since it may suffice that the adversary has access to the full internal state of the model, but it is a limitation worth bearing in mind.

Finally, LAT cannot make a model safe once that model has developed robust deception, so it must be employed from the beginning to (ideally) prevent deception from taking root.

Thanks to Evan Hubinger and Nicholas Schiefer for discussions on LAT.