(Work done as part of the SERI MATS Summer 2023 cohort under the supervision of @Lee Sharkey. A blog post containing audio features that you can listen to can be found here.)
TL;DR - Mechanistic Interpretability has mainly focused on language and image models, but there's a growing need for interpretability in multimodal models that can handle text, images, audio, and video. Thus far, there have been minimal efforts directed toward interpreting audio models, let alone multimodal ones. To the best of my knowledge, this work presents the first attempt to do interpretability on a multimodal audio-text model. I show that acoustic features inside OpenAI's Whisper model are human interpretable and formulate a way of listening to them. I then go on to present some macroscopic properties of the model, specifically showing that encoder attention is highly localized and the decoder alone acts as a weak LM.
Up to this point, the main focus in mechanistic interpretability has centred around language and image models. GPT-4, which currently accepts both text and images, is paving the way for the development of fully multimodal models capable of handling images, text, audio, and video. A robust mechanistic interpretability toolbox should allow us to understand all parts of a model. However, when it comes to audio models, let alone multimodal ones, there is a notable lack of mechanistic interpretability research. This is concerning, because it suggests that there might be parts of multimodal models that we cannot understand. Specifically, an inability to interpret the input representations that are fed into the more cognitive parts of these models (which theoretically could perform dangerous computations) presents a problem. If we cannot understand the inputs, it is unlikely that we can understand the potentially dangerous bits.
This post is structured into 3 main claims that I make about the model:
For context: Whisper is a speech-to-text model. It has an encoder-decoder transformer architecture, as shown below. We used Whisper tiny, which has only 39M parameters but is remarkably good at transcription! The input to the encoder is a 30s chunk of audio (shorter chunks can be padded) and the output from the decoder is the transcript, predicted autoregressively. It is trained only on labelled speech-to-text pairs.
By finding maximally activating dataset examples (from a dataset of 10,000 2s audio clips) for MLP neurons and for directions in the residual stream, we are able to detect acoustic features corresponding to specific phonemes. By amplifying the audio around the sequence position where a feature is maximally active, you can clearly hear these phonemes, as demonstrated by the audio clips below.
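The post does not spell out its exact pipeline, so the following is only a minimal, hypothetical sketch of the two steps in plain Python: pick the clips whose peak activation for one feature is largest, then boost the audio around the peak position. `top_k_examples` and `emphasise_position` are illustrative names, and the gain, floor and window width are arbitrary assumptions. It relies on Whisper's 16 kHz input rate and 1500 encoder positions per 30s chunk, so each encoder position covers 20 ms (320 samples).

```python
import heapq

SAMPLE_RATE = 16_000          # Whisper resamples audio to 16 kHz
SAMPLES_PER_POSITION = 320    # 30 s * 16 kHz / 1500 encoder positions = 20 ms each


def top_k_examples(activations_per_clip, k):
    """Return (clip_id, position, value) for the k clips with the largest peak activation.

    activations_per_clip maps a clip id to a list of one feature's activations,
    one value per encoder position.
    """
    peaks = []
    for clip_id, acts in activations_per_clip.items():
        position = max(range(len(acts)), key=acts.__getitem__)
        peaks.append((clip_id, position, acts[position]))
    return heapq.nlargest(k, peaks, key=lambda item: item[2])


def emphasise_position(waveform, position, window_positions=2, gain=4.0, floor=0.25):
    """Boost the audio around one encoder position and attenuate everything else.

    Samples within `window_positions` encoder positions of the centre of `position`
    are scaled by `gain`; all other samples are scaled by `floor`.
    Results are clipped to [-1, 1].
    """
    centre = position * SAMPLES_PER_POSITION + SAMPLES_PER_POSITION // 2
    half_width = window_positions * SAMPLES_PER_POSITION
    out = []
    for i, sample in enumerate(waveform):
        scale = gain if abs(i - centre) <= half_width else floor
        out.append(max(-1.0, min(1.0, sample * scale)))
    return out
```

Attenuating rather than silencing the surrounding audio keeps some context audible, which arguably makes it easier to tell which sound the feature responds to.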
It turns out that neurons in the MLP layers of the encoder are highly interpretable. The table below shows the phonetic sound that each of the first 50 neurons in blocks.2.mlp.1 activates on. You can also listen to some of these audio features here.
The residual stream is not in a privileged basis, so we would not expect the features it learns to be neuron-aligned. We can, however, train sparse autoencoders on the residual stream activations and find maximally activating dataset examples for these learnt features. We also find these to be highly interpretable, and they often correspond to phonemes. Example audio clips for these learnt features can also be found here.
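The post does not state the sparse autoencoder's architecture or hyperparameters. As an assumption, here is the commonly used ReLU formulation (subtract the decoder bias, linear encoder, ReLU, linear decoder, loss = squared reconstruction error plus an L1 sparsity penalty), written in plain Python for a single activation vector. It shows only the forward pass and loss, not training.

```python
def relu(x):
    return x if x > 0.0 else 0.0


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Forward pass of a ReLU sparse autoencoder on one residual-stream vector.

    x:     list of d_model floats
    W_enc: n_features rows, each a list of d_model floats
    W_dec: d_model rows, each a list of n_features floats
    Returns (feature_activations, reconstruction).
    """
    centred = [xi - bi for xi, bi in zip(x, b_dec)]
    features = [relu(sum(w * c for w, c in zip(row, centred)) + b)
                for row, b in zip(W_enc, b_enc)]
    recon = [sum(w * f for w, f in zip(row, features)) + b
             for row, b in zip(W_dec, b_dec)]
    return features, recon


def sae_loss(x, features, recon, l1_coeff):
    """Squared reconstruction error plus an L1 penalty that pushes features towards sparsity."""
    mse = sum((xi - ri) ** 2 for xi, ri in zip(x, recon))
    l1 = sum(abs(f) for f in features)
    return mse + l1_coeff * l1
```

Each learnt feature (a row of `W_enc`) can then be treated like a neuron: record its activation at every encoder position and collect maximally activating clips as for the MLP neurons.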
The presence of polysemantic neurons in both language and image models is widely acknowledged, suggesting that they may exist in acoustic models as well. By listening to dataset examples at different ranges of neuron activation, we were able to uncover these polysemantic acoustic neurons. If you only listen to the maximally activating dataset examples, these neurons appear to respond to a single phoneme; listening to examples at varying levels of activation, however, reveals polysemantic behaviour. The following plots show the sounds that neuron 1 and neuron 3 in blocks.2.mlp.1 activate on at different ranges of activation. Again, example audio clips can be found in the blog post.
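Listening "at different ranges of activation" amounts to bucketing examples by activation value and sampling a few from each bucket. A minimal sketch, assuming each example has already been reduced to a `(clip_id, activation)` pair; the bin edges and `sample_by_activation_range` are hypothetical choices for illustration, not the post's code.

```python
import random


def sample_by_activation_range(examples, edges, per_bin, seed=0):
    """Group (clip_id, activation) pairs by activation range and sample a few per range.

    edges: sorted bin boundaries, e.g. [0.0, 1.0, 2.0, 3.0] gives the bins
    [0, 1), [1, 2) and [2, 3). Activations outside [edges[0], edges[-1]) are dropped.
    Returns {(low, high): [clip_id, ...]}.
    """
    bins = {(lo, hi): [] for lo, hi in zip(edges, edges[1:])}
    for clip_id, act in examples:
        for lo, hi in bins:
            if lo <= act < hi:
                bins[(lo, hi)].append(clip_id)
                break
    rng = random.Random(seed)  # fixed seed so the same clips are picked on every run
    return {bin_range: rng.sample(ids, min(per_bin, len(ids)))
            for bin_range, ids in bins.items()}
```

Sampling uniformly within each bin, rather than taking the top of each bin, avoids the same bias towards extreme examples that hid the polysemanticity in the first place.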
Interestingly, the encoder attention patterns are highly temporally localized. This contrasts with standard LLMs, which often attend to source tokens based on semantic content rather than distance to the destination token.
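One simple way to put a number on this localization (not used in the post; it is an assumption borrowed from common practice) is the attention-weighted mean distance between query and key positions for each head. A highly localized head scores close to 0; a head that attends far away scores high.

```python
def mean_attention_distance(pattern):
    """Average |query - key| distance, weighted by attention, for one head.

    pattern[q][k] is the attention weight from query position q to key position k;
    each row is assumed to sum to 1.
    """
    n = len(pattern)
    total = 0.0
    for q, row in enumerate(pattern):
        total += sum(w * abs(q - k) for k, w in enumerate(row))
    return total / n
```

For example, an identity pattern (every position attends only to itself) gives 0.0, while a pattern in which all three positions attend to position 0 gives 1.0.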
We propagate the attention scores $R_t$ down the layers of the encoder as in Generic Attention-model Explainability for Interpreting Bi-Modal and Encoder-Decoder Transformers. This roughly equates to

$$R_t = R_{t-1} + \bar{A}_t R_{t-1}, \qquad \bar{A}_t = \mathbb{E}_h\left[(\nabla A_t \odot A_t)^+\right], \qquad R_0 = I,$$

where $A_t$ is the attention pattern in layer $t$ and $\bar{A}_t$ is the attention pattern weighted by gradient contribution (averaged over heads $h$, keeping only positive values). This produces the striking pattern below: up to the point where the audio ends, the attention pattern is very localized. When the speech ends (at frame ~500 in the following plot), all future positions attend back to the end of the speech.
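A minimal plain-Python sketch of the relevance propagation from the cited paper: the gradient-weighted pattern $\bar{A}_t$ is the head-average of the positive part of gradient ⊙ attention, and the relevance map starts at the identity and is updated as $R \leftarrow R + \bar{A}_t R$ at each layer. It assumes the per-head attention patterns and their gradients have already been extracted from the model (not shown); the function names are hypothetical.

```python
def matmul(a, b):
    """Multiply two square matrices given as lists of lists."""
    n = len(a)
    return [[sum(a[i][m] * b[m][j] for m in range(n)) for j in range(n)] for i in range(n)]


def gradient_weighted(attn_heads, grad_heads):
    """A_bar: mean over heads of max(grad * attn, 0), computed elementwise."""
    h = len(attn_heads)
    n = len(attn_heads[0])
    return [[sum(max(grad_heads[k][i][j] * attn_heads[k][i][j], 0.0) for k in range(h)) / h
             for j in range(n)] for i in range(n)]


def propagate_relevance(a_bars):
    """Start from the identity and apply R <- R + A_bar @ R for each layer in order."""
    n = len(a_bars[0])
    r = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for a_bar in a_bars:
        update = matmul(a_bar, r)
        r = [[r[i][j] + update[i][j] for j in range(n)] for i in range(n)]
    return r
```

The identity initialisation accounts for the residual connection: each position always keeps some relevance to itself, and attention adds relevance on top of that.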
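Given how localized the attention pattern appears to be, we investigate what happens if we constrain it so that every audio embedding can only attend to the k nearest tokens on either side. Concretely, before the softmax we add a mask to the attention scores that is 0 wherever the query and key positions are at most k apart and negative infinity everywhere else; e.g. with k=2, each position can attend only to itself and the two positions on either side.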
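A minimal sketch of such a banded mask and its effect on a softmax, in plain Python on toy scores. How the mask is injected into Whisper's encoder attention is not described in the post and is not shown here; `local_attention_mask` and `masked_softmax` are illustrative names.

```python
import math

NEG_INF = float("-inf")


def local_attention_mask(n_positions, k):
    """Additive mask: 0 where |i - j| <= k, -inf elsewhere (added to scores before softmax)."""
    return [[0.0 if abs(i - j) <= k else NEG_INF for j in range(n_positions)]
            for i in range(n_positions)]


def masked_softmax(scores, mask):
    """Row-wise softmax of scores + mask; masked positions get exactly zero weight."""
    out = []
    for score_row, mask_row in zip(scores, mask):
        logits = [s + m for s, m in zip(score_row, mask_row)]
        peak = max(logits)  # finite, because the diagonal is never masked
        exps = [math.exp(x - peak) for x in logits]  # exp(-inf) == 0.0
        total = sum(exps)
        out.append([e / total for e in exps])
    return out
```

With all-zero scores, `n_positions=6` and `k=2`, row 0 spreads its weight evenly over positions 0 to 2 and row 3 over positions 1 to 5, while every position more than 2 away gets weight 0.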
Here are the transcripts that emerge from a short audio clip from Hot Ones as we limit the attention window for various values of k. We observe that even when the attention window is reduced to k=75 (normally k=750), the model continues to generate reasonably accurate transcripts, indicating that information is encoded in a localized manner.
Recall that Whisper is an encoder-decoder transformer; the decoder cross-attends to the output of the final layer of the encoder. Given the apparent localization of the embeddings in this final layer, we postulate that we could remove words from the transcript by 'chopping' out their corresponding embeddings. Concretely, we set:
`final_layer_output[start_index:stop_index] = final_layer_output_for_padded_input[start_index:stop_index]`
Consider the following example, in which we substitute the first 50 audio embeddings with their padded equivalents (i.e., start_index=0, stop_index=50). These 50 embeddings represent (50/1500)*30s=1s of audio. The transcript resulting from this replacement omits the first two words. The fact that we can do this suggests that, for each word in the transcript, the decoder cross-attends to a small window of audio embeddings and uses only a limited amount of context from the rest of the audio embeddings.
Original transcript:
`hot ones. The show where celebrities answer hot questions while feeding even hotter wings.`
Substitute embeddings between (start_index=0, stop_index=50):
`The show where celebrities answer hot questions while feeding even hotter wings.`
We can also do this in the middle of the sequence. Here we let (start_index=150, stop_index=175) which corresponds to 3-3.5s in the audio and observe that the transcript omits the words `hot questions`:
Substitute embeddings between (start_index=150, stop_index=175):
`hot ones. The show where celebrities while feeding even hotter wings.`
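The slice arithmetic used in these experiments can be sketched in plain Python. The encoder emits 1500 positions per 30s chunk (50 per second), so a slice of positions maps directly to a time span. `positions_to_seconds`, `seconds_to_positions` and `splice_padded` are hypothetical helpers; in the real experiment the rows would be encoder output vectors rather than the placeholder strings used in the example.

```python
CHUNK_SECONDS = 30
N_ENCODER_POSITIONS = 1500  # Whisper encoder output length for a 30 s chunk


def positions_to_seconds(start_index, stop_index):
    """Convert an encoder-position slice [start, stop) into a time span in seconds."""
    return (start_index * CHUNK_SECONDS / N_ENCODER_POSITIONS,
            stop_index * CHUNK_SECONDS / N_ENCODER_POSITIONS)


def seconds_to_positions(start_s, stop_s):
    """Inverse mapping, rounded to the nearest encoder position."""
    per_second = N_ENCODER_POSITIONS / CHUNK_SECONDS  # 50 positions per second
    return round(start_s * per_second), round(stop_s * per_second)


def splice_padded(final_layer_output, padded_output, start_index, stop_index):
    """Copy of final_layer_output with rows [start, stop) taken from the padded-input run."""
    patched = list(final_layer_output)
    patched[start_index:stop_index] = padded_output[start_index:stop_index]
    return patched
```

For example, `positions_to_seconds(0, 50)` gives `(0.0, 1.0)` and `positions_to_seconds(150, 175)` gives `(3.0, 3.5)`, matching the two substitutions above.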
Whisper is trained exclusively on supervised speech-to-text data; the decoder is not pre-trained on text. In spite of this, the model still acquires rudimentary language modeling capabilities. While this outcome isn't unexpected, the subsequent experiments that validate this phenomenon are quite interesting/amusing in themselves.
If we feed only 'padding' frames to the encoder and 'prompt' the decoder, we can recover bigram statistics. For example, at the start of transcription, the decoder is normally prompted with: `<|startoftranscript|><|en|><|transcribe|>`
Instead, we set the 'prompt' to be: `<|startoftranscript|><|en|><|transcribe|> <our_prompt_token>`
This is analogous to telling the model that the first word in the transcription is <our_prompt_token>.
Below we plot the top 20 most likely next tokens and their corresponding logits for a variety of prompts. We can see that when the model has no acoustic information, it relies on learnt bigrams.
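The probing setup can be sketched as follows, assuming the decoder's logits at the last prompt position have already been computed with padding-only encoder input (the model call is not shown). `bigram_prompt` and `top_next_tokens` are hypothetical helpers, and the prompt is kept as token strings for readability; real code would first convert it to ids with Whisper's tokenizer.

```python
import heapq

# The usual start-of-transcription prefix quoted above, as token strings.
PROMPT_PREFIX = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>"]


def bigram_prompt(prompt_token):
    """Append the probe token, telling the decoder the transcript starts with it."""
    return PROMPT_PREFIX + [prompt_token]


def top_next_tokens(logits, vocab, k=20):
    """Return the k (token, logit) pairs with the largest logits, highest first."""
    return heapq.nlargest(k, zip(vocab, logits), key=lambda pair: pair[1])
```

With no acoustic evidence to condition on, whatever ranks highest in this list reflects what the decoder has learnt about which tokens tend to follow the probe token.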
Bigram statistics are often learnt by the token embedding layer in transformer language models. Additionally, in LLMs we observe semantically similar words clustered in embedding space. This phenomenon also holds for Whisper, but we additionally find that words with similar sounds lie close together in embedding space. To illustrate this, we choose specific words and plot the 20 nearest tokens by cosine similarity. ‘rug’ is close in embedding space to lug, mug and tug. This is not very surprising for a speech-to-text model; if you think you hear the word ‘rug’, it is quite likely that the word was in fact lug or mug.
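The nearest-neighbour lookup is a ranking by cosine similarity. A minimal sketch in plain Python; the embeddings below are made-up 2-D toy vectors for illustration only, not Whisper's. With the real model one would use the rows of the decoder's token embedding matrix (in the openai-whisper package we assume this is `model.decoder.token_embedding.weight`).

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def nearest_tokens(query, embeddings, k=20):
    """Rank the other tokens in `embeddings` (token -> vector) by cosine similarity to `query`."""
    q = embeddings[query]
    scored = [(tok, cosine(q, vec)) for tok, vec in embeddings.items() if tok != query]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors, invented for illustration.
toy = {"rug": [1.0, 0.0], "mug": [0.9, 0.1], "tug": [0.8, 0.3], "sky": [0.0, 1.0]}
```

Here `nearest_tokens("rug", toy, k=2)` returns `mug` then `tug`. Cosine similarity ignores vector length, so frequent and rare tokens are compared on direction alone.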
Finally, we collected maximally activating dataset examples (using the same dataset of 10,000 2s audio clips) for the neuron basis of decoder blocks.0.mlp.1. We find that they often activate on semantically similar concepts, suggesting that a) the model is already operating at the word level by the first MLP layer and b) it has acquired rudimentary language modelling capabilities like a weak LLM. Below we show the transcripts for the maximally activating dataset examples for some neurons in decoder.blocks.0.mlp.1.
Neuron 10 - Food related
Neuron 12 - Numbers (particularly *th)
Neuron 14 - Verbs related to moving things
To the best of our knowledge, this work presents the first attempt to do interpretability on a multimodal audio-text model. We have demonstrated that acoustic features are human interpretable and formulated a way of listening to them. We have also presented some macroscopic properties of Whisper’s encoder and decoder. Our findings reveal that the audio encoder’s attention is highly localized, in contrast to the semantically aware attention patterns observed in Large Language Models. Furthermore, despite being trained exclusively on a supervised speech-to-text task, the decoder has acquired basic language modelling capabilities. This is a first step in developing universal interpretability techniques that can be used to detect dangerous/deceptive computation in multimodal models. This work is, however, by no means comprehensive. A notable limitation is that we simply used dataset examples to demonstrate acoustic features (rather than an optimization-based method like DeepDream), which may bias the features we find towards the dataset. Future work would include getting an optimization-based feature visualization method working in the audio domain, as well as looking more closely into how the acoustic features in the encoder are mapped to linguistic ones in the decoder.
This post was very helpful to me, thank you. But probably not for the reasons you intended (it made me more technically ambitious in my quest to solve podcast diarization). That said, I do have some questions. 1) Are there any "glitch phonemes" analogous to glitch tokens, e.g. SolidGoldMagikarp? 2) I don't understand this plot. Is it saying that the model's attention is sometimes highly nonlocalized? Part of my confusion is that I don't know what "Source" and "Destination" mean in this case.
3) How does the model handle multi-speaker audio? 3a) Are there features which shift iff the speaker changes? Text which is highly likely given that feature (e.g. "-" to represent someone being cut off)? 3b) What happens if multiple people are talking at once? 4) What happens when the model is listening to e.g. nature sounds, music, laughter etc.? 5) Unrelated, but how hard would it be to stitch together Whisper and e.g. Llama to make a little multimodal model?
Interesting work! Cool to see mech interp done in such different modalities.
Did you look at neurons in other layers in the encoder? I'm curious if there are more semantic or meaningful audio features. I don't know how many layers the encoder in Whisper tiny has.
Re localisation of attention, did you compute statistics per head of how far away it attends? That seems a natural way to get more info on this - I'd predict that most but not all encoder heads are highly localised (just like language models!). The fact that k=75 starts to mess up performance demonstrates that such heads must exist, IMO. And it'd be cool to investigate what kind of attentional features exist - what are the induction heads of audio encoders?!
Re other layers in the encoder: There are only 4 layers in Whisper tiny. I couldn't find any 'listenable' features in the earlier layers (0 and 1), so I'm guessing they activate more on frequency patterns than on human-recognisable sounds. Simple linear probes trained on layers 2 and 3 suggest they learn language features (eg is_french) and is_speech. Haven't looked into it any more than that though.
Re localisation of attention - 'I'd predict that most but not all encoder heads are highly localised' - this looks true when you look at the attn patterns per head. As you said most heads (4/6) in each layer are highly localised - you can mask them up to k=10. But there are 1 or 2 heads in each layer that are not so localized and are responsible for the degradation seen when you mask them.
Thanks for the post Ellena!
I was wondering if the finding "words are clustered by vocal and semantic similarity" also exists in traditional LLMs? I don't remember seeing that, so could it mean that this modularity could also make interpretability easier?
It seems logical: we have more structure on the data, so better way to cluster the text, but I'm curious of your opinion.
I wouldn't expect an LLM to do this. An LLM wants to predict the most likely next word, so is going to assign high probabilities to semantically similar words (hence why they are clustered in embedding space). Whisper is trying to do speech-to-text, so as well as needing to know about semantic similarity of words it also needs to know about words that sound the same. Eg if it thinks it heard 'rug', it is pretty likely that the person speaking actually said 'mug' hence these words are clustered. Does that make sense?
Just curious if there are notebooks or code for sharing so we can rerun the above analysis on the shared samples and other samples. Thanks.
Working on that one - the code is not in a shareable state yet but I will link a notebook here once it is!
At a glance, I couldn't find any significant capability externality, but I think that all interpretability work should, as a standard, have a paragraph explaining why the authors don't think their work will be used to improve AI systems in an unsafe manner.
Whisper seems sufficiently far from the systems pushing the capability frontier (GPT-4 and co) that I really don't feel concerned about that here