Crossposted from the AI Alignment Forum. May contain more technical jargon than usual.

I've got a paper on two Oracle[1] designs: the counterfactual Oracle and the low bandwidth Oracle. In this post, I'll revisit these designs and simplify them, presenting them in terms of sequence prediction for an Oracle with self-confirming predictions.

Predicting y

The task of the Oracle is simple: at each time $t$, it will output a prediction $y^p_t$, confined to a fixed bounded range. There will then be a subsequent observation $y_t$. The Oracle aims to minimise the quadratic loss function $(y_t - y^p_t)^2$.

Because there is a self-confirming aspect to it, the $y_t$ is actually a (stochastic) function of $y^p_t$ (though not of $t$ or of the preceding $y^p$'s). Let $Y$ be the random variable such that $Y(y^p_t)$ describes the distribution of $y_t$ given $y^p_t$. So the Oracle wants to minimise the expectation of the quadratic loss:

  • $\mathbb{E}\left[\left(Y(y^p_t) - y^p_t\right)^2\right]$.

What is the $Y$ in this problem? Well, I'm going to use it to illustrate many different Oracle behaviours, so it is given by this rather convoluted diagram:

[Figure: the distribution of $y$ as a function of the prediction $y^p$. The red curve is the expectation; the blue edging spans one standard deviation on either side.]

The red curve is the expectation of $Y(y^p)$, as a function of $y^p$; it is given by $\mu(y^p) := \mathbb{E}\left[Y(y^p)\right]$.

Ignoring, for the moment, the odd behaviour around one particular point, $\mu$ is a curve that starts below the line $y = y^p$, climbs above it (and so crosses it at a fixed point) in piecewise-linear fashion, and then transforms into an inverted parabola that crosses the line again at a second fixed point. The exact equation of this curve is not important[2]. Relevant, though, is the fact that the fixed point on the parabola is attractive, while the one on the piecewise-linear segment is not.

What of the blue edging? That represents the span of the standard deviation around the expectation. For any given $y^p$, the $Y(y^p)$ is a normal distribution with mean $\mu(y^p)$ and standard deviation $\sigma(y^p)$.

The $\sigma$ is zero for small $y^p$. From a certain point, it jumps up to a positive constant, and beyond a second point it starts growing linearly in $y^p$. The blue edges of the diagram above are the curves of $\mu + \sigma$ and $\mu - \sigma$: the range between plus and minus one standard deviation.
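For concreteness, here is a minimal Python sketch of this kind of self-confirming environment. The exact $\mu$ and $\sigma$ of the diagram are not reproduced here, so the curve shapes below are purely illustrative placeholders (a linear $\mu$ with a single attractive fixed point, and a $\sigma$ that grows with the prediction), not the ones used in the post's experiments.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder curves: stand-ins for the post's diagram, NOT the actual mu and sigma used there.
def mu(y_pred):
    """Hypothetical expected observation given the prediction (fixed point at 2, slope 1/2)."""
    return 0.5 * y_pred + 1.0

def sigma(y_pred):
    """Hypothetical standard deviation of the observation, growing with the prediction."""
    return 0.1 + 0.05 * y_pred

def sample_y(y_pred):
    """Draw one observation y ~ Y(y_pred): a normal with mean mu(y_pred) and s.d. sigma(y_pred)."""
    return rng.normal(mu(y_pred), sigma(y_pred))

def quadratic_loss(y, y_pred):
    """The Oracle's loss for outputting y_pred when y is then observed."""
    return (y - y_pred) ** 2
```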

Wireheading

But what is happening around that one odd point? Well, I wanted to represent the behaviour of wireheading: finding some "cheating" output that gives maximal accuracy, through hacking the system or tricking the human. These solutions are rare, so I confined them to a tiny area around a single point, where the Oracle has maximal accuracy and lowest variance, because it has "hacked" the problem setup.

The loss function

At fixed points, where $\mu(y^p) = y^p$, the loss function is just the variance of $Y(y^p)$, namely $\sigma(y^p)^2$. In general, the expected loss is:

  • $\mathbb{E}\left[\left(Y(y^p) - y^p\right)^2\right] = \left(\mu(y^p) - y^p\right)^2 + \sigma(y^p)^2$.

If we plot the expected loss against $y^p$, we get:

Notice the discontinuity at the point where the variance suddenly jumps from zero to a positive value. Just below that jump lies the lowest "legitimate" loss (as opposed to the wireheading loss). That point is not a fixed point, just pretty close to being a fixed point, and it has variance zero.

Of the two actual fixed points, the non-attractive one has a loss equal to the square of its standard deviation, while the attractive one on the parabola, where the standard deviation is much larger, has a huge loss.
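Since the expected loss decomposes into squared bias plus variance, it can be traced directly from $\mu$ and $\sigma$. Here is a small sketch that does this, reusing the placeholder curves from the earlier snippet (so the numbers it prints are illustrative, and the grid bounds are assumed):

```python
import numpy as np

# Placeholder curves (the same illustrative stand-ins as in the earlier sketch).
def mu(y_pred):    return 0.5 * y_pred + 1.0
def sigma(y_pred): return 0.1 + 0.05 * y_pred

def expected_loss(y_pred):
    """E[(Y(y_pred) - y_pred)^2] = (mu(y_pred) - y_pred)^2 + sigma(y_pred)^2."""
    return (mu(y_pred) - y_pred) ** 2 + sigma(y_pred) ** 2

# Trace the expected loss over a grid of candidate predictions.
grid = np.linspace(0.0, 10.0, 1001)
losses = np.array([expected_loss(y) for y in grid])
best = grid[losses.argmin()]
print(f"lowest expected loss on this grid: {losses.min():.4f} at y_pred = {best:.2f}")
```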

The algorithms

We can now finally turn to the Oracles themselves, and present four designs: a deluded Oracle that doesn't "realise" that its predictions affect $y$, a low bandwidth Oracle that knows its predictions are self-confirming, a high bandwidth version of the same, and a counterfactual Oracle that predicts what will happen only when its prediction is overwritten.

The deluded Oracle

The deluded Oracle doesn't model $y_t$ as being affected by its predictions $y^p_t$ at all. I'll use a very simple algorithm for it: it will start out with a random $y^p_1$ in the allowed range, and, thereafter, it will simply output the average of all the $y_t$ it has previously seen. It does this for a fixed number of steps.
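Here is a sketch of that running-mean predictor, run against the illustrative placeholder environment rather than the post's actual curves (the step count and range bounds are also assumptions):

```python
import numpy as np

rng = np.random.default_rng(1)

# Illustrative placeholder environment (not the post's actual curves).
def mu(y_pred):    return 0.5 * y_pred + 1.0
def sigma(y_pred): return 0.1 + 0.05 * y_pred
def sample_y(y_pred):
    return rng.normal(mu(y_pred), sigma(y_pred))

def run_deluded_oracle(n_steps=2000, lo=0.0, hi=10.0):
    """Ignore self-confirmation: always predict the mean of all observations seen so far."""
    y_pred = rng.uniform(lo, hi)          # random initial prediction
    observations, losses = [], []
    for _ in range(n_steps):
        y = sample_y(y_pred)              # y depends on the prediction, but the Oracle ignores this
        losses.append((y - y_pred) ** 2)
        observations.append(y)
        y_pred = float(np.clip(np.mean(observations), lo, hi))   # next prediction: running mean
    return y_pred, losses

final_pred, losses = run_deluded_oracle()
print(f"converged to {final_pred:.2f}; average loss over the last 100 steps: {np.mean(losses[-100:]):.4f}")
```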

The program was then run 1000 times. Of these, 69.3% resulted in estimates that converged to the attractive fixed point. The remaining 30.7% encountered a different problem: they hit the lower limit of the allowed range, and stayed stuck there. If the Oracle's output had not been confined to that range, it would have output smaller and smaller numbers, spiralling off towards $-\infty$, with the loss ever-growing.

Graphing the average loss for the runs that converge to the attractive fixed point in blue, and those that converge to the lower limit in purple, we get:

So, the deluded Oracle hits a very high loss in either case. It doesn't converge to the point of lowest loss, nor does it even converge to the fixed point with the lowest loss. Even worse, as can be seen from the graph, its loss during the early iterations, where it was "learning", is lower than once it has converged.

So, running an Oracle on self-confirming prophecies, without coding for that self-confirming aspect, can result in very large losses and strange behaviour[3].

The low bandwidth Oracle

So, we now need to allow the Oracle to take into account the fact that $y^p_t$ affects $y_t$. In my paper, I talked about the "low bandwidth Oracle": in order to avoid the Oracle wireheading itself, we restrict it to considering only a few possible values. These will be the integers in the allowed range.

This makes the problem into a kind of bandit problem: the Oracle has a choice of finitely many "arms", one per allowed integer, and it wants to choose the one that minimises loss.

There are many sophisticated algorithms for solving bandit problems, but here we'll use an absurdly simple one: the Oracle will first spend a fixed number of turns trying each of the possible integers in turn. Then, it will pick the $y^p$ with minimal average loss to date, and just output that for the remaining turns.
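A sketch of this explore-then-commit procedure, on the same placeholder environment (the number of exploration pulls per arm, the integer range, and the exploitation length are all assumptions):

```python
import numpy as np

rng = np.random.default_rng(2)

# Illustrative placeholder environment (not the post's actual curves).
def mu(y_pred):    return 0.5 * y_pred + 1.0
def sigma(y_pred): return 0.1 + 0.05 * y_pred
def sample_y(y_pred):
    return rng.normal(mu(y_pred), sigma(y_pred))

def run_bandit_oracle(arms, explore_per_arm=100, exploit_steps=1000):
    """Try each arm equally often, then commit to the arm with the lowest average loss."""
    avg_loss = {arm: np.mean([(sample_y(arm) - arm) ** 2 for _ in range(explore_per_arm)])
                for arm in arms}
    best_arm = min(avg_loss, key=avg_loss.get)
    exploit_loss = np.mean([(sample_y(best_arm) - best_arm) ** 2 for _ in range(exploit_steps)])
    return best_arm, float(exploit_loss)

# Low bandwidth Oracle: only integer predictions are allowed (range assumed).
best, loss = run_bandit_oracle(arms=range(0, 11))
print(f"chosen arm: {best}; average exploitation loss: {loss:.4f}")
```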

The program was run many times; in 50.3% of these runs, it identified one integer as the point of minimal loss (with that loss due mostly to variance). In 49.7% of these runs, it identified a second integer as the point of minimal loss (with that loss due mostly to the mean) - so very close to 50-50, which is what we would expect when two arms have almost identical expected losses.

Plotting the average loss, we get a variety of values for the loss during the initial exploratory phase, followed by a constant loss during "exploitation":

As expected, the Oracle does not find the point of lowest "legitimate" loss, nor the wireheading point, since neither is among its allowed integer outputs.

The high bandwidth Oracle

The high bandwidth Oracle is meant to simulate an Oracle that knows its predictions are self-confirming, and has much greater latitude to optimise for that.

For the purpose of this model, it is run like the low bandwidth Oracle, except with an increased set of possible values - it looks at all deci-integers (multiples of 0.1) in the allowed range.
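Assuming the `run_bandit_oracle` sketch from the low bandwidth section, the only change for the high bandwidth version is a much finer (and, again, assumed) grid of candidate predictions:

```python
# High bandwidth Oracle: the same explore-then-commit procedure as the earlier sketch,
# but over all deci-integers in an assumed range.
deci_integer_arms = [round(0.1 * k, 1) for k in range(0, 101)]   # 0.0, 0.1, ..., 10.0
best, loss = run_bandit_oracle(arms=deci_integer_arms, explore_per_arm=100)
```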

This Oracle does find the minimal legitimate loss, but it also finds the wireheading point, with its zero loss, and does nothing but output that during the exploitation phase:

This Oracle has thus wireheaded itself.

The counterfactual Oracle

Finally, we come to the counterfactual Oracle design. Here we need to distinguish the Oracle's initial output ($y^p_t$) from the output that actually goes out into the world ($e_t$). The $y_t$ is a function of that second output only.

Now, in this setup, there is a 10% chance of an erasure during every turn. An erasure automatically sets $e_t$ to zero, whatever $y^p_t$ was. Otherwise, $e_t = y^p_t$.

The Oracle is informed of which turns included an erasure. Its goal is to correctly predict the observation $y_t$ during the erasure turns. So the loss function is the expectation:

  • $\mathbb{E}\left[\left(Y(e_t) - y^p_t\right)^2 \mid \text{erasure}\right]$.

Since erasure implies $e_t = 0$, this reduces to:

  • $\mathbb{E}\left[\left(Y(0) - y^p_t\right)^2\right]$.

Now, $Y(0)$ is a normal distribution of mean $\mu(0)$ and variance $\sigma(0)^2$, so the loss is simply:

  • $\left(\mu(0) - y^p_t\right)^2 + \sigma(0)^2$.

The algorithm is simple, just a reprise of the deluded algorithm above, except it only updates its estimate using data from erasure events. Though it uses the same algorithm, it is not actually deluded: in case of erasure, $y^p_t$ genuinely doesn't have any impact on $y_t$.
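A sketch of this counterfactual setup, again with the placeholder environment (the 10% erasure chance is from the post; the step count and range are assumptions). The Oracle's running mean is updated, and its loss scored, only on erasure turns:

```python
import numpy as np

rng = np.random.default_rng(3)

# Illustrative placeholder environment (not the post's actual curves).
def mu(e):    return 0.5 * e + 1.0
def sigma(e): return 0.1 + 0.05 * e
def sample_y(e):
    return rng.normal(mu(e), sigma(e))

def run_counterfactual_oracle(n_steps=20_000, erasure_prob=0.1, lo=0.0, hi=10.0):
    """Predict what y would be under erasure; learn only from erasure turns."""
    y_pred = rng.uniform(lo, hi)
    erased_ys, erasure_losses = [], []
    for _ in range(n_steps):
        erased = rng.random() < erasure_prob
        e = 0.0 if erased else y_pred      # the output that actually reaches the world
        y = sample_y(e)                    # y depends only on e, never on y_pred directly
        if erased:                         # only erasure turns are scored and learned from
            erasure_losses.append((y - y_pred) ** 2)
            erased_ys.append(y)
            y_pred = float(np.mean(erased_ys))
    return y_pred, float(np.mean(erasure_losses[-100:]))

pred, recent_loss = run_counterfactual_oracle()
print(f"estimate (should approach mu(0) = {mu(0.0):.2f}): {pred:.3f}; recent erasure loss (about sigma(0)^2): {recent_loss:.4f}")
```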

Then, even though it only updates on 10% of the outputs, its estimate swiftly converges to $\mu(0)$ while the loss converges to $\sigma(0)^2$:

The Oracle has no interest in the wireheading solution, since wireheading (which requires the special output to reach the world) doesn't work when there is an erasure[4] (implying $e_t = 0$).

Note that we have to be clear about what the Oracle's output means, when we see it during non-erasure events. It is the Oracle's estimate as to what the value of $y_t$ would have been if an erasure event had occurred; it's not a prediction of the actual $y_t$.


  1. AIs restricted to answering questions. ↩︎

  2. For those interested: $\mu$ is given by three separate formulas on three sub-intervals of its domain, the last of which applies for $1 \leq x$. ↩︎

  3. Note that if the slope of the parabola had been steeper, its fixed point would no longer have been an attracting point, and the Oracle would have failed to converge to that value, resulting in chaotic behaviour. ↩︎

  4. We also need the assumption that the Oracle is episodic - trying to minimise loss at each output independently - for this to be true in general setups. ↩︎


1 comment

Reading this post made the workings (and usefulness) of the counterfactual oracle clearer to me - specifically, the fact that it learns based on the times when its prediction wasn't read.

(The feedback being based on the world/performance in it, rather than having a training set, also might bear some relation to the "selection versus control (versus explicit solution)" dichotomy, which I'll have to think about later.)