Sho Yaida



Thank you for the discussion!

Let us start by stressing that the maximal-update parametrization is an intriguing recent development, and it would be very interesting to find tools for understanding the strongly-coupled regime in which it resides.

Now, it seems like there are two different issues tangled in this discussion: (i) is one parametrization "better" than another in practice? and (ii) is our effective theory analysis useful in practically interesting regimes?

  1. The first item is perhaps more an empirical question, whose answer will likely emerge in coming years. But, even if maximal-update parametrization turns out to be universally better for every task, its strongly-coupled nature makes it very difficult to analyze, which perhaps makes it more problematic from a safety/interpretability perspective.
  2. For the second item, we hope the details of our reply below will address your concerns.

We'd also like to emphasize that, even if you are against the NTK parametrization in practice and don't think it's relevant at all -- a position we don't hold, but one might -- it's still worth pointing out that our work provides a simple solvable model of representation learning, from which we might learn general principles applicable to safety and interpretability.

With that said, let us respond to your comments point by point.

Aren't Standard Parametrisation and other parametrisations with a kernel limit commonly used mostly in cases where you're far away from reaching the depth-to-width≈0 limit, so expansions like the one derived for the NTK parametrisation aren't very predictive anymore, unless you calculate infeasibly many terms in the expensive perturbative series?
As far as I'm aware, when you're training really big models where the limit behaviour matters, you use parametrisations that don't get you too close to a kernel limit in the regime you're dealing with. Am I mistaken about that?

We aren't sure if that's accurate: empirically, as nicely described in Jennifer's 8-page summary (in Sec. 1.5), many practical networks -- from a simple MLP to the not-very-simple GPT-3 -- seem to perform well in a regime where the depth-to-width aspect ratio is small (like 0.01 or at most 0.1). So, the leading-order perturbative description would be fairly accurate for describing these practically-useful networks.

Moreover, one of the takeaways from "effective theory" descriptions is that we understand the truncation error: in particular, the errors from truncation will be of order (depth-to-width aspect ratio)^2. So this means we can estimate what we would miss by truncating the series and learn that sometimes -- if not most of the time -- we really don't have to compute these extra terms.
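To make the truncation estimate concrete, here is a minimal sketch of the arithmetic. It assumes only the scaling stated above: the leading correction is of order the aspect ratio r = depth/width, and the error from truncating there is of order r^2. The function name and the example depth/width pairs are hypothetical, and the unknown order-one coefficients are dropped, so the numbers are orders of magnitude rather than predictions.

```python
def truncation_estimate(depth, width):
    """Order-of-magnitude sizes in the depth-to-width expansion.

    Assumption: order-one prefactors are set to 1, so only the
    scaling with r = depth / width is meaningful.
    """
    r = depth / width
    return {
        "aspect_ratio": r,        # expansion parameter
        "leading_correction": r,  # size of the first finite-width term
        "truncation_error": r**2, # size of what is dropped after that term
    }


# Hypothetical shapes chosen to give the aspect ratios 0.01 and 0.1
# mentioned in the text; they are not measurements of real models.
for depth, width in [(10, 1000), (10, 100)]:
    est = truncation_estimate(depth, width)
    print(f"depth={depth}, width={width}: "
          f"r={est['aspect_ratio']:.3g}, "
          f"dropped terms ~ {est['truncation_error']:.3g}")
```

On this rough accounting, the dropped terms are suppressed by one more power of the aspect ratio than the correction that is kept.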

As for NTK being more predictable and therefore safer, it was my impression that it's more predictive the closer you are to the kernel limit, that is, the further away you are from doing the kind of representational learning AI Safety researchers like me are worried about. As I leave that limit behind, I've got to take into account ever higher order terms in your expansion, as I understand it. To me, that seems like the system is just getting more predictive in proportion to how much I'm crippling its learning capabilities.

It is true that decreasing the depth-to-width aspect ratio reduces the representation-learning capability of the network and -- to the extent that representation learning is useful for the task -- doing so would degrade the performance. But (i) let us reiterate that, as alluded to above, empirically networks seem to operate well in the perturbative regime where the aspect ratio is small and (ii) the converse is not true (i.e., it is not beneficial to keep increasing the aspect ratio indefinitely), as we illustrate in responding to the following point.

Yes, of course NTK parametrisation and other parametrisations with a kernel limit can still learn features at finite width, I never doubted that. But it generally seems like adding more parameters means your system should work better, not worse, and if it's not doing that, it seems like the default assumption should be that you're screwing up.

Actually, that last point is not always the case. One of the results from our book is that while increasing the depth-to-width ratio leads to more representation learning, it also leads to more fluctuations in gradients from random seed to random seed. Thus, the deeper your network is for fixed width, the harder it is to train, in the sense that different realizations will not only behave differently, but will also likely not be critical (i.e., they will not be on what is sometimes referred to as the "edge of chaos" and will suffer from exploding/vanishing gradients). And this last observation is true for both the NTK parametrization and the maximal-update parametrization, so by your logic, we would be screwing up no matter which parametrization we use. :)
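The growth of seed-to-seed fluctuations with depth at fixed width can be seen in a toy numerical experiment. The sketch below is an illustration under simplifying assumptions, not the calculation from the book: it uses a deep linear network at initialization with Gaussian weights of variance 1/width, and it tracks the spread of the output norm across random seeds rather than gradients. The function name and all numerical settings are hypothetical.

```python
import numpy as np


def relative_variance_of_output_norm(depth, width, num_seeds=2000, seed=0):
    """Seed-to-seed relative variance of the mean squared output
    of a random deep linear network (toy model, no training)."""
    rng = np.random.default_rng(seed)
    # The same fixed input is fed to every network realization.
    z = np.ones((num_seeds, width))
    for _ in range(depth):
        # One independent weight matrix per realization, variance 1/width.
        W = rng.normal(0.0, 1.0 / np.sqrt(width),
                       size=(num_seeds, width, width))
        z = np.einsum("sij,sj->si", W, z)
    sq_norm = np.mean(z**2, axis=1)  # one number per realization
    return np.var(sq_norm) / np.mean(sq_norm) ** 2


shallow = relative_variance_of_output_norm(depth=2, width=32)
deep = relative_variance_of_output_norm(depth=16, width=32)
print(f"depth 2:  relative variance ~ {shallow:.2f}")
print(f"depth 16: relative variance ~ {deep:.2f}")
```

For this particular toy model the exact relative variance is (1 + 2/width)^depth - 1, which is roughly 2 x depth/width when the aspect ratio is small, so the spread across realizations is controlled by the same ratio discussed above.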

As it turns out, this tradeoff between the benefit of representation learning and the cost of seed-to-seed fluctuations leads to the concept of the optimal aspect ratio where networks should perform the best. Empirical results indirectly indicate that this optimal aspect ratio may be in the perturbative regime; in the Appendix of our book, we also did a calculation using tools from information theory that gives evidence that the optimal depth-to-width ratio is in the perturbative regime.

If it was the case that there's no parametrisation in which you can avoid converging to a trivial limit as you heap on more parameters onto the width of an MLP, that would be one thing, and I think it'd mean we'd have learned something fundamental and significant about MLP architectures. But if it's only a certain class of parametrisations, and other parametrisations seem to deal with you piling on more parameters just fine, both in theory and in practice, my conclusion would be that what you're seeing is just a result of choosing a parametrisation that doesn't handle your limit gracefully. Specifically, as I understood it, Standard parametrisation for example just doesn't let enough gradient reach the layers before the final layer if the network gets too large. As the network keeps getting wider, those layers are increasingly starved of updates until they just stop doing anything altogether, resulting in you training what's basically a one layer network in disguise. So you get a kernel limit.

We don't think this is the case. Both the NTK and maximal-update parametrizations can avoid converging to kernel limits and can allow features to evolve: for the NTK parametrization, we need to keep increasing the depth in proportion to the width; for the maximal-update parametrization, we need to keep the depth fixed while increasing the width.


Sho and Dan