🌳 Superposition & SAEs in Mechanistic Interpretability: An Intro
Anthropic's Scaling Monosemanticity and Golden Gate Claude are some really exciting work in mechanistic interpretability, so I've been trying to better wrap my head around them. I wasn't satisfied by existing explanations of their work, so here's my attempt at something more digestible than the original papers.
I also have some (work-in-progress) code to show for it!
Features & Superposition
Anthropic's Transformer Circuits Thread has had a few previous works that try to interpret the hidden states of deep learning models. Ultimately, the goal is to find meaningful features in models' activations.
But what is a feature? We can go by multiple definitions, such as:
- Aspects of the input data that are meaningful to us humans
- Directions in the activation space of the neural network's layers
Suppose we have an activation vector $a \in \mathbb{R}^2$, and we're trying to represent two features (according to Definition 1 above)—say "airplane" and "water bottle." Here, we have 2 features and 2 dimensions, so we're able to assign orthogonal basis vectors to each feature.
I trained a toy 1-layer model, following this notebook. More details on that later.
For now, here are the actual results (code):
Note how the features that this toy model has learned don't exactly line up with the standard basis—there's no reason they have to! That's because the models I trained have a non-privileged basis. I can rotate all of the features by some arbitrary rotation matrix $R$, and the model will behave in exactly the same way as it does (provided I also rotate the inputs by $R$).
Source: Anthropic's Toy Models Paper
Now, let's scale things up. Suppose we're instead trying to represent 3 or 5 features in our activation space $a$. All of a sudden, we can't give each feature its own orthogonal basis direction. We have to squeeze 3 or 5 features into $\mathbb{R}^2$ somehow, and the best way is to pick directions that have the least overlap/interference (i.e. minimize cosine similarity between any 2 features).
NB 1: There's a beautiful connection to bond angles in chemistry here!
NB 2: You'll sometimes read that these features form an overcomplete basis for the activation space.
This phenomenon is called superposition!
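To make "least overlap" concrete, here's a small sketch that measures the worst-case cosine similarity between feature directions in $\mathbb{R}^2$. It assumes the features sit at equal angles around the circle (a triangle for 3 features, a pentagon for 5). That layout is my assumption for illustration, not output from the trained toy model, and the function names are made up.

```python
import numpy as np

def evenly_spaced_directions(n_features: int) -> np.ndarray:
    """Return n unit vectors in R^2 at equal angular spacing (an assumed layout)."""
    angles = 2 * np.pi * np.arange(n_features) / n_features
    return np.stack([np.cos(angles), np.sin(angles)], axis=1)

def max_pairwise_cosine(directions: np.ndarray) -> float:
    """Largest cosine similarity between two distinct unit-norm rows."""
    gram = directions @ directions.T  # entries are cosines because rows have norm 1
    off_diagonal = gram[~np.eye(len(directions), dtype=bool)]
    return float(off_diagonal.max())

for n in (3, 5):
    worst = max_pairwise_cosine(evenly_spaced_directions(n))
    print(f"{n} features: max pairwise cosine similarity = {worst:.3f}")
```

With 3 features the closest pair is 120° apart (cosine $-0.5$); with 5 features the nearest neighbors are 72° apart (cosine $\approx 0.309$), so some positive interference is unavoidable.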
Actually, I missed one thing: The model won't always spread out the vectors like this. It only does this when every feature occurs somewhat infrequently (sparsely) in the training data.
- If the features aren't sparse (they occur frequently in the data), the model represents only the most important features in the activation space. In this case, we don't get much superposition.
- If the features are sparse (they occur infrequently in the data), the model will represent the important features in the activation space. If we have enough important sparse features, the model will use superposition to represent all of them with minimum overlap/interference.
Source: Anthropic's Toy Models Paper
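If "sparse" feels abstract, here's a rough sketch of how synthetic sparse features could be sampled. It assumes each feature is independently zero with probability `sparsity` and otherwise uniform on $[0, 1)$. That recipe and the function name `sample_sparse_features` are my assumptions for illustration.

```python
import numpy as np

def sample_sparse_features(n_samples, n_features, sparsity, rng):
    """Each entry is 0 with probability `sparsity`, else uniform on [0, 1)."""
    values = rng.uniform(0.0, 1.0, size=(n_samples, n_features))
    active = rng.uniform(size=(n_samples, n_features)) >= sparsity
    return values * active

rng = np.random.default_rng(0)
for sparsity in (0.0, 0.9):
    x = sample_sparse_features(10_000, 5, sparsity, rng)
    print(f"sparsity={sparsity}: fraction of nonzero entries = {np.mean(x > 0):.3f}")
```

At `sparsity=0.9`, roughly 10% of the entries are nonzero, so any given example has only a feature or two "switched on."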
Here are some intuitive explanations of sparsity and importance:
Example 1: Say we have a model that predicts annual income. It's trained on features $\{\text{attendedCollege}, \text{occupation}, \text{hasBrownHair}\}$, but only 5% of the dataset examples have brown hair. Since hair color is sparse in this dataset and (probably) not a good indicator of income, the model will likely send the feature vector representing "hasBrownHair" to zero. Superposition (probably) won't happen.
Example 2: Say we have a model that predicts colorblindness. It's trained on features $\{\text{age}, \text{citizenship}, \text{isMale}\}$, but the dataset examples are only 5% male. The "isMale" feature is very sparse (it's active only 5% of the time), but when predicting colorblindness, it's very important. So the model will likely not send that feature vector to zero in the activation space. If there are enough other sparse, important features like "isMale," the model will (probably) represent them all in superposition.
Of course, in deep neural nets, this all happens in very high-dimensional vector spaces. But the basic principles are the same.
Details on the Toy Model
The 1-layer model is simple: it first maps its input $x \in \mathbb{R}^{1 \times n}$ to $\mathbb{R}^2$, then tries to reconstruct it to get $\hat{x}$. Formally, with our weight matrix $W \in \mathbb{R}^{n \times 2}$ and bias $b \in \mathbb{R}^n$,
$$h = xW$$ $$\hat{x} = \text{ReLU}(h W^T + b)$$
We train using the MSE loss (to push the predictions closer to the original inputs).
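Here's a minimal NumPy sketch of that forward pass and loss. The weights are random (untrained) and the names `toy_forward` and `mse_loss` are mine; it only illustrates the shapes and the two equations above. As a bonus, it checks what happens when the hidden space is rotated: in this tied-weight model the same $W$ maps into and back out of $\mathbb{R}^2$, so replacing $W$ with $WR$ leaves $\hat{x}$ unchanged because $RR^T = I$.

```python
import numpy as np

def toy_forward(x, W, b):
    """h = xW, x_hat = ReLU(h W^T + b)."""
    h = x @ W                             # (batch, 2): compressed representation
    x_hat = np.maximum(h @ W.T + b, 0.0)  # (batch, n): reconstruction
    return h, x_hat

def mse_loss(x, x_hat):
    return float(np.mean((x - x_hat) ** 2))

rng = np.random.default_rng(0)
n = 5
W = rng.normal(size=(n, 2))   # untrained stand-in weights
b = np.zeros(n)
x = rng.uniform(size=(8, n))  # a batch of 8 inputs

h, x_hat = toy_forward(x, W, b)
print(h.shape, x_hat.shape, mse_loss(x, x_hat))

# Rotate the hidden space by an arbitrary angle: W -> W R.
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
_, x_hat_rotated = toy_forward(x, W @ R, b)
print(np.allclose(x_hat, x_hat_rotated))  # the rotation cancels inside h W^T
```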
Now we can extract the features from the toy 1-layer model (relatively) easily. Let's look at $W$. Spelled out, we have
$$ x = \begin{bmatrix} x_1 & \dots & x_n \end{bmatrix} \quad W = \begin{bmatrix} a_0(x_1) & a_1(x_1)\\ \vdots&\vdots \\ a_0(x_n) & a_1(x_n) \end{bmatrix} $$
The features are very easy to find in this model: the $i$th row of $W$, $a(x_i) = \begin{bmatrix} a_0(x_i) & a_1(x_i) \end{bmatrix}$, is the $i$th feature $x_i$ embedded in $\mathbb{R}^2$. Thus, the model represents each feature $x_i$ as the direction of its activation vector $a(x_i)$. Our $n$ features are then just $\{a(x_1), \dots, a(x_n)\}$. That's it!
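In code, reading the features off $W$ could look something like this. The matrix here is random, standing in for trained weights, so the numbers themselves are meaningless; the point is just the row-wise bookkeeping.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 2))  # stand-in for a trained (n x 2) weight matrix

# Row i of W is feature i embedded in R^2.
norms = np.linalg.norm(W, axis=1)   # how strongly each feature is represented
directions = W / norms[:, None]     # unit-length feature directions
angles_deg = np.degrees(np.arctan2(directions[:, 1], directions[:, 0]))

# Pairwise cosine similarity between features: off-diagonal entries are interference.
interference = directions @ directions.T

for i, (norm, angle) in enumerate(zip(norms, angles_deg)):
    print(f"feature {i}: norm={norm:.2f}, angle={angle:.1f} degrees")
```

A feature whose row has norm near zero is one the model has effectively dropped.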
Motivating Sparse Autoencoders: Sparse Dictionary Learning
With deeper models, this kind of interpretability isn't easy. We can interpret the weight matrix applied to the input, but what if we want to find meaningful features in a deep hidden layer? Maybe we're trying to interpret the activations within the hidden state of a transformer (the residual stream), or we're trying to interpret the activations of a deep MLP.
Sparse autoencoders (SAEs) are a popular approach, and Anthropic used them in the Scaling Monosemanticity paper to find interpretable features in Claude 3 Sonnet. Let's motivate them from first principles:
Say we want to interpret the activations of an intermediate layer of some deep model $f(x)$.
Our (substantiated) hope is that in each hidden layer, the model is trying to represent a ton of features that the model learned during training.
Thus, we're trying to analyze an intermediate activation $a$, which is a $d$-dimensional vector that's trying to represent a lot more than $d$ features (ex. "traffic cone", "studying in a library", "smokestack", "geologist", etc.). That means we're going to have superposition!
If we're trying to find interpretable features from $a$, then we're trying to find a set of directions in the activation space. Each direction in this activation space should correspond to something that makes sense to us humans.
The question is, how do we figure out a good set of vectors?
Since our activation space doesn't have enough dimensions to give each feature its own orthogonal basis vector, why not map $a$ to a "higher-dimensional space" that does?
We don't know exactly how many features $a$ is trying to represent, so we just want to map it to a high-enough dimensional space.
To get really in the weeds: My intuition is that we want to map $a$ to a high dimensional space, where it's more likely that each entry in the high-dimensional vector (i.e. each neuron) corresponds to a feature (...at least in a privileged basis, where the standard basis directions are likely more meaningful). We have to be careful about privileged/non-privileged bases when talking about "neurons corresponding to features." In our case, this mapping will be in a privileged basis.
The idea is that $a$ is some combination of interpretable features in this "higher-dimensional space," which has then been projected into the lower-dimensional activation space. (This map from activation space to the "higher-dimensional space" is the dictionary part of sparse dictionary learning.)
Since $a$ is representing features in superposition, that means the features occurred sparsely in the training dataset. (For any data example, only a few features were present.) Thus, $a$ is a sparse combination of these "higher-dimensional features."
So, we want an algorithm that maps $a$ to a sparse vector in this "higher-dimensional space."
Sparse autoencoders do exactly that!
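To picture the assumption behind all of this, here's a tiny sketch of the "generative story": a sparse vector $h$ of feature strengths in a high-dimensional space gets projected down to a $d$-dimensional activation $a$. The dictionary is random and the sizes `d`, `m`, `k` are arbitrary choices of mine, purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, k = 8, 64, 3  # activation dim, number of features, active features per example

# One unit-norm direction in activation space per feature (random stand-ins).
dictionary = rng.normal(size=(m, d))
dictionary /= np.linalg.norm(dictionary, axis=1, keepdims=True)

# A sparse code: only k of the m features are present in this example.
h = np.zeros(m)
active = rng.choice(m, size=k, replace=False)
h[active] = rng.uniform(0.5, 1.5, size=k)

# The activation we actually observe is the sparse combination of those directions.
a = h @ dictionary
print(a.shape, np.count_nonzero(h))
```

Sparse dictionary learning is the inverse problem: given many vectors like `a`, recover both `dictionary` and the sparse `h` for each one.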
Sparse Autoencoders (SAEs)
The structure of a basic autoencoder is pretty simple. It's just an MLP with one hidden layer. It maps the input activation $a$ to a high-dimensional hidden state $h$. Then, using $h$, it computes $\hat{a}$, a reconstruction of $a$.
The intuition is that $h$ will contain some useful information about $a$, since it's an intermediate layer in the computation.
In essence, the SAE computation looks like this:
$$h = \text{ReLU}(a^TW_\text{enc})$$ $$\hat{a} = hW_\text{dec}$$
$$a \in \mathbb{R}^d \quad W_\text{enc} \in \mathbb{R}^{d \times m} \quad W_\text{dec} \in \mathbb{R}^{m \times d}$$
In reality, Anthropic made a few modifications, but the idea is the same.
Since we're trying to reconstruct the original input (the activation $a$), we use an MSE loss. But since $a$ is a sparse combination of features in $h$, we also want to ensure that $h$ is a sparse vector—so we add an L1 penalty/regularization term (with hyperparameter $\lambda$). For a single training example, our SAE loss is
$$\mathcal{L}(a, \hat{a}) = ||a - \hat{a}||_2^2 + \lambda ||h||_1$$
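Here's a minimal sketch of the forward pass and loss for a single activation vector. A few assumptions on my part: the ReLU sits on the encoder output so that $h$ is non-negative (the arrangement used in Towards Monosemanticity), the decoder is linear, biases are left out, and the weights are random rather than trained. The function names are made up for this example.

```python
import numpy as np

def sae_forward(a, W_enc, W_dec):
    """Encode a (d,) into a non-negative h (m,), then linearly decode to a_hat (d,)."""
    h = np.maximum(a @ W_enc, 0.0)
    a_hat = h @ W_dec
    return h, a_hat

def sae_loss(a, a_hat, h, lam):
    """Squared reconstruction error plus an L1 penalty on the hidden state."""
    reconstruction = np.sum((a - a_hat) ** 2)
    sparsity = lam * np.sum(np.abs(h))
    return float(reconstruction + sparsity)

rng = np.random.default_rng(0)
d, m = 4, 16                           # m > d: the "higher-dimensional space"
a = rng.normal(size=d)                 # stand-in for one model activation
W_enc = 0.1 * rng.normal(size=(d, m))
W_dec = 0.1 * rng.normal(size=(m, d))

h, a_hat = sae_forward(a, W_enc, W_dec)
print(h.shape, a_hat.shape, sae_loss(a, a_hat, h, lam=1e-3))
```

Raising `lam` trades reconstruction quality for sparsity: more entries of `h` get pushed to exactly zero.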
To be super clear, the original model $f(x)$ was trained on the input dataset. The SAE is trained on $a$, the activations that were gathered from running the original model on a bunch of data.
There are a few more details (dead neuron resampling, etc.) that improve the utility of the SAEs—but this is the core.
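For a feel of what training on activations looks like, here's a bare-bones sketch. Lots of simplifying assumptions: the "activations" are random Gaussian vectors standing in for ones gathered from a real model, the optimizer is plain full-batch gradient descent with hand-derived gradients, the ReLU is on the encoder, and there are no biases or dead-neuron resampling. This is not how Anthropic trained their SAEs; it's just the loss above minimized in the simplest way I could write down.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, n_examples = 4, 16, 512
lam, lr, n_steps = 1e-3, 1e-2, 500

# Stand-in for activations collected by running the original model f(x) on data.
activations = rng.normal(size=(n_examples, d))

W_enc = 0.1 * rng.normal(size=(d, m))
W_dec = 0.1 * rng.normal(size=(m, d))

def loss_and_grads(A, W_enc, W_dec, lam):
    """Mean SAE loss over the batch and its gradients w.r.t. both weight matrices."""
    n = A.shape[0]
    pre = A @ W_enc                 # (n, m) pre-activations
    H = np.maximum(pre, 0.0)        # (n, m) non-negative hidden state
    A_hat = H @ W_dec               # (n, d) reconstruction
    # H >= 0, so its L1 norm is just its sum.
    loss = (np.sum((A - A_hat) ** 2) + lam * np.sum(H)) / n

    d_A_hat = 2.0 * (A_hat - A) / n
    grad_W_dec = H.T @ d_A_hat
    d_H = d_A_hat @ W_dec.T + lam / n
    d_pre = d_H * (pre > 0)         # ReLU passes gradient only where it was active
    grad_W_enc = A.T @ d_pre
    return float(loss), grad_W_enc, grad_W_dec

losses = []
for step in range(n_steps):
    loss, grad_enc, grad_dec = loss_and_grads(activations, W_enc, W_dec, lam)
    losses.append(loss)
    W_enc -= lr * grad_enc
    W_dec -= lr * grad_dec

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

In practice you'd reach for an autodiff library and minibatches, but the structure (collect activations, then minimize reconstruction error plus the L1 term) is the same.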
Using SAEs to Interpret Model Activations
Now that we have features, we can interpret them!
For large language models, one way we can do this is to prompt them carefully and see what features in the SAE activate. If we input a bunch of prompts related to dolphins to an LLM, collect the activations, run them through the SAE, and see that neuron $i$ in the SAE hidden state consistently activates, then we can be reasonably confident that the $i$th neuron of the SAE correlates with dolphins. (We can't get causal relationships, but correlations are still very useful!)
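One simple way to operationalize "consistently activates" is to compare average SAE hidden activations on concept prompts against a baseline set of unrelated prompts. The baseline comparison is my addition, and the data below is synthetic: I plant a fake "dolphin" unit at index 7 so there's something to find. The function name is hypothetical.

```python
import numpy as np

def top_feature_for_concept(h_concept, h_baseline):
    """Index of the SAE hidden unit whose mean activation rises most on concept prompts."""
    difference = h_concept.mean(axis=0) - h_baseline.mean(axis=0)
    return int(np.argmax(difference)), difference

rng = np.random.default_rng(0)
m = 32  # SAE hidden size

# Synthetic SAE hidden states: rows are prompts, columns are SAE neurons.
h_baseline = np.maximum(rng.normal(-1.0, 1.0, size=(200, m)), 0.0)
h_concept = np.maximum(rng.normal(-1.0, 1.0, size=(200, m)), 0.0)
h_concept[:, 7] += 3.0  # pretend neuron 7 fires on dolphin-related prompts

index, difference = top_feature_for_concept(h_concept, h_baseline)
print(index, round(float(difference[index]), 2))
```

With real data you'd replace the two synthetic arrays with SAE hidden states computed from the LLM's activations on each prompt set.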
Once we find the $i$th neuron in the SAE that corresponds to a particular feature (ex. dolphins), we can create a "fake" hidden state with only the $i$th neuron active. If we pass that hidden state through the SAE decoder, we can get the direction in the model's activation space that corresponds to that feature! Like a dictionary, we can use the SAE to translate back and forth between interpretable features and directions in activation space!
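As a sketch, assuming a linear decoder $\hat{a} = hW_\text{dec}$ with no bias (and a random `W_dec` standing in for a trained one), decoding a one-hot hidden state just picks out a row of the decoder matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 8, 32
W_dec = rng.normal(size=(m, d))  # stand-in for a trained SAE decoder

i = 7                            # the neuron we decided corresponds to our feature
fake_h = np.zeros(m)
fake_h[i] = 1.0                  # "fake" hidden state: only neuron i is active

direction = fake_h @ W_dec       # direction in the model's activation space
direction_unit = direction / np.linalg.norm(direction)

print(np.allclose(direction, W_dec[i]))  # same as reading off row i of W_dec
```

So under these assumptions, the rows of `W_dec` are the dictionary: one activation-space direction per SAE feature.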
Of course, there are multiple ways to actually use the SAE for interpretability—this is frontier research after all. This is just one possible way.
For further reading, I'll point to this section and this section of the Scaling Monosemanticity paper (which are relatively digestible). If you want to read even further into mechanistic interpretability, look into this paper and post on steering vectors, which are related!
References
- Toy Models of Superposition
- Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
- Scaling Monosemanticity: Extracting Interpretable Features from Claude 3 Sonnet
- Sparse Autoencoders Find Highly Interpretable Features in Language Models
- MATS Colab Exercises