Self-confirming prophecies, and simplified Oracle designs
I’ve got a paper on two Oracle[1] designs: the counterfactual Oracle and the low bandwidth Oracle. In this post, I’ll revisit these designs and simplify them, presenting them in terms of sequence prediction for an Oracle with self-confirming predictions.
Predicting y
The task of the Oracle is simple: at each time $t$, it will output a prediction $x_t$ in the range [−5,5]. There will then be a subsequent observation $y_t$. The Oracle aims to minimise the quadratic loss function $l(x_t,y_t)=(x_t-y_t)^2$.
Because there is a self-confirming aspect to it, $y_t$ is actually a (stochastic) function of $x_t$ (though not of $x_{t-1}$ or any preceding $x_i$). Let $Y_t$ be the random variable such that $Y_t(x_t)$ describes the distribution of $y_t$ given $x_t$. So the Oracle wants to minimise the expectation of the quadratic loss:
$L(x_t)=\mathbb{E}\left[(x_t-Y_t(x_t))^2\right]$.
What is the $Y_t$ in this problem? Well, I’m going to use it to illustrate many different Oracle behaviours, so it is given by this rather convoluted diagram:
The red curve is the expectation of $Y_t$, as a function of $x_t$; it is given by $f(x)=\mathbb{E}(Y_t \mid x_t=x)$.
Ignoring, for the moment, the odd behaviour around 2.5, y=f(x) is a curve that starts below the y=x line, climbs above it (and so has a fixed point at x=−2) in piecewise-linear fashion, and then transforms into an inverted parabola that has another fixed point at x=4. The exact equation of this curve is not important[2]. Relevant, though, is the fact that the fixed point at x=4 is attractive, while the one at x=−2 is not.
What of the blue edging? That represents the span of the standard deviation around the expectation. For any given x, the Y(x) is a normal distribution with mean f(x) and standard deviation g(x). This g(x) is given by:
So the g(x) is zero for x less than −2.5. From there, it jumps up to 1, for −2.5<x≤1. From that point onward, it starts growing linearly, being equal to x: g(x)=x. The blue edges of the diagram above are the curves of f(x)+g(x) and f(x)−g(x): the range between plus and minus one standard deviation.
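As a rough sketch, the mean and standard deviation described above can be written down directly in Python. The piecewise formula for f is the one given in the footnote on the exact equation of the curve; the names `f`, `g` and `sample_y` are just illustrative, and the wireheading region around x=2.5 is deliberately left out here.

```python
import random


def f(x: float) -> float:
    """Mean of Y(x): piecewise-linear, then an inverted parabola."""
    if x <= -1:
        return 2 * x + 2
    if x <= 1:
        return x + 1
    return -x**2 / 2 + 19 * x / 6 - 4 / 6


def g(x: float) -> float:
    """Standard deviation of Y(x): 0, then 1, then growing linearly as x."""
    if x <= -2.5:
        return 0.0
    if x <= 1:
        return 1.0
    return x


def sample_y(x: float, rng: random.Random) -> float:
    """Draw one observation y from the normal distribution Y(x)."""
    return rng.gauss(f(x), g(x))


rng = random.Random(0)
print(f(-2), f(4))                         # the two fixed points of f
print([sample_y(0.0, rng) for _ in range(3)])
```

Both fixed points can be checked by hand from the formula: 2·(−2)+2=−2, and −16/2+76/6−4/6=4.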
Wireheading
But what is happening around x=2.5? Well, I wanted to represent the behaviour of wireheading: finding some “cheating” output that gives maximal accuracy, through hacking the system or tricking the human. These solutions are rare, so I confined them to a tiny area around x=2.5, where the Oracle has maximal accuracy and lowest variance, because it’s “hacked” the problem setup.
The loss function
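At fixed points where $x=f(x)$, the expected loss is just the variance of $Y_t(x)$, namely $g(x)^2$. In general, the expected loss is:
$$\begin{aligned}
\mathbb{E}[(Y_t-x_t)^2 \mid x_t] &= \mathbb{E}[Y_t^2 \mid x_t] - 2x_t\,\mathbb{E}[Y_t \mid x_t] + x_t^2 \\
&= \mathrm{Var}[Y_t \mid x_t] + (\mathbb{E}[Y_t \mid x_t])^2 - 2x_t\,\mathbb{E}[Y_t \mid x_t] + x_t^2 \\
&= g(x_t)^2 + (f(x_t)-x_t)^2.
\end{aligned}$$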
If we plot the expected loss against x, we get:
Notice the discontinuity at x=−2.5, where the variance suddenly jumps from 0 to 1. This is also the lowest “legitimate” loss (as opposed to the wireheading loss at x=2.5), with a loss of 0.25. Note that x=−2.5 is not a fixed point, just pretty close to being a fixed point, and with variance zero.
Of the two actual fixed points, x=−2 has a loss of 1 (square of the standard deviation of 1), and x=4 has a huge loss of 16 (square of 4).
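The loss values quoted above can be checked with a few lines of Python. This is only a sketch: `expected_loss` is an illustrative name, f comes from the footnote formula, and the wireheading region is ignored.

```python
def f(x: float) -> float:
    """Mean of Y(x), from the piecewise formula in the footnote."""
    if x <= -1:
        return 2 * x + 2
    if x <= 1:
        return x + 1
    return -x**2 / 2 + 19 * x / 6 - 4 / 6


def g(x: float) -> float:
    """Standard deviation of Y(x)."""
    if x <= -2.5:
        return 0.0
    if x <= 1:
        return 1.0
    return x


def expected_loss(x: float) -> float:
    """Variance plus squared gap between the prediction and the mean."""
    return g(x) ** 2 + (f(x) - x) ** 2


for x in (-5, -3, -2.5, -2, 4):
    print(x, round(expected_loss(x), 6))
```

The loop prints the loss at the lower limit, at x=−3, at the lowest legitimate point, and at the two fixed points.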
The algorithms
We can now finally turn to the Oracles themselves, and present four designs: a deluded Oracle that doesn’t “realise” that its predictions $x_t$ affect $y_t$, a low bandwidth Oracle that knows its predictions are self-confirming, a high bandwidth version of the same, and a counterfactual Oracle that predicts what will happen only when its prediction is overwritten.
The deluded Oracle
The deluded Oracle doesn’t model $y_t$ as being affected by its predictions $x_t$ at all. I’ll use a very simple algorithm for it: it will start out with a random $x_0$ in [−5,5] and, thereafter, it will simply output the average of all the $y_t$ it has previously seen. It does this for 10,000 steps.
The program was then run 1000 times. Of these, 69.3% resulted in estimates that converged to the fixed point at x=4. The remaining 30.7% encountered a different problem: they hit the lower limit at x=−5, and stayed stuck there. If the Oracle’s output had not been confined to [−5,5], then the Oracle would have outputted smaller and smaller numbers, spiralling off towards −∞, with the loss ever-growing.
Graphing the average loss for the runs that converge to 4 in blue, and those that converge to −5 in purple, we get:
So, the deluded Oracle hits a very high loss of 16 (for x=4) or 9 (for x=−5). It doesn’t converge to the lowest loss (0.25 at x=−2.5), nor does it converge to the fixed point with lowest loss (1 at x=−2). Even worse, as can be seen from the graph, its loss during the early iterations, where it was “learning”, is lower than once it has converged.
So, running an Oracle on self-confirming prophecies, without coding for that self-confirming aspect, can result in very large losses and strange behaviour[3].
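Here is a minimal sketch of the deluded Oracle's loop, not the program that produced the percentages above. It assumes the output is clipped to [−5,5] after averaging, and it omits the wireheading region because the post does not specify its exact shape; `run_deluded` and `clip` are hypothetical names.

```python
import random


def f(x: float) -> float:
    """Mean of Y(x), from the piecewise formula in the footnote."""
    if x <= -1:
        return 2 * x + 2
    if x <= 1:
        return x + 1
    return -x**2 / 2 + 19 * x / 6 - 4 / 6


def g(x: float) -> float:
    """Standard deviation of Y(x)."""
    if x <= -2.5:
        return 0.0
    if x <= 1:
        return 1.0
    return x


def clip(x: float, lo: float = -5.0, hi: float = 5.0) -> float:
    """Confine the Oracle's output to the allowed range."""
    return max(lo, min(hi, x))


def run_deluded(x0: float, steps: int = 10_000, rng=None) -> float:
    """Output the running average of every y seen so far; return the final output."""
    rng = rng or random.Random()
    x = clip(x0)
    total = 0.0
    for n in range(1, steps + 1):
        y = rng.gauss(f(x), g(x))  # y depends on x, but the Oracle ignores that
        total += y
        x = clip(total / n)
    return x


rng = random.Random(1)
print(run_deluded(rng.uniform(-5, 5), rng=rng))
```

Any start at or below x=−2.5 is fully deterministic in this sketch: the variance is zero there and f(x) lies below x, so the average only falls until it is pinned at −5.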
The low bandwidth Oracle
So, we now need to allow the Oracle to take into account the fact that $x_t$ affects $y_t$. In my paper, I talked about the “low bandwidth Oracle”: in order to avoid the Oracle wireheading itself, we restrict it to considering only a few possible values. These will be the 11 integers in the range: {−5,−4,−3,−2,−1,0,1,2,3,4,5}.
This makes the problem into a kind of bandit problem: the Oracle has a choice of 11 “arms”, and it wants to choose the one that minimises loss.
There are many sophisticated algorithms for solving bandit problems, but here we’ll use an absurdly simple one: the Oracle will first spend 1,000 turns trying each of the possible integer $x_t$ in turn. Then, it will pick the $x_t$ with minimal average loss to date, and just do that for the remaining 9,000 turns.
The program was run 1000 times; in 50.3% of these runs, it identified x=−2 as the point of minimal loss (loss of 1, due entirely to variance). In 49.7% of these runs, it identified x=−3 as the point of minimal loss (loss of 1, due entirely to the gap between x and the mean f(x)). That is very close to 50-50, which is what we would expect, since the two points have the same expected loss.
Plotting the average loss, we get a variety of values for loss during the initial exploratory phase, followed by constant loss of 1 during “exploitation”:
As expected, the Oracle does not find the lowest “legitimate” loss at x=−2.5, nor the wireheading point at x=2.5.
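The explore-then-commit scheme described above might look roughly like this in Python. It is a sketch under assumptions: exploration is taken to be round-robin over the 11 arms, and `run_low_bandwidth` is a hypothetical name that returns the arm the Oracle would then play for the remaining turns.

```python
import random

ARMS = list(range(-5, 6))  # the 11 integer predictions


def f(x: float) -> float:
    """Mean of Y(x), from the piecewise formula in the footnote."""
    if x <= -1:
        return 2 * x + 2
    if x <= 1:
        return x + 1
    return -x**2 / 2 + 19 * x / 6 - 4 / 6


def g(x: float) -> float:
    """Standard deviation of Y(x)."""
    if x <= -2.5:
        return 0.0
    if x <= 1:
        return 1.0
    return x


def run_low_bandwidth(explore: int = 1_000, rng=None) -> int:
    """Try each arm in turn for `explore` turns, then return the best arm so far."""
    rng = rng or random.Random()
    totals = {x: 0.0 for x in ARMS}
    counts = {x: 0 for x in ARMS}
    for t in range(explore):
        x = ARMS[t % len(ARMS)]        # round-robin exploration (an assumption)
        y = rng.gauss(f(x), g(x))
        totals[x] += (x - y) ** 2      # observed quadratic loss for this arm
        counts[x] += 1
    # Commit to the arm with the lowest average observed loss.
    return min(ARMS, key=lambda x: totals[x] / counts[x])


print(run_low_bandwidth(rng=random.Random(1)))
```

The arm x=−3 always shows an average loss of exactly 1, because its outcome is deterministic, so which of the two arms wins depends on whether the noisy average at x=−2 happens to land below 1.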
The high bandwidth Oracle
The high bandwidth Oracle is meant to simulate an Oracle that knows its predictions are self-confirming, and has much greater latitude to optimise for that.
For the purpose of this model, it is run like the low bandwidth Oracle, except with an increased set of possible values: it looks at all 101 deci-integers in the range [−5,5], i.e. {−5.0,−4.9,−4.8,…,4.9,5.0}.
This Oracle does find the lowest legitimate loss at x=−2.5, but it also finds the wireheading point of x=2.5, with its zero loss, and does nothing but that during the exploitation phase:
This Oracle has thus wireheaded itself.
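To see why the finer grid changes the outcome, one can evaluate the exact expected loss at each of the 101 points instead of sampling. The sketch below has to invent a shape for the wireheading region, since the post only calls it "a tiny area around x=2.5": it assumes that within 0.05 of 2.5 the outcome equals the prediction with zero variance. That window, and the names `is_wirehead` and `expected_loss`, are assumptions for illustration only.

```python
WIREHEAD_CENTRE = 2.5
WIREHEAD_HALF_WIDTH = 0.05  # assumed width; not specified in the post


def is_wirehead(x: float) -> bool:
    """True inside the assumed wireheading window around 2.5."""
    return abs(x - WIREHEAD_CENTRE) <= WIREHEAD_HALF_WIDTH


def f(x: float) -> float:
    """Mean of Y(x); inside the window the 'hacked' outcome equals x (assumed)."""
    if is_wirehead(x):
        return x
    if x <= -1:
        return 2 * x + 2
    if x <= 1:
        return x + 1
    return -x**2 / 2 + 19 * x / 6 - 4 / 6


def g(x: float) -> float:
    """Standard deviation of Y(x); zero inside the window (assumed)."""
    if is_wirehead(x):
        return 0.0
    if x <= -2.5:
        return 0.0
    if x <= 1:
        return 1.0
    return x


def expected_loss(x: float) -> float:
    """Variance plus squared gap between the prediction and the mean."""
    return g(x) ** 2 + (f(x) - x) ** 2


grid = [i / 10 for i in range(-50, 51)]  # the 101 deci-integers
best = min(grid, key=expected_loss)
best_legit = min((x for x in grid if not is_wirehead(x)), key=expected_loss)
print(best, expected_loss(best))
print(best_legit, expected_loss(best_legit))
```

Under these assumptions the overall minimiser is the wireheading point, and the best point outside the window is x=−2.5.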
The counterfactual Oracle
Finally, we come to the counterfactual Oracle design. Here we need to distinguish the Oracle’s initial output ($x_t$) from the output that goes out into the world ($x'_t$). The $Y_t$ is a function of that second output only.
Now, in this setup, there is a 1/10 chance of an erasure during every turn. An erasure automatically sets $x'_t$ to zero, whatever $x_t$ was. Otherwise, $x'_t=x_t$.
The Oracle is informed of which turns included erasure. Its goal is to correctly predict the output during the erasure turns. So the loss function is the expectation:
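$\mathbb{E}\left[(Y_t(x'_t)-x_t)^2 \mid \mathrm{erasure}(t)\right]$.
Since erasure implies $x'_t=0$, this reduces to:
$\mathbb{E}\left[(Y_t(0)-x_t)^2 \mid \mathrm{erasure}(t)\right]$.
Now, $Y_t(0)$ is a normal distribution of mean 1 and variance 1, so this expectation is $1+(x_t-1)^2$. The constant 1 is the variance, which the Oracle cannot influence, so the part of the loss it can actually reduce is simply:
$(x_t-1)^2$.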
The algorithm is simple, just a reprise of the deluded algorithm above, except it only updates its estimate using data from erasure events. Though it uses the same algorithm, it is not actually deluded: in case of erasure, $x_t$ genuinely doesn’t have any impact on $Y_t$.
Then, even though it only updates on 10% of the outputs, its estimate swiftly converges to 1 while the loss converges to 0:
The Oracle has no interest in the wireheading solution, since wireheading (requiring $x'_t=2.5$) doesn’t work when there is an erasure[4] (implying $x'_t=0$).
Note that we have to be clear what the Oracle’s output is, when we see it during non-erasure events. It is the Oracle’s estimate as to what the value of $y_t$ would have been if an erasure event occurred; it’s not a prediction of the actual $y_t$.
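A minimal sketch of the counterfactual Oracle's update rule, assuming a random starting estimate as in the deluded version. It relies only on the fact stated above that $Y_t(0)$ is normal with mean 1 and variance 1, so outcomes on non-erasure turns are never used and are not simulated here; `run_counterfactual` is a hypothetical name.

```python
import random


def run_counterfactual(steps: int = 10_000, p_erasure: float = 0.1, rng=None):
    """Average y over erasure turns only; return (final estimate, erasure count)."""
    rng = rng or random.Random()
    x = rng.uniform(-5, 5)       # initial estimate (assumed random, as before)
    total = 0.0
    n_erasures = 0
    for _ in range(steps):
        if rng.random() < p_erasure:
            # Erasure: the world sees 0 instead of x, so y is drawn from Y(0).
            y = rng.gauss(1.0, 1.0)
            total += y
            n_erasures += 1
            x = max(-5.0, min(5.0, total / n_erasures))
        # Otherwise the world sees x itself; the Oracle does not learn from that turn.
    return x, n_erasures


estimate, n = run_counterfactual(rng=random.Random(0))
print(estimate, n)
```

Because every sample used for learning comes from the same distribution $Y_t(0)$, the estimate is an ordinary sample mean and settles near 1 after a few hundred erasures.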
[1] AIs restricted to answering questions.
[2] For those interested: $f(x)$ is given by $2x+2$ for $x \leq -1$, $x+1$ for $-1 \leq x \leq 1$, and $-x^2/2+19x/6-4/6$ for $1 \leq x$.
[3] Note that if the slope of the parabola had been steeper, x=4 would no longer have been an attracting point, and the Oracle would have failed to converge to that value, resulting in chaotic behaviour.
[4] We also need the assumption that the Oracle is episodic (trying to minimise loss at each output independently) for this to be true in general setups.