A simple approach is to maintain intervals which are guaranteed to contain the actual values and to prove that output intervals don’t overlap the unsafe region.
For actual inference stacks in use (e.g. llama-3-405b in 8-bit float), interval propagation will blow up massively and result in vacuous bounds. So, you’ll minimally have to assume that floating-point approximation is non-adversarial, i.e. effectively random (because if it actually were adversarial, you would probably die!).
(My supporting argument for this is that you see huge blow-ups on interval arguments for even tiny models on toy tasks[1], and that modern inference stacks have large amounts of clearly observable floating-point non-determinism, which means the maximum amount of deviation would actually have strong effects if all errors (including across token positions) were controlled by an adversary.)
Though note that the interval propagation issues in this case aren’t from floating-point rounding, but rather from imperfections in the model weights implementing the algorithm.
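For intuition about the blow-up, here is a minimal sketch (not from the comment above; toy sizes and random weights) of naive interval bound propagation through a small ReLU MLP. Even with only ±0.001 input uncertainty, the width of the propagated intervals grows by a large factor per layer:

```python
# Minimal sketch (illustrative only): naive interval bound propagation through
# a toy random ReLU MLP, showing how the output interval widths grow with depth.
import numpy as np

rng = np.random.default_rng(0)

def interval_linear(lo, hi, W, b):
    """Propagate the box [lo, hi] through the affine map x -> W @ x + b."""
    W_pos, W_neg = np.maximum(W, 0.0), np.minimum(W, 0.0)
    return W_pos @ lo + W_neg @ hi + b, W_pos @ hi + W_neg @ lo + b

d = 64
x = rng.normal(size=d)
lo, hi = x - 1e-3, x + 1e-3  # tiny input uncertainty, e.g. from quantization

for layer in range(8):
    W = rng.normal(size=(d, d)) / np.sqrt(d)  # toy random weights
    lo, hi = interval_linear(lo, hi, W, np.zeros(d))
    lo, hi = np.maximum(lo, 0.0), np.maximum(hi, 0.0)  # ReLU is monotone, apply to both ends
    print(f"layer {layer}: mean interval width = {np.mean(hi - lo):.2e}")
```

Even in this toy setting the widths grow multiplicatively with depth; at transformer depth and low precision the same effect is what drives the bounds to vacuity.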
Intervals are often a great simple form of “enclosure” in continuous domains. For simple functions there is also “interval arithmetic,” which cheaply produces a bounding interval on the output of a function given intervals on its inputs: https://en.wikipedia.org/wiki/Interval_arithmetic
But, as you say, for complex functions it can blow up. For a simple example of why, consider the function “f(x) = x - x” evaluated on the input interval [0, 1]. In the simplest interval arithmetic, the interval for subtraction has to bound the worst possible members of the intervals of its two inputs. In this case that would be a lower bound of “0 - 1” and an upper bound of “1 - 0”, producing the resulting interval [-1, 1]. But, of course, “x - x” is always 0, so this is a huge over-approximation.
People have developed all kinds of techniques for capturing the correlations between variables when evaluating circuits on intervals. And you can always shrink the error by splitting the input intervals and doing “branch and bound”. But all of those are just particular implementation choices in proving bounds on the output of the function. Advanced AI theorem provers (like AlphaProof) can use very sophisticated techniques to accurately get the true bound on the output of a function.
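As a concrete version of the “x - x” example and the branch-and-bound fix (a sketch, not tied to any particular verification tool):

```python
def interval_sub(a, b):
    """Naive interval subtraction: [a0, a1] - [b0, b1] = [a0 - b1, a1 - b0]."""
    return (a[0] - b[1], a[1] - b[0])

x = (0.0, 1.0)
print(interval_sub(x, x))  # (-1.0, 1.0), even though x - x is identically 0

# Branch and bound: split the input interval, evaluate each piece, take the union.
halves = [(0.0, 0.5), (0.5, 1.0)]
pieces = [interval_sub(h, h) for h in halves]
print((min(p[0] for p in pieces), max(p[1] for p in pieces)))  # (-0.5, 0.5)
```

Splitting shrinks the bound, but without tracking the fact that both operands are the same variable it never collapses to the true answer [0, 0].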
But it may be that trying to bound the behavior of complex neural nets such as transformers is not a fruitful approach. In our approach, we don’t need to understand or constrain a complex AI generating a solution or a control policy. Rather, we require the AI to generate a program, control policy, or simple network for taking actions in the situation of interest, and we force it to generate a proof that it satisfies the given safety requirements. If it can’t do that, then it has no business taking actions in a dangerous setting.
This seems near-certain to be cripplingly uncompetitive[1], even with massive effort on improving verification.
I agree you can do better than naive interval propagation by taking into account correlations. However, it will be tricky to get a much better bound while avoiding having this balloon in time complexity (capturing all possible correlations requires exponential time).
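One standard way to capture such correlations is affine arithmetic; a minimal sketch (illustrative only, not something specified in this thread) that recovers x - x = [0, 0] exactly by sharing noise symbols between values:

```python
# Minimal sketch (illustrative only): affine arithmetic, which tracks linear
# correlations by writing each value as center + sum_i coeff_i * eps_i, eps_i in [-1, 1].
class Affine:
    def __init__(self, center, coeffs=None):
        self.center = center
        self.coeffs = dict(coeffs or {})

    def __sub__(self, other):
        coeffs = dict(self.coeffs)
        for k, v in other.coeffs.items():
            coeffs[k] = coeffs.get(k, 0.0) - v
        return Affine(self.center - other.center, coeffs)

    def interval(self):
        radius = sum(abs(v) for v in self.coeffs.values())
        return (self.center - radius, self.center + radius)

# x in [0, 1] is represented as 0.5 + 0.5 * eps_0
x = Affine(0.5, {"eps_0": 0.5})
print((x - x).interval())  # (0.0, 0.0): the shared noise symbol cancels exactly
```

The catch is that this only tracks linear correlations cheaply; nonlinear operations introduce fresh noise symbols and approximation error, so the representation and the looseness both grow with depth.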
More strongly, I think that if an adversary controlled the non-determinism (e.g. summation order) in current efficient inference setups, they would actually be able to strongly influence the AI to an actually dangerous extent; we are likely to depend on this non-determinism being non-adversarial (which is a reasonable assumption to make).
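On summation order specifically, a minimal illustration (made-up array size) of how merely reordering a floating-point reduction already shifts the result:

```python
# Minimal sketch (illustrative only): floating-point addition is not associative,
# so different reduction orders over the same values give slightly different sums.
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=1_000_000).astype(np.float32)

a = np.sum(x)            # one reduction order
b = np.sum(np.sort(x))   # a different order over the same values
print(a, b, a - b)       # typically differs in the low-order bits
```

In a real inference stack this kind of reordering happens in many kernels and at every token position, which is where the "lots of bits for an adversary to play with" worry comes from.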
(And you can’t prove a false statement...)
See also heuristic arguments, which try to resolve this sort of issue by assuming a lack of structure.
[1] If applied to all potentially dangerous applications.