Hessian analysis with JAX: a platform-agnostic, high-performance approach
In mechanistic interpretability research, we often want to analyze the Hessian of the loss function, for example by computing its eigenspectrum. Ideally, our Hessian analysis code would work seamlessly for all models, regardless of architecture or training platform (e.g. PyTorch, TensorFlow, Flax). However, this is hard...