Alfred Harwood and I were working through this as part of a Dovetail project and unfortunately I think we’ve found a mistake. The Taylor expansion in Step 2 has the 3rd order term . This term should disappear as goes to zero, but this is only true if stays constant. The transformation in Part 1 reduces (most terms of) and at the same rate, so decreases at the same rate as . So the 2nd order approximation isn’t valid.
For example, we could consider two binary random variables with probability distributions
and and and .
If , then as .
But consider the third order term for which is
This is a constant term which does not vanish as .
We found a counterexample to the whole theorem (which is what led to us finding this mistake), which has , and it can be found in this colab. There are some stronger counterexamples at the bottom as well. We used sympy because we were getting occasional floating point errors with numpy.
Sorry to bring bad news! We’re going to keep working on this over the next 7 weeks, so hopefully we’ll find a way to prove a looser bound. Please let us know if you find one before us!
I plan to spend today digging into this, and will leave updates under this comment as I check things.
(Update 0)
I'm starting by checking that there's actually a counterexample here. We also found some numerical counterexamples which were qualitatively similar (i.e. approximately-all of the weight was on one outcome), but thought it was just numerical error. Kudos for busting out the sympy and actually checking it.
Looking at the math on that third-order issue... note that the whole expansion is multiplied by . So even if , itself will still go to zero for small , so will go to zero. So it's not obviously a fatal flaw, though at the very least some more careful accounting would be needed at that step to make sure everything converges.
(Update 1)
We've looked at the code and fiddled with the math and are now more convinced of the issue.
The 2nd order approximation holds when ... which our scaling-down construction does not provide. So (among a host of other things) we are now thinking about other ways to try wrangling the bound into a Euclidean space, or otherwise into some form that is similarly "easy" to work with.
(Thanks for finding this!)
(Update 2)
Taking the limit of the ratio of s (using summation rather than max) with while gives
Setting c very small and ramping up r indeed breaks the bound more and more severely. (Code changes from the colab you provided, below.)
Code changes / additions
Block 1:
a, b, c, d, r = sp.symbols("a b c d r")  # assumes sympy is already imported as sp in the colab
variable_substitutions = {  # numerical values to substitute for these symbols
    a: 0.25,
    b: 1e-90,
    c: 1e-91,
    r: 20000000,
}
Block 2 (later on):
expr = (kl3/(kl1 + kl2)).subs(d, (1 - 3*c - (r+1)*b - 2*a))  # kl1, kl2, kl3 are defined earlier in the colab
print("KL(X2->X1->L')/sum[KL(X1->X2->L),KL(X2->X1->L)] =", expr.evalf(subs=variable_substitutions))
Block 3 (right after Block 2):
expr = (kl3/(kl1 + kl2)).subs(d, (1-3*c-(r+1)*b-2*a)).subs(b, 10*c)
lim = sp.simplify(sp.limit(expr, c, 0))
print("Limit of KL(X2->X1->L')/sum[KL(X1->X2->L),KL(X2->X1->L)] as c->0+ =", lim)
(Update 3)
We're now pursuing two main threads here.
One thread is to simplify the counterexamples into something more intuitively-understandable, mainly in hopes of getting an intuitive sense for whatever phenomenon is going on with the counterexamples. Then we'd build new theory specifically around that phenomenon.
The other thread is to go back to first principles and think about entirely different operationalizations of the things we're trying to do here, e.g. not using diagram 's as our core tool for approximation. The main hope there is that maybe isn't really the right error metric for latents, but then we need to figure out a principled story which fully determines some other error metric.
Either way, we're now >80% that this is a fundamental and fatal flaw for a pretty big chunk of our theory.
(Update 4)
We have now started referring to "Jeremy et Al" when discussing the findings at top-of-thread, and find this amusing.
As of this morning, our current thread is an adjustment to the error measure. Thinking it through from first principles, it makes intuitive sense to marginalize out latents inside a , i.e. rather than (where is typically some factorization of ). Conceptually, that would mean always grounding out errors in terms of predictions on observables, not on mind-internal latent constructs. We're now checking whether that new error gives us the properties we want in order to make the error measure useful (and in the process, we're noticing what properties we want in order for the error measure to be useful, and making those more explicit than we had before).
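To make the proposal concrete, here's a quick numerical sketch comparing the two error forms for the mediation condition. The explicit factorization, variable names, and alphabet sizes below are illustrative assumptions, not the notation from the post.
import numpy as np

rng = np.random.default_rng(0)

# Toy joint distribution P(x1, x2, lam) over three binary variables.
P = rng.random((2, 2, 2))
P /= P.sum()

def kl(p, q):
    # KL divergence in nats between two arrays of matching shape.
    return float(np.sum(p * np.log(p / q)))

# Mediation factorization (illustrative form): Q(x1,x2,lam) = P(lam) P(x1|lam) P(x2|lam).
P_lam = P.sum(axis=(0, 1))
P_x1_given_lam = P.sum(axis=1) / P_lam
P_x2_given_lam = P.sum(axis=0) / P_lam
Q = np.einsum('k,ik,jk->ijk', P_lam, P_x1_given_lam, P_x2_given_lam)

# Current error form: latent kept inside the KL.
err_joint = kl(P, Q)
# Proposed error form: latent marginalized out inside the KL.
err_marginal = kl(P.sum(axis=2), Q.sum(axis=2))

print(err_joint, err_marginal)  # marginalizing can only shrink the KL (data processing)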
(Update 5)
A conjecture we are working on which we expect to be generally useful beyond possibly rescuing the stoch->det proof that used to rely on the work in this post:
Chainability (Conjecture):
with , ,
Define , and .
Then,
Here is a colab with a quick numerical test that suggests the bound holds (and that n=1, in this case).
(Note: The above as written is just one step of chaining, and ultimately we are hoping to show it holds for arbitrarily many steps, accumulating an associated number of epsilons as error.)
(Update 6)
Most general version of the chainability conjecture (for arbitrary graphs) has now been falsified numerically by David, but the version specific to the DAGs we need (i.e. the redundancy conditions, or one redundancy and the mediation condition) still looks good.
Most likely proof structure would use this lemma:
Lemma
Let be nonexpansive maps under distance metric . (Nonexpansive maps are the non-strict version of contraction maps.)
By the nonexpansive map property, . And by the triangle inequality for the distance metric, . Put those two together, and we get
(Note: this is a quick-and-dirty comment so I didn't draw a nice picture, but this lemma is easiest to understand by drawing the picture with the four points and distances between them.)
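Writing it out in symbols (with f and g as stand-in names for the two nonexpansive maps, d the metric, and x the starting point), the intended chain is:
\begin{align*}
d(f(g(x)),\, f(x)) &\le d(g(x),\, x) && \text{(nonexpansiveness of } f) \\
d(x,\, f(g(x))) &\le d(x,\, f(x)) + d(f(x),\, f(g(x))) && \text{(triangle inequality)} \\
\Rightarrow\quad d(x,\, f(g(x))) &\le d(x,\, f(x)) + d(x,\, g(x))
\end{align*}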
I think that lemma basically captures my intuitive mental picture for how the chainability conjecture "should" work, for the classes of DAGs on which it works at all. Each DAG would correspond to one of the functions , where takes in a distribution and returns the distribution factored over the DAG , i.e.
In order to apply the lemma to get our desired theorem, we then need to find a distance metric which:
The first two of those are pretty easy to satisfy for the redundancy condition DAGs: those two DAG operators are convex combinations, so good ol' Euclidean distance on the distributions should work fine. Making it match at is trickier, still working that out.
(Update 7)
After some back and forth last night with an LLM[1], we now have a proof of "chainability" for the redundancy diagrams in particular. (And have some hope that this will be most of what we need to rescue the stochastic->deterministic nat lat proof.)
Let P be a distribution over , , and .
Define:
Where you can think of Q as 'forcing' P into factorizing per one redundancy pattern: , S as forcing the other pattern: , and R as forcing one after the other: first , and then .
The theorem states,
,
Or in words: the error (in from ) accrued by applying both factorizations to P is bounded by the sum of the errors accrued by applying each of the factorizations to P separately.
The proof proceeds in 3 steps.
Pf.
Let
Let
By the log-sum inequality:
as desired.
Pf.
Combining steps 1 and 2,
which completes the proof.
Notes:
In the second to last line of step 2, the expectation over is allowed because there are no free 's in the expression. Then, this aggregates into an expectation over as .
We are hopeful that this, though different from the invalidated result in the top level post, will be an important step to rescuing the stochastic natural latent => deterministic natural latent result.
A (small) positive update for me on their usefulness to my workflow!
Additional note which might be relevant later: we can also get proof step 1 in a somewhat more general way, which establishes that the function is a nonexpansive map under . We'll write that proof down later if we need it.
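For concreteness, here is a quick numerical sanity check of the inequality. The explicit forms of Q, S, and R below are one assumed reading of the definitions above (Λ forced to depend on X only through X1, then only through X2); the code just computes the three KLs and prints the claimed comparison.
import numpy as np

rng = np.random.default_rng(1)

# Random joint distribution P(x1, x2, lam) over small finite alphabets.
P = rng.random((3, 3, 4))
P /= P.sum()

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))  # nats

P_x = P.sum(axis=2)                                             # P(x1, x2)
lam_given_x1 = P.sum(axis=1) / P_x.sum(axis=1, keepdims=True)   # P(lam | x1)
lam_given_x2 = P.sum(axis=0) / P_x.sum(axis=0).reshape(-1, 1)   # P(lam | x2)

# Q: force lam to depend on X only through X1 (assumed reading of "one redundancy pattern").
Q = np.einsum('ij,ik->ijk', P_x, lam_given_x1)
# S: force lam to depend on X only through X2 (the other pattern).
S = np.einsum('ij,jk->ijk', P_x, lam_given_x2)
# R: apply the second factorization to Q, i.e. force one pattern and then the other.
Q_lam_given_x2 = Q.sum(axis=0) / Q.sum(axis=(0, 2)).reshape(-1, 1)
R = np.einsum('ij,jk->ijk', P_x, Q_lam_given_x2)

print(kl(P, R), "<=", kl(P, Q) + kl(P, S))  # the claimed chainability inequality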
our current thread is an adjustment to the error measure.
We're not sure that this is necessary. I quite like the current form of the errors. I've spent much of the past week searching for counterexamples to the ∃ deterministic latent theorem and I haven't found anything yet (although it's partially a manual search). My current approach takes a P(X_1,X_2) distribution, finds a minimal stochastic NL, then finds a minimal deterministic NL. The deterministic error has always been within a factor of 2 of the stochastic error. So currently we're expecting the theorem can be rescued.
rather than
That seems like a cool idea for the mediation condition, but isn't it trivial for the redundancy conditions?
That seems like a cool idea for the mediation condition, but isn't it trivial for the redundancy conditions?
Indeed, that specific form doesn't work for the redundancy conditions. We've been fiddling with it.
Would this still give us guarantees on the conditional distribution ?
E.g. Mediation:
is really about the expected error conditional on individual values of , and it seems like there are distributions with high mediation error but low error when the latent is marginalized inside , which could be load-bearing when the agents cast out predictions on observables after updating on
Oh nice, we tried to wrangle that counterexample into a simple expression but didn't get there. So that rules out a looser bound under these assumptions, that's good to know.
but thought it was just numerical error
I was totally convinced it was a numerical error. I spent a full day trying to trace it in my numpy code before I started to reconsider. At that point we'd worked through the proof carefully and felt confident of every step. But we needed to work out what was going on because we wanted empirical support for a tighter bound before we tried to improve the proof.
Do you have sympy code for the example noted at the bottom of the colab that claims a ratio of > 9.77 including the mediation ? I tried with the parameters you mention and am getting a ratio of ~3.4 (which is still a violation of previous expectations, tbc.)
That'll be the difference between max and sum in the denominator. If you use sum it's 3.39.
Here's one we worked out last night, where the ratio goes to infinity.
By the way, there seems to be an issue where sympy silently drops precision under some circumstances. Definitely a bug. A couple of times it's caused non-trivial errors in my KLs. It's pretty rare, but I don't know any way to completely avoid it. Thinking of switching to a different library.
Part 1 feels like magic. I don't understand it at an intuitive level and so I'm kinda suspicious of it. It seems like such a powerful technique for working with KL divergences. I'll spend some more time playing around with it. Everything else makes sense to me.
My question is how did you come up with this technique? Was "small KL inequalities can be equivalent to larger KL inequalities" a background fact that you knew beforehand? Or did you start by wanting to find a way to make the Hellinger distances work?
It sure does feel like a powerful technique! We haven't explored much how to generalize it yet, though.
At the time, we were thinking about the optimization problem "max (the one error) subject to (constraint on other errors)", and what the curve looks like which gives the max value as a function of the constraint errors. One (of many) angles I tried was to consider ways of transforming a latent, which would move it from one point in the feasible set to another point in the feasible set. And once I asked that question, basically the first thing I tried was the transformation in the proof which just scales down all the errors.
At that point we had already done the Hellinger distances thing (also among many other things), on the general principle of "try it in the second order regime before trying to prove globally", so it was just a matter of connecting the pieces together.
Suppose random variables and contain approximately the same information about a third random variable , i.e. both of the following diagrams are satisfied to within approximation :
We call a "redund" over , since conceptually, any information contains about must be redundantly represented in both and (to within approximation).
Here's an intuitive claim which is surprisingly tricky to prove: suppose we construct a new variable by sampling from , so the new joint distribution is
By construction, this "resampled" variable satisfies one of the two redundancy diagrams perfectly: . Intuitively, we might expect that approximately satisfies the other redundancy diagram as well; conceptually, (approximately) only contains redundant information about , so contains (approximately) the same information about as does, so the resampling operation should result in (approximately, in some sense) the same distribution we started with and therefore (approximately) the same properties.
In this post, we'll prove that claim and give a bound for the approximation error.
Specifically:
Theorem: Resampling (Approximately) Conserves (Approximate) Redundancy
Let random variables , satisfy the diagrams and to within , i.e.
Also, assume .
Construct by sampling from , so . Then is perfectly satisfied by construction, and is satisfied to within , i.e.
In diagrammatic form:
We will use the shorthand to mean . For instance, is shorthand for , which is equivalent to .
We will work with nats for mathematical convenience (i.e. all logarithms are natural logs).
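A small numerical sketch of the setup may help; the explicit forms used below for the diagram errors and for the resampling step are assumed concretizations of the diagrams above, and all KLs are in nats.
import numpy as np

rng = np.random.default_rng(2)

# A toy joint distribution P(x1, x2, lam) over small finite alphabets (illustrative only).
P = rng.random((3, 3, 4))
P /= P.sum()

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))  # nats

P_x = P.sum(axis=2)                                             # P(x1, x2)
lam_given_x1 = P.sum(axis=1) / P_x.sum(axis=1, keepdims=True)   # P(lam | x1)
lam_given_x2 = P.sum(axis=0) / P_x.sum(axis=0).reshape(-1, 1)   # P(lam | x2)

# The two redundancy-diagram errors for the original latent (assumed explicit forms).
eps_1 = kl(P, np.einsum('ij,ik->ijk', P_x, lam_given_x1))
eps_2 = kl(P, np.einsum('ij,jk->ijk', P_x, lam_given_x2))

# Resample the latent from its conditional on X1, so the new joint is P(x1,x2) P(lam'|x1).
P_resampled = np.einsum('ij,ik->ijk', P_x, lam_given_x1)
# One redundancy diagram now holds exactly; the other has this error:
lamp_given_x2 = P_resampled.sum(axis=0) / P_x.sum(axis=0).reshape(-1, 1)
eps_resampled = kl(P_resampled, np.einsum('ij,jk->ijk', P_x, lamp_given_x2))

print(eps_1, eps_2, eps_resampled)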
The proof proceeds in three steps:
First we construct the new variable as a stochastic function of . Specifically, with probability , else is a constant , where is outside the support of (so when we see , we gain no information about ).
A little algebra confirms that 's errors are simply 's errors scaled down by :
Similarly, constructing just like (i.e. ) is equivalent to constructing as a stochastic function of where with probability , else is . So, by the same algebra as above,
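As a quick check of that scaling claim, here is a numerical sketch (with alpha as the keep-probability and the diagram errors in the same assumed explicit forms as in the sketch above): both diagram errors come out to exactly alpha times the originals.
import numpy as np

rng = np.random.default_rng(3)

P = rng.random((3, 3, 4))   # toy joint P(x1, x2, lam)
P /= P.sum()

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))

def diagram_errors(P3):
    # The two redundancy-diagram KLs (assumed explicit forms) for a joint over (x1, x2, lam).
    P_x = P3.sum(axis=2)
    lam_given_x1 = P3.sum(axis=1) / P_x.sum(axis=1, keepdims=True)
    lam_given_x2 = P3.sum(axis=0) / P_x.sum(axis=0).reshape(-1, 1)
    e1 = kl(P3, np.einsum('ij,ik->ijk', P_x, lam_given_x1))
    e2 = kl(P3, np.einsum('ij,jk->ijk', P_x, lam_given_x2))
    return e1, e2

alpha = 0.1
# Scaled-down latent: keep lam with probability alpha, else emit one fresh constant symbol.
P_scaled = np.concatenate([alpha * P, (1 - alpha) * P.sum(axis=2, keepdims=True)], axis=2)

print(diagram_errors(P))         # original errors
print(diagram_errors(P_scaled))  # exactly alpha times the originals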
The upshot: if there exists a distribution over variables for which
then there also exists a distribution satisfying the same inequality with all 's arbitrarily small[1]. Flipping that statement around: if there does not exist any distribution for which the 's are all arbitrarily small and the inequality is satisfied, then there does not exist any distribution for which the inequality is satisfied.
In other words: if we can show
in the regime where all the 's are arbitrarily small, then the same inequality is also established globally, proving our theorem. The rest of the proof will therefore show
in the regime where all the 's are arbitrarily small. In particular, we'll use a second order approximation for the 's.
Before we can use a second order approximation of the 's, we need to show that small implies that the second order approximation is valid.
For that purpose, we use the Hellinger-KL inequality:
where is the squared Hellinger distance.[2]
Using standard logarithm inequalities, we can weaken the Hellinger-KL inequality to
So, as goes to 0, the Hellinger distance goes to 0, and therefore and are arbitrarily close together in standard Euclidean distance. Since is smooth (for strictly positive distributions, which we have assumed), we can therefore use a second order approximation (with respect to ) for our arbitrarily small 's.
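As a spot check on that step: with the convention that the squared Hellinger distance is the squared Euclidean distance between the square-root vectors (the convention used later in this post), one standard weakened form of the inequality is that the KL divergence upper-bounds the squared Hellinger distance. That may differ by a constant from the exact form used here; the sketch below just checks it numerically on random distributions.
import numpy as np

rng = np.random.default_rng(4)

def kl(p, q):
    return float(np.sum(p * np.log(p / q)))

def hellinger_sq(p, q):
    # Convention: squared Euclidean distance between the square-root vectors.
    return float(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))

for _ in range(5):
    p = rng.random(6); p /= p.sum()
    q = rng.random(6); q /= q.sum()
    assert kl(p, q) >= hellinger_sq(p, q)
print("D_KL >= squared Hellinger distance held on all random trials")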
Now for the second order expansion itself.
Our small quantity is . Then
To simplify that further, we can use the sum-to-1 constraints on the distributions: implies
so . That simplifies our second order approximation to
i.e. in the second order regime is twice the Hellinger distance.
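A quick numerical illustration of that second-order relationship (a sketch, using the same squared-Hellinger convention as above): as P approaches Q, the ratio of the KL to twice the squared Hellinger distance tends to 1.
import numpy as np

rng = np.random.default_rng(5)

q = rng.random(6); q /= q.sum()
direction = rng.random(6) - 0.5   # an arbitrary perturbation direction

for t in [1e-1, 1e-2, 1e-3]:
    p = q * (1 + t * direction)
    p /= p.sum()                  # renormalize; p stays close to q for small t
    kl = float(np.sum(p * np.log(p / q)))
    h2 = float(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2))
    print(t, kl / (2 * h2))       # ratio approaches 1 as the perturbation shrinks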
Combining this with Step 1, we've now established that if we can prove our desired bound for Hellinger distances rather than , then the bound also applies globally for errors. So now, we can set aside the notoriously finicky KL divergences, and work with good ol' Euclidean geometry.
Writing everything out in the second order regime, our preconditions say
and we want to bound
That last expression has a Jensen vibe to it, so let's use Jensen's inequality.
We're going to use Jensen's inequality on the squared Hellinger distance, so we need to establish that squared Hellinger distance is convex as a function of the distributions .
Differentiating twice with respect to yields the Hessian
Note that one column is the other column multiplied by , so one of the eigenvalues is 0. The trace is positive, so the other eigenvalue is positive. Thus, the function is (non-strictly) convex.
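That Hessian computation is easy to verify symbolically. Here's a sympy sketch for a single summand of the squared Hellinger distance, as a function of one pair of matching probability entries (p, q standing in for those entries):
import sympy as sp

p, q = sp.symbols('p q', positive=True)
f = (sp.sqrt(p) - sp.sqrt(q)) ** 2   # one summand of the squared Hellinger distance

H = sp.hessian(f, (p, q))
print(sp.simplify(H.det()))      # 0, so one eigenvalue is zero
print(sp.simplify(H.trace()))    # positive for p, q > 0, so the other eigenvalue is positive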
Now, we'll use Jensen's inequality on
Specifically:
So, applying Jensen's, we get
With that, we have bounds on three (squared) Hellinger distances:
So, on average over :
So, on average, the Euclidean distance from end-to-end, between and , is at most .
That gives us the desired bound:
implying
Combined with our previous two sections, that establishes the desired upper bound of on .
Below is a plot of the maximal error achieved via numerical minimization of subject to a constraint on , searching over distributions of X and . Above, we proved that the ratio of those two quantities can be no higher than . As expected from the proof, it is visually clear that the straight line from the origin to each point on the curve always lies below the curve. Some noise is present, presumably due to occasional failures of the optimizer to find the maximum error.
Zooming in on the steepest part of the curve and eyeballing the plot, it looks like the maximum ratio achieved is around 4 (.02/.005), implying an empirical upper bound of ~8 for the resampled diagram:
Looking into the actual solutions found, the solutions with a ratio of ~4 involve one of the two terms in the x-axis sum being much larger than the other (5-10x). Therefore we expect to be able, in principle, to get a tighter bound (~4 empirically, rather than the proven 9 or empirical 8). The most likely place for improvement in the proof is to bound the Hellinger distance between and directly by , cutting one step out of the "path", and that would indeed reduce the bound from 9 to 4. We'll leave that for future work.
Interesting additional piece for future reference: if we include the mediation condition in the denominator, and so look for a bound in terms of a factor of the sum of all natural latent condition epsilons, we find that the empirical factor in question is 1 (roughly; not sure what happened at ~0.4):
Note that, since by assumption, none of the 's are infinite. This is the only place where we need ; that assumption can probably be eliminated by considering the infinite case directly, but we're not going to do that here.
A quick aside: while it might look messy at first, the Hellinger distance is a particularly natural way to talk about Euclidean distances between probability distributions. In general, if one wants to view a distribution as a vector, is the most natural vector to consider, since the sum-to-1 constraint says is a unit vector under the standard Euclidean distance.
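Spelled out, that's just:
\[
\left\lVert \sqrt{P} \right\rVert_2^2 \;=\; \sum_x \left(\sqrt{P(x)}\right)^2 \;=\; \sum_x P(x) \;=\; 1.
\]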