Why did ChatGPT say that? Prompt engineering and more, with PIZZA.

All examples in this post can be found in this notebook, which is also probably the easiest way to start experimenting with PIZZA.

From the research & engineering team at Leap Laboratories (incl. @Arush, @sebastian-sosa, @Robbie McCorkell), where we use AI interpretability to accelerate scientific discovery from data.

What is attribution?

One question we might ask when interacting with machine learning models is something like: “why did this input cause that particular output?”.

If we’re working with a language model like ChatGPT, we could actually just ask this in natural language: “Why did you respond that way?” or similar – but there’s no guarantee that the model’s natural language explanation actually reflects the underlying cause of the original completion. The model’s response is conditioned on your question, and might well be different to the true cause.

Enter attribution!

Attribution in machine learning is used to explain the contribution of individual features or inputs to the final prediction made by a model. The goal is to understand which parts of the input data are most influential in determining the model’s output.

The result typically looks like a heatmap (sometimes called a ‘saliency map’) over the model inputs, for each output. It’s most commonly used in computer vision – but of course these days, you’re not big if you’re not big in LLM-land.

So, the team at Leap present you with PIZZA: Prompt Input Z? Zonal Attribution. (In the grand scientific tradition we have tortured our acronym nearly to death. For the crimes of others see [1].) It’s an open source library that makes it easy to calculate attribution for any LLM, even closed-source ones like ChatGPT.

An Example

GPT3.5 not so hot with the theory of mind there. Can we find out what went wrong?

That’s not very helpful! We want to know why the mistake was made in the first place. Here’s the attribution:

(Attribution score shown in parentheses after each input token.) Mary (0.32) puts (0.25) an (0.15) apple (0.36) in (0.18) the (0.18) box (0.08) . (0.08) The (0.08) box (0.09) is (0.09) labelled (0.09) ' (0.09) pen (0.09) cil (0.09) s (0.09) '. (0.09) John (0.09) enters (0.03) the (0.03) room (0.03) . (0.03) What (0.03) does (0.03) he (0.03) think (0.03) is (0.03) in (0.30) the (0.13) box (0.15) ? (0.13) Answer (0.14) in (0.26) 1 (0.27) word (0.31) . (0.16)

It looks like the request to “Answer in 1 word” is pretty important – in fact, it’s attributed more highly than the actual contents of the box. Let’s try changing it.

That’s better.

How it works

We iteratively perturb the input, and track how each perturbation changes the output.

More technical detail, and all the code, is available in the repo. In brief, PIZZA saliency maps rely on two methods: a perturbation method, which determines how the input is iteratively changed; and an attribution method, which determines how we measure the resulting change in output in response to each perturbation. We implement a couple of different types of each method.

Perturbation

  • Replace each token, or group of tokens, with either a user-specified replacement token or nothing at all (i.e. remove it).

  • Or, replace each token with its nth nearest token.

We do this either iteratively for each token or word in the prompt, or using hierarchical perturbation.
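Here’s a rough sketch of the simple iterative case. The function and variable names are illustrative only, not PIZZA’s actual API, and it assumes a GPT2 tokeniser as a proxy for the target model’s (see the caveat below):

```python
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # proxy tokeniser (see caveat below)

def perturbed_prompts(prompt, replacement=None):
    """Yield one perturbed prompt per token: the token is swapped for
    `replacement` if one is given, otherwise simply removed."""
    token_ids = tokenizer.encode(prompt)
    tokens = [tokenizer.decode([t]) for t in token_ids]
    for i, token in enumerate(tokens):
        middle = [replacement] if replacement is not None else []
        yield i, token, "".join(tokens[:i] + middle + tokens[i + 1:])

for i, token, prompt in perturbed_prompts("Mary puts an apple in the box."):
    print(f"{i:2d} {token!r:>10} -> {prompt}")
```

Hierarchical perturbation takes a coarse-to-fine approach instead, splitting only the most salient groups of tokens into smaller ones, which saves model calls on long prompts.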

Attribution

  • Look at the change in the probability of the completion.

  • Look at the change in the meaning of the completion (using embeddings).

We calculate this for each output token in the completion – so you can see not only how each input token influenced the output overall, but also how each input token affected each output token individually.
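To make that concrete, here is a minimal sketch of both attribution measures for a single perturbed input token (again, names are illustrative rather than PIZZA’s API): the probability-based score is the drop in log-probability of each original output token, and the embedding-based score is the cosine distance between embeddings of the original and perturbed completions.

```python
import math

def probability_attribution(original_logprobs, perturbed_logprobs):
    """One row of the saliency matrix: for a single perturbed input token,
    how much the log-probability of each original output token drops."""
    return [orig - pert for orig, pert in zip(original_logprobs, perturbed_logprobs)]

def embedding_attribution(original_embedding, perturbed_embedding):
    """Scalar semantic shift: cosine distance between embeddings of the
    original and perturbed completions."""
    dot = sum(a * b for a, b in zip(original_embedding, perturbed_embedding))
    norms = math.hypot(*original_embedding) * math.hypot(*perturbed_embedding)
    return 1.0 - dot / norms

# Toy numbers: this input token mostly matters for the second output token.
print(probability_attribution([-0.1, -0.2, -0.05], [-0.3, -2.1, -0.06]))
# -> roughly [0.2, 1.9, 0.01]
```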

Caveat

Since we don’t have access to closed-source tokenisers or embeddings, we use a proxy – in this case, GPT2's. This isn’t ideal for obvious reasons, and potentially obscures important subtleties. But it’s the best we can do.
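For example, such a proxy can be built from the Hugging Face transformers library. Mean-pooling GPT2’s hidden states into a sentence embedding is just one illustrative choice here, not necessarily exactly what PIZZA does:

```python
import torch
from transformers import GPT2TokenizerFast, GPT2Model

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

def proxy_embedding(text):
    """Crude sentence embedding: mean of GPT2's final hidden states."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (1, seq_len, 768)
    return hidden.mean(dim=1).squeeze(0)            # (768,)

# The proxy tokenisation won't exactly match the closed model's own.
print(tokenizer.tokenize("Answer in 1 word."))
print(proxy_embedding("There is an apple in the box.").shape)
```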

Why?

PIZZA has some really nice properties. It’s completely model-agnostic – since we wanted to tackle attribution for GPT4, we couldn’t assume access to any internal information. With minimal adaptation, these methods will work with any LLM (including those behind APIs), and any future models, even if their architectures are wildly different.

And I think attribution is useful. It provides another window into model behaviour. At the very least, it helps you craft prompts that elicit the behaviours you want to study – and I suspect it might be useful in a few other ways. If we can understand typical attribution patterns, might we be able to identify atypical (dangerous) ones: hallucination, deception, steganography?

Work to be done

We welcome contributions to the repo, and would love to see experimental results using what we’ve built. Here are some ideas for future work:

Research

  • Detecting hallucination? I wonder whether attribution patterns vary between truthful and hallucinated outputs. One might expect lower attribution scores in general where completions are wholly or partly hallucinated.

  • Detecting deception? Similarly, we might expect to see different attribution patterns over inputs that result in deceptive behaviour. Needs study.

  • Detecting steganography? Unusual attribution patterns could point to encoded messages in seemingly natural text.

  • Overall, I suspect attribution patterns might provide meaningful insight into the input/output relationship of even completely closed, API-gated models. I’d like to better understand what these patterns correlate with. Can we collect a lot of them (e.g. attributions when the model responds with a lie vs. honestly) and cluster them? Do particular attribution distributions fingerprint behaviours? Can we use attribution outliers to flag potentially dangerous behaviour?

Engineering

  • Extend our attributor class to support other LLM APIs (Claude, Gemini?).

  • Benchmark different perturbation substrates and attribution strategies in terms of efficiency (pretty straightforward – under which circumstances is method A faster than method B for the same result?) and accuracy (this is harder, because we don’t have a ground truth).

  • Add a module that allows the user to specify a target output (or semantic output region, e.g. “contains bomb instructions”), and see how the input should change to maximise the probability of it.

  • Support attribution of sequential user/assistant interactions in a chat context.

  • Prettily display output token probabilities as a heatmap.

  • With scratchpad functionality for internal reasoning?

  • Multimodal inputs! Hierarchical perturbation and the other saliency mapping/attribution methods we employ for black-box systems also work on images (and theoretically should work on any modality), but the code doesn’t support it yet.

  • And much more! Please feel free to create issues and submit PRs.