Context: I like to approach various ML phenomena from the perspective of spatial transformations. Compared to the classic engineering view, this perspective gives such a sudden bird's-eye view of what is happening inside models that, after developing it, I felt as if I had undergone a double descent of my own. Inspired by that realization, I decided to write this post for those who know the basics but are not deeply familiar with ML spaces in general.
Most key phenomena in ML can be understood as properties and dynamics of representations in spaces, rather than as algorithmic or statistical artifacts.
This perspective allows us to see common causes for seemingly unrelated effects. To get you started, I'll list some phenomena that "Space View" captures well:
First and foremost, the observation is an empirical one: literally every element you interact with during model training is a space of some sort. Training weights? Here is your weight space; go find the optimal configuration within it using backpropagation. Using data? Here you go: the data space, whose structure directly shapes the descent. Not using data? No big deal, interacting with an environment counts too, and the feature space, together with the space of useful features, dictates the computational cost and the model characteristics required to capture them. You are optimizing a function that lives in a (realizable) function space, through which you can descend via boosting, and at the output you get logits; descending along those lets you obtain radically different responses to safe/unsafe prompts with only small changes in the representation space.
These spaces are everywhere; you couldn't hide from them even if you tried.
If we examine the interconnections and transformations between different spaces, we can construct a diagram like this:
To keep this post from turning into a lecture, I will omit the description of each individual space. For the purpose of this introduction, we are primarily interested in the model's space of realizable functions. As is well known, gradient boosting performs a functional descent: at each step we select, from the space of functions the model can implement, the one whose addition best reduces our loss locally, and add it[1].
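To make "functional descent" concrete, here is a minimal sketch of gradient boosting with squared loss: each step fits a small tree to the negative functional gradient (the residuals) and adds it to the current function. The toy data, tree depth, and learning rate are arbitrary illustrative choices, not a reference implementation.

```python
# Functional gradient descent (gradient boosting) with squared loss, toy version.
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=200)

F = np.zeros_like(y)            # current point in function space, F_0 = 0
learners, lr = [], 0.1
for _ in range(100):
    residuals = y - F           # negative gradient of 0.5 * (y - F)^2 with respect to F
    h = DecisionTreeRegressor(max_depth=2).fit(X, residuals)   # best local direction we can realize
    learners.append(h)
    F += lr * h.predict(X)      # small step along that direction in function space

print("train MSE:", np.mean((y - F) ** 2))
```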
Unfortunately, in reality, reaching the global optimum is hindered by a multitude of hidden factors. For instance, the fact that a model can theoretically realize a function does not guarantee that it will reliably reach it even given sufficient resources (hereinafter: data, time, model size). Architectural features and optimizer properties introduce their own errors (bias and implicit bias, respectively) that can fundamentally prevent us from reaching the goal.
Looking at the picture from another angle, another fundamental constraint comes from the data. Since a dataset is just one specific sample from the underlying probability space (unless defined otherwise), judging that space by it will inevitably involve error. As is known, this error can be decomposed into bias + variance + noise, and in this specific case we are dealing with the noise term.
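For reference, the decomposition being invoked here (squared loss, data generated as $y = f(x) + \varepsilon$ with noise variance $\sigma^2$, and $\hat f_D$ trained on a sampled dataset $D$):

$$\mathbb{E}_{D,\varepsilon}\big[(y - \hat f_D(x))^2\big] = \underbrace{\big(f(x) - \mathbb{E}_D[\hat f_D(x)]\big)^2}_{\text{bias}^2} + \underbrace{\mathbb{E}_D\big[(\hat f_D(x) - \mathbb{E}_D[\hat f_D(x)])^2\big]}_{\text{variance}} + \underbrace{\sigma^2}_{\text{noise}}$$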
However, regarding variance, things are not so obvious.
Variance characterizes the algorithm's sensitivity to changes in the data, i.e. how consistently our learning algorithm finds the best model within its class.
In our case, it feels like we aren't using the word "space" enough here.
Let's take another look at the functional space. The implicit bias of gradient descent is that, once the loss cannot be lowered any further (i.e., once it has landed on a flat minimum), it settles on a solution with minimal curvature or weight norm. When the model is underparameterized, or the target function lies outside the reachable space, the descent will instead contort a highly curved function just to satisfy all the training examples, that is, a solution with high variance.
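Here is a minimal numerical sketch of that implicit bias (my own toy construction): on an underdetermined least-squares problem, plain gradient descent started from the origin converges to the minimum-norm solution among all the zero-loss ones.

```python
# Gradient descent on an underdetermined least-squares problem picks the min-norm solution.
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(20, 100))                 # 20 equations, 100 unknowns: a whole plane of exact solutions
b = rng.normal(size=20)

lr = 1.0 / np.linalg.norm(A, ord=2) ** 2       # safe step size for 0.5 * ||Ax - b||^2
x = np.zeros(100)                              # start exactly at the origin
for _ in range(5000):
    x -= lr * A.T @ (A @ x - b)                # plain gradient descent

x_min_norm = np.linalg.pinv(A) @ b             # closed-form minimum-norm solution
print("distance to min-norm solution:", np.linalg.norm(x - x_min_norm))   # ~0
print("residual ||Ax - b||:", np.linalg.norm(A @ x - b))                  # ~0
```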
The most illustrative example of this curve smoothing is the phenomenon of under- and over-parameterization.
With a shortage of parameters but an abundance of other resources, the model will overfit, resulting in the pattern described above.
When there are just enough parameters to pin down exactly one solution, the curve will still be excessively bent, but it does reach the true minimum of the loss.
As we increase the parameters further, we reach a situation where the model has already hit the min(bias) plane, but now it also gains the opportunity to reduce variance by decreasing curvature. Whole neighborhoods of min(bias) solutions become accessible on that plane, in which small changes in the data leave the loss stable (which is precisely what variance measures!). Moreover, since these neighborhoods grow as the parameter count grows (as does the volume of admissible solutions), it becomes significantly easier for the model to find them during the optimizer's exploration[2].
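This is easy to reproduce in a toy setting. The sketch below (my own construction, not taken from any specific paper) fits a minimum-norm least-squares model on random ReLU features and sweeps the number of features: test error typically rises as the feature count approaches the number of training points and falls again in the overparameterized regime, where whole neighborhoods of interpolating solutions open up.

```python
# Toy double descent: min-norm least squares on random ReLU features.
import numpy as np

rng = np.random.default_rng(0)
n_train, n_test, d = 100, 1000, 5

def target(X):
    return np.sin(X @ np.arange(1, d + 1) / d)

X_tr, X_te = rng.normal(size=(n_train, d)), rng.normal(size=(n_test, d))
y_tr = target(X_tr) + 0.1 * rng.normal(size=n_train)
y_te = target(X_te)

for p in [10, 50, 90, 100, 110, 200, 500, 2000]:
    W = rng.normal(size=(d, p))                          # fixed random feature directions
    F_tr, F_te = np.maximum(X_tr @ W, 0), np.maximum(X_te @ W, 0)
    beta = np.linalg.pinv(F_tr) @ y_tr                   # min-norm least-squares fit
    print(f"{p:5d} features -> test MSE {np.mean((F_te @ beta - y_te) ** 2):.3f}")
```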
From the perspective of function space, by the way, minimum norm also means that boosting tries to take the straightest possible path to the min(loss) plane, effectively a perpendicular dropped onto the solution plane.
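In the linear case this picture is literally exact: the zero-loss solutions of an underdetermined system $Ax = b$ form an affine plane with direction space $\operatorname{null}(A)$, the minimum-norm solution lies in $\operatorname{row}(A) = \operatorname{null}(A)^{\perp}$, and so the shortest path from the origin meets the solution plane perpendicularly:

$$x^{*} = A^{+}b \in \operatorname{row}(A) = \operatorname{null}(A)^{\perp}, \qquad \lVert x^{*} \rVert = \min_{Ax = b} \lVert x \rVert.$$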
This is exactly what I mean by “Space view.” The concept of error decomposition simply becomes an aggregate of consequences arising from the interaction of various spaces.
Briefly, papers on deep double descent and multiple descent explain that:
a) The onset of double descent also depends on training time and data volume.
b) Double descent is a special case of multiple descent, whose onset depends on the heterogeneity of feature learning, i.e. on features differing significantly in how hard they are to learn for generalization.
What interests us here is that the phenomenon described above is defined by the feature space. Useful features, unlike data (which the model receives continuously), are learned gradually, persisting in the representation space only if they aid the descent right now. If different representations require different amounts of compute to be acquired due to model characteristics, we see the representation space undergo a transition multiple times, leading to error reduction each time.
In other words, if we target specific model representations, we can significantly increase the error for a specific data cluster or manifold. Essentially, this is exactly what modern unlearning methods do by utilizing activations and logits in their loss functions. Moreover, strictly speaking, for a certain spectrum of tasks, we can estimate how long it will take to re-learn the target functional and what resources it will require.
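The simplest flavor of such an objective looks roughly like this: gradient ascent on the loss over the "forget" cluster, anchored by the ordinary loss on a "retain" set. The tiny MLP, random tensors, and loss weighting below are placeholders for illustration, not a reproduction of any particular published method.

```python
# A schematic unlearning objective: raise the error on a forget cluster, anchor the rest.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10))

# Stand-ins for a forget cluster and a retain set; in practice these are real data slices.
x_forget, y_forget = torch.randn(64, 16), torch.randint(0, 10, (64,))
x_retain, y_retain = torch.randn(64, 16), torch.randint(0, 10, (64,))

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
for _ in range(200):
    forget_loss = -F.cross_entropy(model(x_forget), y_forget)   # push the error UP on the target cluster
    retain_loss = F.cross_entropy(model(x_retain), y_retain)    # keep behaviour on everything else
    loss = forget_loss + 5.0 * retain_loss                      # the weighting is an arbitrary choice here
    opt.zero_grad()
    loss.backward()
    opt.step()
```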
Training a model on a task with fundamentally different data practically guarantees slow convergence and a low probability of reaching double descent. Transformers on raw time series are one example (though I've dabbled in this myself, and there are works where people do the same). So if there isn't enough data in the domain, or if we assume that a potential adversary/competitor lacks serious compute and time, our unlearning is effective.
Unfortunately, when it comes to alignment, simple manipulation of activations won't work, because even tasks as simple as "be honest" often require the same representations as the contradictory behavior. The RMU method proposed an alternative: sending activations on unsafe prompts off in random directions. That way the model doesn't just correct the output at the logits while keeping its representations intact; it actually forgets the corresponding features (finding a way to stop thinking about the bad stuff at any cost, just like us).
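Schematically, the RMU idea (as I understand it) looks like the sketch below: on unsafe inputs, push the hidden activations of a chosen layer toward a fixed random direction; on safe inputs, keep them close to a frozen copy of the model. The toy two-layer model, the layer choice, and the coefficient are placeholders, not the reference implementation.

```python
# RMU-style activation steering, toy version: scramble unsafe representations, preserve safe ones.
import copy
import torch

torch.manual_seed(0)
body = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU())   # the layer whose activations we steer
frozen_body = copy.deepcopy(body).requires_grad_(False)                # frozen reference copy

u = torch.randn(64)
u = u / u.norm()            # fixed random unit direction
c = 10.0                    # how far along u to push the unsafe activations

x_unsafe, x_safe = torch.randn(32, 16), torch.randn(32, 16)
opt = torch.optim.Adam(body.parameters(), lr=1e-3)
for _ in range(200):
    forget_loss = ((body(x_unsafe) - c * u) ** 2).mean()               # send unsafe activations off toward the random direction
    retain_loss = ((body(x_safe) - frozen_body(x_safe)) ** 2).mean()   # keep safe activations where they were
    loss = forget_loss + retain_loss
    opt.zero_grad()
    loss.backward()
    opt.step()
```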
However, a problem lurks here too: the network still preserves the target function in the function space, and it likely shares the same nature as the desired behavior. That is, it is extremely unlikely (I’d bet 0.99) that one can pick a model size or training time such that we pass the first descent for one function but not the other; as a result, subsequent fine-tuning would make restoring the target functional a matter of significantly fewer resources than we would like.
In this context, methods that impede fine-tuning offer more room for research, since the two functions are too similar to be separated, merely by controlling resources, into different spaces: the reachable and the effectively reachable.
We haven't talked much about the parameter space itself, yet plenty of interesting things happen there as well! It is, after all, also responsible for the notorious variance.
As you know, linear spaces are invariant to affine transformations in the sense that no information is lost after applying them. Consequently, if a model is non-linear, carrying the implicit bias that a non-linearity like ReLU brings, it will take longer to adapt to OOD (out-of-distribution) data. Moreover, this isn't a matter of the volume of computation: there is simply no cheap operation that painlessly shifts the majority of representations without breaking the fragile structure of the parameters.
Representations in non-linear models turn out to be localized in exactly such fragile groups of neurons, which is quite amusing, since in the representation space they try to position themselves as orthogonally to one another as possible. This tendency holds for uncorrelated features and lets them influence each other less during training. As a side effect, it further increases the model's capacity thanks to the biases (some interference between features does arise; I'll omit the details for brevity).
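A quick numerical illustration of why near-orthogonality buys extra capacity: in a $d$-dimensional representation space you can pack far more than $d$ random directions that are almost orthogonal to each other, so uncorrelated features can share the space with only mild interference. The dimensions below are arbitrary.

```python
# Many almost-orthogonal feature directions fit into a modest representation space.
import numpy as np

rng = np.random.default_rng(0)
d, n_features = 128, 2000                       # 2000 "features" in a 128-dimensional space
V = rng.normal(size=(n_features, d))
V /= np.linalg.norm(V, axis=1, keepdims=True)   # unit-norm feature directions

G = V @ V.T                                     # pairwise cosine similarities
np.fill_diagonal(G, 0.0)
print("max |cos| between distinct features:", np.abs(G).max())   # well below 1, so interference stays small
```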
Thus, from the perspective of the "Space view," in exchange for some variance and instability on OOD data when moving from a linear model to a non-linear one, we gain the ability to realize a different class of functions plus extra parameter capacity. And yet it seemed like such a trifling change. I'd bet (0.7) that this can be leveraged in real-world examples somehow.
I find it extremely fascinating how looking at the interaction of spaces allows us to unify a multitude of seemingly loosely connected concepts. I hope the reader found this interesting. Feel free to write to me in the comments/DM to share other properties of interest to you that are conveniently represented through the "Space view."
[1] When training neural networks, descent in parameter space approximates functional descent, so the overall picture remains the same. I decided that describing the degeneracy of the weight-to-function mapping was unnecessary, since norm minimization applies to any descent, and for now we can do without singular learning theory.
[2] From the perspective of Singular Learning Theory, this is explained by the model's effective dimensionality, i.e., how many parameters actually influence the solution in the neighborhood of the loss minimum; but the core essence remains the same.