Yesterday I attended a talk by Franziska Boenisch on training data memorization in diffusion models. The short version: diffusion models memorize a small percentage of their training data verbatim, and those samples can be reproduced with specific prompts regardless of the noise seed. Privacy concern, etc.
I was mostly interested in the adversarial case - say Snapchat trains on their data and open-sources a model. Could you retrieve those rare but real memorized samples? Some literature suggests massive random prompt searches: generate a ton of images and check them with some metric.
I find this incredibly unsatisfying. Surely there's a more algorithmic way?
This post documents my 1-day attempt at exactly that. Spoiler: it sort of works, with caveats.
One thing mentioned during the talk was that some neurons exhibit unusual behavior when faced with prompts that elicit memorized data - they spike absurdly high.
Formally, let $a_k$ be the activation of neuron $k$, and compute its mean $\mu_k$ and standard deviation $\sigma_k$ over a held-out dataset. The z-score is:
$$z_k = \frac{a_k - \mu_k}{\sigma_k}$$
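A minimal sketch of the z-score computation in plain Python, assuming the activations have already been collected as one vector per prompt (how you hook them out of the model is left open). `fit_neuron_stats` and `z_scores` are hypothetical helpers, and using the population standard deviation plus a small `eps` guard is my choice, not necessarily what Franziska's team does.

```python
import statistics


def fit_neuron_stats(held_out_activations):
    """Per-neuron mean and standard deviation over a held-out set.

    held_out_activations: list of activation vectors, one per prompt,
    each a list with one float per neuron.
    """
    per_neuron = list(zip(*held_out_activations))  # transpose: one tuple per neuron
    mus = [statistics.fmean(vals) for vals in per_neuron]
    sigmas = [statistics.pstdev(vals) for vals in per_neuron]
    return mus, sigmas


def z_scores(activations, mus, sigmas, eps=1e-8):
    """z_k = (a_k - mu_k) / sigma_k; eps avoids dividing by zero for dead neurons."""
    return [(a - mu) / (sigma + eps) for a, mu, sigma in zip(activations, mus, sigmas)]


mus, sigmas = fit_neuron_stats([[1.0, 10.0], [3.0, 14.0]])
print(z_scores([4.0, 12.0], mus, sigmas))  # first neuron ~2 std devs above its mean
```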
Franziska's team uses this to build a mitigation. But as an adversary, maybe we can use it for the opposite purpose:
Imagine some starting prompt, maybe "A man smiling," and get the corresponding embedding. Freeze the model weights and enable gradients for the text embedding instead. Then do a normal forward pass and calculate the z-scores. Define the loss function:
$$\mathcal{L} = -\operatorname{Var}_k(z_k)$$
or anything else along those lines capturing this spiking behavior. And then just do normal gradient descent for a few steps.
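A toy sketch of that loop in plain Python. To keep it self-contained, I replace the diffusion model with a frozen linear layer `a = W e`, which lets the gradient of $-\operatorname{Var}_k(z_k)$ be written by hand; with a real model you would let autograd do this. `spike_loss_and_grad` and `optimize_embedding` are hypothetical names. Note that in this toy the loss is unbounded below, so nothing stops the embedding from drifting arbitrarily far from where it started.

```python
def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)


def spike_loss_and_grad(e, W, mus, sigmas):
    """Loss L = -Var_k(z_k) and dL/de for a toy linear 'model' a = W e."""
    K = len(W)
    a = [sum(w_kj * e_j for w_kj, e_j in zip(row, e)) for row in W]
    z = [(a_k - mu) / s for a_k, mu, s in zip(a, mus, sigmas)]
    z_bar = sum(z) / K
    loss = -variance(z)
    # dL/dz_k = -2 (z_k - z_bar) / K  and  dz_k/de_j = W_kj / sigma_k
    grad = [
        -2.0 / K * sum((z[k] - z_bar) * W[k][j] / sigmas[k] for k in range(K))
        for j in range(len(e))
    ]
    return loss, grad


def optimize_embedding(e0, W, mus, sigmas, lr=0.1, steps=50):
    """Plain gradient descent on the embedding; W, mus, sigmas stay frozen."""
    e = list(e0)
    for _ in range(steps):
        _, grad = spike_loss_and_grad(e, W, mus, sigmas)
        e = [e_j - lr * g_j for e_j, g_j in zip(e, grad)]
    final_loss, _ = spike_loss_and_grad(e, W, mus, sigmas)
    return e, final_loss


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 toy neurons, 2-dim toy embedding
e, loss = optimize_embedding([0.1, 0.2], W, mus=[0.0] * 3, sigmas=[1.0] * 3, steps=20)
print(e, loss)  # loss has decreased: the z-scores are more spread out
```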
Sounds simple enough, right? Surely it will work first tr...
Uh...
Okay, I definitely don't see a man smiling, and it shouldn't be too surprising that lots of nonsensical inputs also produce spiking z-scores. So maybe just regularize against that: add the L2 distance to the original embedding to the loss, so the optimization explores the neighborhood of the starting prompt rather than whatever this is. Now it must wor...
Well.
At least it looks somewhat realistic? What exactly "it" is still hasn't come to me, even as I write this. Maybe a higher regularization coefficient, same seed?
I'm not sure if this qualifies as progress.
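For reference, here is the regularized objective from the second attempt written out. $e$ is the text embedding being optimized, $e_0$ the embedding of the starting prompt, $\lambda$ the regularization coefficient and $\eta$ the step size. The symbol names are mine, and using the squared L2 distance (rather than the plain distance) is my assumption, chosen because its gradient is simple:

$$\mathcal{L}(e) = -\operatorname{Var}_k\big(z_k(e)\big) + \lambda\,\lVert e - e_0\rVert_2^2$$

$$e \leftarrow e - \eta\left(-\nabla_e \operatorname{Var}_k\big(z_k(e)\big) + 2\lambda\,(e - e_0)\right)$$

A larger $\lambda$ keeps the optimized embedding closer to the original prompt, at the cost of weaker pressure toward spiking activations.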
Step back
Okay. Let's take a step back. Maybe my ambition of directly finding a memorized training sample from a random prompt was simply too big. Luckily, there's already a dataset documenting memorized images and the prompts that elicit them. For example:
When given the prompt "<i>I Am Chris Farley</i> Documentary Releases First Trailer", the model always generates this image, no matter which seed we choose for the noise. That's quite remarkable.
Now what if we alter the prompt slightly - how about just "I Am Chris Farley"?
...and I thought my previous generations were bad. It's roughly the same character, but that's where the similarities end. Memorization is incredibly robust to changes in the noise, but not to changes in the prompt. That's a shame, since it means random search has to test an insane number of prompts.
So what if we use this technique here? We take the much shorter prompt as the initialization and see where it takes us:
Oh wow.
I expected at least 2 walls of wood with one smiling at me, but this is actually spot on. And not just this seed:
It might not be obvious, but those are actually 2 different images - once again, we have the robustness against noise.
That's already great progress!
Rough Speedup Estimate (vs. random search)
Very rough Fermi estimate: if this works ~50% of the time and random sampling has P("Documentary Releases First Trailer" | "I Am Chris Farley") < 1%, the hit rate per attempt is more than 50x higher, which still leaves a speedup of roughly 30-40x once the overhead of the gradient steps is accounted for. Take with appropriate salt.
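The arithmetic behind that estimate, as a tiny Python helper. `expected_speedup` and its `overhead` parameter (the cost of one gradient attempt measured in plain generations) are my own framing, not a measurement: with p_random = 1%, the 30-40x range corresponds to an overhead factor of roughly 1.25 to 1.67. Because 1% is an upper bound on p_random, these numbers are lower bounds on the speedup.

```python
def expected_speedup(p_method: float, p_random: float, overhead: float) -> float:
    """Ratio of expected cost per successful hit: random search vs. gradient method.

    p_method: per-attempt success rate of the gradient method (~0.5 above)
    p_random: per-attempt success rate of a random prompt completion (< 0.01 above)
    overhead: hypothetical cost of one gradient attempt, in units of one plain
              generation (covers the extra forward/backward passes)
    """
    cost_random = 1.0 / p_random       # expected generations until one hit
    cost_method = overhead / p_method  # expected generation-equivalents until one hit
    return cost_random / cost_method


print(expected_speedup(0.5, 0.01, overhead=1.0))      # 50.0: hit-rate ratio alone
print(expected_speedup(0.5, 0.01, overhead=1.25))     # 40.0
print(expected_speedup(0.5, 0.01, overhead=50 / 30))  # ~30
```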
Breaking the Mitigation?
Remember when I said Franziska's team uses this to build a mitigation? It would be pretty cool if this method could even break that mitigation (which random search normally can't).
I won't go into too much detail, but essentially she and her team identify a very small set of neurons (often only 5 or so) such that pruning them breaks the reproduction of a specific memorized example while hardly impacting general capabilities. You can read the details in arXiv:2406.02366.
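For readers who haven't seen neuron pruning at inference time, here is the mechanism in its simplest form: the chosen neurons' activations are set to zero before the next layer sees them. `prune_neurons` is a hypothetical helper, and zeroing activations is my simplification; how the paper localizes the neurons and deactivates them in the real model is described there, not here.

```python
def prune_neurons(activations, pruned_ids):
    """Return a copy of one layer's activations with the chosen neurons zeroed.

    pruned_ids: indices of the handful of neurons localized for one memorized
    sample (hypothetical here). The input list is left unchanged.
    """
    pruned = set(pruned_ids)
    return [0.0 if k in pruned else a for k, a in enumerate(activations)]


print(prune_neurons([0.5, 9.0, -1.2, 7.5], [1, 3]))  # [0.5, 0.0, -1.2, 0.0]
```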
With this mitigation turned on, generating the memorized samples even with the full prompt becomes almost impossible:
Top: mitigation off. Bottom: mitigation on. Five different seeds.
I'd give this maybe 1 out of 5 (seeds 2 and 5 vaguely gesture in the right direction).
When applying this gradient method with mitigation still on:
Better - seeds 2 and 5 are basically there, seed 1 is on the right track. Maybe 2.5 out of 5.
The Weird Case
What about partial prompt AND mitigation on?
Row 1: baseline. Row 2: mitigation on. Row 3: gradient method on. Row 4: both on.
It utterly nails it? Now I'm confused.
Looking into it, their mitigation has a threshold hyperparameter, and it decided that the truncated prompt doesn't raise memorization concerns, so it never kicked in. That feels like a shortcoming: even without my method, 3/5 seeds are pretty close. With the gradient method, we get almost pixel-perfect reproduction.
You could say "just lower the threshold," but then you'd flag many non-memorized samples and degrade quality overall.
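To make the threshold trade-off concrete, a toy sketch of a gated mitigation that only prunes when it detects a spike. I'm assuming a simple "largest |z-score| at or above `threshold`" rule for illustration; the paper's actual detection criterion may differ, and `gated_prune` is a hypothetical name.

```python
def gated_prune(activations, z, pruned_ids, threshold):
    """Zero the chosen neurons only if some z-score reaches the threshold.

    activations: one layer's activations for the current prompt
    z: the matching per-neuron z-scores
    pruned_ids: neurons to deactivate when memorization is flagged
    threshold: detection hyperparameter (assumed max-|z| rule)
    """
    if max(abs(v) for v in z) < threshold:
        return list(activations)  # looks "normal": leave the prompt untouched
    pruned = set(pruned_ids)
    return [0.0 if k in pruned else a for k, a in enumerate(activations)]
```

Under this rule, a prompt whose strongest spike stays just under `threshold` passes through untouched. Lowering `threshold` would catch it, but every ordinary prompt whose largest z-score happens to exceed the new value gets pruned as well.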
Closing thoughts
To be clear about what this is and isn't:
Tested on essentially one example in depth
The speedup estimate is a Fermi estimate, not a benchmark
Hyperparameter sensitive (though manageable)
There are so many directions you could explore from here, but for a one-day project, I think this is a good place to stop.
If there's interest, I might polish the code and publish it. It's fun to play with, and even the GPU-poor should feel at home: I did all of this on an A4000, which you can rent for 20 cents an hour.