Monthly Algorithmic Challenge (October 2023): Sorted List
Problem
This post is the fourth in the sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of Stephen Casper's challenges, but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far. You can access the full ARENA materials from the Streamlit page. The purpose of the page you're currently reading is mainly to provide an easy-to-access place to read the solutions and understand the nature of these problems; I'd recommend using the Streamlit page instead if you're actually working through the problems. You can also find the solutions Colab here. I'll assume everyone has read the first problem, and has some overall context on the sequence.
Task & Dataset
The problem for this month is interpreting a model which has been trained to sort a list. The model is fed sequences like:
[11, 2, 5, 0, 3, 9, SEP, 0, 2, 3, 5, 9, 11]
and has been trained to predict each element in the sorted list (in other words, the output at the SEP token should be a prediction of 0, the output at 0 should be a prediction of 2, etc).
Here is an example of what this dataset looks like:
dataset = SortedListDataset(size=1, list_len=5, max_value=10, seed=42)
print(dataset[0].tolist())
print(dataset.str_toks[0])
['9', '6', '2', '4', '5', 'SEP', '2', '4', '5', '6', '9']
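To make the sequence format concrete, here is a minimal sketch of how one such sequence could be put together. This is not the actual SortedListDataset implementation: the helper name make_sequence is made up for illustration, and it assumes that each list is sampled from the vocabulary without replacement, so it won't reproduce the exact sample above.

```python
import random

def make_sequence(list_len: int, max_value: int, seed: int) -> list[str]:
    """Build one [unsorted..., SEP, sorted...] sequence of string tokens (illustrative only)."""
    rng = random.Random(seed)
    # Assumption: values are drawn without replacement from 0..max_value inclusive
    unsorted = rng.sample(range(max_value + 1), k=list_len)
    return [str(v) for v in unsorted] + ["SEP"] + [str(v) for v in sorted(unsorted)]

seq = make_sequence(list_len=5, max_value=10, seed=0)
print(seq)
```

The model is only ever scored on the tokens after SEP, which is why the setup code in the solutions slices the logits from position list_len onwards.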
Model
The model is attention-only, with 1 layer, and 2 attention heads per layer. It was trained with layernorm, weight decay, and an Adam optimizer with linearly decaying learning rate.
Solutions
First, let's do some setup:
logits, cache = model.run_with_cache(dataset.toks)
logits: Tensor = logits[:, dataset.list_len:-1, :]
targets = dataset.toks[:, dataset.list_len+1:]
logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)
batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")
avg_cross_entropy_loss = -logprobs_correct.mean().item()
print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")
Mean probability on correct label: 0.966
Median probability on correct label: 0.981
Min probability on correct label: 0.001
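One detail in the setup code that might look odd: probs is computed as logprobs.softmax(-1) rather than logits.softmax(-1). The two agree, because log-softmax only subtracts the same constant from every logit, and softmax is unchanged by constant shifts. Here is a small pure-Python check of that identity (the logits are arbitrary made-up values, not model outputs):

```python
import math

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(xs):
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]

logits = [2.0, -1.0, 0.5]  # arbitrary illustrative values
probs_direct = softmax(logits)
probs_via_logprobs = softmax(log_softmax(logits))
assert all(math.isclose(a, b) for a, b in zip(probs_direct, probs_via_logprobs))

# Cross-entropy loss for one prediction is minus the log-prob of the correct label
correct = 0
loss = -log_softmax(logits)[correct]
assert math.isclose(loss, -math.log(probs_direct[correct]))
```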
And some simple (kinda hacky) visualisation:
def show(dataset: SortedListDataset, batch_idx: int):
    logits: Tensor = model(dataset.toks)[:, dataset.list_len:-1, :]
    logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
    probs = logprobs.softmax(-1)
    str_targets = dataset.str_toks[batch_idx][dataset.list_len+1: dataset.seq_len]
    imshow(
        probs[batch_idx].T,
        y=dataset.vocab,
        x=[f"{dataset.str_toks[batch_idx][j]}<br><sub>({j})</sub>" for j in range(dataset.list_len+1, dataset.seq_len)],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>Unsorted = ({','.join(dataset.str_toks[batch_idx][:dataset.list_len])})",
        text=[
            ["〇" if (str_tok == target) else "" for target in str_targets]
            for str_tok in dataset.vocab
        ],
        width=400,
        height=1000,
    )
show(dataset, 0)
Summary of how the model works
In the second half of the sequence, the attention heads perform the algorithm "attend back to (and copy) the smallest token which is larger than me". For example, in a sequence like:
[7, 5, 12, 3, SEP, 3, 5, 7, 12]
we would have the 3 in the sorted half attending back to the 5 in the unsorted half (because it's the smallest token that's larger than itself), the 5 in the sorted half attending back to 7, etc. The SEP token just attends to the smallest token.
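As a sanity check on this description, here is the idealised version of the algorithm written out in plain Python. It's a sketch of the rule described above, not of the model itself (the real heads only approximate it), and it assumes all values in the list are distinct.

```python
def idealised_sort_predictions(unsorted: list[int]) -> list[int]:
    """Predict the sorted half one token at a time using the 'attend and copy' rule."""
    predictions = []
    current = None  # None stands in for the SEP token
    for _ in unsorted:
        if current is None:
            # SEP attends to the smallest token in the unsorted list
            candidates = unsorted
        else:
            # A number attends to tokens strictly larger than itself
            candidates = [v for v in unsorted if v > current]
        current = min(candidates)  # copy the attended-to token
        predictions.append(current)
    return predictions

print(idealised_sort_predictions([7, 5, 12, 3]))  # [3, 5, 7, 12]
```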
Some more refinements to this basic idea:
- The two attention heads split responsibilities across the vocabulary. Head 0.0 is the less important head; it deals with values in the range 28-37 (roughly). Head 0.1 deals with most other values.
- Heads actually sometimes attend more to values like d+2, d+3 than to d+1 (when d is the destination token). So why aren't sequences with [d, d+1, d+2] adversarial examples (i.e. making the model incorrectly predict d+2 after d)?
  - Answer - the OV circuit shows that when we attend to source token s, we also boost things slightly less than s, and suppress things slightly more than s.
  - So imagine we have a sequence [d, d+1, d+2]:
    - Attention to d+1 will boost d+1 a lot, and suppress d+2 a bit.
    - Attention to d+2 will boost d+2 a lot, and boost d+1 a bit.
    - So even if d+2 gets slightly more attention, d+1 might end up getting slightly more boosting.
- Sequences with large jumps are adversarial examples (because they're rare in the training data, which was randomly generated by choosing subsets without replacement).
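The [d, d+1, d+2] argument is easier to follow with numbers. The sketch below uses made-up attention weights and OV effects (none of these values are measured from the model) chosen only to have the qualitative shape described above: d+2 gets a bit more attention, but each source token also nudges its neighbour's logit.

```python
# Illustrative numbers only: these are NOT measured from the model.
attn = {"d+1": 0.45, "d+2": 0.55}  # d+2 gets slightly more attention

# ov[src][pred] = logit boost to `pred` per unit of attention paid to `src`
ov = {
    "d+1": {"d+1": 10.0, "d+2": -2.0},  # copies d+1, suppresses the larger d+2
    "d+2": {"d+1": 2.0, "d+2": 10.0},   # copies d+2, mildly boosts the smaller d+1
}

def total_boost(pred: str) -> float:
    """Sum the attention-weighted OV effect on `pred` over both source tokens."""
    return sum(attn[src] * ov[src][pred] for src in attn)

boost_d1 = total_boost("d+1")  # 0.45 * 10 + 0.55 * 2
boost_d2 = total_boost("d+2")  # 0.45 * -2 + 0.55 * 10
print(boost_d1, boost_d2)
assert boost_d1 > boost_d2  # d+1 still wins despite receiving less attention
```

With different made-up numbers the inequality can flip, which is consistent with the "might" in the last bullet: the off-diagonal OV terms help, but they don't guarantee the right answer.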
Attention patterns
First, let's visualise attention like we usually do:
cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(10)),
    radioitems = True,
    return_mode = "view",
    batch_labels = ["<code>" + " ".join(s) + "</code>" for s in dataset.str_toks],
    mode = "small",
)
Note, we only care about the attention patterns from the second half of the sequence back to earlier values (since it's a 1-layer model, and that's where we're taking predictions from).
Some observations:
- SEP consistently attends to the smallest value.
- Most of the time, token d will attend to the smallest token which is strictly larger than d, in at least one of the heads.
  - Seems like heads 0.0 and 0.1 split responsibility across the vocabulary: 0.1 deals with most values, 0.0 deals with a small range of values around ~30.
  - This strongly suggests that the heads are predicting whatever they pay attention to.
- One slightly confusing result - sometimes token d will pay attention more to the value which is 2 positions higher than d in the sorted list, rather than 1 position higher (e.g. very first example: 4 attends more to 7 than to 5). This is particularly common in sequences with 3 numbers very close together.
  - Further investigation (not shown here) suggests that these are not adversarial examples, i.e. attending more to 7 than to 5 doesn't stop 5 from being predicted. At this point, I wasn't sure what the reason for this was.
Next steps - confirm anecdotal observations about OV and QK circuits (plus run some basic head ablation experiments).
Ablating heads
Testing whether head 0.1 matters more (this was my hypothesis, since it seems to cover more of the vocabulary than 0.0). Conclusion - yes.
def get_loss_from_ablating_head(layer: int, head: int):
    def hook_fn(activation: Float[Tensor, "batch seq nheads d_model"], hook: HookPoint):
        # Mean over the batch dimension only, so we keep a separate mean for each position
        activation_mean: Float[Tensor, "seq d_model"] = cache[hook.name][:, :, head].mean(0)
        activation[:, :, head] = activation_mean
        return activation
    model.reset_hooks()
    logits_orig = model(dataset.toks)
    logprobs_orig = logits_orig.log_softmax(-1)[:, dataset.list_len:-1, :]
    # Note: the "result" hook is only populated if model.cfg.use_attn_result is True
    logits_ablated = model.run_with_hooks(dataset.toks, fwd_hooks=[(utils.get_act_name("result", layer), hook_fn)])
    logprobs_ablated = logits_ablated.log_softmax(-1)[:, dataset.list_len:-1, :]
    targets = dataset.toks[:, dataset.list_len+1:]
    logprobs_orig_correct = eindex(logprobs_orig, targets, "batch seq [batch seq]")
    logprobs_ablated_correct = eindex(logprobs_ablated, targets, "batch seq [batch seq]")
    # Positive value = ablation made the correct token less likely (i.e. increased the loss)
    return (logprobs_orig_correct - logprobs_ablated_correct).mean().item()
print("Loss from mean ablating the output of...")
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        print(f" ...{layer}.{head} = {get_loss_from_ablating_head(layer, head):.3f}")
...0.0 = 0.920
...0.1 = 4.963
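For anyone unfamiliar with mean ablation: the hook above replaces one head's output with its average over the batch (separately at each sequence position), which strips out input-specific information while keeping the activations roughly on-distribution. Here is a toy pure-Python version of the same operation, with nested lists standing in for tensors. The helper name mean_ablate_head and all the values are made up for illustration.

```python
def mean_ablate_head(result, head):
    """result[batch][seq][head] is a list of floats (d_model).
    Returns a copy where `head`'s output at each position is replaced by its batch mean."""
    n_batch = len(result)
    n_seq = len(result[0])
    d_model = len(result[0][0][head])
    # Per-position mean over the batch dimension, shape [seq][d_model]
    mean = [
        [sum(result[b][s][head][i] for b in range(n_batch)) / n_batch for i in range(d_model)]
        for s in range(n_seq)
    ]
    return [
        [
            [list(mean[s]) if h == head else list(vec) for h, vec in enumerate(result[b][s])]
            for s in range(n_seq)
        ]
        for b in range(n_batch)
    ]

# batch=2, seq=1, heads=2, d_model=2 (toy values)
result = [
    [[[1.0, 0.0], [4.0, 4.0]]],
    [[[3.0, 2.0], [6.0, 8.0]]],
]
ablated = mean_ablate_head(result, head=1)
assert ablated[0][0][1] == ablated[1][0][1] == [5.0, 6.0]  # head 1 replaced by its batch mean
assert ablated[0][0][0] == [1.0, 0.0]  # head 0 is untouched
```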
OV & QK circuits
We expect OV to be a copying circuit, and QK to be an "attend to anything bigger than self" circuit. SEP should attend to the smallest values.
W_OV = model.W_V[0] @ model.W_O[0] # [head d_model_in d_model_out]
W_QK = model.W_Q[0] @ model.W_K[0].transpose(-1, -2) # [head d_model_dest d_model_src]
W_OV_full = model.W_E @ W_OV @ model.W_U
W_QK_full = model.W_E @ W_QK @ model.W_E.T
imshow(
    W_OV_full,
    labels = {"x": "Prediction", "y": "Source token"},
    title = "W<sub>OV</sub> for layer 0 (shows that the heads are copying)",
    width = 900,
    height = 500,
    facet_col = 0,
    facet_labels = [f"W<sub>OV</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)
# W_QK_full has shape [head dest src], so rows are destination (query) tokens
# and columns are source (key) tokens
imshow(
    W_QK_full,
    labels = {"x": "Source (key) token", "y": "Destination (query) token"},
    title = "W<sub>QK</sub> for layer 0 (shows that the heads are attending to next largest thing)",
    width = 900,
    height = 500,
    facet_col = 0,
    facet_labels = [f"W<sub>QK</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)
Conclusion - this basically matches the previous hypotheses:
- Strong diagonal pattern for the OV circuits shows that 0.1 is a copying head on most of the vocabulary (everything outside the values in the [28, 37] range), and 0.0 is a copying head on the other values.
- Weak patchy diagonal pattern in QK circuit shows that most tokens attend more to ones which are slightly above them (and also that there are some cases where d attends more to d+2, d+3 etc than to d+1).
Visualising that last observation in more detail, for the case d=25:
def qk_bar(dest_posn: int):
    bar(
        [W_QK_full[0, dest_posn, :], W_QK_full[1, dest_posn, :]], # Heads 0.0 and 0.1, attention scores from destination token dest_posn to every source token
        title = f"Attention scores for destination token {dest_posn}",
        width = 900,
        height = 400,
        template = "simple_white",
        barmode = "group",
        names = ["0.0", "0.1"],
        labels = {"variable": "Head", "index": "Source token", "value": "Attention score"},
    )
qk_bar(dest_posn=25)
The most attended to are actually 28 and 29! We'll address this later, but first let's also explain a slightly simpler but also confusing-seeming result from the heatmap above.
What's with the attention to zero?
One weird observation in the heatmap that's worth mentioning - some tokens with very high values (i.e. >35) attend a lot to very small tokens, e.g. zero.
qk_bar(dest_posn=40)
Why don't these tokens all attend to zero?
Answer - plotting the QK circuit with positional embeddings on the query side and token embeddings on the key side shows that tokens near the end of the sequence have a bias against attending to very small tokens. Since tokens near the end of the sequence are likely to be precisely these larger values (i.e. >35), it's reasonable to guess that this effect cancels out the previously observed bias towards small tokens.
POSN_LABELS = [str(i) for i in range(dataset.seq_len)]
POSN_LABELS[dataset.list_len] = "SEP"
# Positional embeddings on the query side, token embeddings on the key side
W_Qpos_Kemb = model.W_pos @ W_QK @ model.W_E.T
imshow(
    W_Qpos_Kemb,
    labels = {"x": "Key token", "y": "Query position"},
    title = "W<sub>QK</sub> for layer 0 (query position vs key token)",
    y = POSN_LABELS,
    width = 950,
    height = 350,
    facet_col = 0,
    facet_labels = [f"W<sub>QK</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)
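The reason it makes sense to look at these two plots separately is that (ignoring layernorm) the query is built from the sum of a token embedding and a positional embedding, and attention scores are bilinear, so the score towards a given key token splits into a token-query term plus a position-query term. The toy numbers below are made up (not read off the plots) and just show how a position-based penalty can cancel a token-based bias towards zero:

```python
# Made-up score contributions for a large destination token sitting at a late position.
# Keys of each dict are source tokens; values are pre-softmax attention score terms.
token_query_term = {0: 3.0, 41: 2.5, 45: 1.0}      # token-on-token QK: odd bias towards 0
position_query_term = {0: -4.0, 41: 0.5, 45: 0.2}  # position-on-token QK: late positions avoid small tokens

# Because the score is bilinear, the two contributions simply add
combined = {src: token_query_term[src] + position_query_term[src] for src in token_query_term}
most_attended = max(combined, key=combined.get)
print(combined, most_attended)
assert most_attended == 41  # the bias towards 0 has been cancelled out
```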
Advexes
This plot also reveals a lot of potential advexes - for example, SEP consistently attends to the smallest value up to around ~30, where this pattern falls off. So if your entire sequence was in the range [30, 50], it's very possible that the model would fail to correctly identify the smallest token. Can you exhibit an example of this?
Another possible advex: if there's a large gap between tokens x and y, then x might attend to itself rather than to y. I created a CustomSortedListDataset class to confirm this. I also wrote a function show_multiple which can show multiple different plots in a batch at once (this was helpful for quickly testing out advexes) - you can see this in the Setup code section.
class CustomSortedListDataset(SortedListDataset):
    def __init__(self, unsorted_lists: List[List[int]], max_value: int):
        '''
        Creates a dataset from the unsorted lists in unsorted_lists.
        '''
        self.size = len(unsorted_lists)
        self.list_len = len(unsorted_lists[0])
        self.seq_len = 2*self.list_len + 1
        self.max_value = max_value
        self.vocab = [str(i) for i in range(max_value+1)] + ["SEP"]
        sep_toks = t.full(size=(self.size, 1), fill_value=self.vocab.index("SEP"))
        unsorted_list = t.tensor(unsorted_lists)
        sorted_list = t.sort(unsorted_list, dim=-1).values
        self.toks = t.concat([unsorted_list, sep_toks, sorted_list], dim=-1)
        self.str_toks = [[self.vocab[i] for i in toks] for toks in self.toks.tolist()]
custom_dataset = CustomSortedListDataset(
    unsorted_lists = [
        [0] + list(range(40, 49)),
        [5] + list(range(30, 48, 2)),
    ],
    max_value=50,
)
custom_logits, custom_cache = model.run_with_cache(custom_dataset.toks)
cv.attention.from_cache(
    cache = custom_cache,
    tokens = custom_dataset.str_toks,
    radioitems = True,
    return_mode = "view",
    batch_labels = ["<code>" + " ".join(s) + "</code>" for s in custom_dataset.str_toks],
    mode = "small",
)
show_multiple(custom_dataset)
Conclusion - yes, we successfully tricked x into self-attending rather than attending to y in these cases. The predictions were a bit unexpected, but we can at least see that the model predicts x with non-negligible probability (i.e. showing it's incorrectly predicted the token it attends to), and doesn't predict y at all.
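If you want to hunt for advexes of the first kind (whole list inside a narrow range like [30, 50]) more systematically, one option is a small random-search harness like the sketch below. Everything here is hypothetical scaffolding: find_advexes and toy_predictor are made-up names, and predict_sorted stands in for whatever wrapper you write around the model (e.g. taking the argmax of the logits at each position of the sorted half). The toy predictor is only there so the example runs without a model.

```python
import random
from typing import Callable, List

def find_advexes(
    predict_sorted: Callable[[List[int]], List[int]],
    low: int,
    high: int,
    list_len: int,
    n_trials: int,
    seed: int = 0,
) -> List[List[int]]:
    """Sample lists from [low, high] without replacement and return those the predictor gets wrong."""
    rng = random.Random(seed)
    failures = []
    for _ in range(n_trials):
        unsorted = rng.sample(range(low, high + 1), k=list_len)
        if predict_sorted(unsorted) != sorted(unsorted):
            failures.append(unsorted)
    return failures

def toy_predictor(unsorted: List[int]) -> List[int]:
    """Hypothetical stand-in for the model: misses the smallest token whenever it is >= 35."""
    out = sorted(unsorted)
    if out[0] >= 35:
        out[0] = out[1]
    return out

failures = find_advexes(toy_predictor, low=30, high=50, list_len=10, n_trials=200)
print(f"{len(failures)} failing lists out of 200")
```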
Solving the [d, d+1, d+2] mystery
At this point, I spent frankly too long trying to figure out how sequences of the form [d, d+1, d+2] weren't adversarial for this model. Before eventually finding the correct answer, the options I considered were:
- Maybe the less important head does something valuable. For instance, if there's a token where 0.1 boosts d+1 and d+2, maybe head 0.0 suppresses d+2.
  - After all, it does seem from the OV circuit plot like head 0.0 is an anti-copying head at the tokens where 0.1 is a copying head.
  - (However, the same cannot be said for the tokens where 0.0 is a copying head, i.e. 0.1 doesn't seem like it's anti-copying here - which made me immediately suspicious of this explanation.)
- The direct path W_E @ W_U is responsible for boosting tokens like d+1 much more than d+2.
  - This proves to kinda be true (see plot below), but if this was the main factor then you'd expect [d, d+1, d+2] sequences to become advexes once you remove the direct path. I ran an ablation experiment to test this, and it turned out not to be true.
imshow(
    model.W_E @ model.W_U,
    title = "DLA from direct path",
    labels = {"x": "Prediction", "y": "Input token"},
    height = 500,
    width = 600,
)
Finally, I found the actual explanation. As described earlier, attending to d+2 will actually slightly boost d+1, and attending to d+1 will slightly suppress d+2 (and the same holds true for slightly larger gaps between source tokens). So even if d+2 is getting a bit more attention, the net effect will be that d+1 gets boosted more than d+2.
To visualise this, here's a set of 4 examples. Each of them contains 3 values x < y < z close together, which I judged from the QK bar charts earlier would trick the model by having x attend to z as much as / more than y. For each of them, I measured the direct logit attribution to y and z respectively, coming from the source tokens y and z respectively.
I already knew that I would see:
- DLA from y -> y large, positive
- DLA from z -> z large, positive (often larger than y -> y)
And if this hypothesis was correct, then I expected to see:
- DLA from y -> z weakly negative
- DLA from z -> y weakly positive
- The total DLA to y should be larger than the total DLA to z (summing over source tokens y and z)
This is exactly what we see:
custom_dataset = CustomSortedListDataset(
    unsorted_lists = [
        [0, 5, 14, 15, 17, 25, 30, 35, 40, 45],
        [0, 5, 10, 15, 20, 25, 26, 27, 40, 45],
        [0, 5, 10, 15, 20, 25, 30, 31, 32, 45],
        [0, 5, 10, 15, 20, 25, 30, 31, 34, 45],
    ],
    max_value=50,
)
custom_logits, custom_cache = model.run_with_cache(custom_dataset.toks)
# For each sequence, define which head I expect to be the important one, and define
# which tokens are acting as (x, y, z) in each case
head_list = [1, 1, 0, 0]
x_tokens = [14, 25, 30, 30]
y_tokens = [15, 26, 31, 31]
z_tokens = [17, 27, 32, 34]
src_tokens = t.tensor([y_tokens, z_tokens]).T
# Get the positions of (x, y, z) by indexing into the string tokens lists (need to be
# careful that I'm taking them from the correct half of the sequence)
L = custom_dataset.list_len
x_posns = t.tensor([L + toks[L:].index(str(dt)) for dt, toks in zip(x_tokens, custom_dataset.str_toks)])
y_posns = t.tensor([toks.index(str(st)) for st, toks in zip(y_tokens, custom_dataset.str_toks)])
z_posns = t.tensor([toks.index(str(st)) for st, toks in zip(z_tokens, custom_dataset.str_toks)])
src_posns = t.stack([y_posns, z_posns]).T
out = einops.einsum(
    custom_cache["v", 0], model.W_O[0],
    "batch seqK head d_head, head d_head d_model -> batch seqK head d_model",
)
# out = out.sum(2)
out = out[range(4), :, head_list] # [batch seqK d_model]
attn = custom_cache["pattern", 0][range(4), head_list] # [batch seqQ seqK]
result_pre_sum = einops.einsum(
    out, attn,
    "batch seqK d_model, batch seqQ seqK -> batch seqQ seqK d_model",
)
scale = custom_cache["scale"].unsqueeze(-1) # [batch seqQ 1 1]
dla = (result_pre_sum / scale) @ model.W_U # [batch seqQ seqK d_vocab]
# want tensor of shape (4, 2, 2), with dimensions (batch dim, predicted token = y/z, src token = y/z)
dla_from_yz_to_yz = dla[t.arange(4)[:, None, None], x_posns[:, None, None], src_posns[:, None, :], src_tokens[:, :, None]]
fig = imshow(
    dla_from_yz_to_yz,
    facet_col = 0,
    facet_labels = [
        f"Seq #{i}<br>(x, y, z) = ({x}, {y}, {z})"
        for i, (x, y, z) in enumerate(zip(x_tokens, y_tokens, z_tokens))
    ],
    title = "DLA for custom dataset with (x, y, z) close together",
    title_y = 0.95,
    labels = {"y": "Effect on prediction", "x": "Source token"},
    x = ["src = y", "src = z"],
    y = ["pred = y", "pred = z"],
    width = 900,
    height = 400,
    text_auto = ".2f",
)
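The broadcasted indexing expression that builds dla_from_yz_to_yz is the least readable line in this section, so here is what I believe it is equivalent to, written as explicit loops over nested lists. This is a sketch for intuition rather than a drop-in replacement: the helper name gather_dla is made up, and the toy dla values are constructed so that each entry encodes its own indices.

```python
def gather_dla(dla, x_posns, src_posns, src_tokens):
    """Loop version of the broadcasted index.
    out[b][i][j] is the DLA in sequence b, at query position x_posns[b],
    coming from source position src_posns[b][j], to predicted token src_tokens[b][i]."""
    out = []
    for b in range(len(x_posns)):
        rows = []
        for i in range(2):      # predicted token: y then z (rows of each heatmap)
            row = []
            for j in range(2):  # source position: y then z (columns of each heatmap)
                row.append(dla[b][x_posns[b]][src_posns[b][j]][src_tokens[b][i]])
            rows.append(row)
        out.append(rows)
    return out

# Toy dla of shape [batch=1][seqQ=3][seqK=3][vocab=4], where each entry encodes its indices
toy_dla = [
    [[[100 * q + 10 * k + v for v in range(4)] for k in range(3)] for q in range(3)]
    for _ in range(1)
]
result = gather_dla(toy_dla, x_posns=[2], src_posns=[[0, 1]], src_tokens=[[2, 3]])
print(result)  # [[[202, 212], [203, 213]]]
```

In each 2x2 block the row picks the predicted token and the column picks the source position, which matches the "pred = ..." and "src = ..." axis labels passed to imshow.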