[Not very confident, but just saying my current view.]
I'm pretty skeptical about integrated gradients.
As for why: I don't think we should care about the derivative at the baseline (zero or the mean).
As for the axioms, I think I get off the train at "Completeness", which doesn't seem like a property we need or want.
I think you just have to eat the fact that there isn't any sensible attribution method that gets you Completeness.
The same applies to attribution in general (e.g. in decision-making).
As in, you're also skeptical of traditional Shapley values in discrete coalition games?
"Completeness" strikes me as a desirable property for attributions to be properly normalized. If attributions aren't bounded in some way, it doesn't seem to me like they're really 'attributions'.
Very open to counterarguments here, though. I'm not particularly confident here either. There's a reason this post isn't titled 'Integrated Gradients are the correct attribution method'.
Integrated gradients is a computationally efficient attribution method (compared to activation patching / ablations) grounded in a series of axioms.
Maybe I'm confused, but isn't integrated gradients strictly slower than an ablation to a baseline?
If you want to get attributions between all pairs of basis elements/features in two layers, attributions based on the effect of a marginal ablation will take you n forward passes, where n is the number of features in a layer. Integrated gradients will take n × n_alpha backward passes, where n_alpha is the number of integration steps, and if you're willing to write custom code that exploits the specific form of the layer transition, it can take less than that.
If you're averaging over a dataset, IG is also amenable to additional cost reduction through stochastic source techniques.
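To make the counting concrete, here's a minimal sketch of the generic (no custom code) version, assuming a straight-line path from the baseline and standard autodiff; the function and variable names are just illustrative, not our actual implementation:

```python
import torch

def ig_pairwise_attributions(layer_fn, x, baseline, n_alpha=50):
    # Integrated-gradients attribution of every upstream feature j to every
    # downstream feature i of `layer_fn`, approximated with n_alpha points
    # along the straight line from `baseline` to `x`. Illustrative sketch only.
    n_down = layer_fn(x).shape[-1]
    avg_grads = torch.zeros(n_down, x.shape[-1])
    for k in range(1, n_alpha + 1):
        point = (baseline + (k / n_alpha) * (x - baseline)).detach().requires_grad_(True)
        out = layer_fn(point)
        for i in range(n_down):
            # One backward pass per downstream feature and integration step;
            # each pass gives d f_i / d x_j for *all* upstream features j at once.
            (g,) = torch.autograd.grad(out[i], point, retain_graph=True)
            avg_grads[i] += g / n_alpha
    # Scale the path-averaged gradients by the displacement from the baseline,
    # as in the usual integrated-gradients formula.
    return avg_grads * (x - baseline)
```

In this generic form that's n × n_alpha backward passes for all pairs; exploiting the analytic form of the layer transition (e.g. writing down the Jacobian of an MLP layer directly instead of calling autograd per feature) is what lets you cut that down.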
Maybe I'm confused, but isn't integrated gradients strictly slower than an ablation to a baseline?
For a single interaction, yes (1 forward pass vs an integral with n_alpha integration steps, each requiring a backward pass).
For many interactions (e.g. all connections between two layers) IGs can be faster:
(This is assuming you do path patching rather than "edge patching", which you should in this scenario.)
Sam Marks makes a similar point in Sparse Feature Circuits, near equations (2), (3), and (4).
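To spell out the counting for the many-interactions case (a rough sketch, using the same n_alpha as above and writing $x'$ for the baseline): the IG attribution of upstream feature $j$ on a downstream quantity $f$ is approximated as

$$\mathrm{IG}_j \approx (x_j - x'_j)\,\frac{1}{n_\alpha}\sum_{k=1}^{n_\alpha} \left.\frac{\partial f}{\partial x_j}\right|_{x' + \frac{k}{n_\alpha}(x - x')},$$

and each backward pass from $f$ yields $\partial f / \partial x_j$ for every upstream feature $j$ simultaneously. So all the attributions into a single downstream node cost n_alpha backward passes, whereas ablating the upstream features one at a time costs one forward pass per feature.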
We now have a method for how to do attributions on single data points. But when we're searching for circuits, we're probably looking for variables that have strong attributions between each other on average, measured over many data points.
Maybe?
One thing I've been thinking about a lot recently is that building tools to interpret networks on individual datapoints might be more relevant than attributing over a dataset. This applies if the goal is to make statistical generalizations, since a richer structure on an individual datapoint gives you more to generalize with, but it also applies if the goal is the inverse, to go from general patterns to particulars, since this would provide a richer method for debugging, noticing exceptions, etc.
And basically, the trouble that a lot of work attempting to generalize runs into is that some phenomena are very particular to specific cases, so one risks losing a lot of information by focusing only on the generalizable findings.
Either way, cool work; it seems like we've been thinking along similar lines, but you've put in more work.
The issue with single datapoints, at least in the context we used this for, which was building interaction graphs for the LIB papers, is that the answer to 'what directions in the layer were relevant for computing the output?' is always trivially just 'the direction the activation vector was pointing in.'
This then leads to every activation vector becoming its own 'feature', which is clearly nonsense. To understand generalisation, we need to see how the network is re-using a small common set of directions to compute outputs for many different inputs. Which means looking at a dataset of multiple activations.
And basically, the trouble that a lot of work attempting to generalize runs into is that some phenomena are very particular to specific cases, so one risks losing a lot of information by focusing only on the generalizable findings.
The application we were interested in here was getting some well-founded measure of how 'strongly' two features interact. Not a description of what the interaction is doing computationally. Just some way to tell whether it's 'strong' or 'weak'. We wanted this so we could find modules in the network.
Averaging over data loses us information about what the interaction is doing, but it doesn't necessarily lose us information about interaction 'strength', since that's a scalar quantity. We just need to set our threshold for connection relevance low enough that making a sizeable difference on a very small handful of training datapoints still qualifies.
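As a toy sketch of what I mean (the names and the exact aggregation here are illustrative, not what we actually ran): average the magnitude of the per-datapoint attributions over the dataset and keep every edge above a deliberately low threshold, so that an edge which only matters on a few datapoints can still clear it.

```python
import torch

def relevant_edges(attribution_fn, dataset, threshold=1e-3):
    # `attribution_fn(x)` is assumed to return per-datapoint edge attributions
    # of shape (n_downstream, n_upstream); we average their magnitudes.
    total = None
    for x in dataset:
        attr = attribution_fn(x).abs()
        total = attr if total is None else total + attr
    mean_strength = total / len(dataset)
    # A low threshold keeps edges that are large on only a handful of
    # datapoints, even though averaging over the dataset has diluted them.
    return mean_strength > threshold
```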