Right now in your code, you only calculate reconstruction error gradients for the very last step.
if random.random() > delta:
    loss = loss + (probs * err).sum(dim=1).mean()
    break
Pragmatically, it is more efficient to calculate reconstruction error gradients at every step and just weight by the probability of being the final image:
loss = loss + (1 - delta) * (probs * err).sum(dim=1).mean()
if random.random() > delta:
    break
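To see how this slots into a full training step, here is a minimal sketch of the whole iterative loop. The model interface (returning K candidate outputs), the sigma2 temperature from the selection scheme in the post below, and the tensor shapes are my assumptions for illustration, not the actual code:

import random
import torch

def train_step(model, target, delta, sigma2, max_depth=64):
    # Hypothetical interface: model(x) maps images (B, C, H, W) to K
    # candidate outputs (B, K, C, H, W). Names and shapes are assumptions.
    x = torch.zeros_like(target)                 # start from a blank input
    loss = torch.zeros((), device=target.device)
    for _ in range(max_depth):
        outs = model(x)
        err = ((outs - target.unsqueeze(1)) ** 2).flatten(2).mean(2)  # (B, K)
        probs = torch.softmax(-err.detach() / (2 * sigma2), dim=1)
        # Weight every step's loss by (1 - delta), the probability that
        # this step turns out to be the final one.
        loss = loss + (1 - delta) * (probs * err).sum(dim=1).mean()
        if random.random() > delta:
            break
        # Sample a candidate and feed it back in, detached, for refinement.
        idx = torch.multinomial(probs, 1).squeeze(1)                  # (B,)
        x = outs[torch.arange(outs.size(0)), idx].detach()
    return loss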
I think you should choose the variance $\sigma^2$ in the selection probabilities so that it matches the sample variance over the batch between the closest choice and the target. This is because a good model should match both the mean and the variance of the ground truth. The ground truth is that, when you encode an image, you choose the output that has the least reconstruction error. The probabilities $p_k$ can be interpreted as conditional probabilities that you chose the right output $x_k$ for the encoding, where each $x_k$ has a Gaussian prior for being the "right" encoding with mean $x_k$ and variance $\sigma^2$. The variance of the prior for the $x_k$ that is actually chosen should match the variance it sees in the real world. Hence, my recommendation for $\sigma^2$.
(You should weight the MSE loss by $\frac{1}{2\sigma^2}$ as well.)
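A sketch of that recommendation, under the same assumed interface as above: estimate $\sigma^2$ from the batch each step, treating the closest candidate's residual as roughly mean-zero. The helper below is hypothetical, not from the post or the paper:

import torch

def selection_probs_and_loss(outs, target):
    # outs: (B, K, C, H, W) candidates; target: (B, C, H, W). Assumed shapes.
    err = ((outs - target.unsqueeze(1)) ** 2).flatten(2).mean(2)   # (B, K) MSE
    # Use the closest candidate's mean squared residual as sigma^2, i.e. match
    # the prior's variance to the variance actually seen over the batch
    # (this assumes the residual is roughly mean-zero).
    sigma2 = err.min(dim=1).values.mean().detach()
    probs = torch.softmax(-err.detach() / (2 * sigma2), dim=1)
    # Weight the MSE term by 1/(2 sigma^2), per the recommendation above.
    loss = (probs * err).sum(dim=1).mean() / (2 * sigma2)
    return probs, loss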
Motivation for this post: Discrete Distribution Networks (Lei Yang, ICLR 2025).
Generative models aim to reproduce a real-world distribution given many training examples. One way to do this is to train the neural network so that the least amount of information is needed to reconstruct the training examples. For example, in a generative text model using next-token prediction, the only information needed is, "which token is next?". If the model outputs a distribution $q$, while the correct distribution is $p$, the number of bits needed to identify the correct next token is
$$H(p, q) = -\sum_x p(x) \log_2 q(x),$$
the cross-entropy loss. While it is possible to train an autoregressive model for images, scanning pixel-by-pixel and line-by-line, next-token prediction is inherently flawed. Sometimes the right choice of the next token depends on tokens that come after it, which means the model must implicitly predict not just the next token but all the ones after it as well, while only being trained on the next one. Text happens to be written mostly one-dimensionally and is relatively small, so reinforcement learning can compensate for these flaws. However, images are much larger and intrinsically two-dimensional, so another approach is needed.
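As a quick concrete check of the formula, a few lines of Python (my illustration, not from the post):

import math

p = [0.0, 1.0, 0.0]          # ground truth: token 1 is the correct next token
q = [0.1, 0.7, 0.2]          # model's predicted distribution
# Cross-entropy H(p, q) = -sum_x p(x) log2 q(x): bits to identify the right token.
bits = -sum(px * math.log2(qx) for px, qx in zip(p, q) if px > 0)
print(bits)                   # -log2(0.7) ≈ 0.515 bits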
This cannot be done by directly outputting probabilities for every possible image. Even a simple black-and-white MNIST image has $2^{28 \times 28} = 2^{784}$ possibilities. Besides, the image on the screen is only an approximation to the image captured in the real world, so ignoring quantum effects, images should be continuous. The most common approach to modeling continuous distributions is to train a reversible model $f$ that maps it to another continuous distribution $q$ that is already known. The original image $x$ can be recovered by pointing to its mapped value, as well as the reverse path, so the bits needed are
$$-\log_2 p(x) = -\log_2 q(f(x)) - \log_2 \left|\det \frac{\partial f(x)}{\partial x}\right|.$$
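As a toy instance of this formula (my own example): a one-dimensional affine flow onto a standard normal, where the Jacobian term is just $-\log_2 s$.

import math

def bits_under_affine_flow(x, mu, s):
    # Flow f(x) = (x - mu) / s maps data onto a standard normal q.
    z = (x - mu) / s
    log2_q = -0.5 * z * z / math.log(2) - 0.5 * math.log2(2 * math.pi)
    log2_det = math.log2(1.0 / s)       # |det df/dx| = 1/s in one dimension
    return -(log2_q + log2_det)         # -log2 p(x): bits to encode x

print(bits_under_affine_flow(1.0, mu=0.0, s=2.0))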
This technique is known as normalizing flows, as usually a normal distribution is chosen for the known distribution. The second term can be a little hard to compute, so diffusion models approximate it by using a stochastic differential equation for the mapping. When $f_t(x)$ is a solution to an ordinary differential equation,
$$\frac{d}{dt} f_t(x) = v_t(f_t(x)),$$
then
$$\frac{d}{dt} \log \left|\det \frac{\partial f_t(x)}{\partial x}\right| = (\nabla \cdot v_t)(f_t(x)).$$
Switching to a stochastic differential equation,
$$dX_t = v_t(X_t)\,dt + \sigma_t\,dW_t,$$
and tracking the difference $\epsilon_t = X_t - f_t(x)$ between the noisy and noiseless paths, the mean-squared error approximately satisfies
$$\frac{d}{dt}\,\mathbb{E}\!\left[\|\epsilon_t\|^2\right] \approx 2\,\mathbb{E}\!\left[\epsilon_t^\top \nabla v_t(X_t)\,\epsilon_t\right] + \sigma_t^2\, d,$$
which is close to Hutchinson's estimator for the divergence $\nabla \cdot v_t$, but weighted a little strangely.
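For reference, here is a standard Hutchinson divergence estimator in PyTorch (a generic sketch, not code from the post): it estimates $\nabla \cdot v = \mathrm{tr}(\nabla v)$ as $\mathbb{E}\!\left[\epsilon^\top (\nabla v)\,\epsilon\right]$ for noise $\epsilon$ with identity covariance, which is the unweighted version of the expression above.

import torch

def hutchinson_divergence(v, x, n_samples=8):
    # Estimates div v(x) = trace(Jacobian of v at x) without forming the Jacobian.
    x = x.requires_grad_(True)
    out = v(x)
    est = 0.0
    for _ in range(n_samples):
        eps = torch.randn_like(x)                       # E[eps eps^T] = I
        # eps^T J eps via a vector-Jacobian product.
        (vjp,) = torch.autograd.grad(out, x, grad_outputs=eps,
                                     retain_graph=True)
        est = est + (eps * vjp).sum(dim=-1)
    return est / n_samples

# Example: v(x) = A x has divergence trace(A).
A = torch.randn(5, 5)
x = torch.randn(1, 5)
print(hutchinson_divergence(lambda y: y @ A.T, x), torch.trace(A))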
Flow models are pretty good, but the continuity assumption creates its own problems. Some features of images or videos, such as the number of fingers or the number of dogs, are discrete. While a flow model can push most of its outputs towards correct, discrete values, sometimes it will have to interpolate between them, generating 4.5 fingers or 2.5 dogs. This motivates the need for discrete distribution networks.
In Lei Yang's work of the same name, this is achieved with a fractal structure. A model is trained to output several slightly different images from the one it is fed. Each iteration, the output closest to the target image is chosen to be fed back into the model for more, hopefully similar, images. An initially blank input should slowly become the target. To train the model, the chosen image in each iteration is updated towards the target. Since there are a finite number of images at each level, they will specialize into different parts of the target distribution. If they perfectly divvy it up, a sample image can be generated by randomly choosing output images, which, after enough iterations, is unlikely to be an image the model was ever specifically trained on.
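A compressed sketch of that training procedure as I understand it (the interface names are my own, and this simplifies away the paper's architectural details):

import torch

def ddn_train_step(model, target, optimizer, depth=8):
    # Hypothetical interface: model(x) returns K candidate images (B, K, C, H, W).
    x = torch.zeros_like(target)
    loss = torch.zeros((), device=target.device)
    for _ in range(depth):
        outs = model(x)
        err = ((outs - target.unsqueeze(1)) ** 2).flatten(2).mean(2)   # (B, K)
        best = err.argmin(dim=1)                 # closest output per batch image
        # Only the chosen output is pulled toward the target...
        loss = loss + err[torch.arange(err.size(0)), best].mean()
        # ...and fed back in, detached, as the next level's input.
        x = outs[torch.arange(outs.size(0)), best].detach()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()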
Unfortunately, many outputs, even at the top level, end up "dead" in training. They are never seen, so never updated, and very far from the target distribution. The issue is with always picking the most similar image; sometimes it is alright to pick a similar image, even if it is not the most similar. Although not mentioned in Yang's paper, we can instead select image $k$ with probability
$$p_k \propto \exp\!\left(-\frac{\mathrm{err}_k}{2\sigma^2}\right),$$
where $\mathrm{err}_k$ is the reconstruction error of output $k$, and increase $\frac{1}{2\sigma^2}$ over time. The information needed to construct a target image at a given iteration is some error-correction bits, as well as the path taken. When generating, we will sample uniformly over the $K$ outputs, which requires
$$\log_2 K$$
bits per iteration to describe. This gives the loss
$$\mathcal{L} = \sum_k p_k\,\mathrm{err}_k$$
per iteration.
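In code, with err shaped (B, K) as before, this selection rule and loss are just a softmax over negative errors (a sketch; sigma2 and its schedule are knobs I am assuming, not values from the paper):

import torch

def step_loss(err, sigma2):
    # err: (B, K) reconstruction errors; smaller is better.
    # Softer than argmin: every output near the best one gets some traffic,
    # so fewer outputs go "dead" from never being selected.
    probs = torch.softmax(-err.detach() / (2 * sigma2), dim=1)
    # Expected error-correction bits under the selection distribution.
    # (The log2 K path bits are constant and contribute no gradient.)
    return (probs * err).sum(dim=1).mean()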
If we want an infinite-depth model, we can choose to sometimes halt, but usually sample another image, with probability $\delta$ (for 'discount factor'). Also, as the depth increases, the images should become more similar to each other, so $\frac{1}{2\sigma^2}$ should increase exponentially to compensate. Empirically, I found this to give decent results.
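Putting the infinite-depth generation procedure together (again a sketch with assumed names; the halting rule mirrors the random.random() > delta check in the training code above):

import random
import torch

@torch.no_grad()
def generate(model, shape, delta, device="cpu"):
    # Start blank; at each level pick one of the K outputs uniformly at random
    # (log2 K bits per level), then continue with probability delta.
    x = torch.zeros(shape, device=device)
    while True:
        outs = model(x.unsqueeze(0)).squeeze(0)      # (K, C, H, W)
        x = outs[random.randrange(outs.size(0))]     # uniform choice of output
        if random.random() > delta:                  # halt with probability 1 - delta
            return x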