Alfred Harwood and I were working through this as part of a Dovetail project and unfortunately I think we've found a mistake. The Taylor expansion in Step 2 has a 3rd order term. That term should disappear as the differences between the distributions go to zero, but this is only true if the underlying probabilities stay constant. The transformation in Part 1 reduces (most terms of) the differences and the probabilities at the same rate, so the 3rd order term decreases at the same rate as the 2nd order term. So the 2nd order approximation isn't valid.
For example, we could consider two binary random variables with probability distributions
and and and .
If , then as .
But consider the third order term for this KL, which is
This is a constant term which does not vanish in the limit.
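To illustrate the flavor of the problem, here is a minimal sympy sketch with stand-in binary distributions whose small-outcome probabilities both scale with eps (these are not the exact numbers from our example above): the KL and its 2nd order approximation both shrink like eps, but their ratio does not go to 1.
import sympy as sp

a, b, eps = sp.symbols("a b epsilon", positive=True)

# Stand-in binary distributions whose small-outcome probabilities both scale with eps.
p, q = eps * a, eps * b

kl = p * sp.log(p / q) + (1 - p) * sp.log((1 - p) / (1 - q))
# Second-order ("chi-squared over two") approximation of KL(p||q):
approx2 = (p - q) ** 2 / (2 * q) + ((1 - p) - (1 - q)) ** 2 / (2 * (1 - q))

# Leading-order coefficients in eps of the true KL and of its 2nd order approximation:
kl_lead = sp.limit(kl / eps, eps, 0, "+")            # a*log(a/b) - a + b
approx2_lead = sp.limit(approx2 / eps, eps, 0, "+")  # (a - b)**2 / (2*b)

# The ratio depends on a and b and is not 1 in general, so the 2nd order
# approximation does not become exact as eps -> 0 when p and q shrink together.
print(sp.simplify(kl_lead / approx2_lead))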
We found a counterexample to the whole theorem (which is what led to us finding this mistake), which has , and it can be found in this colab. There are some stronger counterexamples at the bottom as well. We used sympy because we were getting occasional floating point errors with numpy.
Sorry to bring bad news! We’re going to keep working on this over the next 7 weeks, so hopefully we’ll find a way to prove a looser bound. Please let us know if you find one before us!
I plan to spend today digging into this, and will leave updates under this comment as I check things.
(Update 0)
I'm starting by checking that there's actually a counterexample here. We also found some numerical counterexamples which were qualitatively similar (i.e. approximately-all of the weight was on one outcome), but thought it was just numerical error. Kudos for busting out the sympy and actually checking it.
Looking at the math on that third-order issue... note that the whole expansion is multiplied by an overall prefactor. So even if the third-order term stays comparable to the second-order term, that prefactor still goes to zero as the scaling parameter shrinks, so the term itself will still go to zero. So it's not obviously a fatal flaw, though at the very least some more careful accounting would be needed at that step to make sure everything converges.
(Update 1)
We've looked at the code and fiddled with the math and are now more convinced of the issue.
The 2nd order approximation only holds under a condition which our scaling-down construction does not provide. So, (among a host of other things,) we are now thinking about other ways to try wrangling the bound into a Euclidean space or otherwise into some form that is similarly "easy" to work with.
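For reference, the standard pointwise expansion I have in mind here (restated from scratch, so it may differ in bookkeeping from the post's Step 2):
$$D_{KL}(p\|q)=\sum_i\frac{(p_i-q_i)^2}{2q_i}-\sum_i\frac{(p_i-q_i)^3}{6q_i^2}+\dots$$
so dropping everything past the quadratic term is only justified when |p_i − q_i| ≪ q_i for every outcome i.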
(Thanks for finding this!)
(Update 2)
Taking the limit of the ratio of KLs (using summation rather than max) with b = 10c while c -> 0+ gives
Setting c very small and ramping up r indeed breaks the bound more and more severely. (Code changes from the colab you provided, below.)
Code changes / additions
Block 1:
a,b,c,d,r = sp.symbols("a b c d r")
variable_substitutions = { # The definitions of these variables
a: 0.25,
b: 1e-90,
c: 1e-91,
r: 20000000,
}
Block 2 (later on):
expr = (kl3/(kl1 + kl2)).subs(d, (1-3*c-(r+1)*b-2*a))
print("KL(X2->X1->L')/sum[KL(X1->X2->L),KL(X2->X1->L)]=", expr.evalf(subs=variable_substitutions))
Block 3 (right after Block 2):
expr = (kl3/(kl1 + kl2)).subs(d, (1-3*c-(r+1)*b-2*a)).subs(b, 10*c)
lim = sp.simplify(sp.limit(expr, c, 0))
print("Limit of KL(X2->X1->L')/sum[KL(X1->X2->L),KL(X2->X1->L)] as c->0+ =", lim)
(Update 3)
We're now pursuing two main threads here.
One thread is to simplify the counterexamples into something more intuitively-understandable, mainly in hopes of getting an intuitive sense for whatever phenomenon is going on with the counterexamples. Then we'd build new theory specifically around that phenomenon.
The other thread is to go back to first principles and think about entirely different operationalizations of the things we're trying to do here, e.g. not using the diagram KL divergences as our core tool for approximation. The main hope there is that maybe KL divergence isn't really the right error metric for latents, but then we need to figure out a principled story which fully determines some other error metric.
Either way, we're now >80% that this is a fundamental and fatal flaw for a pretty big chunk of our theory.
(Update 4)
We have now started referring to "Jeremy et Al" when discussing the findings at top-of-thread, and find this amusing.
As of this morning, our current thread is an adjustment to the error measure. Thinking it through from first principles, it makes intuitive sense to marginalize out latents inside a D_KL, i.e. compare distributions over the observables with the latent summed out, rather than comparing full joint distributions over observables and latent (where the second distribution is typically some factorization of P). Conceptually, that would mean always grounding out errors in terms of predictions on observables, not on mind-internal latent constructs. We're now checking whether that new error gives us the properties we want in order to make the error measure useful (and in the process, we're noticing what properties we want in order for the error measure to be useful, and making those more explicit than we had before).
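For concreteness, here's the kind of comparison we have in mind, using the mediation factorization over two observables as a stand-in (the explicit formulas here are one possible reading, not a final definition):
import numpy as np

rng = np.random.default_rng(0)

def kl(p, q):
    # KL divergence between two distributions given as arrays of the same shape.
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

P = rng.random((3, 3, 4))   # random joint distribution over (X1, X2, Lambda)
P /= P.sum()

P_lam = P.sum(axis=(0, 1))                      # P[lam]
P_x1_given_lam = P.sum(axis=1) / P_lam          # P[x1 | lam]
P_x2_given_lam = P.sum(axis=0) / P_lam          # P[x2 | lam]

# Mediation factorization: Q[x1, x2, lam] = P[lam] * P[x1|lam] * P[x2|lam].
Q = P_lam[None, None, :] * P_x1_given_lam[:, None, :] * P_x2_given_lam[None, :, :]

error_joint = kl(P, Q)                               # latent kept inside the D_KL
error_marginal = kl(P.sum(axis=2), Q.sum(axis=2))    # latent marginalized out inside the D_KL
# By data processing, error_marginal <= error_joint; it can be much smaller.
print(error_joint, error_marginal)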
(Update 5)
A conjecture we are working on which we expect to be generally useful beyond possibly rescuing the stoch->det proof that used to rely on the work in this post:
Chainability (Conjecture):
with , ,
Define , and .
Then,
Here is a colab with a quick numerical test that suggests the bound holds (and that n=1, in this case).
(Note: The above as written is just one step of chaining, and ultimately we are hoping to show it holds for arbitrarily many steps, accumulating an associated number of epsilons as error.)
(Update 6)
Most general version of the chainability conjecture (for arbitrary graphs) has now been falsified numerically by David, but the version specific to the DAGs we need (i.e. the redundancy conditions, or one redundancy and the mediation condition) still looks good.
Most likely proof structure would use this lemma:
Lemma
Let f and g be nonexpansive maps under a distance metric d. (Nonexpansive maps are the non-strict version of contraction maps.)
By the nonexpansive map property, d(f(g(x)), f(x)) ≤ d(g(x), x). And by the triangle inequality for the distance metric, d(f(g(x)), x) ≤ d(f(g(x)), f(x)) + d(f(x), x). Put those two together, and we get d(f(g(x)), x) ≤ d(f(x), x) + d(g(x), x).
(Note: this is a quick-and-dirty comment so I didn't draw a nice picture, but this lemma is easiest to understand by drawing the picture with the four points and distances between them.)
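A quick numerical illustration of the lemma, using two Euclidean projections onto convex sets as stand-in nonexpansive maps (generic examples, not the DAG operators themselves):
import numpy as np

rng = np.random.default_rng(0)

def proj_ball(x, radius=1.0):
    # Euclidean projection onto a ball (a convex set), hence a nonexpansive map.
    n = np.linalg.norm(x)
    return x if n <= radius else x * (radius / n)

def proj_halfspace(x, a=np.array([1.0, -2.0, 0.5]), b=0.3):
    # Euclidean projection onto the half-space {y : a.y <= b}, also nonexpansive.
    slack = a @ x - b
    return x if slack <= 0 else x - slack * a / (a @ a)

f, g = proj_ball, proj_halfspace
for _ in range(1000):
    x = rng.normal(size=3) * 3.0
    lhs = np.linalg.norm(f(g(x)) - x)                          # d(f(g(x)), x)
    rhs = np.linalg.norm(g(x) - x) + np.linalg.norm(f(x) - x)  # d(g(x), x) + d(f(x), x)
    assert lhs <= rhs + 1e-9
print("d(f(g(x)), x) <= d(g(x), x) + d(f(x), x) held on all samples")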
I think that lemma basically captures my intuitive mental picture for how the chainability conjecture "should" work, for the classes of DAGs on which it works at all. Each DAG would correspond to one of the functions f, g, where the function for a given DAG takes in a distribution and returns that distribution factored over the DAG, i.e. the product of the input distribution's marginals/conditionals according to the DAG's structure.
In order to apply the lemma to get our desired theorem, we then need to find a distance metric which:
The first two of those are pretty easy to satisfy for the redundancy condition DAGs: those two DAG operators are convex combinations, so good ol' Euclidean distance on the distributions should work fine. Making it match the KL-based errors is trickier; still working that out.
(Update 7)
After some back and forth last night with an LLM[1], we now have a proof of "chainability" for the redundancy diagrams in particular. (And have some hope that this will be most of what we need to rescue the stochastic->deterministic nat lat proof.)
Let P be a distribution over X1, X2, and a latent Λ.
Define:
Where you can think of Q as 'forcing' P into factorizing per one redundancy pattern, S as forcing the other pattern, and R as forcing one pattern and then the other.
The theorem states:
D_KL(P||R) ≤ D_KL(P||Q) + D_KL(P||S),
or in words: the error (in D_KL from P) accrued by applying both factorizations to P is bounded by the sum of the errors accrued by applying each of the factorizations to P separately.
The proof proceeds in 3 steps.
Pf.
Let
Let
By the log-sum inequality:
as desired.
Pf.
Combining steps 1 and 2,
which completes the proof.
Notes:
In the second to last line of step 2, the expectation over is allowed because there are no free 's in the expression. Then, this aggregates into an expectation over as .
We are hopeful that this, though different from the invalidated result in the top level post, will be an important step to rescuing the stochastic natural latent => deterministic natural latent result.
[1] A (small) positive update for me on their usefulness to my workflow!
Additional note which might be relevant later: we can also get proof step 1 in a somewhat more general way, which establishes that the 'forcing' function is a nonexpansive map under D_KL. We'll write that proof down later if we need it.
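And a quick numerical sanity check of the inequality, with one explicit reading of the definitions spelled out in numpy (Q forces the factorization P[x1,x2] * P[lam|x1], S forces P[x1,x2] * P[lam|x2], and R applies the second forcing to Q; if the definitions above are set up slightly differently, the check should be adjusted to match):
import numpy as np

rng = np.random.default_rng(0)

def kl(p, q):
    # KL divergence between two joint distributions given as arrays of the same shape.
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def force_lam_given_x1(M):
    # Force M[x1, x2, lam] into the factorization M[x1, x2] * M[lam | x1].
    M_lam_given_x1 = M.sum(axis=1) / M.sum(axis=(1, 2))[:, None]
    return M.sum(axis=2)[:, :, None] * M_lam_given_x1[:, None, :]

def force_lam_given_x2(M):
    # Force M[x1, x2, lam] into the factorization M[x1, x2] * M[lam | x2].
    M_lam_given_x2 = M.sum(axis=0) / M.sum(axis=(0, 2))[:, None]
    return M.sum(axis=2)[:, :, None] * M_lam_given_x2[None, :, :]

for _ in range(1000):
    P = rng.random((3, 4, 5))    # random joint over (X1, X2, Lambda)
    P /= P.sum()
    Q = force_lam_given_x1(P)    # force one redundancy pattern
    S = force_lam_given_x2(P)    # force the other redundancy pattern
    R = force_lam_given_x2(Q)    # force one pattern, then the other
    assert kl(P, R) <= kl(P, Q) + kl(P, S) + 1e-9
print("D_KL(P||R) <= D_KL(P||Q) + D_KL(P||S) held on all samples")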
our current thread is an adjustment to the error measure.
We're not sure that this is necessary. I quite like the current form of the errors. I've spent much of the past week searching for counterexamples to the ∃ deterministic latent theorem and I haven't found anything yet (although it's partially a manual search). My current approach takes a P(X_1,X_2) distribution, finds a minimal stochastic NL, then finds a minimal deterministic NL. The deterministic error has always been within a factor of 2 of the stochastic error. So currently we're expecting the theorem can be rescued.
marginalize out latents inside a D_KL, i.e. compare distributions over the observables with the latent summed out, rather than comparing full joint distributions over observables and latent
That seems like a cool idea for the mediation condition, but isn't it trivial for the redundancy conditions?
That seems like a cool idea for the mediation condition, but isn't it trivial for the redundancy conditions?
Indeed, that specific form doesn't work for the redundancy conditions. We've been fiddling with it.
Would this still give us guarantees on the conditional distribution of the observables given the latent?
E.g. Mediation:
is really about the expected error conditional on individual values of the latent, & it seems like there are distributions with high mediation error but low error when the latent is marginalized inside the D_KL, which could be load-bearing when the agents cast out predictions on observables after updating on the latent
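Concretely, the decomposition I'm leaning on here (written for two observables; details may differ from the post's exact setup) is
$$D_{KL}\big(P[X_1,X_2,\Lambda]\,\big\|\,P[\Lambda]P[X_1|\Lambda]P[X_2|\Lambda]\big)=\mathbb{E}_{\Lambda}\Big[D_{KL}\big(P[X_1,X_2|\Lambda]\,\big\|\,P[X_1|\Lambda]P[X_2|\Lambda]\big)\Big],$$
so a latent can score well after marginalizing it out inside the D_KL while still being a poor mediator conditional on particular values of the latent.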
Oh nice, we tried to wrangle that counterexample into a simple expression but didn't get there. So that rules out a looser bound under these assumptions, that's good to know.
but thought it was just numerical error
I was totally convinced it was a numerical error. I spent a full day trying to trace it in my numpy code before I started to reconsider. At that point we'd worked through the proof carefully and felt confident of every step. But we needed to work out what was going on because we wanted empirical support for a tighter bound before we tried to improve the proof.
Do you have sympy code for the example noted at the bottom of the colab that claims a ratio of > 9.77 including the mediation? I tried with the parameters you mention and am getting a ratio of ~3.4 (which is still a violation of previous expectations, tbc.)
That'll be the difference between max and sum in the denominator. If you use sum it's 3.39.
Here's one we worked out last night, where the ratio goes to infinity.
By the way, there seems to be an issue where sympy silently drops precision under some circumstances. Definitely a bug. A couple of times it's caused non-trivial errors in my KLs. It's pretty rare, but I don't know any way to completely avoid it. Thinking of switching to a different library.
Part 1 feels like magic. I don't understand it at an intuitive level and so I'm kinda suspicious of it. It seems like such a powerful technique for working with KL divergences. I'll spend some more time playing around with it. Everything else makes sense to me.
My question is how did you come up with this technique? Was "small KL inequalities can be equivalent to larger KL inequalities" a background fact that you knew beforehand? Or did you start by wanting to find a way to make the Hellinger distances work?
It sure does feel like a powerful technique! We haven't explored much how to generalize it yet, though.
At the time, we were thinking about the optimization problem "max (the one error) subject to (constraint on other errors)", and what the curve looks like which gives the max value as a function of the constraint errors. One (of many) angles I tried was to consider ways of transforming a latent, which would move it from one point in the feasible set to another point in the feasible set. And once I asked that question, basically the first thing I tried was the transformation in the proof which just scales down all the errors.
At that point we had already done the Hellinger distances thing (also among many other things), on the general principle of "try it in the second order regime before trying to prove globally", so it was just a matter of connecting the pieces together.