| name | transformer-lens-interpretability |
| description | Provides guidance for mechanistic interpretability research using TransformerLens to inspect and manipulate transformer internals via HookPoints and activation caching. Use when reverse-engineering model algorithms, studying attention patterns, or performing activation patching experiments. |
| version | 1.0.0 |
| author | Orchestra Research |
| license | MIT |
| tags | ["Mechanistic Interpretability","TransformerLens","Activation Patching","Circuit Analysis"] |
| dependencies | ["transformer-lens>=2.0.0","torch>=2.0.0"] |
TransformerLens: Mechanistic Interpretability for Transformers
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
GitHub: TransformerLensOrg/TransformerLens (2,900+ stars)
When to Use TransformerLens
Use TransformerLens when you need to:
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution
Consider alternatives when:
- You need to work with non-transformer architectures → Use nnsight or pyvene
- You want to train/analyze Sparse Autoencoders → Use SAELens
- You need remote execution on massive models → Use nnsight with NDIF
- You want higher-level causal intervention abstractions → Use pyvene
Installation
pip install transformer-lens
For development version:
pip install git+https://github.com/TransformerLensOrg/TransformerLens
Core Concepts
HookedTransformer
The main class that wraps transformer models with HookPoints on every activation:
from transformer_lens import HookedTransformer

# Load a small open model
model = HookedTransformer.from_pretrained("gpt2-small")

# Gated models (e.g., Llama 2) need a Hugging Face token
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
Supported Models (50+)
| Family | Models |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |
Activation Caching
Run the model and cache all intermediate activations:
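tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# Index the cache by (activation name, layer)
residual = cache["resid_post", 5]   # [batch, pos, d_model]
attn_pattern = cache["pattern", 3]  # [batch, head, q_pos, k_pos]
mlp_out = cache["mlp_out", 7]       # [batch, pos, d_model]

# Cache only selected activations to save memory
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name
)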
ActivationCache Keys
| Key Pattern | Shape | Description |
|---|---|---|
| resid_pre, layer | [batch, pos, d_model] | Residual before attention |
| resid_mid, layer | [batch, pos, d_model] | Residual after attention |
| resid_post, layer | [batch, pos, d_model] | Residual after MLP |
| attn_out, layer | [batch, pos, d_model] | Attention output |
| mlp_out, layer | [batch, pos, d_model] | MLP output |
| pattern, layer | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| q, layer | [batch, pos, head, d_head] | Query vectors |
| k, layer | [batch, pos, head, d_head] | Key vectors |
| v, layer | [batch, pos, head, d_head] | Value vectors |
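The shorthand keys in the table correspond to full hook names, such as the `blocks.{layer}.hook_resid_post` string used when registering hooks. The sketch below illustrates that naming convention in plain Python. `full_hook_name` and `ATTENTION_KEYS` are hypothetical names introduced here for illustration, not part of TransformerLens (which provides its own `utils.get_act_name`), and the helper only covers the keys listed in the table.

```python
# Hypothetical helper illustrating the hook naming convention;
# not part of the TransformerLens API.
ATTENTION_KEYS = {"pattern", "q", "k", "v"}  # these live under blocks.{layer}.attn


def full_hook_name(name: str, layer: int) -> str:
    """Map a shorthand cache key such as ("pattern", 3) to its full hook name."""
    prefix = f"blocks.{layer}.attn" if name in ATTENTION_KEYS else f"blocks.{layer}"
    return f"{prefix}.hook_{name}"


print(full_hook_name("resid_post", 5))  # blocks.5.hook_resid_post
print(full_hook_name("pattern", 3))     # blocks.3.attn.hook_pattern
```

The full name is what `names_filter` receives and what `fwd_hooks` expects, so a helper like this can keep the two in sync.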
Workflow 1: Activation Patching (Causal Tracing)
Identify which activations causally affect model output by patching clean activations into corrupted runs.
Step-by-Step
from transformer_lens import HookedTransformer, patching
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# Clean and corrupted prompts should tokenize to the same length
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# Cache activations from the clean run
_, clean_cache = model.run_with_cache(clean_tokens)

# Metric: logit difference between the clean and corrupted answers
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# Patch the residual stream at each (layer, position) in turn
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])
for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook):
            # Overwrite the corrupted activation with the clean one
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)]
        )
        # .item() stores a plain float instead of keeping the autograd graph
        results[layer, pos] = metric(patched_logits).item()
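The raw values in `results` depend on the scale of the logit difference, which makes them hard to compare across prompts. A common convention is to rescale each patched value against the two unpatched baselines, so that 0 means "no better than the corrupted run" and 1 means "fully restores the clean run". The sketch below shows that arithmetic in plain Python. `normalize_patching_score` is a hypothetical helper, and the numbers are made up for illustration; in practice the baselines would come from calling `metric` on the unpatched clean and corrupted logits.

```python
def normalize_patching_score(patched: float, clean: float, corrupted: float) -> float:
    """Rescale a patched metric so 0.0 = corrupted baseline and 1.0 = clean baseline."""
    gap = clean - corrupted
    if gap == 0:
        raise ValueError("clean and corrupted baselines are identical")
    return (patched - corrupted) / gap


# Made-up logit differences, for illustration only
clean_diff, corrupted_diff = 4.0, -2.0
print(normalize_patching_score(1.0, clean_diff, corrupted_diff))   # 0.5
print(normalize_patching_score(-2.0, clean_diff, corrupted_diff))  # 0.0
```

Values outside the 0 to 1 range are possible and usually mean the patch overshoots the clean behavior or pushes the model further toward the corrupted answer.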
Checklist
Workflow 2: Circuit Analysis (Indirect Object Identification)
Replicate the IOI circuit discovery from "Interpretability in the Wild".
Step-by-Step
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# IOI prompt: the correct completion is the indirect object, " John"
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# Positive values mean the model prefers the correct name
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")

def get_head_contribution(layer, head):
    # Direct logit attribution for one head (the final LayerNorm is not applied here)
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]                 # [d_head, d_model]
    W_U = model.W_U                              # [d_model, d_vocab]
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]

head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)
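Per-head attribution works because the residual stream is a sum of component outputs and the unembedding is linear, so each component's share of the logit difference can be read off separately. The toy sketch below checks that property with three-dimensional vectors in plain Python. All vectors and component names are invented for illustration, and, like the function above, it ignores the final LayerNorm.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Toy unembedding columns for the two candidate tokens (d_model = 3)
unembed = {"John": [1.0, 0.0, 2.0], "Mary": [0.5, 1.0, 0.0]}
# Direction in residual space that increases logit(John) - logit(Mary)
direction = [j - m for j, m in zip(unembed["John"], unembed["Mary"])]

# Toy vectors written into the final-position residual stream by each component
components = {
    "head_a": [2.0, 0.0, 1.0],
    "head_b": [0.0, 1.0, 0.0],
    "mlp": [1.0, 1.0, 1.0],
}

# Attribute the logit difference to each component separately
attribution = {name: dot(vec, direction) for name, vec in components.items()}
print(attribution)  # {'head_a': 3.0, 'head_b': -1.0, 'mlp': 1.5}

# The attributions add up to the logit difference of the summed residual
residual = [sum(values) for values in zip(*components.values())]
total = dot(residual, direction)
print(total)  # 3.5
```

Projecting onto a single difference direction also avoids computing logits for the whole vocabulary when only two tokens matter.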
Checklist
Workflow 3: Induction Head Detection
Find induction heads, which implement the [A][B]...[A] → [B] pattern.
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# Token sequence [A][B][A]: an induction head at the second A attends to B
repeated_tokens = torch.tensor([[1000, 2000, 1000]])
_, cache = model.run_with_cache(repeated_tokens)

induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]  # [head, q_pos, k_pos]
    # Attention from position 2 (second A) to position 1 (B)
    induction_scores[layer] = pattern[:, 2, 1]

# Indices refer to the flattened [layer, head] grid
top_heads = torch.topk(induction_scores.flatten(), k=5)
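A three-token sequence gives only one attention entry per head, which is a noisy signal. A more robust variant repeats a longer block of tokens twice and averages the attention each second-half position pays to the token that followed its earlier copy. The sketch below shows that averaging on a single head's pattern, represented as a nested list so it runs without a model. `induction_score` and the 4x4 pattern are hypothetical and for illustration only.

```python
def induction_score(pattern, period):
    """Mean attention from each second-half query to the token after its earlier copy.

    pattern: square nested list, pattern[q][k] = attention from query q to key k
    period:  length of the block that is repeated twice (len(pattern) == 2 * period)
    """
    if len(pattern) != 2 * period:
        raise ValueError("pattern must cover exactly two repeats of the block")
    queries = range(period, 2 * period)
    return sum(pattern[q][q - period + 1] for q in queries) / period


# Synthetic causal pattern for tokens [A][B][A][B]; each row sums to 1
toy_pattern = [
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.1, 0.8, 0.1, 0.0],  # second A attends mostly to the first B
    [0.1, 0.1, 0.6, 0.2],  # second B attends mostly to the second A
]
score = induction_score(toy_pattern, period=2)
print(round(score, 3))  # 0.7
```

With a real model, the same offset corresponds to reading the diagonal at `k = q - period + 1` of each head's slice of `cache["pattern", layer]`.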
Common Issues & Solutions
Issue: Hooks persist after debugging
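# Problem: hooks left attached from earlier debugging (e.g., via model.add_hook) still fire
model.run_with_hooks(tokens, fwd_hooks=[...])
model.run_with_hooks(tokens, fwd_hooks=[...])
# Fix: clear all hooks before running again
model.reset_hooks()
model.run_with_hooks(tokens, fwd_hooks=[...])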
Issue: Tokenization gotchas
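# Problem: similar-looking strings can split into different numbers of tokens
model.to_tokens("Tim")
model.to_tokens("Neel")
# Fix: inspect the tokenization explicitly
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))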
Issue: LayerNorm ignored in analysis
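# Problem: projecting the raw residual skips the LayerNorm that the MLP actually sees
pre_activation = residual @ model.W_in[layer]
# Fix: apply the block's LayerNorm first
ln_scale = model.blocks[layer].ln2.w  # learned scale; only present when LayerNorm is not folded (fold_ln=False)
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]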
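To see why skipping LayerNorm changes the result, it helps to look at what the normalization does to a vector: it subtracts the mean and rescales to unit variance before the MLP reads it. The plain-Python sketch below assumes a LayerNorm with no learned weight or bias and an epsilon of 1e-5. `layer_norm`, the toy residual, and the toy neuron weights are illustrative, not taken from a real model.

```python
import math


def layer_norm(vector, eps=1e-5):
    """Center to zero mean and scale to unit variance (no learned weight or bias)."""
    mean = sum(vector) / len(vector)
    centered = [x - mean for x in vector]
    variance = sum(c * c for c in centered) / len(vector)
    scale = math.sqrt(variance + eps)
    return [c / scale for c in centered]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


residual = [1.0, 2.0, 3.0]   # toy residual stream vector
neuron_in = [1.0, 1.0, 1.0]  # toy input weights for one MLP neuron

raw = dot(residual, neuron_in)
normed = dot(layer_norm(residual), neuron_in)
print(raw)               # 6.0 without LayerNorm
print(abs(normed) < 1e-9)  # True: the mean component is removed by LayerNorm
```

In this toy case the neuron reads only the mean of the residual, so it looks strongly active on the raw vector but receives essentially nothing once the input is normalized.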
Issue: Memory explosion with large models
# Cache only the activations you need and keep the cache on the CPU
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu"
)
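A quick size estimate shows which activations dominate the cache. Using the shapes from the ActivationCache Keys table, `resid_post` grows linearly with sequence length while `pattern` grows quadratically. The sketch below does the arithmetic for GPT-2 small (12 layers, d_model = 768, 12 heads) at a 1024-token context, assuming 4-byte float32 values. `cache_bytes` is a hypothetical helper written for this example.

```python
def cache_bytes(batch, pos, n_layers, d_model, n_heads, bytes_per_value=4):
    """Estimate memory for caching resid_post and pattern at every layer."""
    resid_post = batch * pos * d_model * n_layers      # [batch, pos, d_model] per layer
    pattern = batch * n_heads * pos * pos * n_layers   # [batch, head, q_pos, k_pos] per layer
    return {
        "resid_post": resid_post * bytes_per_value,
        "pattern": pattern * bytes_per_value,
    }


# GPT-2 small at its full 1024-token context, float32 assumed
sizes = cache_bytes(batch=1, pos=1024, n_layers=12, d_model=768, n_heads=12)
for name, size in sizes.items():
    print(f"{name}: {size / 2**20:.0f} MiB")
# resid_post: 36 MiB
# pattern: 576 MiB
```

Because the pattern term scales with `pos * pos`, filtering out attention patterns or shortening prompts usually saves far more memory than dropping residual stream entries.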
Key Classes Reference
| Class | Purpose |
|---|---|
| HookedTransformer | Main model wrapper with hooks |
| ActivationCache | Dictionary-like cache of activations |
| HookedTransformerConfig | Model configuration |
| FactoredMatrix | Efficient factored matrix operations |
Integration with SAELens
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
from transformer_lens import HookedTransformer
from sae_lens import SAE
model = HookedTransformer.from_pretrained("gpt2-small")
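# Arguments: SAE release name, then SAE id (hook point). Older SAELens releases return a tuple (sae, cfg_dict, sparsity) here.
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")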
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder.
External Resources
Tutorials
Papers
Official Documentation
Version Notes
- v2.0: Removed HookedSAE (moved to SAELens)
- v3.0 (alpha): TransformerBridge for loading any nn.Module