6/27: GPUs and Transformer Interpretability Basics
I've been keeping a weekly work journal while at the Recurse Center to help me organize my notes as I learn. These are essentially snapshots of what I've been looking at. A fellow Recurser pointed out that these might be worth putting out into the world on their own, so this is an experiment in that.
This post is going to be more or less those raw notes, lightly edited.
This week, we have:
- Recurse goals check-in
- Paper reading notes on A Mathematical Framework for Transformer Circuits
- Paper reading notes on Transformer Feed-Forward Layers Are Key-Value Memories
- Notes on CS336 Lecture 5
- Notes on ARENA 1.1, Transformer from scratch
Recurse goals check-in
I extended my RC batch by another 6 weeks. Judging by the first half, 6 weeks is a lot less time than you'd think, so I'm writing this out to be a bit more picky about what I spend time on here. The main culprits are CS336 and ARENA, but I'd also like to read through the original Circuits and Transformer Circuits threads and get to the point where I can read recent research papers and generally understand what's going on.
Paper reading goals:
For CS336
The course covers basics (implementation), systems (GPUs), scaling (laws), data (gathering), and RL.
I may try to gloss over the GPUs and data gathering portions, but I do want to do the RL part more deeply. I feel somewhat torn about the GPU assignment because I think it would really sharpen my intuition on GPUs but I also think it’s probably distracting and would take me quite a while. Assignment 3 seems pretty short and sweet. Assignment 4 on data cleanup I will probably skim, then focus on assignment 5.
For ARENA
There's Chapter 1 (transformer interpretability), Chapter 2 (RL), and Chapter 3 (LLM evals).
I feel more interested in chapters 1 and 2, and may skim the work on evals.
Much of this content overlaps between the two courses, which works in my favor, though more as reinforcement of the material than as time savings. I would also like to spend some time writing and/or working on a final project to cap off this work.
Can I actually get everything done?
Here’s a sample schedule to see if these goals are actually feasible:
Schedule:
- Week 7: half of ARENA chapter 1, cs336 lecture 6, reading through original circuits thread
- Week 8: second half of ARENA chapter 1, cs336 lecture 7 and 8
- Week 9: cs336 lectures 9, 10, 11, 12 and assignment 3 (scaling)
- Week 10: cs336 lectures 13, 14 (data), ARENA chapter 3 (evals), additional paper reading
- Week 11: cs336 lectures 15, 16, 17 (rl) and ARENA chapter 2 (rl)
- Week 12: cs336 lecture 18 and 19 (guest lectures), tbd final project
It’s possible that I should spend more time on RL and move it earlier, depending on how far out of order the Stanford class can be taken. This seems vaguely doable, but it's still pretty aggressive scheduling.
A Mathematical Framework for Transformer Circuits
Basic framework of OV and QK circuits
Instead of looking at an attention head as one opaque unit, this paper decomposes each head into a QK circuit and an OV circuit, which it argues operate roughly independently.
The QK circuit is the one that makes up the attention scores (with softmax), i.e. what should the head attend to?
The OV circuit is the one that tells the attention head how much that attention should update the output logits. OV is made up of the values matrix and the output projection matrix that goes from d_head to d_model (i.e. dimension of the residual stream).
The destination token attends to the source token. The QK is a function of both, but OV is a function of only the source. In other words, the destination token is attending to the source by writing some information about the source to itself. Note that “tokens” here refer to their positions in the context window.
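To make this concrete for myself, here's a minimal sketch of the two circuits for a single head (made-up dimensions and random weights, just to show the shapes):

```python
import torch

# Hypothetical per-head weights, purely for illustration.
d_model, d_head = 512, 64
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

# QK circuit: scores how much a destination residual vector attends to a source one.
W_QK = W_Q @ W_K.T   # (d_model, d_model), rank <= d_head
# OV circuit: what attending to a source vector writes back into the residual stream.
W_OV = W_V @ W_O     # (d_model, d_model), rank <= d_head

x_dst = torch.randn(d_model)  # residual stream at the destination position
x_src = torch.randn(d_model)  # residual stream at the source position

attn_score = x_dst @ W_QK @ x_src / d_head**0.5  # pre-softmax attention score
head_output = x_src @ W_OV                        # contribution written to the destination
```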
One-layer models
When you have a one-layer model, values can reach the model output either via the residual connection or by entering through one attention head.
Skip trigrams
Note that "skip trigram" is a bit of an odd name, because it's still a prediction based on just the src and dest tokens. Neel mentions it might be better to call this a "skip bigram" instead.
One cool behavior we see is the model doing what looks like "normalizing" the tokenizer output, where some tokens are the ~same in some sense (" Ralph", "Ralph", and "RALPH" should be the ~same, conceptually but not grammatically).
Copying and skip trigram limits
A lot of the heads seem to be doing copying behavior like above, where the destination token copies the source token into its predicted output (like the “perfect” example above).
A limit of skip-trigram behavior is that it also produces incorrect predictions.
This happens because the destination token only enters the QK circuit, not the OV circuit. It can tell the head "where" to look, but not "what" to write, so the same source token boosts the same outputs for every destination that attends to it.
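Here's a toy version of that failure mode (my own made-up example, in the spirit of the paper's "keep ... in mind" / "keep ... at bay" discussion):

```python
# Toy skip-trigram head: attention depends on (dst, src) via QK,
# but the output boost depends only on src via OV.
attn = {("in", "keep"): 1.0, ("at", "keep"): 1.0}   # QK: where to look
ov_boost = {"keep": {"mind": 1.0, "bay": 1.0}}      # OV: what "keep" suggests as output

def skip_trigram_logits(dst: str, src: str) -> dict[str, float]:
    a = attn.get((dst, src), 0.0)
    return {tok: a * boost for tok, boost in ov_boost.get(src, {}).items()}

print(skip_trigram_logits("in", "keep"))  # boosts BOTH "mind" and "bay"
print(skip_trigram_logits("at", "keep"))  # same boosts; the head can't separate the two cases
```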
The paper also does some analysis of the eigenvalues/eigenvectors of these circuits, with the intuition that:
| | Negative | Imaginary | Positive |
|---|---|---|---|
| QK circuit | Avoid same-token attention | Other tokens | Prefer same-token attention |
| OV circuit | Anti-copying | Other tokens | Copying |
It’s unclear to me exactly how this was computed or how generally useful it is, but note that these eigenvectors live in the vocabulary space, not the context window.
Generalization to 2-layer models
The paper then points out that when you go to a 2-layer model, you have three kinds of path: the pure residual path, the individual attention heads, and the "virtual attention heads" formed by composing a layer-1 head with a layer-2 head (every pair of heads across the two layers acts like an extra head).
There's then some "term importance" math to decide whether the virtual heads matter; empirically, they use a few tricks to see which terms affect the output loss the most. The individual attention heads seem most important, followed by the residual stream, followed by the virtual attention heads.
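To keep track of where those three kinds of path come from, here's the path expansion written out (my paraphrase of the paper's notation, so take the exact conventions with a grain of salt): each attention-only layer is the identity plus a sum of head terms, and multiplying the two layers out gives

$$
\Big(\mathrm{Id} + \sum_{h_2} A^{h_2} \otimes W^{h_2}_{OV}\Big)\Big(\mathrm{Id} + \sum_{h_1} A^{h_1} \otimes W^{h_1}_{OV}\Big)
= \mathrm{Id} + \sum_{h_1} A^{h_1} \otimes W^{h_1}_{OV} + \sum_{h_2} A^{h_2} \otimes W^{h_2}_{OV} + \sum_{h_1, h_2} \big(A^{h_2} A^{h_1}\big) \otimes \big(W^{h_2}_{OV} W^{h_1}_{OV}\big)
$$

The identity term is the residual path, the two single sums are the individual heads, and the double sum is the virtual attention heads, each with a composed attention pattern and a composed OV circuit.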
Induction heads
Following skip trigrams, the paper posits that the 2-layer model spends its time composing the layers and forming “induction heads” which guess the next output by trying to look at previous examples within the context window. There are a few interesting parts to this.
In-context learning. This is, in some sense, "learning" within the context window, because the model infers possible continuations from the surrounding text rather than from its weights alone.
Learned feature from composition. This behavior relies on the previous layer to figure out what's going on: the induction head composes with an earlier layer's head that does "previous token lookup".
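A quick empirical check I'd like to try (a sketch using TransformerLens, which ARENA is built around; the model name and the idea that a loss drop on repeated random tokens is evidence of induction are my assumptions here, not something from the paper):

```python
import torch
from transformer_lens import HookedTransformer

# Assumes TransformerLens and a small pretrained model are available.
model = HookedTransformer.from_pretrained("gpt2-small")

# Random tokens repeated twice: the second half is only predictable via the first half.
seq_len = 50
rand = torch.randint(1000, 10000, (1, seq_len))
tokens = torch.cat([rand, rand], dim=-1)

# Per-token loss; if induction is happening, loss on the repeated half should drop sharply.
loss = model(tokens, return_type="loss", loss_per_token=True)  # shape (1, 2*seq_len - 1)
print("first half loss: ", loss[0, :seq_len].mean().item())
print("second half loss:", loss[0, seq_len:].mean().item())
```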
Questions
- Neel mentions that combining the embed/unembed layers doesn't work, but doesn't that seem contrary to the weight tying that many models do?
- How do we think about why some parts of the model have a privileged basis or not?
Garcon
No notes, but interesting overview of a tool they made at Anthropic to help with inspecting models. This sounds pretty neat to work on.
Transformer Feed-Forward Layers Are Key-Value Memories
The rough idea is that they analyzed which context prefixes trigger a given neuron in the MLP layers of an LLM, to get a sense for what each neuron is doing. They took each neuron, found the inputs that best activated it, then had humans judge whether those inputs shared an interpretable pattern in English.
Key-value decomposition
More concretely,
- Keys are the input patterns that trigger a given value
- Values are a probability distribution over output tokens
Key-value pairs correspond to what individual neurons are doing.
For instance, in a transformer the key might be "military bases in" and the value distribution would be over places where bases might be. For MNIST, a key might be "inputs that contain a closed loop", with a value distribution over 0, 6, 8, and 9.
Even more concretely,
- Keys can be thought of as the vector k, a slice of the weights feeding into the hidden layer (one neuron's input weights)
- Values can be thought of as the vector v, the corresponding slice of the weights coming out of the hidden layer (see the sketch below)
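In code, the decomposition they have in mind looks roughly like this (my own sketch with tiny made-up dimensions, not the paper's notation):

```python
import torch

# Tiny made-up dimensions, just for illustration.
d_model, d_mlp = 8, 32
W_in = torch.randn(d_model, d_mlp)   # residual stream -> hidden neurons
W_out = torch.randn(d_mlp, d_model)  # hidden neurons -> residual stream
x = torch.randn(d_model)             # residual stream vector at one position

i = 3                                # pick one hidden neuron
key_i = W_in[:, i]                   # "key": the input pattern this neuron matches
value_i = W_out[i, :]                # "value": what firing this neuron writes out

activation_i = torch.relu(x @ key_i)     # how strongly the key is triggered
contribution_i = activation_i * value_i  # this neuron's additive contribution to the output

# The full MLP output is just the sum of these per-neuron contributions.
mlp_out = torch.relu(x @ W_in) @ W_out
per_neuron = sum(torch.relu(x @ W_in[:, j]) * W_out[j, :] for j in range(d_mlp))
assert torch.allclose(mlp_out, per_neuron, atol=1e-4)
```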
With this, they saw that:
- neurons did tend to have interpretable behavior, i.e. the inputs that triggered the key could be logically grouped by a human (“these all end with the word scenario”)
- earlier layers tended to have simpler behaviors (“ends in this word”) and later ones had more complex behavior (groupable by semantic meaning)
Note that the simpler behaviors here feel a bit like the "skip trigrams" from the Anthropic mathematical framework paper, except those live in the attention heads, not the MLP. Also, this paper doesn't talk about in-context learning in the attention heads; the MLPs here are doing something closer to the opposite, word-specific learning baked into the weights. How do these interact? It somewhat feels like they are learning similar things, but in different ways.
Values as output distributions
They also sort of discuss that later layers seem more focused on output than the earlier layers.
They also suggest that each MLP contributes small parts to the residual stream, which get combined into the answer. This fits pretty well with our current mental model of the residual stream, in which vectors get sharpened as they flow through the network, with the residual stream used as a basis.
This paper seems to be the one that one of the NGW talks (applying this idea to MNIST) was based on.
Questions
- are there more transformers this can be applied to? they only use one random one?
- why is it ok to interpret the values this way? is it because you’d expect inputs that trigger these keys to roughly trigger these outputs anyways?
- what work has come out of this paper? the only citation I found is ROME
- what do the keys that had "no grouping" look like?
- 3.6 agreements seems pretty high - is it actually that many? what do the raw annotations look like?
- could you try to do this at scale by having a larger llm grade the results?
- why is this different than say, trying to interpret individual neurons in CNNs? This feels really similar to how the Circuits thread thinks about things.
Lecture 5 - GPUs
Modern GPU workloads tend to be memory-bound rather than compute-bound, because compute has scaled faster than memory bandwidth.
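A back-of-the-envelope version of that claim (the hardware numbers are rough A100-ish figures I'm assuming, not from the lecture):

```python
# Roughly A100-class numbers, assumed for illustration.
peak_flops = 312e12      # ~312 TFLOP/s of bf16 matmul throughput
hbm_bandwidth = 2e12     # ~2 TB/s of HBM bandwidth

# You need this many FLOPs per byte moved to keep the math units busy.
breakeven_intensity = peak_flops / hbm_bandwidth
print(f"need ~{breakeven_intensity:.0f} FLOPs per byte to be compute-bound")

# An elementwise add in fp32 moves 12 bytes (read 2 floats, write 1) per 1 FLOP.
elementwise_intensity = 1 / 12
print(f"elementwise add: ~{elementwise_intensity:.2f} FLOPs/byte -> badly memory-bound")
```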
Major tricks:
Reduce memory bandwidth usage
- Kernel fusion - do more steps within a single kernel on the SM instead of separate kernels that each round-trip through HBM
- Coalescing - arrange accesses so adjacent threads touch adjacent addresses, letting the hardware combine them into bursts
- Tiling - for matmuls, compute the output in smaller blocks so each block's inputs are loaded from HBM once and reused (sketch below)
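Here's a minimal sketch of the tiling idea in plain NumPy, standing in for what a real kernel does with on-chip memory (tile size and shapes are made up):

```python
import numpy as np

def tiled_matmul(A: np.ndarray, B: np.ndarray, tile: int = 64) -> np.ndarray:
    """Compute A @ B one (tile x tile) output block at a time.

    On a GPU, each block's inputs would be loaded into shared memory once and
    reused for the whole block, instead of re-reading rows/columns from HBM.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                # One small block of work: everything here fits in fast memory.
                C[i:i+tile, j:j+tile] += A[i:i+tile, k:k+tile] @ B[k:k+tile, j:j+tile]
    return C

A = np.random.randn(256, 256).astype(np.float32)
B = np.random.randn(256, 256).astype(np.float32)
assert np.allclose(tiled_matmul(A, B), A @ B, atol=1e-3)
```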
Tradeoff more compute for less memory bandwidth
- Recomputation - instead of storing and re-fetching a large intermediate from memory, prefer recomputing it if it's a stateless artifact that's cheap to rebuild, e.g. the attention scores during the backward pass of attention (see the checkpointing sketch after this list)
- Use lower-precision FLOPs: less accuracy, but faster and fewer bytes moved
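The most common PyTorch form of the recomputation trick is activation checkpointing; a minimal sketch (the module and sizes here are made up, and this targets a generic MLP rather than attention specifically):

```python
import torch
from torch.utils.checkpoint import checkpoint

mlp = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)

x = torch.randn(32, 1024, requires_grad=True)

# Instead of storing the wide intermediate activations for backprop,
# checkpoint() drops them and recomputes the forward pass during backward.
y = checkpoint(mlp, x, use_reentrant=False)
y.sum().backward()
```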
FlashAttention
Use all of the above tricks when computing attention. In particular, reduce the cost of the seq_len² attention score matrix by fusing the entire attention computation into one kernel.
To do this, the matrices need to be tiled. Softmax is the tricky part because it normally requires the entire row at once, but online softmax avoids this with a clever trick: keep a running maximum and a running denominator as you stream through blocks.
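Here's a minimal sketch of that running-denominator trick on a single row of scores (my own toy version, not FlashAttention's actual kernel):

```python
import numpy as np

def online_softmax(scores: np.ndarray, block: int = 4) -> np.ndarray:
    """Softmax over one row, seeing it a block at a time (as a fused kernel would)."""
    m = -np.inf          # running max (for numerical stability)
    d = 0.0              # running denominator: sum of exp(score - m)
    for start in range(0, len(scores), block):
        chunk = scores[start:start + block]
        m_new = max(m, chunk.max())
        # Rescale the old denominator to the new max, then add the new block.
        d = d * np.exp(m - m_new) + np.exp(chunk - m_new).sum()
        m = m_new
    return np.exp(scores - m) / d

row = np.random.randn(10)
assert np.allclose(online_softmax(row), np.exp(row) / np.exp(row).sum())
```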
FlashAttention also prefers recomputing certain parts (attention scores) during the backprop instead of storing/loading from memory.
FlashAttention is different in that it is just as correct/precise as vanilla attention, but with less memory access for better arithmetic intensity (how much math per byte moved).
ARENA 1.1
This was mostly review from assignment 1 of cs336 so I won’t have much here.
Beam search
Keep the k best partial sequences at each step and return the completed output with the best loss at the end. This makes it more likely to find a sequence with good loss overall, and it can uncover sequences that start with a lower-probability token but end up with good loss.
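For reference, a minimal beam search sketch over a generic next-token log-prob function (the `log_probs_fn` interface here is my own stand-in, not ARENA's API):

```python
from typing import Callable

def beam_search(
    prompt: list[int],
    log_probs_fn: Callable[[list[int]], dict[int, float]],  # token id -> log prob of next token
    k: int = 3,
    max_new_tokens: int = 10,
) -> list[int]:
    """Keep the k highest-scoring sequences at each step; return the best at the end."""
    beams = [(0.0, prompt)]  # (total log prob, token ids)
    for _ in range(max_new_tokens):
        candidates = []
        for score, seq in beams:
            for tok, lp in log_probs_fn(seq).items():
                candidates.append((score + lp, seq + [tok]))
        # Prune back down to the k best continuations.
        beams = sorted(candidates, key=lambda c: c[0], reverse=True)[:k]
    return max(beams, key=lambda c: c[0])[1]
```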
KV-caching
Questions
- How does KV cache interact with positional embeddings?
- What happens to KV cache when the context window length overflows?