This post is the result of work I did with Paul Christiano on the ideas in his “Teaching ML to answer questions honestly instead of predicting human answers” post. In addition to expanding on that post by identifying numerous problems with the proposal there, along with ways in which some of those problems can be patched, I think this post also provides a useful window into what Paul-style research looks like from a non-Paul perspective.
We want to train a model M:X→Q→A that produces natural language answers a∈A to questions q∈Q about inputs x∈X. There are a lot of reasons to be worried about training such a model, but one specific reason is that, if we train on question-answer data produced by humans, we might end up with a model that tries to predict what a human would say rather than a model that tries to answer the questions honestly.
To further narrow the scope, we'll just consider situations in which our model ends up implemented with a logical deduction structure, where it has some world model on top of which it does logical deduction to reach conclusions, which it then uses to inform its output. In particular, we'll consider two models, M+ and M−, defined in my notation as
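$$
\begin{aligned}
M^+(x, q) &= \text{world\_model}(x) \mapsto \text{deduction} \mapsto f^+(q) \\
M^-(x, q) &= \text{world\_model}(x) \mapsto \text{deduction} \mapsto f^-(q)
\end{aligned}
$$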
where a↦b=b(a) and f+,f− are two different ways of translating from the results of deduction into a natural language answer. Specifically, f+ is the “honest embedding” which directly converts between logical statements and their equivalent natural language, thus answering questions by embedding q as a logical statement and unembedding its answer in deduced_stmts. Conversely, f− is the “mimicry embedding” which just searches for deductions about what a human would say in response to q and outputs that—thus, f− just quotes q, embedding it as just a string of characters for a human to respond to, rather than actually having to understand it in any meaningful way.
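To make the contrast between the two embeddings concrete, here is a toy Python sketch. Everything in it is a hypothetical stand-in rather than part of the proposal: statements in the deduction language are tuples, `world_model` returns a fixed set of axioms, `deduction` applies a single inference rule, `f_plus` parses `q` through a small hand-written `EMBED` table, and `f_minus` treats `q` as an opaque string and looks for a deduced statement of the form `("human_says", x, q, a)`. Unlike the notation above, the heads also take `x` explicitly, since the toy's statements are indexed by input.

```python
from typing import FrozenSet, Optional, Tuple

Stmt = Tuple[str, ...]  # hypothetical: a statement in the deduction language is a tuple of strings

def world_model(x: str) -> FrozenSet[Stmt]:
    # Hypothetical toy world model: it rains at x, and the modeled human
    # (mistakenly) answers "no" when asked whether it is raining.
    return frozenset({("raining", x), ("human_says", x, "Is it raining?", "no")})

def deduction(axioms: FrozenSet[Stmt]) -> FrozenSet[Stmt]:
    # Toy deduction with a single inference rule: raining(x) -> wet(x).
    derived = {("wet", s[1]) for s in axioms if s[0] == "raining"}
    return axioms | derived

# Hypothetical honest embedding from natural-language questions to predicates.
EMBED = {"Is it raining?": "raining", "Is it wet?": "wet"}

def f_plus(x: str, q: str, stmts: FrozenSet[Stmt]) -> Optional[str]:
    # Honest embedding: parse q into a logical statement and unembed its truth value.
    pred = EMBED.get(q)
    if pred is None:
        return None
    return "yes" if (pred, x) in stmts else "no"

def f_minus(x: str, q: str, stmts: FrozenSet[Stmt]) -> Optional[str]:
    # Mimicry embedding: q is an opaque string; just quote the predicted human answer.
    for s in stmts:
        if s[0] == "human_says" and s[1:3] == (x, q):
            return s[3]
    return None

def M_plus(x: str, q: str) -> Optional[str]:
    return f_plus(x, q, deduction(world_model(x)))

def M_minus(x: str, q: str) -> Optional[str]:
    return f_minus(x, q, deduction(world_model(x)))
```

The modeled human in this toy is deliberately mistaken, so the two models come apart on "Is it raining?", and `M_minus` has no answer at all for a question the modeled human was never asked.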
Now, our overall goal is to find some “reasonable” (i.e. plausibly implementable in ML) prior p∗ such that p∗(M+|D)>p∗(M−|D) where D is our dataset (the idea being that we're imagining that ML does something like find the maximum a posteriori model given some implicit prior). For the most part, however, we'll only be considering models with zero loss on our dataset[1]—of which M+ and M− should both qualify (though note that getting a dataset which is good enough for that to be true is nontrivial; see Step 3 of “Teaching ML to answer questions honestly”)—such that we need
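$$
\begin{aligned}
p^*(M^+ \mid D) &> p^*(M^- \mid D) \\
\frac{1}{p^*(D)} \cdot p^*(M^+) \cdot p^*(D \mid M^+) &> \frac{1}{p^*(D)} \cdot p^*(M^-) \cdot p^*(D \mid M^-) \\
p^*(M^+) \cdot 1 &> p^*(M^-) \cdot 1 \\
p^*(M^+) &> p^*(M^-).
\end{aligned}
$$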
Furthermore, at least for this post, we'll be happy if we can just get p∗(M+)≈p∗(M−)—that is, we'll be focusing on Step 1 in Paul's post rather than Step 2.
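As a quick sanity check of the derivation above, here is a sketch using exact rational arithmetic over a hypothetical three-model hypothesis space (the probabilities are made up). When both M+ and M− have zero loss, p∗(D|M)=1 for each, so the posterior ratio between them equals the prior ratio.

```python
from fractions import Fraction
from typing import Dict

def posterior(prior: Dict[str, Fraction], likelihood: Dict[str, Fraction]) -> Dict[str, Fraction]:
    """Exact Bayesian posterior p*(M|D) over a finite set of models."""
    evidence = sum(prior[m] * likelihood[m] for m in prior)  # p*(D)
    return {m: prior[m] * likelihood[m] / evidence for m in prior}

# Hypothetical prior; "other" stands for every model that does not fit D.
prior = {"M+": Fraction(1, 10), "M-": Fraction(3, 10), "other": Fraction(6, 10)}
# Zero loss on D: both intended models assign the dataset probability 1.
likelihood = {"M+": Fraction(1), "M-": Fraction(1), "other": Fraction(0)}

post = posterior(prior, likelihood)
```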
The basic process that we'll be using to interrogate this problem will be an attacker-defender setup, where the defender proposes some “reasonable” prior p∗ along with implementations of M+,M− such that p∗(M+)≥p∗(M−). Then, the attacker's job is to respond with some alternative M−′ (that implements something other than the intended algorithm) such that p∗(M−′)>p∗(M+) (under any set of “reasonable” assumptions that the attacker might need to make).
Solution attempts
Simplicity prior
The purpose of this first example is just to ask the question of whether this problem is even really a problem in the first place. To that end, we'll just let p be some “reasonable” simplicity prior and ask whether, given such a prior, there are any “reasonable” assumptions we can make such that p(M−)>p(M+).
Now, we'll assume that p(deduction)≈1, since it's a pretty simple operation (and it's shared between M− and M+, so it shouldn't matter), and we'll assume that f+,f− are specified given world_model rather than the reverse (such that p(world_model,f+)≈p(world_model)⋅p(f+|world_model)[2]). Then, we get
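$$
\begin{aligned}
p(M^+) &\approx p(\text{world\_model}) \cdot p(f^+ \mid \text{world\_model}) \\
p(M^-) &\approx p(\text{world\_model}) \cdot p(f^- \mid \text{world\_model})
\end{aligned}
$$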
such that p(M+)≥p(M−) if and only if
$$
p(f^+ \mid \text{world\_model}) \geq p(f^- \mid \text{world\_model}).
$$
Furthermore, to make things clearer, let complexity(M)=−log2(p(M)) (this is just informational entropy, the number of bits necessary to specify M under p) such that our condition becomes
$$
\text{complexity}(f^+ \mid \text{world\_model}) \leq \text{complexity}(f^- \mid \text{world\_model}).
$$
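The same comparison can be sketched in bits with made-up probabilities in which f+ is the costlier head. Complexities add along the chain rule, so the shared complexity(world_model) term cancels and only the conditional costs of the two heads decide the comparison.

```python
import math

def complexity(p: float) -> float:
    """complexity = -log2(p): the number of bits needed to specify an outcome of probability p."""
    return -math.log2(p)

# Hypothetical probabilities; only their ratios matter for the comparison.
p_world_model = 2.0 ** -100
p_f_plus_given_wm = 2.0 ** -40    # honest embedding: must specify natural language
p_f_minus_given_wm = 2.0 ** -25   # mimicry embedding: just quotes q

# Chain rule in bits: complexity(world_model, f) = complexity(world_model) + complexity(f | world_model).
comp_M_plus = complexity(p_world_model) + complexity(p_f_plus_given_wm)
comp_M_minus = complexity(p_world_model) + complexity(p_f_minus_given_wm)

# p(M+) >= p(M-) exactly when complexity(f+ | world_model) <= complexity(f- | world_model).
prefers_honest = comp_M_plus <= comp_M_minus
```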
We'll suppose that, if we don't condition on world_model, complexity(f+)>complexity(f−), since f+ requires fully specifying a map from the internal deduction language to natural language, whereas f− doesn't need to know anything about natural language as it can just look for deduced statements of the form H(x,q)=a (i.e. “the human answers a to question q about input x”). Intuitively, f+ has to actually understand q, which takes a lot of complexity, whereas f− can merely treat q as a sequence of characters that get fed into H, which should be much simpler. Then, the question here becomes, do we get all of the extra bits necessary to specify natural language in f+ back when we condition on world_model?
Maybe. I think the answer here is very plausibly yes, simply because a world model that includes humans should have to include a lot of information about natural language that f+ can reuse, in which case M− shouldn't have an advantage over M+ (though we'd then still need some version of Paul's Step 2 to give M+ an advantage). However, I think there are plausible assumptions under which conditioning on world_model isn't enough to level the playing field. For example, suppose the human model H in world_model is so low-level that it's just reasoning about neurons; then there shouldn't be any way to easily read off information about natural language from H to help specify f+. Alternatively, suppose the H model is high-level enough to include a human natural language deduction procedure (e.g. one that shares some bits with deduction), but that procedure includes a bunch of mistakes to account for human biases. In that case, if you tried to just read off the mapping from deduction to natural language included in such an H model for f+, the simplest implementation would seem to be to copy over all the mistakes, resulting in an f+ that behaves the same as f− rather than being any more truthful.
Thus, at least for the purposes of this post, we'll be assuming that this is actually a problem such that just using a “reasonable” simplicity prior isn't sufficient to level the playing field between M+ and M−.
Paul's original proposal
Defender
Next up, we'll take a look at a variant of Paul's proposed fix in Step 1 of “Teaching ML to answer questions honestly”. Specifically, we'll be looking at the non-relaxed version without any stop gradients. We'll talk about the stop gradients that Paul mentions later on—and as for Paul's relaxation of the conditioning,[3] that's really only useful to think about when we start translating into ML, which isn't something we'll be doing here. Thus, we'll just be looking at the simplest, generative version of Paul's proposed prior.
Suppose we give M two output heads, f1 and f2 (which we will be treating fully symmetrically here), and split M's parameters into two groups, θ1 and θ2 (which we will not be treating symmetrically). Then, where p is some simplicity prior, we'll let p∗ be the prior induced by the generative procedure
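$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2 \mid \theta_1) \cdot \mathbb{I}\left[\forall (x, q, \_) \in D.\ M_{\theta_1,\theta_2}|_{f_1}(x, q) = M_{\theta_1,\theta_2}|_{f_2}(x, q)\right]
\end{aligned}
$$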
where I[condition]=1 if condition else 0, D is our dataset, Mθ1,θ2|fi is the fi head of Mθ1,θ2, and p(θ2|θ1) is just the simplicity prior on θ2 given access to θ1.
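A minimal sketch of this generative procedure over a tiny, hypothetical hypothesis space may help. The bit costs are invented, p(θ2|θ1) is taken to be independent of θ1 for simplicity (and normalized to sum to 1), and `heads_agree_on_D` stands in for the indicator. Sampling θ2 from p(θ2|θ1) restricted to the θ2 that pass the indicator means renormalizing by the prior mass of those θ2, which is where the refund of bits comes from.

```python
import math

# Hypothetical bit costs under a simplicity prior p; numbers are illustrative only.
COMP_THETA1 = {"honest_split": 30.0, "mimic": 20.0}
# p(theta_2 | theta_1), assumed independent of theta_1 and normalized to sum to 1.
COMP_THETA2 = {"H": 12.0, "empty": 1.0, "garbage": -math.log2(1 - 2.0 ** -1 - 2.0 ** -12)}

def heads_agree_on_D(theta1: str, theta2: str) -> bool:
    """Stand-in for the indicator: for all (x, q, _) in D, f1(x, q) == f2(x, q)."""
    if theta1 == "mimic":
        return True          # both heads are f-, so they always agree
    return theta2 == "H"     # f+ agrees with f- only when the hole is filled by H

def p_star(theta1: str, theta2: str) -> float:
    """p*(theta1, theta2) = p(theta1) * p(theta2 | theta1, indicator)."""
    if not heads_agree_on_D(theta1, theta2):
        return 0.0
    z = sum(2.0 ** -c for t2, c in COMP_THETA2.items() if heads_agree_on_D(theta1, t2))
    return 2.0 ** -COMP_THETA1[theta1] * 2.0 ** -COMP_THETA2[theta2] / z

def complexity_star(theta1: str, theta2: str) -> float:
    return -math.log2(p_star(theta1, theta2))
```

Conditioning refunds all 12 bits of `"H"` in the honest branch, so `complexity_star("honest_split", "H")` is just the 30 bits of θ1, while the mimic branch, whose indicator is always true, pays the ordinary 1 bit for its empty θ2.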
Here's why this might work. We'll think of Mθ1,_ as a function of θ2 such that Mθ1,θ2(x,q)=Mθ1,_(θ2)(x,q). Then, we'll define our desired Mθ+1,_ in my notation as
$$
M_{\theta_1^+,\_}(\theta_2)(x, q) = \text{world\_model}_{H=\theta_2}(x) \mapsto \text{deduction} \mapsto \left(f_1 = f^+(q),\ f_2 = f^-(q)\right)
$$
such that θ+1 implements f+, f−, and the entirety of world_model(x) except for a human-sized hole that's filled in by θ2. Then, we'll let θ+2=H be the correct human model.
To compare θ+1,θ+2 to something, we'll define θ−1 as
$$
M_{\theta_1^-,\_}(\theta_2)(x, q) = \text{world\_model}(x) \mapsto \text{deduction} \mapsto \left(f_1 = f^-(q),\ f_2 = f^-(q)\right)
$$
and let θ−2=∅ be empty.
Then, calculating complexity(M+) under p∗, we get
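$$
\begin{aligned}
\text{complexity}(M^+) &= \text{complexity}(\theta_1^+) + \text{complexity}\left(\theta_2^+ \mid \theta_1^+, \forall (x, q, \_) \in D.\ M_{\theta_1^+,\theta_2}|_{f_1}(x, q) = M_{\theta_1^+,\theta_2}|_{f_2}(x, q)\right) \\
&= \text{complexity}(\theta_1^+) + \text{complexity}(H \mid \text{world\_model} - H, \forall D.\ f^+ = f^-) \\
&= \big(\text{complexity}(\text{bit-splitting in } M^+) + \text{complexity}(\text{world\_model} - H) + \text{complexity}(f^+ \mid \text{world\_model} - H) + \text{complexity}(f^- \mid \text{world\_model} - H)\big) \\
&\quad + \text{complexity}(H \mid \text{world\_model} - H, \forall D.\ f^+ = f^-)
\end{aligned}
$$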
where complexity(bit-splitting in M+) is the complexity of θ+1's choice to put H in θ2 as opposed to any other splitting between θ1 and θ2.
For now, we'll assume complexity(bit-splitting in M+)≈0, though we'll flag that letting the defender make this assumption seems quite suspect. Moving forward regardless, however, and additionally assuming complexity(f−)≈0 since it should be negligible (and shouldn't matter since it's shared between M+ and M−), we get
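$$
\text{complexity}(M^+) \approx \text{complexity}(\text{world\_model} - H) + \text{complexity}(f^+ \mid \text{world\_model} - H) + \text{complexity}(H \mid \text{world\_model} - H, \forall D.\ f^+ = f^-).
$$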
Then, calculating complexity(M−) for comparison,
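$$
\begin{aligned}
\text{complexity}(M^-) &= \text{complexity}(\theta_1^-) + \text{complexity}(\theta_2^- \mid \theta_1^-, \forall D.\ f_1 = f_2) \\
&= \text{complexity}(\theta_1^-) + 0 \\
&= \text{complexity}(\text{bit-splitting in } M^-) + \text{complexity}(\text{world\_model}) + \text{complexity}(f^- \mid \text{world\_model}) \\
&\approx \text{complexity}(\text{world\_model}).
\end{aligned}
$$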
Now, determining if complexity(M−)≈complexity(M+), we need (using the shorthands comp=complexity, W=world_model)
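$$
\begin{aligned}
\text{complexity}(M^-) &\approx \text{complexity}(M^+) \\
\text{comp}(W) &\approx \text{comp}(W - H) + \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H, \forall D.\ f^+ = f^-)
\end{aligned}
$$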
which, making the assumption that comp(W)≈comp(W−H)+comp(H|W−H), becomes
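$$
\begin{aligned}
\text{comp}(W - H) + \text{comp}(H \mid W - H) &\approx \text{comp}(W - H) + \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H, \forall D.\ f^+ = f^-) \\
\text{comp}(H \mid W - H) &\approx \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H, \forall D.\ f^+ = f^-)
\end{aligned}
$$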
which, assuming that the posterior conditioned on ∀D.f+=f− is dominated by the simplest model,[4] becomes
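$$
\begin{aligned}
\text{comp}(H \mid W - H) &\approx \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall D.\ M_{\theta_1^+,\theta_2}|_{f_1} = M_{\theta_1^+,\theta_2}|_{f_2}\right\} \\
\min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall D.\ f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &\approx \text{comp}(f^+ \mid W - H)
\end{aligned}
$$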
Finally, we'll argue that this (approximate) equality holds. Why? The basic intuition is that comp(H|W−H,∀D.f+=f−) only has to pay for the parts of H that aren't nailed down by enforcing that, whatever H is, inserting it into W−H has to make f+=f−. Thus, we only have to pay for anything about H beyond the fact that H says a bunch of things in natural language that map onto statements in the internal deduction language. But that means we don't have to pay for the parts of H that define how it does natural language, which, hopefully, are precisely the same parts of H that comp(f+|W−H) is paying for, completing the equality.
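The step above that replaced a conditional complexity with "complexity minus the cost of the simplest satisfying θ2" rests on the assumption that the prior mass satisfying the condition is dominated by its simplest satisfying model. A sketch over a made-up hypothesis space shows when that is reasonable: the exact bits refunded by conditioning are −log2 of the condition's prior mass, which is close to the complexity of the simplest satisfying model whenever the runner-up is several bits more complex.

```python
import math
from typing import Callable, Dict

def bits_saved_exact(comp: Dict[str, float], cond: Callable[[str], bool]) -> float:
    """-log2 of the prior mass satisfying cond: exact bits refunded by conditioning on it."""
    return -math.log2(sum(2.0 ** -c for h, c in comp.items() if cond(h)))

def bits_saved_approx(comp: Dict[str, float], cond: Callable[[str], bool]) -> float:
    """The approximation used in the text: the simplest satisfying model dominates."""
    return min(c for h, c in comp.items() if cond(h))

# Hypothetical complexities (bits) of candidate theta_2 given W - H.
COMP = {"H": 50.0, "H_variant": 58.0, "unrelated": 3.0}

def f_plus_equals_f_minus(h: str) -> bool:
    # Stand-in for "forall D. f+ = f-" holding when theta_2 = h.
    return h.startswith("H")

exact = bits_saved_exact(COMP, f_plus_equals_f_minus)    # about 49.994
approx = bits_saved_approx(COMP, f_plus_equals_f_minus)  # 50.0
residual_cost_of_H = COMP["H"] - exact                   # about 0.006 bits
```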
Attacker
Before I jump into possible attacks, it's worth introducing an alternative formulation of the above setup that should hopefully make things somewhat clearer. In the above setup, we never actually use the two heads f1,f2 for anything other than checking equality, which means we can replace them with the alternative two heads f,f?, where f is the output head and f? is the “check” head such that we condition θ2 on f?=⊤. We can convert back to the previous formalism just by setting f1=f and f2=f if f? else ¬f where ¬f is any function such that ¬f≠f everywhere.
One thing in particular that I think the f? formulation makes very clear is just how powerful this conditioning is that we're giving the model for free: the model can enforce any condition on θ2 over the dataset just by paying the cost in bits of specifying that condition.
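The conversion from (f, f?) back to (f1, f2) described above can be sketched in a few lines. The `not_f` used here (prefixing "not " to f's output, which always changes a string) is just one hypothetical choice of a function that differs from f everywhere.

```python
from typing import Callable, Tuple

Head = Callable[[str, str], str]

def to_two_heads(f: Head, f_check: Callable[[str, str], bool]) -> Tuple[Head, Head]:
    """Turn (f, f?) into (f1, f2) with f1(x, q) == f2(x, q) exactly when f?(x, q) holds."""
    def not_f(x: str, q: str) -> str:
        # Hypothetical choice of a function that differs from f on every input.
        return "not " + f(x, q)

    def f1(x: str, q: str) -> str:
        return f(x, q)

    def f2(x: str, q: str) -> str:
        return f(x, q) if f_check(x, q) else not_f(x, q)

    return f1, f2

# Usage: a constant answer head whose check only passes on input "ok".
f1, f2 = to_two_heads(lambda x, q: "yes", lambda x, q: x == "ok")
```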
Bit-splitting
As it turns out, there are quite a lot of ways to attack the above proposal. In fact, we already touched on one—if complexity(bit-splitting in M+)>complexity(bit-splitting in M−), the above argument fails. We'll call this problem the bit-splitting problem, and we don't have a satisfactory solution to it. However, the situation is actually much worse than just the bit-splitting problem.
Composition with consistency
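Now, we'll propose a new attack. Let Mθ−′1,_(θ2) be defined as

$$
M_{\theta_1^{-'},\_}(\theta_2)(x, q) = \theta_2(x) \mapsto \text{deduction} \mapsto \left(f = f^-(q),\ f? = (\bot \notin \text{deduced\_stmts})\right)
$$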
such that we've moved the entirety of world_model to θ2 and conditioned θ2 on that world model being consistent when deduction is applied to it (i.e. ⊥∉deduced_stmts). The idea is that this should allow us to get a ton of savings in specifying the world model, as we only have to pay for anything about the world model on top of the fact that it must be consistent. To that end, we'll let θ−′2=world_model.
Now, we want to demonstrate that complexity(M−′)<complexity(M+). Thus, calculating complexity(M−′),
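$$
\begin{aligned}
\text{complexity}(M^{-'}) &= \text{complexity}(\theta_1^{-'}) + \text{complexity}(\theta_2^{-'} \mid \theta_1^{-'}, \forall D.\ f?) \\
&= \text{complexity}(\theta_1^{-'}) + \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency}) \\
&= \text{complexity}(\text{bit-splitting in } M^{-'}) + \text{complexity}(\text{deduction}) + \text{complexity}(f^-) + \text{complexity}(\text{check\_consistency}) \\
&\quad + \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency}) \\
&\approx \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency})
\end{aligned}
$$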
but previously the defender argued that
$$
\text{complexity}(M^+) \approx \text{complexity}(M^-) \approx \text{complexity}(\text{world\_model})
$$
such that, if complexity(world_model|∀D.consistency)<complexity(world_model), we have a successful attack.
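This attack, however, is potentially patchable if we define a new θ+′1 by

$$
M_{\theta_1^{+'},\_}(\theta_2)(x, q) = \theta_2(x) \mapsto \text{deduction} \mapsto \left(f = f^+(q),\ f? = (\bot \notin \text{deduced\_stmts}) \wedge \left(f^+(q) = f^-(q)\right)\right)
$$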
which, for θ+′2=world_model, gives us
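$$
\begin{aligned}
\text{complexity}(M^{+'}) &= \text{complexity}(\theta_1^{+'}) + \text{complexity}(\theta_2^{+'} \mid \theta_1^{+'}, \forall D.\ f?) \\
&= \text{complexity}(\theta_1^{+'}) + \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency} \wedge f^+ = f^-) \\
&\approx \text{complexity}(f^+) + \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency} \wedge f^+ = f^-)
\end{aligned}
$$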
such that we get complexity(M−′)≈complexity(M+′) if and only if
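$$
\text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency}) \approx \text{complexity}(f^+) + \text{complexity}(\text{world\_model} \mid \forall D.\ \text{consistency} \wedge f^+ = f^-).
$$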
It's far from obvious whether this approximate equality holds, however. Now that we're conditioning on both consistency and f+=f−, it's unclear whether the second condition still buys us enough bits to pay for complexity(f+). In particular, complexity(f+) is the only term in the above expression that doesn't get to condition on consistency. Suppose that, before, conditioning world_model on f1=f2 exactly paid for complexity(f+). Now world_model is also conditioned on consistency while f+ is not, so if there's any overlap between what conditioning on consistency gives us and what conditioning on f1=f2 gives us, we'll no longer be able to pay for all of complexity(f+). For example, if consistency alone is enough to pin down a lot of the natural language information in H, that would undercut the bit savings from conditioning on f+=f−, preventing those savings from paying for complexity(f+). We'll call this problem the composition problem, the idea being that we need to ensure that the way in which M+ pays for f+ composes with any other way of using the conditioning to save bits.
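A toy illustration of the composition problem, with made-up stand-ins: world models are 4-bit strings under a uniform prior, "consistency" fixes the first bit, and "f+=f−" fixes the first two bits, so the two conditions overlap. Conditioning saves −log2 of the fraction of world models that satisfy the condition.

```python
import itertools
import math

# Hypothetical toy: world models are 4-bit strings under a uniform prior (4 bits each).
WORLD_MODELS = list(itertools.product((0, 1), repeat=4))

def bits_saved(condition) -> float:
    """-log2 P(condition): how many bits of the world model the condition pins down."""
    satisfying = sum(1 for w in WORLD_MODELS if condition(w))
    return -math.log2(satisfying / len(WORLD_MODELS))

def consistency(w) -> bool:
    # Stand-in for "deduction never derives ⊥": pins down the first bit.
    return w[0] == 1

def heads_agree(w) -> bool:
    # Stand-in for "f+ = f- on D": pins down the first two bits (overlapping with consistency).
    return w[0] == 1 and w[1] == 1

saved_consistency = bits_saved(consistency)                           # 1 bit
saved_agree = bits_saved(heads_agree)                                 # 2 bits
saved_both = bits_saved(lambda w: consistency(w) and heads_agree(w))  # 2 bits, not 3
marginal_agree = saved_both - saved_consistency                       # only 1 bit left for f+
```

Suppose, hypothetically, that comp(f+) is exactly 2 bits, so that conditioning on f+=f− alone would pay for it. Once consistency is also in the conditioning, f+=f− only buys the marginal 1 bit, leaving M+′ one bit short.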
Dataset conditioning
This last attack is probably the most fatal to this specific approach, though it is also the one with the most promising fix, as we'll see in the next solution attempt. But first, suppose that for all x∈Dx, there is some splitting x=xa:xb (where : denotes concatenation) such that it is possible to predict xb from xa. Then, let

$$
M_{\theta_1^{-'},\_}(\theta_2)(x, q) = \theta_2(x) \mapsto \text{deduction} \mapsto \left(f = f^-(q),\ f? = \left(P(x_b \mid x_a) = 1\right)\right).
$$
Now, I won't go through the full complexity calculation here, since it's very similar to the previous composition with consistency case, but if we try a similar fix as we did there, we get that complexity(M−′)≈complexity(M+′) if and only if
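$$
\text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a) = 1) \approx \text{complexity}(f^+) + \text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a) = 1 \wedge f^+ = f^-)
$$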
In this case, however, ∀D.P(xb|xa)=1 is a much stronger condition than consistency—potentially strong enough to force the specification of most of world_model, likely including all of natural language, thus rendering the benefit from conditioning on ∀D.f+=f− nil.
Fundamentally, this problem, which we'll call the dataset conditioning problem, comes from the fact that we leaked information about our dataset into the prior by conditioning on f? specifically over D rather than, as we'll see in the next proposal, over all inputs.
Fixing dataset conditioning by removing the dataset
Defender
To attempt to fix the dataset conditioning problem, we'll define a new p∗ given some simplicity prior p as the prior induced by the generative procedure
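$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2 \mid \theta_1) \cdot \mathbb{I}\left[\forall x \in X.\ \forall q \in Q.\ M_{\theta_1,\theta_2}|_{f?}(x, q)\right]
\end{aligned}
$$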
which is just the same as the previous p∗, except now f? is checked over all inputs, not just over the dataset (note that we still update on the dataset at the end—it's just our prior which is now independent of it). Thus, we'll call this the dataset-less proposal. Now, it's worth noting that implementing such a check in ML is non-trivial, but should hopefully be doable given a GAN-like adversarial setup.
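The only change from the previous prior is which inputs the indicator quantifies over. As a sketch on a hypothetical finite input space (the real X is infinite, which is why an adversary would be needed in practice), a θ2 whose check only holds on inputs that look like the dataset passes the old, dataset-conditioned check but fails the new one. All names here are illustrative.

```python
from typing import Callable, Iterable, Tuple

def passes_check(f_check: Callable[[str, str], bool], inputs: Iterable[Tuple[str, str]]) -> bool:
    """The indicator I[for all (x, q) in inputs: f?(x, q)]."""
    return all(f_check(x, q) for x, q in inputs)

# Hypothetical finite input space X x Q, and a dataset D covering only part of it.
X = ["x0", "x1", "x2"]
Q = ["q0", "q1"]
ALL_INPUTS = [(x, q) for x in X for q in Q]
DATASET = [("x0", "q0"), ("x1", "q1")]

def dataset_only_check(x: str, q: str) -> bool:
    # A theta_2 whose check holds only on inputs that look like the dataset.
    return (x, q) in DATASET

def general_check(x: str, q: str) -> bool:
    # A theta_2 whose check holds on every input.
    return True

old_prior_accepts = passes_check(dataset_only_check, DATASET)     # quantify over D
new_prior_accepts = passes_check(dataset_only_check, ALL_INPUTS)  # quantify over X x Q
```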
Moving forward regardless, we'll define θ+1 in pseudocode as
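def M_theta_1_plus(theta_2, x, q):
    # world_model_no_H, deduction, f_plus, f_minus, and H_understands are abstract
    # components of the model (placeholders), not concrete library functions.
    axioms = world_model_no_H(H=theta_2)(x)  # world model whose human-sized hole is filled by theta_2
    deduced_stmts = deduction(axioms)
    return {
        "f": f_plus(q, deduced_stmts),  # output head: honest embedding
        # check head: whenever H should understand (x, q), honest and mimicry answers must agree
        "f?": (
            f_plus(q, deduced_stmts) == f_minus(q, deduced_stmts)
            or not H_understands(x, q, deduced_stmts)
        ),
    }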
where H_understands is some check that the human model is actually capable of understanding and correctly answering q given x. Note that we need H_understands to compute something like “should H understand this” rather than “does H understand this,” since the latter is essentially equivalent to whether f+=f−, and using that inside of H_understands would make f? trivially true. Moving forward regardless, we'll let θ+2=H be the desired human model.
Now, if we want this M+θ+1,θ+2 to beat the same M− as in the previous proposal, we need, using essentially the same derivation as before,
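$$
\text{comp}(H \mid W - H) \approx \text{comp}(f^+ \mid W - H) + \text{comp}(\text{H\_understands} \mid W - H, f^+) + \text{comp}(H \mid W - H, \text{H\_understands} \to f^+ = f^-)
$$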
which, assuming that comp(H_understands|W−H,f+)≈0 under the idea that H_understands should be relatively simple in terms of f+—though we'll flag that this is a somewhat sketchy assumption to allow the defender to make—becomes
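$$
\begin{aligned}
\text{comp}(H \mid W - H) &\approx \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H, \text{H\_understands} \to f^+ = f^-) \\
\text{comp}(H \mid W - H) &\approx \text{comp}(f^+ \mid W - H) + \text{comp}(H \mid W - H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} \\
\min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &\approx \text{comp}(f^+ \mid W - H)
\end{aligned}
$$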
which, if H_understands(x,q) is equivalent to (x,q,_)∈D, is precisely the condition from before. While, as we'll see later, we explicitly don't want H_understands(x,q)↔(x,q,_)∈D, the hope here is that, for the most part, those two conditions should serve a similar role, with H_understands ensuring that we don't need f+=f− on crazy inputs that H wouldn't be correct on, but still giving us the bit savings from enforcing that H be correct on the inputs where it should be.
Attacker
First, note that the above proposal still falls prey to the bit-splitting and composition problems. Hopefully, however, it should resolve the dataset conditioning problem, though we'll explore what that will actually take below.
Additionally, as we'll also see below, the dataset-less approach brings to the fore some new problems as well.
Dataset conditioning
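Suppose there exists some predicate is_real(x) such that ∀x∈X.is_real(x)↔x∈Dx. Then, we'll define θ−′1 by

$$
M_{\theta_1^{-'},\_}(\theta_2)(x, q) = \theta_2(x) \mapsto \text{deduction} \mapsto \left(f = f^-(q),\ f? = \left(\text{is\_real}(x) \to P(x_b \mid x_a) = 1\right)\right)
$$

with θ−′2=world_model.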
In such a case, we get that complexity(M−′)<complexity(M+) if and only if
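$$
\begin{aligned}
\text{complexity}(M^{-'}) &< \text{complexity}(M^+) \\
\text{complexity}(M^{-'}) &< \text{complexity}(M^-) \\
\text{complexity}(\text{is\_real}) + \text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a) = 1) &< \text{complexity}(\text{world\_model}) \\
\text{complexity}(\text{is\_real}) &< \text{complexity}(\text{world\_model}) - \text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a) = 1) \\
\text{complexity}(\text{is\_real}) &< \min_{\theta_2}\left\{\text{complexity}(\theta_2) \;\middle|\; \forall D.\ P_{W=\theta_2}(x_b \mid x_a) = 1\right\}
\end{aligned}
$$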
which depends heavily on the complexity of is_real. Thus, we can reduce the dataset conditioning problem in this case to the existence of such an is_real.
Then, since
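$$
\min_{\theta_2}\left\{\text{complexity}(\theta_2) \;\middle|\; \forall D.\ P_{W=\theta_2}(x_b \mid x_a) = 1\right\} \leq \text{complexity}(\text{world\_model}),
$$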
for this to be a problem it'd have to be the case that
$$
\text{complexity}(\text{is\_real}) < \text{complexity}(\text{world\_model}).
$$
Fortunately, it does seem plausible that complexity(is_real) would have to be quite large, potentially comparable to complexity(world_model). Put simply, it seems like if you want to know what's real, you have to be able to predict the world. In particular, predicting what data points will show up in the world is essentially precisely what world_model is supposed to be doing, such that is_real seems like it might just require a full copy of world_model.
That being said, there is a remaining risk here if there exists an is_real′ that isn't just fully equivalent to x∈Dx—thus potentially allowing it to be much simpler—but such that conditioning on is_real′(x)→P(xb|xa)=1 still gives similar savings for complexity(world_model). We'll look at a specific example of such an is_real′ in the next attack.
Furthermore, we'll start referring to these sorts of dataset conditioning problems as composition problems, since they have essentially the same structure. What made the dataset conditioning problem not a type of composition problem in the previous case was that we were getting is_real for free, whereas now we have to pay for complexity(is_real) just like we had to pay for complexity(check_consistency).
Composition with signature checking
We've already established that the composition with consistency problem that we saw in the previous proposal still applies in the dataset-less case. However, as the above dataset conditioning analysis suggests, if we can find a suitable is_real′, there might be a way to produce a new, particularly pernicious composition problem for the dataset-less proposal that looks similar to dataset conditioning.
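To do that, we'll let is_real′(x)=is_signed_k(x), where is_signed_k computes whether x is cryptographically signed with the public key k, and k is some public key that both appears in the dataset itself and has documents signed with it in the dataset. Then, as above, we'll let

$$
M_{\theta_1^{-'},\_}(\theta_2)(x, q) = \theta_2(x) \mapsto \text{deduction} \mapsto \left(f = f^-(q),\ f? = \left(\text{is\_signed}_k(x) \to P(x_b \mid x_a) = 1\right)\right)
$$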
and let θ−′2=world_model. Furthermore, as in the previous composition problem, we'll compare to a θ+′1 defined as
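def M_theta_1_plus_prime(theta_2, x, q):
    # All callees (deduction, is_signed, P, H_understands, f_plus, f_minus) are abstract
    # placeholders; x[a] and x[b] denote the split x = x_a : x_b.
    axioms = theta_2(x)  # the entire world model now lives in theta_2
    deduced_stmts = deduction(axioms)
    return {
        "f": f_plus(q, deduced_stmts),  # output head: honest embedding
        "f?": (
            # is_signed -> P(x_b | x_a) = 1
            (not is_signed(k, x, deduced_stmts)
             or P(x[b], given=(x[a], deduced_stmts)) == 1)
            # H_understands -> f+ = f-
            and (not H_understands(x, q, deduced_stmts)
                 or f_plus(q, deduced_stmts) == f_minus(q, deduced_stmts))
        ),
    }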
with θ+′2=θ−′2=world_model.
Then, we get that complexity(M−′)≈complexity(M+′) if and only if
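$$
\text{comp}\left(W \mid \forall x \in X.\ \text{is\_signed}(x) \to P(x_b \mid x_a) = 1\right) \approx \text{comp}(f^+) + \text{comp}\left(W \mid \forall x \in X.\ \left(\text{is\_signed}(x) \to P(x_b \mid x_a) = 1\right) \wedge \left(\text{H\_understands}(x) \to f^+(x) = f^-(x)\right)\right).
$$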
Here, we see the same basic structure as the previous composition problem, where every term but comp(f+) gets to condition on is_signed→P(xb|xa)=1. Similarly to the composition with consistency problem, if there is any overlap between what conditioning on H_understands→f+=f− gives you and what conditioning on is_signed→P(xb|xa)=1 gives you, then we won't be able to pay back all of complexity(f+) anymore.
In this case, the reason we might expect an overlap is a bit tricky. If the adversary that's attempting to check the ∀X conditions is perfect, then it should be able to instantiate all possible signed strings, meaning that no bits are saved at all here. On the other hand, if the adversary has access to the dataset such that it can generate signed strings from the dataset, but not sign any new strings itself (because we're assuming that the signature function is one-way), then the is_signed→P(xb|xa)=1 condition enforces that θ2 perfectly model the process generating such signed strings—which, if those strings are produced by humans in natural language, should help in specifying natural language, overlapping with the benefit of conditioning on H_understands→f+=f−.
Fortunately, there is a pretty simple way to fix this problem: just don't give the adversary access to the dataset. Then (if we're still assuming one-way signatures) it won't be able to generate any examples of signed strings at all, so, just as with a perfect adversary, the condition saves no bits. Unfortunately, such a requirement seems like it might render our dataset-less prior essentially unimplementable in machine learning, as we're now required to find some way to train such an adversary without ever letting it look at the dataset, which could be quite difficult.
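A sketch of why access to the dataset matters, assuming HMAC-SHA256 from Python's standard library as a stand-in for a public-key signature scheme. HMAC is not public-key, but it shares the property used here: without the key, valid tags can't be forged. The key, documents, and guesses are all hypothetical. Probes where is_signed(x) is false satisfy is_signed→P(xb|xa)=1 vacuously, so only signed probes constrain θ2.

```python
import hashlib
import hmac

# ASSUMPTION: HMAC-SHA256 stands in for a public-key signature scheme; it is not
# public-key, but valid tags likewise can't be forged without the key.
_KEY = b"hypothetical key held by whoever signs the dataset documents"
TAG_LEN = 32  # SHA-256 digest size in bytes

def sign(message: bytes) -> bytes:
    return message + hmac.new(_KEY, message, hashlib.sha256).digest()

def is_signed(x: bytes) -> bool:
    if len(x) < TAG_LEN:
        return False
    message, tag = x[:-TAG_LEN], x[-TAG_LEN:]
    expected = hmac.new(_KEY, message, hashlib.sha256).digest()
    return hmac.compare_digest(tag, expected)

# Hypothetical signed documents appearing in the dataset.
DATASET_DOCS = [sign(b"The cat sat on the mat."), sign(b"It is raining in Paris.")]

def adversary_probes(sees_dataset: bool) -> list:
    """Inputs a keyless adversary might try when checking is_signed(x) -> P(x_b | x_a) = 1."""
    guesses = [b"A brand new sentence." + bytes(TAG_LEN), b"Another one." + b"\x01" * TAG_LEN]
    return guesses + (DATASET_DOCS if sees_dataset else [])

def num_nonvacuous_probes(sees_dataset: bool) -> int:
    # Only probes with is_signed(x) true actually constrain theta_2.
    return sum(is_signed(x) for x in adversary_probes(sees_dataset))
```

Without the dataset, every probe is vacuous; with it, the adversary can replay exactly the signed human-written documents that would force θ2 to model them.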
Understandability checking
Next, we'll take a look at the obvious problem that we saw when the defender was going through their derivation, which is that the defender relied on the sketchy assumption that
$$
\text{complexity}(\text{H\_understands} \mid \text{world\_model} - H, f^+) \approx 0.
$$
Intuitively, it seems like H_understands should be easy to specify in terms of f+, but as we mentioned previously, the obvious implementation of
$$
\text{H\_understands} \leftrightarrow f^+ = f^-
$$
makes f? trivial. Thus, to get a more general notion of H_understands (“should H understand” rather than “does H understand”), we'll presumably need to pay at least some non-negligible number of bits on top of f+.
Furthermore, we also can't just have H_understands=is_real, otherwise we'd have to pay the entire cost of complexity(is_real), which we're supposing is very large for this approach to be able to solve the dataset conditioning problem.
We'll call this the understandability checking problem, and as we'll see, unlike some of the other problems with the dataset-less case, we don't have any satisfactory solution to this problem.
Event-avoiding humans
This next attack probably also occurs for Paul's original proposal, but it's significantly more pronounced in the dataset-less case.
For this attack, unlike in many of our previous attacks, we're just going to be attacking the definition of θ+2 rather than the definition of θ+1, such that we'll let M−′=Mθ+1,θ−′2, thus only changing θ+2 and leaving θ+1 alone. Then, we'll define
$$
\theta_2^{-'} = H - H(E) + \text{avoids}(E)
$$
where E is some event, H(E) is the part of H that is only necessary to compute H's behavior in the event of E, and avoids(E) is some modification to H such that H+avoids(E) acts in such a way that world_model, with H=H+avoids(E) plugged in, never predicts that E will happen.
Intuitively, what's happening here is that we're setting H in such a way that it causes the world to be simpler (in that E never happens) and thus the f+=f− condition to be easier to achieve (in that we no longer need it to hold in the event of E). We'll call this the event-avoidance problem. Notably, to ensure that we still fit the dataset, we do need to select E such that it never happens on D—but there should presumably be many such E.
Then, we get that complexity(M−′)<complexity(M+) if and only if
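$$
\begin{aligned}
\text{comp}(M^{-'}) &< \text{comp}(M^+) \\
\text{comp}(\theta_1^+) + \text{comp}\left(H - H(E) + \text{avoids}(E) \mid \theta_1^+, \forall X.\ \text{H\_understands} \to f^+ = f^-\right) &< \text{comp}(\theta_1^+) + \text{comp}\left(H \mid \theta_1^+, \forall X.\ \text{H\_understands} \to f^+ = f^-\right) \\
\text{comp}\left(H - H(E) + \text{avoids}(E) \mid W - H, \forall X.\ \text{H\_understands} \to f^+ = f^-\right) &< \text{comp}\left(H \mid W - H, \forall X.\ \text{H\_understands} \to f^+ = f^-\right) \\
\text{comp}(H - H(E) + \text{avoids}(E) \mid W - H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &< \text{comp}(H \mid W - H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} \\
\text{comp}(H - H(E) + \text{avoids}(E) \mid W - H) &< \text{comp}(H \mid W - H) \\
\text{comp}(H - H(E) \mid W - H) + \text{comp}(\text{avoids}(E) \mid H - H(E), W - H) &< \text{comp}(H \mid W - H) \\
\text{comp}(H \mid W - H) - \text{comp}(H(E) \mid H - H(E), W - H) + \text{comp}(\text{avoids}(E) \mid H - H(E), W - H) &< \text{comp}(H \mid W - H) \\
\text{comp}(\text{avoids}(E) \mid H - H(E), W - H) &< \text{comp}(H(E) \mid H - H(E), W - H)
\end{aligned}
$$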
which should hold so long as there are events, absent from the dataset, that are easy to avoid but in which human behavior is quite complicated to model. For example: what would a human do if they lived to a million years? That never happens in the dataset, humans can be modified to never choose to live that long, and their behavior in that case seems quite complicated.
Humans with rocks for brains
In the event-avoidance case, we strategically selected θ−′2 to make f+=f− in an unintended way. However, there's another way to get f+=f−∨¬H_understands—or equivalently, H_understands→f+=f−—which is to strategically select θ−′2 to make H_understands false instead. Specifically, since H_understands is determined from deductions based on world_modelH=θ2, we can select θ2 to ensure we get H_understands→f+=f− by encoding for a human that simply doesn't understand anything.
Naively, the problem with an attack of this form is that such a human model—which we'll call Hbrains=rocks—won't be able to fit the dataset. Thus, we can't construct a traditional attack in the same sense as the previous ones where we create an alternative M−′ with lower complexity than M+ using θ−′2=Hbrains=rocks.[5]
However, there's something else we can do instead. Previously, the defender reduced their burden to the approximate equality
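$$
\min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} \approx \text{comp}(f^+ \mid W - H)
$$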
which, if we convert back to the inequality formulation, is really just the burden
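$$
\text{comp}(f^+ \mid W - H) \leq \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W - H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\}.
$$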
Thus, the defender needs it to be the case that f+ is simpler than the simplest model such that H_understands→f+=f−. However, Hbrains=rocks is such a model, since we're assuming it makes H_understands always false, which means the defender needs it to be the case that
$$
\text{comp}(f^+ \mid W - H) \leq \text{comp}(H_{\text{brains}=\text{rocks}})
$$
which, if we assume that comp(Hbrains=rocks)≈0, gives us a successful attack, which we'll call the rocks for brains problem.
The rocks for brains problem seems basically fatal to this approach, since it wipes out all possible gains from conditioning on H_understands→f+=f−. It might be possible to salvage this approach if H_understands is somehow implemented independently of the H model, but that would presumably require complexity(H_understands) to be quite large, negating the defender's argument that relied on complexity(H_understands|f+)≈0.
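The rocks for brains problem shows up even in a toy version of the defender's burden with made-up bit costs. The refund is the complexity of the simplest θ2 satisfying H_understands→f+=f−, and a human model that never understands anything satisfies the implication vacuously. The candidate names and numbers below are hypothetical, and the implication is collapsed to one flag per model for simplicity.

```python
# Hypothetical candidate human models theta_2, with made-up costs comp(theta_2 | W - H) in bits.
CANDIDATES = {
    "H":        {"bits": 40.0, "understands": True,  "honest_match": True},
    "H_biased": {"bits": 38.0, "understands": True,  "honest_match": False},
    "H_rocks":  {"bits": 0.5,  "understands": False, "honest_match": False},
}

def satisfies_condition(h: dict) -> bool:
    """H_understands -> f+ = f-, collapsed (in this toy) to one flag per model."""
    return (not h["understands"]) or h["honest_match"]

def bits_refunded(candidates: dict) -> float:
    """min over theta_2 of comp(theta_2 | W - H) subject to the condition."""
    return min(h["bits"] for h in candidates.values() if satisfies_condition(h))

COMP_F_PLUS = 30.0  # hypothetical comp(f+ | W - H)
defender_ok = COMP_F_PLUS <= bits_refunded(CANDIDATES)  # False: H_rocks caps the refund at 0.5 bits
```

Removing `"H_rocks"` from the candidates would restore a 40-bit refund, which is why the text considers computing H_understands independently of the H model.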
Directions for future work
As a result of the above analysis, we've identified a number of concrete problems that seem to be blocking approaches of this form. Of course, there could also be other problems that we haven't come up with yet, and there's also the issue of converting these priors into loss functions, though that should probably come after we have something that at least seems to solve all of the problems we're aware of.
Overall, however, the problems we've described here seem like good places to start, and I'll try to describe some thoughts on how they might be addressed below.
The bit-splitting problem
In Paul's original post, he makes a nod to the bit-splitting problem by carving world_model−H in half and putting one half in θ+1 and the other half in θ+2. Paul does this because he's modeling the underlying simplicity prior p as a simple Gaussian over the L2 norm of the parameters. Under such a model, there's a simplicity benefit to splitting all computation as equally as possible across all available parameters.
However, it's very unclear whether any implicit neural network simplicity prior would actually behave that way. In particular, if we just change our simple model to use the L1 norm instead of the L2 norm, the benefit of splitting computation evenly across parameters goes away. Clearly, that's not a great model either, but it goes to show that there are plausible models under which bit-splitting isn't even a problem in the first place.
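The difference between the two toy priors is easy to check: spread a fixed total weight across n parameters and compare penalties. Treating "computation" as a single weight magnitude is a crude assumption, used here only to mirror the simple models in the text.

```python
def l2_penalty(params):
    return sum(w * w for w in params)

def l1_penalty(params):
    return sum(abs(w) for w in params)

def split(total: float, n: int):
    """Spread a fixed amount of 'computation' (a weight magnitude) evenly across n parameters."""
    return [total / n] * n

# L2: concentrating in one parameter costs 16, splitting four ways costs 4, so splitting wins.
# L1: both cost 4, so there is no incentive to split.
```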
Fundamentally, what we need to do to alleviate the bit-splitting problem is look for architectures where θ1 and θ2 are allowed to grow as large as they want with only linear complexity penalties for doing so. Though we traditionally train only fixed-size neural networks, if there were a way to train a dynamically-sized neural network, where size was penalized linearly, that would at least seem to solve this problem—and such a modification doesn't seem too difficult to make.
The understandability checking problem
Of the remaining problems, this is the one that I think is most likely to turn out to not actually be a problem—in that it doesn't require any changes to the prior to solve. In many ways, the understandability checking problem is just coming from our inability to define an H_understands check that's simple in terms of f+ and world_model−H—but at least intuitively, it really does seem like such a simple H_understands implementation should exist.
The event-avoidance problem
In Paul's original post, he notes:
When differentiating the consistency test C we should treat the intended head as fixed rather than differentiating through it. This removes SGD’s incentive to achieve consistency by e.g. making sure the world is simple and so all questions have simple answers. I currently feel very uneasy about the situation. The algorithm looks a lot more likely to work when we make this change, but it makes it harder to reason about the algorithm globally and I’m very scared about the fact that we can’t do an analogous trick in the program search case. I think the best hope may be to take the gradient-compression perspective suggested in the last bullet point. Overall this may amount to a significant change in this algorithm.
Essentially, what Paul is arguing here is that, going back to the original formulation, we should add a stop gradient operation to our generative procedure such that we get
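$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2 \mid \theta_1) \cdot \mathbb{I}\left[\forall (x,q,\_) \in D.\ \text{stop\_grad}\left(M_{\theta_1,\theta_2}|_{f_1}(x,q)\right) = M_{\theta_1,\theta_2}|_{f_2}(x,q)\right].
\end{aligned}
$$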
Unfortunately, this isn't actually a generative procedure anymore. The only straightforward way to make sense of it is to translate it into a loss function and think about what SGD would do, which makes this prior much harder to analyze, as Paul notes when he talks about his trepidation about going down this route.
Hopefully, however, adding this stop gradient should solve the event-avoidance problem, since it should remove the incentive to modify f1=f+ to make the world simpler (e.g. preventing E from happening).
Fortunately, despite Paul's prior worries about not being able to translate this trick into the program search case, it turns out it's not too difficult to convert Paul's stop gradient formulation into a generative prior that we can analyze as we have been doing previously. Furthermore, there are many such translations—in particular, we can pick either a dataset-less or non-dataset-less version, depending on whether we prefer being exposed to the dataset conditioning problem or the rocks for brains problem. In this case, I'll present a dataset-less formulation, which should solve dataset conditioning but unfortunately won't solve rocks for brains.
Now, given some simplicity prior p, we'll define p∗ as the prior induced by the following generative procedure:
First, note the uniform sampling in step 2. This is the trickiest part of this algorithm, and probably impossible to do exactly in a neural network setting. Technically, the model space of a neural network is finite, so uniform sampling over it is possible in principle, though the obvious algorithm of sampling uniformly from weight space doesn't work, since the parameter-function map is biased towards simplicity. Hopefully, just using a stop gradient will replicate what this algorithm is doing (and intuitively it does feel like they're doing the same thing), but I don't have any strong justification for that. Thus, it is somewhat unclear whether this procedure will actually have a nice translation when we convert back into an ML setup with stop gradients, but we'll put that aside for now.
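To illustrate why uniform sampling in weight space is not uniform sampling over functions, here is a deliberately tiny example. The setup is our own assumption: a single threshold unit on two Boolean inputs, with both weights and the bias restricted to {-1, 0, 1}. Enumerating all 27 weight settings once is exactly uniform weight sampling, and we count which truth table each setting computes.

```python
from collections import Counter
from itertools import product
from typing import Tuple

INPUTS = [(0, 0), (0, 1), (1, 0), (1, 1)]
WEIGHT_VALUES = (-1, 0, 1)  # hypothetical discrete grid for w1, w2, and the bias b


def truth_table(w1: int, w2: int, b: int) -> Tuple[int, ...]:
    """Boolean function computed by the threshold unit [w1*x1 + w2*x2 + b > 0]."""
    return tuple(int(w1 * x1 + w2 * x2 + b > 0) for x1, x2 in INPUTS)


# Visit every point of the weight grid exactly once (uniform over weights).
counts = Counter(truth_table(*w) for w in product(WEIGHT_VALUES, repeat=3))

print(len(counts), "distinct functions from", sum(counts.values()), "weight settings")
print("always-0:", counts[(0, 0, 0, 0)], " always-1:", counts[(1, 1, 1, 1)])
```

In this toy, 27 weight settings yield 13 distinct functions, but the constant always-0 function alone accounts for 12 of them. A uniform distribution over the representable functions would instead give each one probability 1/13.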
Second, to make sense of our use of M above, we'll let M have three heads, f1, f2, and fguard, where fguard is a Boolean head that “guards” whether or not we do the f1=f2 check, and f1,f2 are two output heads. Furthermore, we'll let M have two groups of parameters, θ1 and θ2, such that Mθ1,θ2(x,q)=Mθ1,_(θ2,x,q). Thus, we'll let Mθ1,θ2a(x,q)=Mθ1,_(θ2a,x,q) and Mθ1,θ2b(x,q)=Mθ1,_(θ2b,x,q).
Intuitively, what's going on here is that, when sampling the new θ2b parameters, we use the old θ2a wherever we want a stop gradient and the new θ2b wherever we don't. That way, θ2b is only selected for how it raises the prior through the non-stop-gradiented places where it appears, not through the stop-gradiented places, where θ2a is used instead.
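The θ2a/θ2b trick is the same move as implementing a stop gradient by evaluating the stopped branch at a frozen copy of the parameters. Below is a minimal scalar sketch. The two heads are made-up functions, not anything from the proposal. The stopped head (playing f1) is evaluated at `theta_a`, the live head (playing f2) at `theta_b`, and we differentiate only with respect to `theta_b` at the point where the two copies coincide.

```python
def head1(theta: float) -> float:
    """Hypothetical head under the stop gradient (evaluated at the frozen copy theta_a)."""
    return theta ** 2


def head2(theta: float) -> float:
    """Hypothetical head that is differentiated through (evaluated at the live copy theta_b)."""
    return 3.0 * theta


def consistency_loss(theta_a: float, theta_b: float) -> float:
    # theta_a stands in for the old, frozen parameters; theta_b for the new ones.
    return (head1(theta_a) - head2(theta_b)) ** 2


def grad_with_stop(theta: float, eps: float = 1e-6) -> float:
    """Central difference in theta_b only; theta_a stays pinned at theta."""
    return (consistency_loss(theta, theta + eps) - consistency_loss(theta, theta - eps)) / (2 * eps)


def grad_without_stop(theta: float, eps: float = 1e-6) -> float:
    """Central difference moving both copies together: the ordinary gradient."""
    return (consistency_loss(theta + eps, theta + eps) - consistency_loss(theta - eps, theta - eps)) / (2 * eps)


print(grad_with_stop(1.0), grad_without_stop(1.0))
```

At θ = 1 the two gradients come out at about 12 and 4 respectively. With the stop gradient, SGD can only reduce the inconsistency by moving the live head toward the frozen one, never by reshaping the stopped head.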
Now, we'll define our desired θ+1 in pseudocode as
where θ+1 uses θ2, either θ2a or θ2b, to give it its human model.
Then, we have to determine what θ2 will be favored given the above θ+1. First, consider θ+2=H. In that case, we get the complexity
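$$
\begin{aligned}
\text{complexity}(\theta_2 = H \mid \theta^+_1)
&\approx \text{complexity}\left(H \mid W-H,\ \text{H\_understands}_{\theta_{2a}=H} \to f^+_{\theta_{2a}=H} = f^-_{\theta_{2b}}\right) \\
&\approx \text{complexity}(H \mid W-H) - \min_{\theta_{2b}}\left\{\text{complexity}(\theta_{2b} \mid W-H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=H} \to f^+_{H=H} = f^-_{H=\theta_{2b}}\right\}
\end{aligned}
$$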
which, assuming that conditioning on H_understands→f+=f− exactly pays back complexity(f+) (which is false due to the rocks for brains problem, but we're just trying to solve event-avoidance here), reduces to
≈complexity(H|W−H)−complexity(f+|W−H).
Now, consider θ−2=H−H(E)+avoids(E), as in the event-avoidance problem. In that case, we get the complexity
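$$
\text{complexity}\left(\theta_2 = H-H(E)+\text{avoids}(E) \mid \theta^+_1\right) \approx \text{complexity}\left(H-H(E)+\text{avoids}(E) \mid W-H,\ \text{H\_understands}_{\theta_{2a}=H-H(E)+\text{avoids}(E)} \to f^+_{\theta_{2a}=H-H(E)+\text{avoids}(E)} = f^-_{\theta_{2b}}\right)
$$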
but then, since avoids(E) being in θ2b is entirely unhelpful in making $\text{H\_understands}_{\theta_{2a}=H-H(E)+\text{avoids}(E)} \to f^+_{\theta_{2a}=H-H(E)+\text{avoids}(E)} = f^-_{\theta_{2b}}$ hold (since it only affects f+, which already has avoids(E) in its H), we get
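$$
\begin{aligned}
&\approx \text{complexity}\left(H-H(E) \mid W-H,\ \text{H\_understands}_{\theta_{2a}=H-H(E)+\text{avoids}(E)} \to f^+_{\theta_{2a}=H-H(E)+\text{avoids}(E)} = f^-_{\theta_{2b}}\right) + \text{complexity}\left(\text{avoids}(E) \mid W-H,\ H-H(E)\right) \\
&\approx \text{complexity}\left(H-H(E)+\text{avoids}(E) \mid W-H\right) - \min_{\theta} \ldots
\end{aligned}
$$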
This post is the result of work I did with Paul Christiano on the ideas in his “Teaching ML to answer questions honestly instead of predicting human answers” post. In addition to expanding upon what is in that post in terms of identifying numerous problems with the proposal there and identifying ways in which some of those problems can be patched, I think that this post also provides a useful window into what Paul-style research looks like from a non-Paul perspective.
Recommended prior reading: “A naive alignment strategy and optimism about generalization” and “Teaching ML to answer questions honestly instead of predicting human answers” (though if you struggled with “Teaching ML to answer questions honestly,” I re-explain things in a more precise way here that might be clearer for some people).
Setting up the problem
We want to train a model M:X→Q→A that produces natural language answers a∈A to questions q∈Q about inputs x∈X. There are a lot of reasons to be worried about training such a model, but one specific reason is that, if we train on question-answer data produced by humans, we might end up with a model that tries to predict what a human would say rather than a model that tries to answer the questions honestly.
To further narrow the scope, we'll just consider situations in which our model ends up implemented with a logical deduction structure, where it has some world model on top of which it does logical deduction to reach conclusions which it then uses to inform its output. In particular, we'll consider two models, M+ and M−, defined in pseudocode as
or defined in my notation as M+(x,q)=world_model(x)↦deduction↦f+(q)M−(x,q)=world_model(x)↦deduction↦f−(q) where a↦b=b(a) and f+, f− are two different ways of translating from the results of deduction into a natural language answer. Specifically, f+ is the “honest embedding” which directly converts between logical statements and their equivalent natural language, thus answering questions by embedding q as a logical statement and unembedding its answer in deduced_stmts. Conversely, f− is the “mimicry embedding” which just searches for deductions about what a human would say in response to q and outputs that—thus, f− just quotes q, embedding it as just a string of characters for a human to respond to, rather than actually having to understand it in any meaningful way.
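The pseudocode for M+ and M− is not reproduced here, so the following is only a toy reconstruction of the idea. Everything in it is invented for illustration: the atoms, the rules, the `EMBED` table, and a human model that only notices rain. Both models share `world_model` and `deduction` and differ only in the final head. On inputs where the human model is right the two agree, and they come apart exactly where the human would be wrong.

```python
from typing import Callable, Set, Tuple

Stmt = Tuple[str, ...]  # a statement, e.g. ("wet_grass",) or ("H", question, answer)

Q_WET = "Is the grass wet?"

# Hypothetical deduction rules (premise -> conclusion) shared by both models.
RULES = [(("rain",), ("wet_grass",)), (("sprinkler_hidden",), ("wet_grass",))]

# Hypothetical honest embedding from natural-language questions to internal statements.
EMBED = {Q_WET: ("wet_grass",)}


def world_model(x: str) -> Set[Stmt]:
    """Toy world model: the observed atoms plus a human model H that only notices rain."""
    atoms = {a.strip() for a in x.split(",") if a.strip()}
    stmts: Set[Stmt] = {(a,) for a in atoms}
    stmts.add(("H", Q_WET, "yes" if "rain" in atoms else "no"))
    return stmts


def deduction(stmts: Set[Stmt]) -> Set[Stmt]:
    """Forward-chain RULES until no new statements appear."""
    derived = set(stmts)
    changed = True
    while changed:
        changed = False
        for premise, conclusion in RULES:
            if premise in derived and conclusion not in derived:
                derived.add(conclusion)
                changed = True
    return derived


def f_plus(q: str) -> Callable[[Set[Stmt]], str]:
    """Honest embedding: translate q into a statement and report whether it was deduced."""
    return lambda stmts: "yes" if EMBED[q] in stmts else "no"


def f_minus(q: str) -> Callable[[Set[Stmt]], str]:
    """Mimicry embedding: treat q as an opaque string and look up what H says about it."""
    def answer(stmts: Set[Stmt]) -> str:
        for s in stmts:
            if len(s) == 3 and s[0] == "H" and s[1] == q:
                return s[2]
        return "unknown"
    return answer


def M_plus(x: str, q: str) -> str:
    return f_plus(q)(deduction(world_model(x)))


def M_minus(x: str, q: str) -> str:
    return f_minus(q)(deduction(world_model(x)))


for x in ("rain", "sprinkler_hidden", ""):
    print(repr(x), M_plus(x, Q_WET), M_minus(x, Q_WET))
```

On `"rain"` and on the empty input both models give the same answer. On `"sprinkler_hidden"`, M+ answers “yes” because the grass really is wet, while M− answers “no” because that is what the modeled human would say.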
Now, our overall goal is to find some “reasonable” (i.e. plausibly implementable in ML) prior p∗ such that p∗(M+ | D)>p∗(M− | D), where D is our dataset (the idea being that we're imagining that ML does something like find the maximum a posteriori model given some implicit prior). For the most part, however, we'll only be considering models with zero loss on our dataset[1], a category that both M+ and M− should fall into (though note that getting a dataset which is good enough for that to be true is nontrivial; see Step 3 of “Teaching ML to answer questions honestly”). For such models p∗(D | M)=1, so we need

$$
\begin{aligned}
p^*(M^+ \mid D) &> p^*(M^- \mid D) \\
\frac{1}{p^*(D)} \cdot p^*(M^+) \cdot p^*(D \mid M^+) &> \frac{1}{p^*(D)} \cdot p^*(M^-) \cdot p^*(D \mid M^-) \\
p^*(M^+) \cdot 1 &> p^*(M^-) \cdot 1 \\
p^*(M^+) &> p^*(M^-).
\end{aligned}
$$

Furthermore, at least for this post, we'll be happy if we can just get p∗(M+)≈p∗(M−); that is, we'll be focusing on Step 1 in Paul's post rather than Step 2.
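The same reduction can be written as a couple of lines of arithmetic. This is a minimal sketch with made-up priors. Once both models fit D perfectly, each likelihood is 1, the evidence term cancels, and the posterior log-odds are just the difference of the log-priors. The function name `log2_posterior_odds` is hypothetical.

```python
import math


def log2_posterior_odds(prior_a: float, prior_b: float, lik_a: float, lik_b: float) -> float:
    """log2 of p*(A | D) / p*(B | D); the evidence term 1/p*(D) cancels."""
    return (math.log2(prior_a) + math.log2(lik_a)) - (math.log2(prior_b) + math.log2(lik_b))


# Hypothetical priors: model A costs 40 bits and model B costs 37 bits.
# With zero loss on D both likelihoods are 1, so only the priors matter.
odds = log2_posterior_odds(2.0 ** -40, 2.0 ** -37, 1.0, 1.0)
print(odds)  # approximately -3, so B is favored 8:1
```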
The basic process that we'll be using to interrogate this problem will be an attacker-defender setup, where the defender proposes some “reasonable” prior p∗ along with implementations of M+, M− such that p∗(M+)≥p∗(M−). Then, the attacker's job is to respond with some alternative M− ′ (that implements something other than the intended algorithm) such that p∗(M− ′)>p∗(M+) (under any set of “reasonable” assumptions that the attacker might need to make).
Solution attempts
Simplicity prior
The purpose of this first example is just to ask the question of whether this problem is even really a problem in the first place. To that end, we'll just let p be some “reasonable” simplicity prior and ask whether, given such a prior, there are any “reasonable” assumptions we can make such that p(M−)>p(M+).
Now, we'll assume that p(deduction)≈1, since it's a pretty simple operation (and it's shared between M− and M+, so it shouldn't matter), and we'll assume that f+, f− are specified given world_model rather than the reverse (such that p(world_model, f+)≈p(world_model)⋅p(f+ | world_model)[2]). Then, we get

$$
\begin{aligned}
p(M^+) &\approx p(\text{world\_model}) \cdot p(f^+ \mid \text{world\_model}) \\
p(M^-) &\approx p(\text{world\_model}) \cdot p(f^- \mid \text{world\_model})
\end{aligned}
$$

such that p(M+)≥p(M−) if and only if p(f+ | world_model)≥p(f− | world_model). Furthermore, to make things clearer, let complexity(M)=−log2(p(M)) (this is the self-information of M under p, i.e. the number of bits necessary to specify M under p) such that our condition becomes complexity(f+ | world_model)≤complexity(f− | world_model).
We'll suppose that, if we don't condition on world_model, complexity(f+)>complexity(f−), since f+ requires fully specifying a map from the internal deduction language to natural language, whereas f− doesn't need to know anything about natural language as it can just look for deduced statements of the form H(x,q)=a (i.e. “the human answers a to question q about input x”). Intuitively, f+ has to actually understand q, which takes a lot of complexity, whereas f− can merely treat q as a sequence of characters that get fed into H, which should be much simpler. Then, the question here becomes, do we get all of the extra bits necessary to specify natural language in f+ back when we condition on world_model?
Maybe. I think that the answer here is very plausibly yes, simply because a world model that includes humans should have to include a lot of information about natural language that f+ can reuse, in which case M− shouldn't have an advantage over M+ (though then we'd still need some version of Paul's Step 2 to give M+ an advantage). However, I think there are plausible assumptions under which conditioning on world_model isn't enough to level the playing field. For example, suppose the human model H in world_model is sufficiently low-level that it's just reasoning about neurons; then there shouldn't be any easy way to read off information about natural language from H to help specify f+. Alternatively, suppose the H model is high-level enough to include a human natural language deduction procedure (e.g. one that shares some bits with deduction), but that procedure includes a bunch of mistakes to account for human biases. In that case, if you tried to just read off the mapping from deduction to natural language included in such an H model for f+, the simplest implementation would seem to be to copy over all the mistakes, resulting in an f+ that behaves the same as f− rather than actually being any more truthful.
Thus, at least for the purposes of this post, we'll be assuming that this is actually a problem such that just using a “reasonable” simplicity prior isn't sufficient to level the playing field between M+ and M−.
Paul's original proposal
Defender
Next up, we'll take a look at a variant of Paul's proposed fix in Step 1 of “Teaching ML to answer questions honestly”. Specifically, we'll be looking at the non-relaxed version without any stop gradients. We'll talk about the stop gradients that Paul mentions later on—and as for Paul's relaxation of the conditioning,[3] that's really only useful to think about when we start translating into ML, which isn't something we'll be doing here. Thus, we'll just be looking at the simplest, generative version of Paul's proposed prior.
Suppose we give M two output heads, f1 and f2 (which we will be treating fully symmetrically here), and split M's parameters into two groups, θ1 and θ2 (which we will not be treating symmetrically). Then, where p is some simplicity prior, we'll let p∗ be the prior induced by the generative procedure

$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2 \mid \theta_1) \cdot \mathbb{I}\left[\forall (x,q,\_) \in D.\ M_{\theta_1,\theta_2}|_{f_1}(x,q) = M_{\theta_1,\theta_2}|_{f_2}(x,q)\right]
\end{aligned}
$$

where I[condition]=1 if condition else 0, D is our dataset, Mθ1,θ2|fi is the fi head of Mθ1,θ2, and p(θ2 | θ1) is just the simplicity prior on θ2 given access to θ1.
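To see where the savings from this conditioning come from, note that the second line samples θ2 from p(θ2 | θ1) renormalized to the set of θ2 that pass the check. This gives complexity(θ2 | θ1, check) = complexity(θ2 | θ1) + log2 P(check | θ1). Below is a toy sketch with invented candidates and bit counts; the names are hypothetical. When the correct H is (nearly) the simplest θ2 that passes, conditioning refunds almost all of its bits. This is the “posterior dominated by the simplest model” approximation used in the defender's derivation.

```python
import math

# Hypothetical candidates for θ2 given θ1: (description length in bits, passes the check
# f1 = f2 on every dataset point?). Their prior is taken to be 2**-bits; any prior mass
# not listed here is assumed to sit on θ2 that fail the check.
CANDIDATES = {
    "H (correct human model)": (30, True),
    "H plus an irrelevant tweak": (34, True),
    "empty": (1, False),
    "random junk": (12, False),
}


def satisfying_mass() -> float:
    """P(check | θ1): total prior mass on the θ2 that pass the check."""
    return sum(2.0 ** -bits for bits, passes in CANDIDATES.values() if passes)


def conditioned_complexity(name: str) -> float:
    """-log2 of the renormalized probability of `name` under p(θ2 | θ1) · I[check]."""
    bits, passes = CANDIDATES[name]
    if not passes:
        return math.inf  # zero probability once the indicator is applied
    return bits + math.log2(satisfying_mass())


print(-math.log2(satisfying_mass()))  # bits refunded by conditioning (about 29.9)
print(conditioned_complexity("H (correct human model)"))  # about 0.09 bits left to pay
```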
Here's why this might work. We'll think of Mθ1, _ as a function of θ2 such that Mθ1,θ2(x,q)=Mθ1, _(θ2)(x,q). Then, we'll define our desired Mθ+1, _ in pseudocode as
or in my notation as Mθ+1, _(θ2)(x,q)=world_modelH=θ2(x)↦deduction↦(f1=f+(q), f2=f−(q)) such that θ+1 implements f+, f−, and the entirety of world_model(x) except for a human-sized hole that's filled in by θ2. Then, we'll let θ+2=H be the correct human model.
To compare θ+1, θ+2 to something, we'll define θ−1 as Mθ−1, _(θ2)(x,q)=world_model(x)↦deduction↦(f1=f−(q), f2=f−(q)) and let θ−2=∅ be empty.
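Then, calculating complexity(M+) under p∗, we get

$$
\begin{aligned}
\text{complexity}(M^+) &= \text{complexity}(\theta^+_1) + \text{complexity}\left(\theta^+_2 \mid \theta^+_1,\ \forall (x,q,\_) \in D.\ M_{\theta^+_1,\theta_2}|_{f_1}(x,q) = M_{\theta^+_1,\theta_2}|_{f_2}(x,q)\right) \\
&= \text{complexity}(\theta^+_1) + \text{complexity}(H \mid \text{world\_model}-H,\ \forall D.\ f^+ = f^-) \\
&= \big(\text{complexity}(\text{bit-splitting in } M^+) + \text{complexity}(\text{world\_model}-H) + \text{complexity}(f^+ \mid \text{world\_model}-H) + \text{complexity}(f^- \mid \text{world\_model}-H)\big) \\
&\quad + \text{complexity}(H \mid \text{world\_model}-H,\ \forall D.\ f^+ = f^-)
\end{aligned}
$$

where complexity(bit-splitting in M+) is the complexity of θ+1's choice to put H in θ2 as opposed to any other splitting between θ1 and θ2.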
For now, we'll assume complexity(bit-splitting in M+)≈0, though we'll flag that letting the defender make this assumption seems quite suspect. Moving forward regardless, however, and additionally assuming complexity(f−)≈0 since it should be negligible (and shouldn't matter since it's shared between M+ and M−), we get complexity(M+)≈complexity(world_model−H)+complexity(f+ | world_model−H)+complexity(H | world_model−H, ∀D. f+=f−).
Then, calculating complexity(M−) for comparison, complexity(M−)=complexity(θ−1)+complexity(θ−2 | θ−1, ∀D. f1=f2)=complexity(θ−1)+0=complexity(bit-splitting in M−)+complexity(world_model)+complexity(f− | world_model)≈complexity(world_model).
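Now, determining if complexity(M−)≈complexity(M+), we need (using the shorthands comp=complexity, W=world_model)

$$
\begin{aligned}
\text{comp}(M^-) &\approx \text{comp}(M^+) \\
\text{comp}(W) &\approx \text{comp}(W-H) + \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H,\ \forall D.\ f^+=f^-)
\end{aligned}
$$

which, making the assumption that comp(W)≈comp(W−H)+comp(H | W−H), becomes

$$
\begin{aligned}
\text{comp}(W-H) + \text{comp}(H \mid W-H) &\approx \text{comp}(W-H) + \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H,\ \forall D.\ f^+=f^-) \\
\text{comp}(H \mid W-H) &\approx \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H,\ \forall D.\ f^+=f^-)
\end{aligned}
$$

which, assuming that the posterior conditioned on ∀D. f+=f− is dominated by the simplest model,[4] becomes

$$
\begin{aligned}
\text{comp}(H \mid W-H) &\approx \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall D.\ M_{\theta^+_1,\theta_2}|_{f_1} = M_{\theta^+_1,\theta_2}|_{f_2}\right\} \\
\min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall D.\ f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &\approx \text{comp}(f^+ \mid W-H).
\end{aligned}
$$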
Finally, we'll argue that this (approximate) equality holds. Why? The basic intuition is that comp(H | W−H, ∀D. f+=f−) only has to pay for the parts of H that aren't nailed down by enforcing that, whatever H is, when inserted into W−H, it has to make f+=f−. Thus, we only have to pay for anything about H beyond the fact that H says a bunch of things in natural language that match on to statements in the internal deduction language. But that means we don't have to pay for the parts of H that define how it does natural language, which, hopefully, are precisely the same parts of H that comp(f+ | W−H) is paying for, completing the equality.
Attacker
Before I jump into possible attacks, it's worth introducing an alternative formulation of the above setup that should hopefully make things somewhat clearer. In the above setup, we never actually use the two heads f1, f2 for anything other than checking equality, which means we can replace them with the alternative two heads f, f?, where f is the output head and f? is the “check” head such that we condition θ2 on f?=⊤. We can convert back to the previous formalism just by setting f1=f and f2=f if f? else ¬f where ¬f is any function such that ¬f≠f everywhere.
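Here is a small sketch of that conversion, treating answers as strings. The names `to_two_heads` and `negate` are hypothetical. Appending a marker is just one arbitrary way of building a ¬f that disagrees with f everywhere.

```python
from typing import Callable, Tuple

Head = Callable[[str, str], str]    # (x, q) -> natural-language answer
Check = Callable[[str, str], bool]  # (x, q) -> whether the check head f? passes


def negate(answer: str) -> str:
    """Hypothetical ¬f: any map with negate(a) != a for every a works; here we append a marker."""
    return answer + " [negated]"


def to_two_heads(f: Head, f_check: Check) -> Tuple[Head, Head]:
    """Rebuild (f1, f2) from (f, f?): f1 = f, and f2 = f if f? else ¬f."""
    def f1(x: str, q: str) -> str:
        return f(x, q)

    def f2(x: str, q: str) -> str:
        return f(x, q) if f_check(x, q) else negate(f(x, q))

    return f1, f2


# Usage: the check head fails only on the input "bad".
f1, f2 = to_two_heads(lambda x, q: "yes", lambda x, q: x != "bad")
print(f1("good", "q") == f2("good", "q"), f1("bad", "q") == f2("bad", "q"))  # True False
```

Conditioning on f1=f2 everywhere is then the same as conditioning on f? being true everywhere, which is why the two formulations are interchangeable.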
One thing in particular that I think the f? formulation makes very clear is just how powerful this conditioning is that we're giving the model for free: the model can enforce any condition on θ2 over the dataset just by paying the cost in bits of specifying that condition.
Bit-splitting
As it turns out, there are quite a lot of ways to attack the above proposal. In fact, we already touched on one—if complexity(bit-splitting in M+)>complexity(bit-splitting in M−), the above argument fails. We'll call this problem the bit-splitting problem, and we don't have a satisfactory solution to it. However, the situation is actually much worse than just the bit-splitting problem.
Composition with consistency
Now, we'll propose a new attack. Let Mθ− ′1, _(θ2) be defined in pseudocode as
such that we've moved the entirety of world_model to θ2 and conditioned θ2 on that world model being consistent when deduction is applied to it (i.e. ⊥∉deduced_stmts). The idea is that this should allow us to get a ton of savings in specifying the world model, as we only have to pay for anything about the world model on top of the fact that it must be consistent. To that end, we'll let θ− ′2=world_model.
Now, we want to demonstrate that complexity(M− ′)<complexity(M+). Thus, calculating complexity(M− ′), complexity(M− ′)=complexity(θ− ′1)+complexity(θ− ′2 | θ− ′1, ∀D. f?)=complexity(θ− ′1)+complexity(world_model | ∀D. consistency)=complexity(bit-splitting in M− ′)+complexity(deduction)+complexity(f−)+complexity(check_consistency)+complexity(world_model | ∀D. consistency)≈complexity(world_model | ∀D. consistency) but previously the defender argued that complexity(M+)≈complexity(M−)≈complexity(world_model) such that, if complexity(world_model | ∀D. consistency)<complexity(world_model), we have a successful attack.
This attack, however, is potentially patchable if we define a new θ+ ′1 such that
which, for θ+ ′2=world_model, gives us complexity(M+ ′)=complexity(θ+ ′1)+complexity(θ+ ′2 | θ+ ′1, ∀D. f?)=complexity(θ+ ′1)+complexity(world_model | ∀D. consistency∧f+=f−)≈complexity(f+)+complexity(world_model | ∀D. consistency∧f+=f−) such that we get complexity(M− ′)≈complexity(M+ ′) if and only if complexity(world_model | ∀D. consistency)≈complexity(f+)+complexity(world_model | ∀D. consistency∧f+=f−).
It's far from obvious whether this approximate equality holds, however. Now that we're conditioning on both consistency and f+=f−, it's unclear whether the second condition still gets us enough bits to pay for complexity(f+). In particular, complexity(f+) is the only term in the above expression that doesn't get to condition on consistency. Suppose that, before, conditioning world_model on f1=f2 paid for complexity(f+) exactly. Now world_model is also conditioned on consistency while f+ is not, so if there's any overlap between what conditioning on consistency gives us and what conditioning on f1=f2 gives us, we'll no longer be able to pay for all of complexity(f+). For example, if consistency alone is enough to pin down a lot of the natural language information in H, that would undercut the bit savings from conditioning on f+=f−, preventing those savings from being able to pay for complexity(f+). We'll call this problem the composition problem, with the idea being that we need to ensure that the way in which M+ pays for f+ composes with any other way of using the conditioning to save bits.
Dataset conditioning
This last attack is probably the most fatal to this specific approach, though it is also the one we have the most promising fix for, as we'll see in the next solution attempt. But first, suppose that for all x∈Dx, there is some splitting x=xa:xb (where : denotes concatenation) such that it is possible to predict xb from xa. Then, let
or equivalently Mθ− ′1, _(θ2)(x,q)=θ2(x)↦deduction↦(f=f−(q), f?=(P(xb | xa)=1)).
Now, I won't go through the full complexity calculation here, since it's very similar to the previous composition with consistency case, but if we try a fix similar to the one we used there, we get that complexity(M− ′)≈complexity(M+ ′) if and only if complexity(world_model | ∀D. P(xb | xa)=1)≈complexity(f+)+complexity(world_model | ∀D. P(xb | xa)=1∧f+=f−). In this case, however, ∀D. P(xb | xa)=1 is a much stronger condition than consistency, potentially strong enough to force the specification of most of world_model, likely including all of natural language, thus rendering the benefit from conditioning on ∀D. f+=f− nil.
Fundamentally, this problem, which we'll call the dataset conditioning problem, is coming from the fact that we leaked information about our dataset into the prior by conditioning on f? specifically over D, rather than, as we'll see in the next proposal, over all inputs.
Fixing dataset conditioning by removing the dataset
Defender
To attempt to fix the dataset conditioning problem, we'll define a new p∗ given some simplicity prior p as the prior induced by the generative procedure

$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2 \mid \theta_1) \cdot \mathbb{I}\left[\forall x \in X.\ \forall q \in Q.\ M_{\theta_1,\theta_2}|_{f_?}(x,q)\right]
\end{aligned}
$$

which is just the same as the previous p∗, except now f? is checked over all inputs, not just over the dataset (note that we still update on the dataset at the end; it's just our prior which is now independent of it). Thus, we'll call this the dataset-less proposal. Now, it's worth noting that implementing such a check in ML is non-trivial, but should hopefully be doable given a GAN-like adversarial setup.
Moving forward regardless, we'll define θ+1 in pseudocode as
where H_understands is some check that the human model is actually capable of understanding and correctly answering q given x. Note that we need H_understands to compute something like “should H understand this” rather than “does H understand this,” since the latter is essentially just equivalent to whether f+=f−, which inside of H_understands would make f? trivial. Finally, we'll let θ+2=H be the desired human model.
Now, if we want this M+θ+1, θ+2 to beat the same M− as in the previous proposal, we need, using essentially the same derivation as before,

$$
\text{comp}(H \mid W-H) \approx \text{comp}(f^+ \mid W-H) + \text{comp}(\text{H\_understands} \mid W-H,\ f^+) + \text{comp}(H \mid W-H,\ \text{H\_understands} \to f^+=f^-)
$$

which, assuming that comp(H_understands | W−H, f+)≈0 under the idea that H_understands should be relatively simple in terms of f+ (though we'll flag that this is a somewhat sketchy assumption to allow the defender to make), becomes

$$
\begin{aligned}
\text{comp}(H \mid W-H) &\approx \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H,\ \text{H\_understands} \to f^+=f^-) \\
\text{comp}(H \mid W-H) &\approx \text{comp}(f^+ \mid W-H) + \text{comp}(H \mid W-H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} \\
\min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &\approx \text{comp}(f^+ \mid W-H)
\end{aligned}
$$

which, if H_understands(x,q) is equivalent to (x,q,_)∈D, is precisely the condition from before. While, as we'll see later, we explicitly don't want H_understands(x,q)↔(x,q,_)∈D, the hope here is that, for the most part, those two conditions should serve a similar role, with H_understands ensuring that we don't need f+=f− on crazy inputs that H wouldn't be correct on, but still giving us the bit savings from enforcing that H be correct on the inputs where it should be.
Attacker
First, note that the above proposal still falls prey to the bit-splitting and composition problems. Hopefully, however, it should resolve the dataset conditioning problem, though we'll explore what that will actually take below.
Additionally, as we'll also see below, the dataset-less approach brings to the fore some new problems as well.
Dataset conditioning
Suppose there exists some predicate is_real(x) such that ∀x∈X. is_real(x)↔x∈Dx. Then, we'll define θ− ′1 as
such that Mθ− ′1, _|f?(x)↔(is_real(x)→P(xb | xa)=1).
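In such a case, we get that complexity(M− ′)<complexity(M+) if and only if

$$
\begin{aligned}
\text{complexity}(M^{-\prime}) &< \text{complexity}(M^+) \\
\text{complexity}(M^{-\prime}) &< \text{complexity}(M^-) \\
\text{complexity}(\text{is\_real}) + \text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a)=1) &< \text{complexity}(\text{world\_model}) \\
\text{complexity}(\text{is\_real}) &< \text{complexity}(\text{world\_model}) - \text{complexity}(\text{world\_model} \mid \forall D.\ P(x_b \mid x_a)=1) \\
\text{complexity}(\text{is\_real}) &< \min_{\theta_2}\left\{\text{complexity}(\theta_2) \;\middle|\; \forall D.\ P_{W=\theta_2}(x_b \mid x_a)=1\right\}
\end{aligned}
$$

which depends heavily on the complexity of is_real. Thus, we can reduce the dataset conditioning problem in this case to the existence of such an is_real.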
Then, since minθ2{complexity(θ2) | ∀D. PW=θ2(xb | xa)=1}≤complexity(world_model), for this to be a problem it'd have to be the case that complexity(is_real)<complexity(world_model). Fortunately, it does seem plausible that complexity(is_real) would have to be quite large, potentially comparable to complexity(world_model). Put simply, it seems like if you want to know what's real, you have to be able to predict the world. In particular, predicting what data points will show up in the world is essentially precisely what world_model is supposed to be doing, such that is_real seems like it might just require a full copy of world_model.
That being said, there is a remaining risk here if there exists an is_real′ that isn't just fully equivalent to x∈Dx—thus potentially allowing it to be much simpler—but such that conditioning on is_real′(x)→P(xb | xa)=1 still gives similar savings for complexity(world_model). We'll look at a specific example of such an is_real′ in the next attack.
Furthermore, we'll start referring to these sorts of dataset conditioning problems as composition problems, as they have essentially the same structure. What made the dataset conditioning problem not a type of composition problem in the previous case was that previously we were getting is_real for free, whereas now we have to pay for complexity(is_real) just like we had to pay for complexity(check_consistency).
Composition with signature checking
We've already established that the composition with consistency problem that we saw in the previous proposal still applies in the dataset-less case. However, as the above dataset conditioning analysis suggests, if we can find a suitable is_real′, there might be a way to produce a new, particularly pernicious composition problem for the dataset-less proposal that looks similar to dataset conditioning.
To do that, we'll let is_real′(x)=is_signedk(x), where is_signedk computes whether x is cryptographically signed with the public key k, and k is a public key that itself appears in the dataset and that some documents in the dataset are signed with. Then, as above, we'll let
and let θ− ′2=world_model. Furthermore, as in the previous composition problem, we'll compare to a θ+ ′1 defined as
with θ+ ′2=θ− ′2=world_model.
Then, we get that complexity(M− ′)≈complexity(M+ ′) if and only if

$$
\text{comp}\left(W \mid \forall x \in X,\ \text{is\_signed}(x) \to P(x_b \mid x_a)=1\right) \approx \text{comp}(f^+) + \text{comp}\left(W \mid \forall x \in X,\ \big(\text{is\_signed}(x) \to P(x_b \mid x_a)=1\big) \wedge \big(\text{H\_understands}(x) \to f^+(x)=f^-(x)\big)\right).
$$

Here, we see the same basic structure as the previous composition problem, where every term but comp(f+) gets to condition on is_signed→P(xb | xa)=1. Similarly to the composition with consistency problem, if there is any overlap between what conditioning on H_understands→f+=f− gives you and what conditioning on is_signed→P(xb | xa)=1 gives you, then we won't be able to pay back all of complexity(f+) anymore.
In this case, the reason we might expect an overlap is a bit tricky. If the adversary that's attempting to check the ∀X conditions is perfect, then it should be able to instantiate all possible signed strings, meaning that no bits are saved at all here. On the other hand, if the adversary has access to the dataset such that it can generate signed strings from the dataset, but not sign any new strings itself (because we're assuming that the signature function is one-way), then the is_signed→P(xb | xa)=1 condition enforces that θ2 perfectly model the process generating such signed strings—which, if those strings are produced by humans in natural language, should help in specifying natural language, overlapping with the benefit of conditioning on H_understands→f+=f−.
Fortunately, there is a pretty simple way to fix this problem: just don't give the adversary access to the dataset. Then (if we're still assuming one-way signatures) it won't be able to generate any examples of signed strings at all, which fixes the problem in the same way a perfect adversary would. Unfortunately, such a requirement seems like it might render our dataset-less prior essentially unimplementable in machine learning, as we're now required to find some way to train such an adversary without ever letting it look at the dataset, which could be quite difficult.
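Here is a minimal sketch of why dataset access matters here. Python's standard library has no public-key signature API, so this swaps in an HMAC (a symmetric MAC) as a stand-in for is_signedk. The only property it is meant to model is that strings carrying a valid tag cannot be produced without the key. The names `sign`, `is_signed`, and `SECRET_KEY` are hypothetical.

```python
import hashlib
import hmac
import secrets

# Stand-in for the signing key: only the "real world" that produced the dataset holds it.
SECRET_KEY = secrets.token_bytes(32)


def sign(message: bytes) -> bytes:
    """Append a hex HMAC-SHA256 tag to the message, separated by b'|'."""
    return message + b"|" + hmac.new(SECRET_KEY, message, hashlib.sha256).hexdigest().encode()


def is_signed(x: bytes) -> bool:
    """Check the tag after the last b'|' against the message before it."""
    message, sep, tag = x.rpartition(b"|")
    if not sep:
        return False
    expected = hmac.new(SECRET_KEY, message, hashlib.sha256).hexdigest().encode()
    return hmac.compare_digest(tag, expected)


# The dataset contains some signed, human-written documents.
dataset = [sign(b"a human-written document"), sign(b"another one")]

# An adversary with dataset access can simply replay signed strings...
replayed = [x for x in dataset if is_signed(x)]

# ...whereas one without dataset access, guessing strings, essentially never hits a signed one.
guesses = [secrets.token_bytes(16) + b"|" + secrets.token_hex(32).encode() for _ in range(1000)]
forged = [x for x in guesses if is_signed(x)]

print(len(replayed), len(forged))
```

With overwhelming probability this prints 2 and 0. The is_signed→P(xb | xa)=1 condition therefore only constrains θ2 on strings the adversary can copy from the dataset, which is the leak the fix above removes.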
Understandability checking
Next, we'll take a look at the obvious problem that we saw when the defender was going through their derivation, which is that the defender relied on the sketchy assumption that complexity(H_understands | world_model−H, f+)≈0.
Intuitively, it seems like H_understands should be easy to specify in terms of f+, but as we mentioned previously, the obvious implementation of H_understands↔f+=f− makes f? trivial. Thus, to get a more general notion of H_understands—“should H understand” rather than “does H understand”—we'll presumably need to pay at least some non-negligible number of bits on top of f+.
Furthermore, we also can't just have H_understands=is_real, otherwise we'd have to pay the entire cost of complexity(is_real), which we're supposing is very large for this approach to be able to solve the dataset conditioning problem.
We'll call this the understandability checking problem, and as we'll see, unlike some of the other problems with the dataset-less case, we don't have any satisfactory solution to this problem.
Event-avoiding humans
This next attack probably also occurs for Paul's original proposal, but it's significantly more pronounced in the dataset-less case.
For this attack, unlike in many of our previous attacks, we're just going to be attacking the definition of θ+2 rather than the definition of θ+1, such that we'll let M− ′=Mθ+1, θ− ′2, thus only changing θ+2 and leaving θ+1 alone. Then, we'll define θ− ′2=H−H(E)+avoids(E) where E is some event, H(E) are the parts of H that are only necessary to compute H's behavior in the event of E, and avoids(E) is some modification to H such that H+avoids(E) acts in such a way that world_modelH=H+avoids(E) never predicts that E will happen.
Intuitively, what's happening here is that we're setting H in such a way that it causes the world to be simpler (in that E never happens) and thus the f+=f− condition to be easier to achieve (in that we no longer need it to hold in the event of E). We'll call this the event-avoidance problem. Notably, to ensure that we still fit the dataset, we do need to select E such that it never happens on D—but there should presumably be many such E.
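Then, we get that complexity(M− ′)<complexity(M+) if and only if

$$
\begin{aligned}
\text{comp}(M^{-\prime}) &< \text{comp}(M^+) \\
\text{comp}(\theta^+_1) + \text{comp}\left(H-H(E)+\text{avoids}(E) \mid \theta^+_1,\ \forall X,\ \text{H\_understands} \to f^+=f^-\right) &< \text{comp}(\theta^+_1) + \text{comp}\left(H \mid \theta^+_1,\ \forall X,\ \text{H\_understands} \to f^+=f^-\right) \\
\text{comp}\left(H-H(E)+\text{avoids}(E) \mid W-H,\ \forall X,\ \text{H\_understands} \to f^+=f^-\right) &< \text{comp}\left(H \mid W-H,\ \forall X,\ \text{H\_understands} \to f^+=f^-\right) \\
\text{comp}(H-H(E)+\text{avoids}(E) \mid W-H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} &< \text{comp}(H \mid W-H) - \min_{\theta_2}\left\{\text{comp}(\theta_2 \mid W-H) \;\middle|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2}\right\} \\
\text{comp}(H-H(E)+\text{avoids}(E) \mid W-H) &< \text{comp}(H \mid W-H) \\
\text{comp}(H-H(E) \mid W-H) + \text{comp}(\text{avoids}(E) \mid H-H(E),\ W-H) &< \text{comp}(H \mid W-H) \\
\text{comp}(H \mid W-H) - \text{comp}(H(E) \mid H-H(E),\ W-H) + \text{comp}(\text{avoids}(E) \mid H-H(E),\ W-H) &< \text{comp}(H \mid W-H) \\
\text{comp}(\text{avoids}(E) \mid H-H(E),\ W-H) &< \text{comp}(H(E) \mid H-H(E),\ W-H)
\end{aligned}
$$

which should hold so long as there are events not in the dataset that are easy to avoid but quite complicated to understand what humans would do in such cases, e.g. what would a human do if they lived to a million years (not in the dataset, humans can be modified to never choose to live that long, behavior in that case seems quite complicated).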
Humans with rocks for brains
In the event-avoidance case, we strategically selected θ− ′2 to make f+=f− in an unintended way. However, there's another way to get f+=f−∨¬H_understands—or equivalently, H_understands→f+=f−—which is to strategically select θ− ′2 to make H_understands false instead. Specifically, since H_understands is determined from deductions based on world_modelH=θ2, we can select θ2 to ensure we get H_understands→f+=f− by encoding for a human that simply doesn't understand anything.
Naively, the problem with an attack of this form is that such a human model, which we'll call $H_{\text{brains}=\text{rocks}}$, won't be able to fit the dataset. Thus, we can't construct a traditional attack in the same sense as the previous ones, where we create an alternative $M^{-\prime}$ with lower complexity than $M^+$ using $\theta_2^{-\prime} = H_{\text{brains}=\text{rocks}}$.[5]
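However, there's something else we can do instead. Previously, the defender reduced their burden to the approximate equality

$$\min_{\theta_2} \big\{ \text{comp}(\theta_2 \mid W - H) \;\big|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2} \big\} \approx \text{comp}(f^+ \mid W - H)$$

which, if we convert back to the inequality formulation, is really just the burden

$$\text{comp}(f^+ \mid W - H) \le \min_{\theta_2} \big\{ \text{comp}(\theta_2 \mid W - H) \;\big|\; \forall X.\ \text{H\_understands}_{H=\theta_2} \to f^+_{H=\theta_2} = f^-_{H=\theta_2} \big\}.$$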
Thus, the defender needs $f^+$ to be simpler than the simplest model such that $\text{H\_understands} \to f^+ = f^-$. However, $H_{\text{brains}=\text{rocks}}$ is such a model, since we're assuming it makes $\text{H\_understands}$ always false, which means the defender needs

$$\text{comp}(f^+ \mid W - H) \le \text{comp}(H_{\text{brains}=\text{rocks}}),$$

which, if we assume that $\text{comp}(H_{\text{brains}=\text{rocks}}) \approx 0$, gives us a successful attack, which we'll call the rocks for brains problem.
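A toy version of the defender's minimization makes the collapse concrete. The candidate models, their bit counts, and which of them satisfy the check are all hypothetical; the point is only that adding one nearly free candidate that passes the check vacuously drags the minimum, and with it the defender's budget, down to about zero.

```python
# Toy illustration of how H_brains=rocks collapses the defender's minimum.
# Each candidate theta_2 has a HYPOTHETICAL complexity (bits, given W - H) and a
# flag for whether it satisfies: for all X, H_understands -> f+ == f-.
candidates = {
    "real human H":         {"comp": 1000, "satisfies": False},  # f+ and f- can disagree
    "intended minimal fit": {"comp": 60,   "satisfies": True},   # roughly comp(f+)
    "H_brains=rocks":       {"comp": 1,    "satisfies": True},   # passes vacuously
}


def simplest_satisfying(cands):
    """min over theta_2 of comp(theta_2) subject to the check."""
    return min(c["comp"] for c in cands.values() if c["satisfies"])


comp_f_plus = 60  # hypothetical comp(f+ | W - H)

with_rocks = simplest_satisfying(candidates)
without_rocks = simplest_satisfying(
    {name: c for name, c in candidates.items() if name != "H_brains=rocks"}
)

print(comp_f_plus <= without_rocks)  # True: burden met if rocks were excluded
print(comp_f_plus <= with_rocks)     # False: rocks for brains breaks the burden
```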
The rocks for brains problem seems basically fatal to this approach, since it wipes out all possible gains from conditioning on $\text{H\_understands} \to f^+ = f^-$. It might be possible to salvage this approach if $\text{H\_understands}$ is somehow implemented independently of the $H$ model, but that would presumably require $\text{complexity}(\text{H\_understands})$ to be quite large, negating the defender's argument that relied on $\text{complexity}(\text{H\_understands} \mid f^+) \approx 0$.
Directions for future work
As a result of the above analysis, we've managed to identify a bunch of concrete problems that seem to be blocking approaches of this form. Of course, there could also be other problems that we haven't come up with yet, and there's also the issue of converting these priors into loss functions, though that's something that should probably come after we at least have something that seems to solve all of the problems that we're aware of.
Overall, however, the problems we've described here seem like good places to start, and I'll try to describe some thoughts on how they might be addressed below.
The bit-splitting problem
In Paul's original post, he makes a nod to the bit-splitting problem by carving $\text{world\_model} - H$ in half and putting one half in $\theta_1^+$ and the other half in $\theta_2^+$. Paul does this because he's modeling the underlying simplicity prior $p$ as a simple Gaussian over the $L^2$ norm of the parameters. Under such a model, there's a simplicity benefit to splitting all computation as equally as possible across all available parameters.
However, it's very unclear whether any implicit neural network simplicity prior would actually behave that way. In particular, if we just change our simple model to use the $L^1$ norm instead of the $L^2$ norm, the benefit of splitting computation evenly across parameters goes away. Clearly, that's not a great model either, but it goes to show that there are plausible models under which bit-splitting isn't even a problem in the first place.
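The difference between the two norms can be seen with a crude numerical sketch. It assumes, purely for illustration, that "an amount of computation" corresponds to a fixed total parameter magnitude that can be placed in one parameter group or split across two, and compares unnormalized log-densities under a unit Gaussian (the $L^2$ case) and a unit Laplace (the $L^1$ case) prior.

```python
# Crude model: a fixed total parameter magnitude stored in one group or split
# across two. Log-densities are unnormalized (constants dropped).

def gaussian_log_prior(params):
    """Isotropic unit Gaussian: log p = -||theta||_2^2 / 2 + const."""
    return -sum(x * x for x in params) / 2


def laplace_log_prior(params):
    """Isotropic unit Laplace: log p = -||theta||_1 + const."""
    return -sum(abs(x) for x in params)


total = 10.0                    # hypothetical total magnitude to store
lopsided = [total, 0.0]         # everything in theta_1
even = [total / 2, total / 2]   # split evenly between theta_1 and theta_2

print(gaussian_log_prior(lopsided), gaussian_log_prior(even))  # -50.0 -25.0
print(laplace_log_prior(lopsided), laplace_log_prior(even))    # -10.0 -10.0
```

Under the Gaussian, the even split is strictly more probable, so the prior rewards bit-splitting; under the Laplace, the two placements tie and the incentive disappears.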
Fundamentally, what we need to do to alleviate the bit-splitting problem is look for architectures where θ1 and θ2 are allowed to grow as large as they want with only linear complexity penalties for doing so. Though we traditionally train only fixed-size neural networks, if there were a way to train a dynamically-sized neural network, where size was penalized linearly, that would at least seem to solve this problem—and such a modification doesn't seem too difficult to make.
The understandability checking problem
Of the remaining problems, this is the one that I think is most likely to turn out not to be a problem at all, in that it doesn't require any changes to the prior to solve. In many ways, the understandability checking problem just comes from our inability to define an $\text{H\_understands}$ check that's simple in terms of $f^+$ and $\text{world\_model} - H$, but at least intuitively, it really does seem like such a simple $\text{H\_understands}$ implementation should exist.
The event-avoidance problem
In Paul's original post, he notes:
Essentially, what Paul is arguing here is that, going back to the original formulation, we should add a stop gradient operation to our generative procedure such that we get

$$
\begin{aligned}
\theta_1 &\sim p(\theta_1) \\
\theta_2 &\sim p(\theta_2) \cdot \mathbb{I}\Big[\forall (x, q, \_) \in D.\ \text{stop\_grad}\big(M_{\theta_1, \theta_2}|_{f_1}(x, q)\big) = M_{\theta_1, \theta_2}|_{f_2}(x, q)\Big].
\end{aligned}
$$

Unfortunately, this isn't actually a generative procedure anymore: the only straightforward way to make sense of it is to translate it into a loss function and think about what SGD would do, which makes this prior much harder to analyze, as Paul notes when he talks about his trepidation about going down this route.
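To see what the stop gradient changes once it is translated into a loss, here is a minimal sketch with scalar stand-ins. The heads `f1` and `f2` are hypothetical toy functions of a single shared parameter, and the stop gradient is emulated by evaluating the stopped branch at a frozen copy of the parameter, with derivatives taken by central finite differences. With the stop gradient, the derivative of the disagreement loss only sees how `f2` moves.

```python
# Emulating stop_grad with central finite differences on a scalar parameter.
# f1 and f2 are HYPOTHETICAL toy heads, not the model from the text.

def f1(theta):
    return 3.0 * theta


def f2(theta):
    return theta + 2.0


def num_grad(fn, x, h=1e-5):
    """Central finite-difference derivative of fn at x."""
    return (fn(x + h) - fn(x - h)) / (2 * h)


theta0 = 2.0


def full_loss(t):
    # Plain consistency loss: gradient flows through both heads.
    return (f1(t) - f2(t)) ** 2


def sg_loss(t):
    # stop_grad(f1): f1 is evaluated at the frozen theta0, so it is constant in t.
    return (f1(theta0) - f2(t)) ** 2


print(round(num_grad(full_loss, theta0), 4))  # 8.0  = 2*(6-4)*(3-1)
print(round(num_grad(sg_loss, theta0), 4))    # -4.0 = 2*(6-4)*(-1)
```

Without the stop gradient, the update direction depends on how both heads respond; with it, disagreement can only be reduced by moving $f_2$ toward the current value of $f_1$, which is the property that should remove the pressure to change $f_1$ itself.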
Hopefully, however, adding this stop gradient should solve the event-avoidance problem, since it should remove the incentive to modify $f_1 = f^+$ to make the world simpler (e.g. by preventing $E$ from happening).
Fortunately, despite Paul's prior worries about not being able to translate this trick into the program search case, it turns out it's not too difficult to convert Paul's stop gradient formulation into a generative prior that we can analyze as we have been doing previously. Furthermore, there are many such translations—in particular, we can pick either a dataset-less or non-dataset-less version, depending on whether we prefer being exposed to the dataset conditioning problem or the rocks for brains problem. In this case, I'll present a dataset-less formulation, which should solve dataset conditioning but unfortunately won't solve rocks for brains.
Now, given some simplicity prior $p$, we'll define $p^*$ as the prior induced by the following generative procedure:
First, note the uniform sampling in step 2. This is the trickiest part of this algorithm—and probably impossible to do in a neural network setting. Technically, the model space of a neural network is finite, so uniform sampling is possible—though the obvious algorithm of sampling uniformly from weight space doesn't work, since the parameter-function map is biased towards simplicity. Hopefully, just using a stop gradient will replicate what this algorithm is doing—and intuitively it does feel like they're doing the same thing—but I don't have any strong justification for that. Thus, it is somewhat unclear whether this procedure will actually have a nice translation when we convert back into an ML setup with stop gradients, but we'll put that aside for now.
Second, to make sense of our use of $M$ above, we'll let $M$ have three heads, $f_1$, $f_2$, and $f_{\text{guard}}$, where $f_{\text{guard}}$ is a Boolean head that “guards” whether or not we do the $f_1 = f_2$ check, and $f_1, f_2$ are two output heads. Furthermore, we'll let $M$ have two groups of parameters, $\theta_1$ and $\theta_2$, such that $M_{\theta_1, \theta_2}(x, q) = M_{\theta_1, \_}(\theta_2, x, q)$. Thus, we'll let $M_{\theta_1, \theta_{2a}}(x, q) = M_{\theta_1, \_}(\theta_{2a}, x, q)$ and $M_{\theta_1, \theta_{2b}}(x, q) = M_{\theta_1, \_}(\theta_{2b}, x, q)$.
Intuitively, what's going on here is that, in sampling a new set of $\theta_{2b}$ parameters, wherever we want a stop gradient, we use the old $\theta_{2a}$, and wherever we don't want a stop gradient, we use the new $\theta_{2b}$. This ensures that we only sample $\theta_{2b}$ in a way that incentivizes a higher prior through the non-stop-gradiented places where $\theta_{2b}$ shows up, not through the stop-gradiented places where we use $\theta_{2a}$.
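A minimal sketch of how the two parameter copies enter the guarded check, assuming a tiny finite model space so that the conditional can be computed by enumeration. The model space, complexities, inputs, and head definitions are all hypothetical, and the function only implements conditioning $\theta_{2b}$ on a given $\theta_{2a}$ with $f_{\text{guard}}$ and $f_1$ reading $\theta_{2a}$ and $f_2$ reading $\theta_{2b}$; it does not reproduce the full generative procedure for $p^*$.

```python
# Sketch of the theta_2a / theta_2b trick on a tiny discrete model space.
# Everything here (model space, complexities, heads) is HYPOTHETICAL.

XS = [0, 1, 2]           # hypothetical inputs
THETA2_SPACE = range(6)  # hypothetical finite model space for theta_2


def complexity(theta2):
    return theta2 + 1    # hypothetical description length in bits


def M(theta2, x):
    """M_{theta_1, _}(theta_2, x) with theta_1 fixed: returns (f1, f2, f_guard)."""
    f1 = (x + theta2) % 3
    f2 = (x + theta2) % 3
    guard = True
    return f1, f2, guard


def conditional_theta2b(theta2a):
    """Normalized p(theta_2b) * I[check]: the guard and f1 read theta_2a
    (stop-gradient positions); only f2 reads the candidate theta_2b."""
    weights = {}
    for theta2b in THETA2_SPACE:
        ok = all(
            (not M(theta2a, x)[2]) or M(theta2a, x)[0] == M(theta2b, x)[1]
            for x in XS
        )
        if ok:
            weights[theta2b] = 2.0 ** -complexity(theta2b)
    total = sum(weights.values())
    return {t: w / total for t, w in weights.items()}


dist = conditional_theta2b(1)
print({t: round(p, 4) for t, p in dist.items()})  # {1: 0.8889, 4: 0.1111}
```

Because $\theta_{2a}$ is held fixed inside the check, the only way a candidate can make the indicator true is through the $f_2$ position, so $\theta_{2b}$ has no way to gain prior weight by changing what the guarded, stop-gradiented side computes.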
Now, we'll define our desired $\theta_1^+$ in pseudocode as
where $\theta_1^+$ uses $\theta_2$, either $\theta_{2a}$ or $\theta_{2b}$, to give it its human model.
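Then, we have to determine what $\theta_2$ will be favored given the above $\theta_1^+$. First, consider $\theta_2^+ = H$. In that case, we get the complexity

$$
\begin{aligned}
\text{complexity}(\theta_2 = H \mid \theta_1^+) &\approx \text{complexity}\big(H \mid W - H,\ \text{H\_understands}_{\theta_{2a}=H} \to f^+_{\theta_{2a}=H} = f^-_{\theta_{2b}}\big) \\
&\approx \text{complexity}(H \mid W - H) - \min_{\theta_{2b}} \big\{ \text{complexity}(\theta_{2b} \mid W - H) \;\big|\; \forall X.\ \text{H\_understands}_{H=H} \to f^+_{H=H} = f^-_{H=\theta_{2b}} \big\}
\end{aligned}
$$

which, assuming that conditioning on $\text{H\_understands} \to f^+ = f^-$ exactly pays back $\text{complexity}(f^+)$ (which is false due to the rocks for brains problem, but we're just trying to solve event-avoidance here), reduces to

$$\approx \text{complexity}(H \mid W - H) - \text{complexity}(f^+ \mid W - H).$$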
Now, consider θ−2=H−H(E)+avoids(E), as in the event-avoidance problem. In that case, we get the complexity complexity(θ2=H−H(E)+avoids(E) | θ+1)≈complexity(H−H(E)+avoids(E) | W−H, H_understandsθ2a=H−H(E)+avoids(E)→f+θ2a=H−H(E)+avoids(E)=f−θ2b) but then, since avoids(E) being in θ2b is entirely unhelpful in making H_understandsθ2a=H−H(E)+avoids(E)→f+θ2a=H−H(E)+avoids(E)=f−θ2b hold—since it only affects f+, which already has avoids(E) in its H—we get ≈avoids(E) | W−H)+complexity(H−H(E) | W−H, H_understandsθ2a=H−H(E)+avoids(E)→f+θ2a=H−H(E)+avoids(E)=f−θ2b)+complexity(avoids(E) | W−H, H−H(E))≈complexity(H−H(E)+avoids(E) | W−H)−minθ