If you (understandably) think "this post is too long, ain't nobody got time for that!", I suggest scrolling to ~the middle of the post and at least checking out the Mandelbrot Training Animation(s). I think they're pretty neat.
After going through ~half of the ARENA curriculum, I couldn't stop myself from taking a few days off and just exploring what happens in a small neural network during gradient descent. I'm summarizing that little adventure in this not-so-little post. Hard to say how far my findings and newly won intuitions generalize beyond this particular case, but at the very least it led to some nice-looking pictures.
I trained a small neural network (an MLP with <100 neurons) to learn complex number multiplication. Technically, I didn't train just one, but a whole bunch of them. Here are some of the questions I wanted to answer:
As this was rather explorative and not super systematic, much of what I did was improvised. For instance, I didn't do any thorough hyperparameter exploration, but just manually tried out different values (e.g. for batch size, training set size and learning rate) and ended up with a configuration that seemed to perform well enough. As the function I'm training on is quite simple and the networks are small, most training runs finished within a few seconds and the entire project cost well under $10 of cloud compute.
Throughout this post, I'll typically start with a question, giving readers a chance to form their own expectations and test their intuition before I summarize what I found.
Caveat: Pretty much everything I did here was vibe-coded using the Google Colab Gemini integration (with occasional help from Claude whenever Gemini got stuck). So I can't entirely guarantee that the code contains no stupid mistakes, or that it always does exactly what I think it does. That being said, most of what I found seems pretty plausible to me, and/or I kept investigating up to a point where things clearly worked out the way they should. The exception to this is probably the Loss Landscape section towards the end, which stands on shakier ground. If you want to check the code, or run any of the experiments yourself, please PM me and I'll gladly share the Colab notebook.
What is complex number multiplication?
I can imagine many LessWrong readers are broadly familiar with the concept, but for those who are not, here's a very basic introduction to complex numbers. While real numbers live on a 1-dimensional number line, complex numbers can be interpreted as points in a 2D space (the complex plane). So a complex number x consists of two components (or coordinates), called its real and its imaginary part, and can be written as re_x + im_x·i, where i is a constant with the property that i² = -1. Here, re_x can be interpreted as "pointing in the direction of real numbers", while i (and thereby also im_x·i) is orthogonal to that, "pointing in the direction of imaginary numbers". Defined this way, addition of two complex numbers is trivial - you just add real parts and imaginary parts separately (like vector addition in R²). Multiplication is also pretty straightforward when doing it in closed form algebraically: for two complex numbers x and y, you get
(re_x + im_x·i) × (re_y + im_y·i) = re_x·re_y + re_x·im_y·i + im_x·re_y·i + im_x·im_y·i²
But since i² is defined as -1, we can turn this into:
= re_x·re_y + re_x·im_y·i + im_x·re_y·i - im_x·im_y
= (re_x·re_y - im_x·im_y) + (re_x·im_y + im_x·re_y)·i
And so we end up with a new complex number, whose real part (re_x·re_y - im_x·im_y) and imaginary part (re_x·im_y + im_x·re_y) can be straightforwardly calculated with simple addition and multiplication.
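As a quick sanity check, here's that formula in code (a minimal sketch; complex_mul is just my name for illustration), compared against Python's built-in complex type:

```python
def complex_mul(re_x, im_x, re_y, im_y):
    """Multiply two complex numbers given as (real, imaginary) pairs."""
    re_out = re_x * re_y - im_x * im_y
    im_out = re_x * im_y + im_x * re_y
    return re_out, im_out

# Compare against Python's built-in complex arithmetic:
print(complex_mul(1.0, 2.0, 3.0, 4.0))   # (-5.0, 10.0)
print((1 + 2j) * (3 + 4j))               # (-5+10j)
```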
One neat "application" of complex numbers is that they allow computing (and thereby rendering visually) some interesting "fractals" like the Mandelbrot set - which we'll do later on.
The neural network(s) that I've trained are not able to express complex number multiplication exactly, as they cannot multiply activations but only add them[1]. Hence, what the networks learn will always just be an approximation of this function rather than the real math behind it.
As I wanted to train a network on complex number multiplication, I needed 4 inputs (real and imaginary parts of the two input numbers) and 2 outputs (real and imaginary part of the result). I varied the number and width of hidden layers throughout my experiments.
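For concreteness, here's a minimal sketch of what such a network could look like in PyTorch (make_model is my name for illustration; the layer sizes and activation below are just one example configuration, and my actual code may differ in details):

```python
import torch
import torch.nn as nn

def make_model(hidden_sizes, activation=nn.ReLU):
    """MLP with 4 inputs (re/im of both factors) and 2 outputs (re/im of the product)."""
    layers, in_features = [], 4
    for h in hidden_sizes:
        layers += [nn.Linear(in_features, h), activation()]
        in_features = h
    layers.append(nn.Linear(in_features, 2))
    return nn.Sequential(*layers)

model = make_model([20, 30, 20], activation=nn.SiLU)  # e.g. one of the configurations used below
```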
Some other implementation and training details:
At first, I trained models with a few different hidden layer setups to see how well they perform:
Feel free to come up with some hypotheses on how these will perform, either relative to each other or even what MSE loss to expect. I trained each of these for 150 epochs (learning rate of 0.01, 4096 × 0.8 training samples).
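The training itself is the standard supervised recipe: sample random complex pairs, compute their true product as the target, and minimize MSE. A rough sketch, reusing make_model from above (the sampling range, batch size and choice of plain SGD are assumptions on my part; the actual runs may differ):

```python
import torch
import torch.nn as nn

def make_dataset(n_samples=4096, scale=2.0):
    """Random complex pairs (as 4 real inputs) and their true products (2 outputs)."""
    x = (torch.rand(n_samples, 4) - 0.5) * 2 * scale   # re_x, im_x, re_y, im_y
    re = x[:, 0] * x[:, 2] - x[:, 1] * x[:, 3]
    im = x[:, 0] * x[:, 3] + x[:, 1] * x[:, 2]
    return x, torch.stack([re, im], dim=1)

def train_model(model, epochs=150, lr=0.01, batch_size=64):
    """Train on fixed samples, return the final test MSE."""
    x_train, y_train = make_dataset(int(4096 * 0.8))
    x_test, y_test = make_dataset(int(4096 * 0.2))
    opt = torch.optim.SGD(model.parameters(), lr=lr)   # optimizer choice is an assumption
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        perm = torch.randperm(len(x_train))
        for i in range(0, len(x_train), batch_size):
            idx = perm[i:i + batch_size]
            opt.zero_grad()
            loss_fn(model(x_train[idx]), y_train[idx]).backward()
            opt.step()
    with torch.no_grad():
        return loss_fn(model(x_test), y_test).item()
```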
Unsurprisingly, larger models performed better than smaller ones, and SiLU helped a lot. In all cases, train loss was lower than test loss, but they were reasonably close, so the networks generalized well enough (but then again, with a function as smooth as complex number multiplication, it would be surprising if there was no generalization).
Here are the train and test losses I obtained after 150 epochs for each of the models:
| Model | Hidden Layers | Train Loss | Test Loss |
|---|---|---|---|
| 0 | [] | 6.2571 | 7.6424 |
| 1 | [10] ReLU | 3.1443 | 3.7585 |
| 2 | [10, 10] ReLU | 1.6828 | 2.2726 |
| 3 | [20, 30, 20] ReLU | 0.1662 | 0.3403 |
| 4 | [20, 30, 20] SiLU | 0.0225 | 0.1108 |
| 5 | [20, 30, 20] SiLU (1000 epochs) | 0.0036 | 0.0262 |
I won't show all the details, but as an example, here's the loss curve of model 3 ([20, 30, 20] hidden layers using ReLU):
Eyeballing it, it appears that this particular model went through ~3 phases during training: a very steep drop over the first ~6 epochs, then a slower but ~linear decline until epoch ~30-40, and then a very slow crawl towards 0. I didn't look more deeply into these phases though. This was also not a general pattern across the other models (although the very broad pattern of a steep initial drop that abruptly turns flat holds for all of them).
I likely could have trained any of these models for much longer to reduce the loss further. E.g. the last model had not plateaued yet after its 1000 epochs. I also could have used more training samples (or dynamically generated training data instead of fixed samples) to improve generalization and reduce the gap between train and test loss.
Taking the case of 2 hidden layers, I was wondering how the neuron count in these layers would affect the loss. So I made a grid of [2…12] × [2…12] and trained one model (for only 30 epochs each) for each combination (the first hidden layer having n neurons and the second having m, with both n and m ranging from 2 to 12), then plotted the test loss each model achieved.
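The sweep itself is just a double loop over the two widths; roughly like this (a sketch reusing make_model and train_model from the sketches above):

```python
import numpy as np
import torch.nn as nn

sizes = list(range(2, 13))                 # 2..12 neurons per hidden layer
grid = np.zeros((len(sizes), len(sizes)))
for i, n in enumerate(sizes):              # width of first hidden layer
    for j, m in enumerate(sizes):          # width of second hidden layer
        model = make_model([n, m], activation=nn.ReLU)
        grid[i, j] = train_model(model, epochs=30)
```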
What's your expectation of what we might find?
I thought that, surely, more is better. I wasn't sure though whether what I'd find would be roughly symmetrical or not. If you have, say, 4 neurons in one layer and 10 in the other, would it be better to have the smaller layer first, or the larger one? No idea. I know that CNNs often have somewhat of a "Christmas tree" shape, starting out wide and then getting smaller and smaller layers, so maybe that is a general pattern - that starting with wide layers and narrowing them down makes sense? Perhaps.
In my noisy results, it's clear that more is indeed better, and there's a slight tendency for a larger first layer to be the better choice. But I wouldn't want to generalize from this tiny experiment.
After experimenting with some ways to visualize what these networks learned, I ended up doing the following: I created an interactive widget where you can set the real and imaginary parts of one of the input numbers. It then renders a 2D function, mapping the real and imaginary parts of the second input number to a color that represents the output number (using its angle and magnitude for colorization). While this is a bit hard to interpret concretely, it does allow for quick visual comparison of the real function vs. what the networks learned. I then also rendered a heatmap to show the diff between the learned and the real function.
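The colorization is essentially a simple form of domain coloring: hue from the output's angle, brightness from its magnitude. A minimal sketch of the idea (my exact color mapping may differ); the example plots the true product with the first factor fixed to 1 + 1i:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb

def colorize(re_out, im_out, max_abs=8.0):
    """Map a grid of complex outputs to RGB: hue = angle, brightness = magnitude."""
    angle = np.arctan2(im_out, re_out)                 # in [-pi, pi]
    hue = (angle / (2 * np.pi)) % 1.0                  # in [0, 1)
    value = np.clip(np.abs(re_out + 1j * im_out) / max_abs, 0, 1)
    hsv = np.stack([hue, np.ones_like(hue), value], axis=-1)
    return hsv_to_rgb(hsv)

# True product (1 + 1i) * (re2 + im2*i), rendered over the second factor's plane:
re2, im2 = np.meshgrid(np.linspace(-2, 2, 200), np.linspace(-2, 2, 200))
re_out = 1 * re2 - 1 * im2
im_out = 1 * im2 + 1 * re2
plt.imshow(colorize(re_out, im_out), extent=[-2, 2, -2, 2], origin="lower")
plt.show()
```

Outputs near zero end up dark (which is why there's a black region in the middle of the images below), and the hue cycles with the output's angle.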
Here we see the performance of Model 0 (the one without any hidden layers), for some arbitrary first complex number:
Comparing left and middle image, we can see that there is something useful going on in this model - the ordering of colors is clearly similar, with blue at the top, red/purple on the left, green on the right, black in the middle. But it's also very easy to see the difference.
The heatmap on the right shows the diff between the two images on the left.
How would you expect things to develop from here, when looking at the larger models?
Unsurprisingly, model 1 ([10]) looks a bit more accurate:
(Also note that the heatmap on the right, indicating the loss, keeps adjusting its scale to the model - so the heatmaps will have similar color schemes independent of the model's overall loss. The purpose of the heatmap is rather to see the "shape" of the difference between the learned function and reality.)
Model 2 ([10, 10]):
Model 3 ([20, 30, 20]):
Here, it starts becoming easy to mistake the predicted image in the middle for the real thing on the left (perhaps depending on your display contrast). Based on this eyeballed benchmark, the model truly comes really close. But looking at the heatmap on the right, we can see that there are still many "triangle shapes" visible - it appears that the model just stitches together piece-wise approximations of the function, distributed somewhat arbitrarily in space.
We can also see that the "resolution" of these piece-wise approximations is finer than in the earlier models, which makes sense, as this model has way more neurons (and weights) with which to approximate the function.
Model 4 ([20, 30, 20], SiLU):
Now, the diff heatmap is smooth at last - no obvious stitching of pieces. I can imagine that the same thing is still going on in principle, and the smoother activation function just makes it harder to see in this rendering. I would still assume that different weights are "responsible" for different parts of space. I didn't look into this though, so maybe I'm wrong.
For completeness, here's model 5 (same model as above, but trained for 1000 instead of 150 epochs):
One of the amazing applications (or, perhaps, "applications") of complex numbers is rendering fractals like the Mandelbrot or Julia set.
So, naturally, I thought: what if I render them, not based on proper complex number multiplication, but using the networks I trained?
What's your expectation? Will it work? Will Mandelbrot be recognizable? If so, for which of the models I trained above?
For reference, here's the actual Mandelbrot, when rendered correctly (with some arbitrary color gradient based on iterations it took to detect divergence):
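The rendering is the usual escape-time iteration z ← z² + c, except that z² can either be computed exactly or by feeding (re_z, im_z, re_z, im_z) through a trained network. A sketch of the idea (the resolution, iteration count and value clipping are my own choices, not necessarily what I used for the images below):

```python
import numpy as np
import torch

def mandelbrot_escape(model=None, res=300, max_iter=50, extent=(-2.0, 1.0, -1.5, 1.5)):
    """Iterations until |z| > 2 for z <- z^2 + c, with z^2 optionally done by the network."""
    xs = np.linspace(extent[0], extent[1], res)
    ys = np.linspace(extent[2], extent[3], res)
    re_c, im_c = np.meshgrid(xs, ys)
    re_z, im_z = np.zeros_like(re_c), np.zeros_like(im_c)
    counts = np.full(re_c.shape, max_iter)
    for it in range(max_iter):
        if model is None:                    # exact squaring
            re_sq = re_z * re_z - im_z * im_z
            im_sq = 2 * re_z * im_z
        else:                                # squaring via the network: multiply z by itself
            inp = torch.tensor(np.stack([re_z, im_z, re_z, im_z], axis=-1),
                               dtype=torch.float32).reshape(-1, 4)
            with torch.no_grad():
                out = model(inp).numpy().reshape(res, res, 2)
            re_sq, im_sq = out[..., 0], out[..., 1]
        re_z, im_z = re_sq + re_c, im_sq + im_c
        re_z, im_z = np.clip(re_z, -1e6, 1e6), np.clip(im_z, -1e6, 1e6)  # keep values finite
        diverged = (re_z * re_z + im_z * im_z > 4) & (counts == max_iter)
        counts[diverged] = it
    return counts
```

The returned iteration counts can then be mapped through any colormap to produce images like the ones below.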
Rendering it based on Model 0 (no hidden layers), we obtain… something. It's not very good:
It's neither Mandelbrot-like, nor even centered around where it's supposed to be, nor notably "fractal". Zooming in on the slightly more interesting part:
Yeah, not much there. I'll spare you Model 1 ([10]) which is similarly underwhelming.
Model 2 ([10, 10]) certainly has something going on:
But zooming into the interesting part, it's closer to repetitive noise than anything fractal-like:
Model 3 ([20, 30, 20]) is actually the first one that bears some distant resemblance to the Mandelbrot set:
But it's still extremely far off. I found this surprising, as the visual function comparison in the last section looked very promising for this model.
Now, what about Model 4 ([20, 30, 20] with SiLU)? Based on loss, we saw that using SiLU instead of ReLU made a modest difference, reducing test loss to roughly a third under otherwise identical training conditions. But will this make a big difference for what the rendered Mandelbrot looks like?
Have a look:
Well, it basically nails it!
And if we zoom in on the middle, the "neck" of the figure:
We can see that even zoomed in, a lot of the detail that one would expect is indeed there.
For comparison, here's what the above viewport should look like:
So we can definitely see some differences, e.g. the one based on Model 4 is much less symmetric - but the basic shapes are all there, just slightly morphed in space.
It's interesting how Model 3 and Model 4 had test loss values that were not all that far apart (0.34 vs 0.11, respectively), yet their Mandelbrot renderings are worlds apart. Why is that?
I don't know for sure, but my loose assumption would be something like this:
The difference between models 3 and 4 is that 3 uses ReLU and hence has all these flat, linear patches, whereas model 4 uses SiLU and is much smoother. The way fractals like the Mandelbrot set are generated is through an iteration of hundreds of steps where we keep multiplying (in the case of Mandelbrot, squaring) complex numbers. Multiplication of complex numbers can be interpreted as scaling + rotation. I suppose that in the model 3 case, where much of the space is flat, this iteration just behaves predictably linearly within these isolated patches of space and doesn't perform the subtle rotational dance that gives fractals their fractal-ness. So, even though the losses of model 3 and model 4 are not that far apart, model 3 may just lack the "smoothness of space" property that is required for fractals to fractal?
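To spell out the "scaling + rotation" view in symbols (these are standard facts about complex numbers, not something specific to my experiments): in polar form, multiplication multiplies magnitudes and adds angles, so squaring squares the magnitude and doubles the angle - exactly the kind of delicate rotational structure that a piecewise-linear approximation can wash out within each flat patch.

$$
z = r\,e^{i\theta},\quad w = s\,e^{i\varphi}
\;\;\Longrightarrow\;\;
z \cdot w = (r s)\, e^{i(\theta + \varphi)},
\qquad
z^2 = r^2\, e^{i\,2\theta}
$$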
Finally, I've spent some of my compute to visualize the learning process of model 3 and 4 to see what the learned Mandelbrot looks like through the epochs.
Here's the model 3 architecture (for this animation I trained it a bit longer than for the images above):
It really looks as if there's a Mandelbrot inside there that's struggling to push space into its proper shape. Maybe if I'd let this run for a few hours, we would have gotten somewhere, at least in this zoomed-out view. (Based on my explanation above, I would expect that we'd basically never get a proper fractal that we can zoom into and find "infinite" detail.)
It reminds me a bit of a bird struggling to hatch from its egg - and of Ilya Sutskever's claim that "the models just want to learn".
And here we have model 4:
It's interesting how it finds a Mandelbrot-like shape extremely quickly, even while its loss is still much higher than that of the fully trained model 3, and afterwards it's basically just fine-tuning the spatial proportions.
The post carries "gradient descent" in its title, yet so far there was relatively little of that. So let's leave the fancy visualizations behind and look at some numbers and dimensions.
Firstly, I was interested in the question of to what degree the random network initialization determines the ultimate performance of the network. Does an initialized network's loss determine its performance after n epochs? Is it worth creating multiple initializations and checking their losses to achieve better results? Or should one perhaps start with a bunch of random initializations, train each for m epochs, and continue with the best one from there?
For this, I settled on a simple model architecture (similar to model 2 above: two hidden layers with [10, 10] neurons) and then created 30 models with random initialization.
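The setup is simple (a sketch reusing make_dataset and make_model from the earlier sketches; the checkpoints 0/10/100 match the histograms below, everything else is an assumption):

```python
import torch
import torch.nn as nn

x_train, y_train = make_dataset(int(4096 * 0.8))
x_test, y_test = make_dataset(int(4096 * 0.2))
loss_fn = nn.MSELoss()

checkpoints, histories = [0, 10, 100], []
for seed in range(30):
    torch.manual_seed(seed)                          # a fresh random init per model
    model = make_model([10, 10], activation=nn.ReLU)
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    history = {}
    for epoch in range(max(checkpoints) + 1):
        if epoch in checkpoints:
            with torch.no_grad():
                history[epoch] = loss_fn(model(x_test), y_test).item()
        for i in range(0, len(x_train), 64):         # one epoch of mini-batch updates
            opt.zero_grad()
            loss_fn(model(x_train[i:i+64]), y_train[i:i+64]).backward()
            opt.step()
    histories.append(history)

for ep in checkpoints:
    vals = [h[ep] for h in histories]
    print(f"epoch {ep}: min {min(vals):.2f}, max {max(vals):.2f}")
```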
So, let's see what their starting test loss distribution looked like:
The initial test loss of the 30 models I set up ranged from roughly 29 to 38 - so some potential for different performance, but not orders of magnitude of difference.
After 10 epochs:
Ranging from roughly 6 to 18.
And after 100:
Ranging from about 1.2 to 3.8.
So there is some difference in performance between the 30 models, but it doesn't seem that huge and they all keep moving within some relatively narrow band of loss values. I'm not sure if they will all converge to the same lower bound if you train them long enough, but that seems certainly possible based on the above data.
However, based on these histograms, we don't know which models end up with the lowest loss. So here's a visualization of the entire process with all 30 models:
It's hard to make out the details here, but one can certainly see that the order of the models changes a lot. The 3 that started out worst soon caught up (although one of them ends up performing worst again later). And the one that started out best did not stay in first place for long.
Another way to look at this data is to check the correlation of the loss of different models after n vs m epochs. E.g., does the loss after 5 epochs predict the loss after 100 epochs to a strong enough degree that it might be worth training multiple models a bit to then pick the best one for a longer training run?
Probably not - while there is a decent correlation of 0.556, it's probably too weak for training multiple models briefly and picking the best one to give you much of an edge over just taking any model and throwing all your compute at it. As with basically all findings here, I would be careful about generalizing much beyond my toy setup, but it's still interesting to see.
Next, I wanted to learn more about the loss landscape. What's its geometry like? Is it rough? Smooth? Regular? Chaotic? Fractal-esque? Will I find anything unexpected there?
What I ended up doing is this: At different points in time during training, I'd look along the current gradient and scale the learning rate from, say, -10 to 10, and check what loss I'd end up with for the learning rate scaled with these factors. This would then give me a slice of the loss landscape as a R → R graph. What kind of structure do you think this will yield?
For this, I trained a [10, 10, 10] model (don't ask me why I chose a different architecture from the earlier experiments - I don't think it matters much) for 45 epochs (45 being arbitrary; I could have equally chosen 5 or 3,000).
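Concretely, the probe looks roughly like this: compute the current gradient once, then for a range of scaling factors apply (factor × learning rate × gradient) to a copy of the weights and record the loss there, relative to the unperturbed loss (a sketch; gradient_loss_slice is my name for illustration, and x, y stand for a batch of training data):

```python
import copy
import numpy as np
import torch

def gradient_loss_slice(model, x, y, loss_fn, lr=0.01, factors=np.linspace(-10, 10, 201)):
    """Relative loss along the gradient direction, scaled by factor * lr."""
    # Compute the current gradient once.
    model.zero_grad()
    loss_fn(model(x), y).backward()
    grads = [p.grad.detach().clone() for p in model.parameters()]

    base_loss = loss_fn(model(x), y).item()
    losses = []
    for f in factors:
        probe = copy.deepcopy(model)              # perturbed copy of the model
        with torch.no_grad():
            for p, g in zip(probe.parameters(), grads):
                p -= f * lr * g                   # positive factor = a gradient descent step
            losses.append(loss_fn(probe(x), y).item() - base_loss)
    return factors, np.array(losses)
```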
And this is the "gradient loss slice" I obtained:
So, again, what do we see here? The x-axis shows different scaling factors for the gradient of the model at that point in time. The y-axis then shows the relative loss of the network when the current gradient times the learning rate times the given factor is applied to the model weights (relative to the loss of the model without any change). So, 3 things should be pretty certain, and are indeed the case here:
Now, what surprised me is that this slice of the loss landscape is basically just a parabola (not exactly - it's not perfectly symmetrical - but it clearly comes very close to one). My initial expectation was to see a far more chaotic landscape. You know, similar to, say, a slice through a mountain range on Earth, or something. Or somewhat like a random walk where we only locally know that right in front of us it should descend, but which otherwise behaves erratically. But no, it's extremely regular and boring. Is this just due to the function we're approximating? After all, complex number multiplication itself is also pretty regular, smooth and boring. So maybe that's the reason?
I went ahead and came up with some cursed and entirely useless function (as a replacement for complex number multiplication, just to see what its loss landscape will look like), optimized purely for the purpose of leading to more interesting geometry:
```python
import numpy as np

def interesting_function(a, b, c, d):
    # A deliberately "cursed" pile of operations, chosen only to produce
    # high-frequency, irregular behavior (both outputs are wrapped into [0, 3)).
    v1 = np.sin(np.power(np.abs(a), np.abs(b)))
    v2 = np.fmod(d, (b - a))
    v3 = np.min([c, d]) / np.max([c, d])
    v4 = np.fmod(a * b, v3)
    return np.mod(np.abs(v1 * c + (1 - c) * v2), 3), np.mod(np.abs(v3 * v4), 3)
```
It doesn't make much sense - I just semi-randomly threw together a bunch of mathematical operators to get some high-frequency irregular behavior out of it.
Visualizing it, I can't say I'm not happy with how the function turned out:
(The model here clearly struggles to mirror that function, but we can still see it made a bit of progress during training)
So, will this much less regular function also have a much less regular loss landscape? Or will we still just get something parabola-like?
It turns out that even for this function, my gradient slice looks rather boring:
Perhaps it's just the direction of the gradient that's boring and regular? Maybe random other directions yield more interesting geometry?
Well, I went ahead and plotted a 2D function, with two (more or less[2]) random directions in weight space, and looked at how a linear combination of them both would affect the loss. And I basically got a bowl:
(This now shows absolute, not relative loss, hence the low point does not have a value much below 0)
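That 2D plot follows the same recipe, just along two fixed (more or less) random directions in weight space instead of the gradient (a sketch; the normalization details hinted at in the footnote are omitted here):

```python
import copy
import numpy as np
import torch

def loss_surface_2d(model, x, y, loss_fn, radius=1.0, steps=41):
    """Loss over a 2D grid spanned by two random directions in weight space."""
    dirs = [[torch.randn_like(p) for p in model.parameters()] for _ in range(2)]
    alphas = np.linspace(-radius, radius, steps)
    surface = np.zeros((steps, steps))
    for i, a in enumerate(alphas):
        for j, b in enumerate(alphas):
            probe = copy.deepcopy(model)
            with torch.no_grad():
                for p, d1, d2 in zip(probe.parameters(), *dirs):
                    p += a * d1 + b * d2          # move along the linear combination
                surface[i, j] = loss_fn(probe(x), y).item()
    return alphas, surface
```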
To be honest, when I went into this experiment, I expected something more in this direction (but with more dimensions):
I assumed the loss landscape would be quite erratic and complex, and beyond the local knowledge about "gradient going down here" anything could happen. But it appears that, at least for the two functions I looked at, the loss landscape is indeed very regular and "boring". If someone has an intuitive understanding for why that should be the case, I'd be very interested to hear it!
Now, I should mention, I'm much less confident in the results of this section than the others, because there are a few things that don't quite add up for me:
So, ummm… yeah. But, on the other hand: asking Claude Opus 4.1 in a very non-leading way what type of geometry it would expect from the above experiment, it did predict: something that looks like a parabola. So, apparently, this result is not actually surprising, and my expectation of some interesting geometry hiding in this high-dimensional loss landscape was just incorrect.
I was wondering about one other thing - how do parameters actually move through the high-dimensional parameter space? I could think of at least two options here. As I trained a model (now again on complex number multiplication) and the loss converged more and more towards its lower bound, I could either observe:
Here's what I found:
These are all weights and biases of a model with hidden layers of sizes [6, 8, 6] that ends up with 154 parameters (4*6 + 6*8 + 8*6 + 6*2 = 132 weights, and 6+8+6+2 = 22 biases).
So, does this look more like option 1 or option 2? Well, hard to say, but perhaps a bit closer to option 1. While most parameters have "settled down" after the first few hundred epochs, some of them are clearly still on a mission. I was also a bit surprised by how smoothly many of them are moving. The above was with a learning rate of 0.01. Increasing it to 0.03, I got the following:
It still seems to be the case that many parameters are systematically drifting in one direction, albeit now in a noisier fashion. It gives me "just update all the way bro!" vibes - but perhaps this is a bit of a coordination problem: multiple parameters have to move more or less in unison, and none of them can reasonably update further on its own, even though the longer-term trajectory is clear. Hence, increasing the learning rate doesn't appear to help much with arriving at the final configuration any sooner; it just leads to a bunch of back-and-forth.
Maybe a suitable metaphor would be: when you're walking through the dark, you'll take super small steps, even when you know the direction of your destination, as you need to be careful not to trip and fall over. These networks similarly seem to often move in relatively consistent directions, but need to do so carefully and slowly, as so many parameters need to be coordinated for this to work out.
Returning to a learning rate of 0.01, I then went for a longer training run of 3000 epochs, this time not plotting every single parameter, but instead a few different metrics:
And we get this mess (starting at epoch 200, so we're skipping much of the initial chaos):
(Note: the step distance was scaled up by 10x as otherwise it was hard to visually tell it apart from 0)
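In case the metrics are unclear: by "step distance" I mean the norm of the per-epoch change of the full parameter vector, and the dot product is between consecutive per-epoch update vectors (positive = the parameters keep moving in a similar direction, negative = they backtrack). A sketch of how these could be computed, assuming that's the right reading of the chart (train_one_epoch stands in for one epoch of the training loop sketched earlier):

```python
import torch

def flat_params(model):
    """All parameters concatenated into one flat vector."""
    return torch.cat([p.detach().reshape(-1) for p in model.parameters()])

def track_movement(model, train_one_epoch, n_epochs=3000):
    """Per-epoch step distance and dot product between consecutive update vectors."""
    step_distances, step_dots = [], []
    prev_params, prev_step = flat_params(model), None
    for _ in range(n_epochs):
        train_one_epoch(model)                        # hypothetical: one epoch of training
        params = flat_params(model)
        step = params - prev_params
        step_distances.append(step.norm().item())     # how far the weights moved this epoch
        if prev_step is not None:
            # > 0: consecutive updates point in broadly the same direction; < 0: they backtrack.
            step_dots.append(torch.dot(step, prev_step).item())
        prev_params, prev_step = params, step
    return step_distances, step_dots
```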
Some observations:
If the learning rate is indeed somewhat too large, let's have a look at the same chart but for learning rate 0.001. Note that the out-of-view epochs (the first 200) still use learning rate 0.01, to get to the same visible starting point (and to get the test loss to roughly 1 so the chart looks nice with all 4 lines sharing the same y-axis).
Indeed, the test loss is much less volatile here - and it even reaches a lower value than before (~0.2 after 3000 epochs, compared to ~0.23 with the previous, higher learning rate). Additionally, it appears that the dot product is now (for the observed 3000 epochs) consistently positive, suggesting that parameters are moving much more regularly/predictably in slowly-changing directions rather than back-and-forth. I would suppose that the dot product eventually drops below 0 here as well. Alright, let's test that too, and then we'll call it a day.
Here we go, 12k epochs instead of 3k:
Interestingly, while the dot product now does generally drop below 0 after around 5000 epochs, it occasionally jumps back up above 0 every now and then for a hundred epochs at a time or so. So it appears that interesting things are still going on to some degree even after 10k+ epochs.
What should we take away from all of this? It's a bit hard to say, given that it's unclear to what degree my findings here generalize beyond the simple toy model I've been investigating.
But if I had to summarize some loosely acquired beliefs that I took away from this:
That all being said, I'm sure there are many things I got wrong in this post, from parts of the implementation to my interpretations of the results. Maybe even some of the questions I've been asking are misguided. So, if anyone actually took the time to read (parts of) this post, please feel very free to correct any misconceptions and errors you encounter. I'd also be happy about further interesting questions getting raised that I failed to think of which would be worth exploring.
I hope this exploration was interesting enough and maybe even contained some novel insights. Thanks for reading!