SAE Circuits - minalee-research/cs257-students GitHub Wiki
#Interpretability
Aditya Singh
Link to code: https://github.com/adsingh-64/SAE-Circuits
Abstract
Sparse autoencoder (SAE) circuits have emerged as a promising tool for understanding language model behavior. These are circuits whose components (called latents) ideally correspond to human-interpretable features. They therefore offer an improvement over traditional units of circuit analysis such as neurons or attention heads, which can be difficult to interpret. The goal of this project is to better understand the basic science of SAE circuits: how do different nodes in an SAE circuit relate to each other? In particular, there are three kinds of relationships in a circuit: latent-to-latent, latent-to-token, and latent-to-logit. We find that latent-to-latent and latent-to-token relationships are extremely sparse, while latent-to-logit relationships are substantially less sparse. Motivated by the lack of sparsity in latent-to-logit relationships, we then study when SAE circuits fail, using a 'starts with E' letter detection task. We define a failure case as an example prompt where the main causally important latent(s) for an SAE circuit do not include the top causally important latent for that particular prompt. We find that in these failure cases, latents playing similar causal roles in the model are nearly orthogonal. This finding may help inform SAE circuit design, since it suggests that latents with similar causal roles need not be geometrically similar (for example, need not have high decoder cosine similarity).
What this project is about
Large language models are reaching human-level performance in many different areas, yet their inner workings often remain opaque. The field of interpretability aims to better understand how large language models work. Prior interpretability work used units such as neurons or attention heads as the basic building blocks for reasoning about language model behavior. However, such units can be difficult to interpret, since they may respond to several unrelated features in the data. Sparse autoencoders (SAEs), on the other hand, have emerged as a technique for reconstructing a model's activations as sparse combinations of a more interpretable "dictionary" of features. The SAE activations are called latents. The hope is that each latent corresponds to exactly one human-interpretable feature (for example, a particular subject matter or grammatical structure). Over roughly the past year, researchers have tried to use SAE latents as the basic units for reasoning about language model behavior, since they are trained to be sparse and tend to be more interpretable. By an SAE circuit, we mean any arrangement of latents that come together in a meaningful way to help the LLM perform a task. Concretely, an SAE circuit is a task-specific subnetwork of the full model in which only the most causally important latents for the task are kept and all other nodes are mean-ablated (set to their mean over some sample activation distribution). This yields a simpler and more interpretable model of the task.
The goal of this project is to use high-level circuit analysis (see Approach) to make general observations about the basic science of SAE circuits, so as to better inform SAE circuit design. There are two distinct analyses in this project. First, we look at the three fundamental types of relationships within an SAE circuit:
- Latent‐to‐latent: How latents in earlier layers influence latents in later layers.
- Latent‐to‐token: How input tokens affect latents.
- Latent‐to‐logit: How each latent ultimately affects the model’s predicted vocabulary logits.
For each of the three relationship types, we measure the fraction of (upstream, downstream) pairs with nonzero attribution, averaged over 1000 examples, and find rates of <2e-5 for latent-to-latent, 0.005 for latent-to-token, and 0.015 for latent-to-logit.
Following the observation that latent-to-logit relationships are less sparse, the second part of this project investigates the process of filtering a model down to an SAE circuit, and in particular looks at what happens when SAE circuits 'fail'. We do this with a 'starts with E' letter detection task, in which we determine the most causally important latent(s) for detecting that a token starts with the letter 'e'. We define a failure case as a prompt whose most causally important latent is not one of these main latents (we call this latent the 'absorbing latent'), and we measure the cosine similarity between the absorbing latent's direction and the main latents' directions. Over 5000 distinct 'starts with E' prompts, ~20% are classified as failure cases. In these failure cases, the absorbing latent and the main latents are nearly orthogonal: in only 6% of failure cases does the cosine similarity exceed the threshold of 0.025.
The findings in this project should not be taken as general laws about SAE circuits as we tested only on limited data distributions and a single task, but we hope that the broad observations made in this project can better inform SAE circuit design.
No techniques used in this project are novel. Attribution patching is credited to (Nanda, 2023), while applying attribution patching at scale to find circuits of SAE latents comes from (Marks et al., 2024).
Approach
All experiments in this project rely on measuring the causal effect of earlier nodes in the model on later nodes. The main technique we use is attribution patching, an inexpensive method that requires only a single forward and a single backward pass to estimate the effect of every upstream activation on a given downstream metric. Most importantly, attribution patching is the main technique used in practice to filter models down to SAE circuits, so it is the appropriate tool for our goal of better understanding the SAE circuit process.
Activation Patching and Attribution Patching
Activation Patching
Activation patching is an interpretability technique for identifying which model activations are most important for determining model behavior between two similar prompts that differ in a key detail.
Given:
- A clean prompt $x_{clean}$
- A patch prompt $x_{patch}$
- A model activation location $a$ to patch
We quantify the importance of activation $a$ on the pair of inputs $(x_{clean}, x_{patch})$ by:
$$\Delta m = m(x_{clean} | do(a = a_{patch})) - m(x_{clean})$$
where $m$ is some scalar metric (for example, the value of a logit), as in the sparse feature circuits setup (Marks et al., 2024).
If $\Delta m$ is large, this provides causal evidence that the activation location $a$ matters significantly for the metric on $x_{clean}$ (more precisely, it is necessary).
A priori, it is not entirely clear which activation locations will matter the most. Hence, a direct approach is to iterate over all activations in the model (or more generally, the circuit one is analyzing), and patch in the activation from the patch run. However, the computational cost of this approach is high, as each data point for activation patching requires a separate forward pass.
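To make that cost concrete, here is a minimal numpy sketch of brute-force activation patching. It assumes a hypothetical toy model with random weights (not GPT-2): `activations` stands in for the forward pass up to location $a$, `metric` for the rest of the forward pass, and each of the 6 units of $a$ is treated as its own activation location, so patching all of them costs 6 extra partial forward passes.

```python
import numpy as np

# Hypothetical toy model with random weights (not GPT-2):
# x -> a = tanh(W1 x) -> m = v . tanh(W2 a)
rng = np.random.default_rng(0)
W1 = rng.normal(size=(6, 3))   # input -> activation location a (6 units)
W2 = rng.normal(size=(4, 6))   # a -> downstream layer
v = rng.normal(size=4)         # downstream layer -> scalar metric m


def activations(x):
    """Forward pass up to the activation location a."""
    return np.tanh(W1 @ x)


def metric(a):
    """Rest of the forward pass: scalar metric m as a function of a."""
    return float(v @ np.tanh(W2 @ a))


def activation_patching(x_clean, x_patch):
    """Exact Delta m for patching each unit of a, one forward pass per unit."""
    a_clean, a_patch = activations(x_clean), activations(x_patch)
    m_clean = metric(a_clean)
    deltas = np.empty(a_clean.shape[0])
    for i in range(a_clean.shape[0]):  # one extra forward pass per location
        a = a_clean.copy()
        a[i] = a_patch[i]              # do(a_i = a_patch_i)
        deltas[i] = metric(a) - m_clean
    return deltas


x_clean, x_patch = rng.normal(size=3), rng.normal(size=3)
print(activation_patching(x_clean, x_patch))
```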
Attribution Patching
One solution to the computational cost issue is attribution patching. Attribution patching (Nanda, 2023) is a technique that uses a first-order Taylor expansion to approximate activation patching.
Viewing the metric $m$ as a function of the activation location $a$, we approximate:
$$m(x_{clean} | do(a = a_{patch})) - m(x_{clean}) \approx (\nabla_{a = a_{clean}}m)(a_{patch} - a_{clean})$$
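A minimal sketch of attribution patching on the same kind of hypothetical toy model (random weights; `grad_metric` is a hand-derived gradient standing in for a backward pass). One gradient at $a_{clean}$ gives first-order estimates for every unit at once; the per-unit terms are the components of the dot product above. For a small perturbation the estimates should closely match exact activation patching, while for large differences between clean and patch activations the linear approximation can be poor.

```python
import numpy as np

# Hypothetical toy model: x -> a = tanh(W1 x) -> m = v . tanh(W2 a)
rng = np.random.default_rng(0)
W1 = rng.normal(size=(6, 3))
W2 = rng.normal(size=(4, 6))
v = rng.normal(size=4)


def activations(x):
    return np.tanh(W1 @ x)


def metric(a):
    return float(v @ np.tanh(W2 @ a))


def grad_metric(a):
    """Analytic gradient of m with respect to a (the 'backward pass')."""
    z = np.tanh(W2 @ a)
    return W2.T @ (v * (1.0 - z**2))


def attribution_patching(x_clean, x_patch):
    """First-order estimate of Delta m for patching each unit of a."""
    a_clean, a_patch = activations(x_clean), activations(x_patch)
    g = grad_metric(a_clean)           # one gradient, reused for every unit
    return g * (a_patch - a_clean)     # per-unit terms of grad . (a_patch - a_clean)


def exact_patching(x_clean, x_patch):
    """Exact Delta m per unit, for comparison (one forward pass per unit)."""
    a_clean, a_patch = activations(x_clean), activations(x_patch)
    out = np.empty(a_clean.shape[0])
    for i in range(a_clean.shape[0]):
        a = a_clean.copy()
        a[i] = a_patch[i]
        out[i] = metric(a) - metric(a_clean)
    return out


x_clean = rng.normal(size=3)
x_near = x_clean + 1e-5 * rng.normal(size=3)   # a small perturbation
print(attribution_patching(x_clean, x_near))
print(exact_patching(x_clean, x_near))
```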
Zero Ablation
A special case of activation patching is zero ablation, where we intervene on a forward pass by simply replacing activations with a zero vector. The goal of this intervention is to simulate the effect of the activation not being active.
Often, a more appropriate ablation is mean ablation (setting the activation to its mean over some set of training examples), since that better simulates the activation not responding to the input. However, SAE latent activations are extremely sparse by design (a latent with a nonzero activation more than 1% of the time would be considered a very high-frequency latent), so the overwhelmingly typical value of a latent is exactly zero. Setting a latent activation to zero is therefore the most principled way to simulate the latent not responding to the input.
The attribution patching approximation for zero ablation is:
$$m(x_{clean}) - m(x_{clean} | do(a = 0)) \approx (\nabla_{a = a_{clean}}m)(a_{clean})$$
All attributions done in this project follow the above equation, where $x_{clean}$ is the input prompt, and $m$ and $a$ vary depending on what we are analyzing. Observe that all this requires is a single forward and backward pass on $x_{clean}$!
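The zero-ablation version can be sketched the same way. This assumes a hypothetical toy in which the activations are sparse and non-negative (a ReLU with a negative offset, loosely mimicking SAE latents); `zero_ablation_attributions` computes $(\nabla_{a} m)\, a_{clean}$ elementwise from one forward and one (hand-derived) backward pass, and `exact_zero_ablation` is included only for comparison. Units with zero activation get exactly zero attribution.

```python
import numpy as np

# Hypothetical toy: a = relu(W1 x - 1) gives sparse, non-negative activations
rng = np.random.default_rng(1)
W1 = rng.normal(size=(10, 3))
W2 = rng.normal(size=(4, 10))
v = rng.normal(size=4)


def activations(x):
    return np.maximum(W1 @ x - 1.0, 0.0)


def metric(a):
    return float(v @ np.tanh(W2 @ a))


def grad_metric(a):
    z = np.tanh(W2 @ a)
    return W2.T @ (v * (1.0 - z**2))


def zero_ablation_attributions(x_clean):
    """Estimate m(x_clean) - m(x_clean | do(a_i = 0)) for every unit i at once."""
    a_clean = activations(x_clean)             # forward pass
    attr = grad_metric(a_clean) * a_clean      # backward pass, times activation
    return attr, a_clean


def exact_zero_ablation(x_clean):
    """Exact effect of zero-ablating each unit (one forward pass per unit)."""
    a_clean = activations(x_clean)
    out = np.empty(a_clean.shape[0])
    for i in range(a_clean.shape[0]):
        a = a_clean.copy()
        a[i] = 0.0
        out[i] = metric(a_clean) - metric(a)
    return out


x_clean = rng.normal(size=3)
attr, a_clean = zero_ablation_attributions(x_clean)
print(np.round(attr, 4))
print(np.round(exact_zero_ablation(x_clean), 4))
```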
Experiments
In the first part of this project, we study relationships in an SAE circuit using residual stream SAEs for GPT-2 small. These SAEs reconstruct a residual stream activation vector $x$ as $$x \approx \sum_{i = 1}^{d_{sae}} f_i d_i + x_0,$$ where the $f_i$ are the latent activations, the $d_i$ are their corresponding directions in the residual stream, and $x_0$ is a bias. The dimension $d_{sae}$ is much larger than the dimension of the residual stream, and the SAE is trained with a sparsity penalty, so for any given input the vast majority of the $f_i$ are zero.
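As a concrete reading of the reconstruction formula, here is a minimal numpy sketch of a standard ReLU SAE. The dimensions are hypothetical and much smaller than GPT-2 small's, the weights are random rather than trained, and subtracting `b_dec` before encoding is a common convention assumed here rather than a detail taken from the project. Row $i$ of `W_dec` plays the role of $d_i$ and `b_dec` the role of $x_0$.

```python
import numpy as np

# Hypothetical small dimensions and random (untrained) weights.
d_model, d_sae = 16, 128
rng = np.random.default_rng(0)
W_enc = rng.normal(size=(d_model, d_sae)) / np.sqrt(d_model)
b_enc = np.full(d_sae, -1.5)      # negative bias so most latents stay off
W_dec = rng.normal(size=(d_sae, d_model)) / np.sqrt(d_model)  # row i is d_i
b_dec = rng.normal(size=d_model)  # x_0 in the text


def encode(x):
    """Latent activations f = ReLU((x - b_dec) W_enc + b_enc)."""
    return np.maximum((x - b_dec) @ W_enc + b_enc, 0.0)


def decode(f):
    """Reconstruction sum_i f_i d_i + x_0."""
    return f @ W_dec + b_dec


x = rng.normal(size=d_model)
f = encode(x)
x_hat = decode(f)
print(f"{np.count_nonzero(f)} of {d_sae} latents active")
```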
Latent-to-latent relationships
In measuring latent-to-latent relationships, we ask: which upstream latents are causally important for a downstream latent? We use attribution patching where the metric $m$ is the activation of a layer $3$ SAE latent (the downstream latent) and the gradient is taken with respect to layer $1$ SAE latents (the upstream latents).
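A minimal sketch of latent-to-latent attribution, assuming hypothetical toy stand-ins (a layer-1 SAE decoder, one residual block in place of the intervening layers, and a layer-3 SAE encoder, all with random weights). The project computes Jacobians with `torch.func.jacrev`; here the Jacobian is written out by hand so the sketch needs only numpy. Entry `A[j, i]` is the attribution of upstream latent `i` to downstream latent `j`.

```python
import numpy as np

# Hypothetical toy stand-ins with random weights.
rng = np.random.default_rng(0)
d_model, d_up, d_down = 8, 32, 32
W_dec1 = rng.normal(size=(d_up, d_model)) / np.sqrt(d_model)
b_dec1 = rng.normal(size=d_model)
M = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
W_enc3 = rng.normal(size=(d_model, d_down)) / np.sqrt(d_model)
b_enc3 = np.full(d_down, -1.0)


def downstream(f_up):
    """Layer-3 latent activations as a function of layer-1 latent activations."""
    r = f_up @ W_dec1 + b_dec1            # decode layer-1 latents
    r2 = r + np.tanh(r @ M)               # toy residual block
    return np.maximum(r2 @ W_enc3 + b_enc3, 0.0)


def latent_to_latent_attributions(f_up):
    """A[j, i] = (d f_down_j / d f_up_i) * f_up_i, via a hand-written Jacobian."""
    r = f_up @ W_dec1 + b_dec1
    h = np.tanh(r @ M)
    active = (r + h) @ W_enc3 + b_enc3 > 0                   # ReLU mask
    J_r2_r = np.eye(d_model) + (1.0 - h**2)[:, None] * M.T   # d r2 / d r
    J = (active[:, None] * W_enc3.T) @ J_r2_r @ W_dec1.T     # (d_down, d_up)
    return J * f_up[None, :]


f_up = np.maximum(rng.normal(size=d_up) - 1.0, 0.0)   # sparse upstream latents
A = latent_to_latent_attributions(f_up)
f_down = downstream(f_up)
print(f"nonzero fraction: {np.count_nonzero(A) / A.size:.3f}")
```

Rows for inactive downstream latents and columns for inactive upstream latents are exactly zero, so in this toy the attribution matrix is sparse in both directions.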
Latent-to-logit relationships
In measuring latent-to-logit relationships, we ask: which latents are causally important for a particular logit? We use attribution patching where the metric $m$ is the value of a logit and the gradient is taken with respect to latents in a layer $9$ residual stream SAE.
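The same construction for latents and logits can be sketched with hypothetical toy stand-ins (a layer-9 SAE decoder, one residual block for the remaining layers, and an unembedding `W_U`, all random). In this toy there is no ReLU between the latents and the logits, so only inactive upstream latents force attributions to zero, and every active latent gets a nonzero attribution to essentially every logit. This is only an illustration, but it is consistent with latent-to-logit attributions being the least sparse.

```python
import numpy as np

# Hypothetical toy stand-ins with random weights; small sizes.
rng = np.random.default_rng(0)
d_model, d_sae, d_vocab = 8, 64, 50
W_dec = rng.normal(size=(d_sae, d_model)) / np.sqrt(d_model)
b_dec = rng.normal(size=d_model)
M = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
W_U = rng.normal(size=(d_model, d_vocab))


def logits(f):
    """All logits as a function of the layer-9 latent activations f."""
    r = f @ W_dec + b_dec
    return (r + np.tanh(r @ M)) @ W_U


def latent_to_logit_attributions(f):
    """A[v, i] = (d logit_v / d f_i) * f_i for every logit v and latent i."""
    r = f @ W_dec + b_dec
    h = np.tanh(r @ M)
    J_final_r = np.eye(d_model) + (1.0 - h**2)[:, None] * M.T   # d r_final / d r
    J = W_U.T @ J_final_r @ W_dec.T                              # (d_vocab, d_sae)
    return J * f[None, :]


f = np.maximum(rng.normal(size=d_sae) - 1.0, 0.0)   # sparse latent activations
A = latent_to_logit_attributions(f)
print(f"active latents: {np.mean(f > 0):.3f}, "
      f"nonzero attributions: {np.count_nonzero(A) / A.size:.3f}")
```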
Latent-to-token relationships
In measuring latent-to-token relationships, we ask: which tokens are causally important for a particular latent? We use attribution patching where the metric $m$ is the activation of the latent. However, it does not make sense to differentiate the latent activation directly with respect to the token embedding vector, since the individual components of that gradient have no clear meaning: what does it mean for a latent activation to have a large attribution with respect to the $i$-th component of a token's embedding vector? Instead, we multiply the embedding vector at each token position $p$ by a scalar $s_p$, collect these scalars into a scale vector $s$ of all ones, and differentiate with respect to $s$. By the chain rule, the derivative with respect to $s_p$ is the dot product of the gradient with respect to position $p$'s embedding and the embedding itself, which is exactly the zero-ablation attribution for that token's embedding.
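A minimal sketch of the scale-vector trick, assuming a hypothetical toy in which a single latent reads a fixed weighted mix of the scaled token embeddings (`alpha` stands in for attention; all weights are random). Evaluated at $s = \mathbf{1}$, the derivative with respect to $s_p$ is one number per token position, and it equals the gradient with respect to that position's embedding dotted with the embedding itself.

```python
import numpy as np

# Hypothetical toy: one latent reading a weighted mix of scaled token embeddings.
rng = np.random.default_rng(0)
n_pos, d_model = 5, 8
E = rng.normal(size=(n_pos, d_model))     # token embeddings, one row per position
alpha = rng.dirichlet(np.ones(n_pos))     # fixed mixing weights (stand-in for attention)
M = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
w_enc = rng.normal(size=d_model)
b_enc = 3.0


def latent(s):
    """Toy latent activation as a function of per-position embedding scales s."""
    r = alpha @ (s[:, None] * E)          # mix the scaled embeddings
    return max(float(np.tanh(r @ M) @ w_enc + b_enc), 0.0)


def token_attributions(s):
    """d latent / d s_p for every position p, via the chain rule."""
    r = alpha @ (s[:, None] * E)
    h = np.tanh(r @ M)
    if float(h @ w_enc + b_enc) <= 0:
        return np.zeros(n_pos)            # inactive latent: zero gradient
    grad_r = M @ ((1.0 - h**2) * w_enc)   # d latent / d r
    # d latent / d s_p = alpha_p * (e_p . grad_r); at s = 1 this equals
    # (d latent / d e_p) . e_p, the zero-ablation attribution for e_p.
    return alpha * (E @ grad_r)


print(np.round(token_attributions(np.ones(n_pos)), 4))
```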
As a first step, we create three plots of latent-to-latent, latent-to-token, and latent-to-logit relationships on the example prompt 'The Eiffel Tower is in Paris'. These plots visually suggest that latent-to-logit relationships stand out as less sparse. To measure sparsity rates more rigorously for each of the three relationships, we compute the fraction of attribution pairs that are nonzero and average it over 1000 prompts. These prompts are taken from the data on which the GPT-2 small residual stream SAEs were trained.
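The sparsity rate described above amounts to the following computation, sketched in numpy. `sparsity_rate` takes one attribution matrix per prompt (shapes may differ between prompts) and averages the per-prompt fraction of nonzero entries; the `tol` parameter is a hypothetical addition for thresholding near-zero values and defaults to counting exact nonzeros.

```python
import numpy as np


def sparsity_rate(attribution_matrices, tol=0.0):
    """Average, over prompts, of the fraction of attribution pairs with |A| > tol."""
    rates = [np.count_nonzero(np.abs(A) > tol) / A.size for A in attribution_matrices]
    return float(np.mean(rates))


A1 = np.array([[0.0, 0.3], [0.0, 0.0]])   # 1 of 4 pairs nonzero
A2 = np.zeros((2, 2))                      # 0 of 4 pairs nonzero
print(sparsity_rate([A1, A2]))             # 0.125
```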
Starts with E Detection Circuit
In light of the above findings, it is somewhat surprising that sparse feature circuits work at all, since SAE dictionary learning optimizes for latents that are sparse as representations of the input (sparse activations), not as actions on the logits (sparse attributions). If a large portion of latents matter for a given logit (or function of the logits), how sparse can SAE circuits be without missing important latents? The second part of this project investigates how sparse feature circuits might fail using a 'starts with E' letter detection task, which the recent work A is for Absorption: Studying Feature Splitting and Absorption in Sparse Autoencoders (Chanin et al., 2024) used to study the phenomenon of feature absorption.
Inputs to the model are few-shot prompts on which the model should output the ‘_E’ token: ‘tartan has the first letter: T mirth has the first letter: M egg has the first letter:’
Our metric is the logit of the correct letter ‘_E’ minus the mean of the 25 incorrect letter logits. This lets us search for latents that promote the ‘_E’ logit in particular, as opposed to letters generally. We then compute attributions of the metric with respect to the SAE latent activations on the layer 5 (pre-) residual stream at the position of the token that starts with ‘e’ (egg in the above example). For this task we switch from the GPT-2 residual stream SAEs to Gemma 2B and the canonical 16k-width GemmaScope SAE for the layer 5 residual stream, since this SAE is more performant. We compute attributions over 5000 different prompts in the few-shot format, each with a distinct starts-with-e token (case-insensitive), and average the attributions to find the main top latent(s). We define a failure case as a prompt whose top latent by attribution is not one of the main top latent(s); we call this top latent the absorbing latent. After collecting failure cases, we compute the cosine similarity of each absorbing latent with the main top latent(s), and compare this failure-case distribution to the baseline distribution of cosine similarities between all latents and the main top latent(s).
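The metric, the choice of main latents, and the failure-case definition can be sketched in numpy as follows. `attributions` is assumed to be an array of shape (number of prompts, number of latents) holding each prompt's attributions at the starts-with-e token; choosing the main latents as the top `k=2` by mean attribution is a simplification of threshold-based filtering, and the toy values are made up.

```python
import numpy as np


def letter_metric(letter_logits, correct_idx):
    """Correct-letter logit minus the mean of the other 25 letter logits."""
    others = np.delete(letter_logits, correct_idx)
    return float(letter_logits[correct_idx] - others.mean())


def main_latents(attributions, k=2):
    """Top-k latents by attribution averaged over prompts (k=2 is an assumption)."""
    mean_attr = attributions.mean(axis=0)
    return np.argsort(mean_attr)[::-1][:k]


def failure_cases(attributions, main):
    """Prompts whose top latent is not a main latent, plus that absorbing latent."""
    top = attributions.argmax(axis=1)
    failed = ~np.isin(top, main)
    return np.flatnonzero(failed), top[failed]


# Toy attributions: 4 prompts x 4 latents.
A = np.array([[5.0, 4.0, 0.0, 1.0],
              [6.0, 1.0, 0.0, 0.0],
              [1.0, 0.0, 7.0, 0.0],
              [0.0, 3.0, 0.0, 0.0]])
main = main_latents(A)
print(main, failure_cases(A, main))
```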
Experimental Details
SAEs
We use the sae_lens library to access the two sets of pretrained SAEs used in this project. First, we use GPT-2 small residual stream SAEs. These use the standard SAE architecture, with width $d_{sae} = 24576 = 32\cdot 768 = 32\cdot d_{model}$. Second, we use the canonical GemmaScope 16k layer 5 residual stream SAE, where "canonical" refers to the sparsity level of the SAE. This SAE uses the state-of-the-art JumpReLU SAE architecture.
We use two separate datasets. For computing average sparsity rates in the first part of the project with the GPT-2 small SAEs, our 1000 prompts come from the training data for the GPT-2 small SAEs themselves. This data is provided within the sae_lens library via the ActivationStore object. For the 'starts with E' letter detection task, our few-shot prompting format follows (Chanin et al., 2024), but we create the prompts ourselves by searching Gemma 2B's vocabulary for distinct starts-with-e tokens. There are 8832 such tokens in its vocabulary, and we randomly select 5000 of them to build the 5000 few-shot prompts used in the 'starts with E' letter detection experiment.
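A minimal sketch of the prompt construction in plain Python. The vocabulary here is a made-up toy list and `starts_with_e` is a hypothetical filter; the real experiment searches Gemma 2B's tokenizer vocabulary, where token strings carry tokenizer-specific whitespace markers that would need handling first. The few-shot examples 'tartan' and 'mirth' come from the example prompt given earlier.

```python
import random


def starts_with_e(token):
    """Hypothetical filter: token text starts with 'e' or 'E' after stripping whitespace."""
    return token.strip()[:1].lower() == "e"


def build_prompt(shots, query):
    """Few-shot prompt in the format shown earlier, ending before the answer letter."""
    prefix = " ".join(f"{w} has the first letter: {w[0].upper()}" for w in shots)
    return f"{prefix} {query} has the first letter:"


vocab = [" egg", "Env", " apple", " elephant", "tartan", " Eiffel", " mirth"]  # toy vocabulary
candidates = sorted({t.strip() for t in vocab if starts_with_e(t)})
rng = random.Random(0)
queries = rng.sample(candidates, k=2)   # stands in for sampling 5000 of 8832 tokens
prompts = [build_prompt(["tartan", "mirth"], q) for q in queries]
print(prompts[0])
```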
Attribution patching is implemented with the torch.func.jacrev wrapper, which lets us take custom Jacobians with respect to activations (not parameters!) in the model and SAEs. We exploit the fact that if the upstream latent or a downstream latent has zero activation, the attribution between them is zero: an inactive downstream latent has zero gradient through its activation function, and an inactive upstream latent multiplies the gradient by zero. Since latent activations are almost always zero, the attribution matrices are highly sparse, and computing the entire Jacobian would be wasteful. Hence, we implement this with a SparseTensor class, which stores a sparse tensor as a triple (nonzero_activations, nonzero_indices, shape).
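A minimal numpy sketch of such a triple-based sparse tensor. Field names follow the triple above, but the class is hypothetical and the project's actual implementation may differ; it only shows how a mostly-zero activation tensor can be stored by its nonzero entries and rebuilt when needed. The toy tensor has 3 positions and 32 × 768 = 24576 latents, with made-up values.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SparseTensor:
    """Sketch of a sparse tensor stored as (nonzero_activations, nonzero_indices, shape)."""
    nonzero_activations: np.ndarray   # shape (nnz,)
    nonzero_indices: np.ndarray       # shape (nnz, ndim), integer coordinates
    shape: tuple

    @classmethod
    def from_dense(cls, dense):
        idx = np.argwhere(dense != 0)                 # row-major order
        return cls(dense[tuple(idx.T)], idx, dense.shape)

    def to_dense(self):
        out = np.zeros(self.shape, dtype=self.nonzero_activations.dtype)
        out[tuple(self.nonzero_indices.T)] = self.nonzero_activations
        return out


# Toy sequence of 3 positions x 24576 latents with only 4 active latents.
acts = np.zeros((3, 24576))
acts[0, 10] = 1.5
acts[1, 500] = 0.2
acts[1, 7] = 3.0
acts[2, 24000] = 0.9
sp = SparseTensor.from_dense(acts)
print(sp.nonzero_indices.tolist())
```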
Results
First, we show the three plots of latent-to-latent, latent-to-token, and latent-to-logit attributions on the example prompt 'The Eiffel Tower is in Paris.' The sparsity rates averaged over 1000 examples, shown afterwards, are the stronger piece of evidence, but these plots are included so the reader can see qualitatively what these attributions look like.
Here is the latent-to-latent attributions plot:
Here is the latent-to-token gradients plot:
Here is the logit-to-latent attributions plot:
Preliminarily, we see that these attributions are less sparse than in the first two cases. The point is driven home if we look at the raw gradients rather than the attributions (gradient times activation), though this is somewhat unprincipled:
Here is the logit-to-latent gradients plot:
Inspired by the initial visual findings, we quantitatively measure sparsity for each of the three relationships by computing the number of nonzero attribution pairs over the total number of pairs. We average this sparsity rate over 1000 examples taken from the activation store for the SAEs involved (activations the SAEs were trained on), and obtain the following average sparsity rates:
- Latent-to-latent: 1.7893e-05
- Token-to-latent: 0.0054
- Latent-to-logit: 0.0154
Thus, latent-to-logit relationships are clearly the least sparse by attribution. We hypothesize that this is due to the SAE's training objective, which penalizes latent activations with an $L_1$ (or $L_0$) penalty but places no such penalty on latent attributions.
For the 'starts with E' letter detection task, here are the results of averaging attributions over 5000 different examples in the few-shot format, each with a distinct starts-with-‘e’ token (case-insensitive). There are two clear top latents by attribution (plot of top 25):
It is reasonable to expect that an SFC-like approach that filters latents by an attribution threshold will filter out all latents except latent 16070 and latent 13484, the two main ‘starts with E’ latents by attribution. We therefore define a failure case as a prompt whose top latent by attribution is neither of the two. Approximately 20% of the examples are classified as failure cases. For example, the prompt with the token ‘Env’ is a failure case:
Next, we check whether the top latent by attribution has a cosine similarity of at least 0.025 with either latent 16070 or latent 13484 (cosine similarity computed using the latents' decoder vectors). The threshold of 0.025 follows (Chanin et al., 2024). As a baseline, nearly a quarter of all latents have a cosine similarity of at least 0.025 with either latent 16070 or latent 13484:
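The threshold check can be sketched in numpy using normalized decoder rows. The toy decoder matrix is made up, latents 0 and 1 play the role of latents 16070 and 13484, and excluding the main latents themselves from the baseline is our assumption. The failure-case rate would then be `exceeds_threshold(W_dec, absorbing_latents, main).mean()` over the absorbing latents.

```python
import numpy as np


def cosine_with_main(W_dec, latents, main):
    """Cosine similarity between decoder rows of `latents` and of the main latents."""
    W = W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)
    return W[latents] @ W[main].T              # shape (len(latents), len(main))


def exceeds_threshold(W_dec, latents, main, threshold=0.025):
    """True where a latent has cosine similarity >= threshold with any main latent."""
    return (cosine_with_main(W_dec, latents, main) >= threshold).any(axis=1)


def baseline_fraction(W_dec, main, threshold=0.025):
    """Fraction of all non-main latents that pass the threshold."""
    others = np.setdiff1d(np.arange(W_dec.shape[0]), main)
    return float(exceeds_threshold(W_dec, others, main, threshold).mean())


# Toy decoder: latents 0 and 1 play the role of the two main latents.
W_dec = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0],
                  [1.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0]])
main = [0, 1]
print(exceeds_threshold(W_dec, np.array([2, 3]), main), baseline_fraction(W_dec, main))
```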
Yet for the failure cases, the cosine similarity exceeds the threshold of 0.025 in only 6% of cases. Moreover, the cosine similarity is typically between -0.025 and 0.025. We conclude that when a different latent takes the place of the main latent(s) as the top latent by attribution, the replacing latent is nearly orthogonal to the main latent(s), so there is minimal geometric overlap between them. We should not over-generalize, since this project used only high-level circuit analysis on the single ‘starts with E’ task, but the results indicate that SAE circuits may implement the same behavior through entirely different, nearly orthogonal computational pathways.
Further Directions
To extend this experiment, it would be interesting to see whether the absorbing latents by attribution share some describable characteristic. For example, Chanin et al. find that on the starts-with-E letter detection task, there is a general latent that activates on tokens that start with E, but on some tokens it fails to activate because a token-aligned latent (for example, one for 'elephants') fires instead. Are the absorbing latents by attribution similarly token-aligned? This is likely harder to answer, since we are trying to describe a latent not by its role as a representation of the input but by its dual role as an action on the logits, so standard tools like looking at max activating examples may be insufficient to classify the latent.
References
- Nanda, N. (2023). Attribution Patching. Retrieved from https://www.neelnanda.io/mechanistic-interpretability/attribution-patching
- Anthropic. (2024, March). March 2024 Monthly Update: Feature Heads. Retrieved from https://transformer-circuits.pub/2024/march-update/index.html#feature-heads
- Marks, S. et al. (2024). Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models. arXiv preprint.
- Neuronpedia. [Dashboard for visualizing neural network features and their explanations].
- SAE Lens. (2023). HookedSAETransformer class for GPT-2 small. [Software library].
- Chanin, D. et al. (2024). A is for Absorption: Studying Feature Splitting and Absorption in Sparse Autoencoders. arXiv preprint.
- ARENA. (2024). [1.3.2] Interpretability with SAEs [code]. Retrieved from https://arena-chapter1-transformer-interp.streamlit.app/[1.3.2]_Interpretability_with_SAEs