
Monthly Algorithmic Challenge (October 2023): Sorted List




Problem

This post is the fourth in the sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of Stephen Casper's challenges, but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far. You can access the full ARENA materials from the Streamlit page. The purpose of the page you're currently reading is mainly to provide an easy-to-access place to read the solutions and understand the nature of these problems; I'd recommend using the Streamlit page instead if you're actually working through the problems. You can also find the solutions Colab here. I'll assume everyone has read the first problem, and has some overall context on the sequence.

Task & Dataset

The problem for this month is interpreting a model which has been trained to sort a list. The model is fed sequences like:

[11, 2, 5, 0, 3, 9, SEP, 0, 2, 3, 5, 9, 11]

and has been trained to predict each element in the sorted list (in other words, the output at the SEP token should be a prediction of 0, the output at 0 should be a prediction of 2, etc).
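To make the prediction offsets concrete, here's a tiny illustration of how each position in the second half lines up with its target (a throwaway snippet of mine, not part of the actual dataset code):

# Each position from SEP up to the second-to-last sorted token is trained to
# predict the *next* token in the sequence (i.e. the next value of the sorted list).
seq = [11, 2, 5, 0, 3, 9, "SEP", 0, 2, 3, 5, 9, 11]
list_len = 6
for pos in range(list_len, len(seq) - 1):
    print(f"position {pos:>2}, token {seq[pos]!r:>5}  ->  target {seq[pos + 1]!r}")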

Here is an example of what this dataset looks like:

dataset = SortedListDataset(size=1, list_len=5, max_value=10, seed=42)

print(dataset[0].tolist())
print(dataset.str_toks[0])

[9, 6, 2, 4, 5, 11, 2, 4, 5, 6, 9]
['9', '6', '2', '4', '5', 'SEP', '2', '4', '5', '6', '9']

Model

The model is attention-only, with 1 layer, and 2 attention heads per layer. It was trained with layernorm, weight decay, and an Adam optimizer with linearly decaying learning rate.
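The full setup code lives on the Streamlit page, but for reference, a TransformerLens config along these lines would look roughly like the sketch below. The specific hyperparameters (d_model, d_head, seed, etc.) are my guesses rather than the real values; the actual trained weights are loaded in the setup code.

from transformer_lens import HookedTransformer, HookedTransformerConfig

# Sketch of the architecture described above (attention-only, 1 layer, 2 heads,
# layernorm). The exact numbers here are guesses, not the real training config.
cfg = HookedTransformerConfig(
    n_layers=1,
    n_heads=2,
    attn_only=True,               # no MLPs
    normalization_type="LN",      # trained with layernorm
    d_model=96,
    d_head=48,
    d_vocab=52,                   # values 0-50, plus SEP
    n_ctx=21,                     # 2 * list_len + 1, with list_len = 10
    seed=42,
    device="cpu",
)
model = HookedTransformer(cfg)    # untrained skeleton; the real weights are loaded separately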





Solutions

First, let's do some setup:
logits, cache = model.run_with_cache(dataset.toks)
logits: Tensor = logits[:, dataset.list_len:-1, :]

targets = dataset.toks[:, dataset.list_len+1:]

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, targets, "batch seq [batch seq]")
probs_correct = eindex(probs, targets, "batch seq [batch seq]")

avg_cross_entropy_loss = -logprobs_correct.mean().item()

print(f"Average cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")
Average cross entropy loss: 0.039
Mean probability on correct label: 0.966
Median probability on correct label: 0.981
Min probability on correct label: 0.001

And some simple (kinda hacky) visualisation:
def show(dataset: SortedListDataset, batch_idx: int):
    
    logits: Tensor = model(dataset.toks)[:, dataset.list_len:-1, :]
    logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
    probs = logprobs.softmax(-1)

    str_targets = dataset.str_toks[batch_idx][dataset.list_len+1: dataset.seq_len]

    imshow(
        probs[batch_idx].T,
        y=dataset.vocab,
        x=[f"{dataset.str_toks[batch_idx][j]}<br><sub>({j})</sub>" for j in range(dataset.list_len+1, dataset.seq_len)],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>Unsorted = ({','.join(dataset.str_toks[batch_idx][:dataset.list_len])})",
        text=[
            ["〇" if (str_tok == target) else "" for target in str_targets]
            for str_tok in dataset.vocab
        ],
        width=400,
        height=1000,
    )

show(dataset, 0)

Summary of how the model works

In the second half of the sequence, the attention heads perform the algorithm "attend back to (and copy) the first token which is larger than me". For example, in a sequence like:

[7, 5, 12, 3, SEP, 3, 5, 7, 12]

we would have the second 3 token attending back to the first 5 token (because it's the first one that's larger than itself), the second 5 attending back to 7, etc. The SEP token just attends to the smallest token.
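Written out directly in Python, the rule being implemented looks something like this (a sketch of mine, purely for intuition - the model does this via attention + copying rather than explicit comparisons):

def predicted_next_value(unsorted, current=None):
    # At SEP (current=None), predict the smallest value in the list. At any other
    # position, predict the smallest value strictly larger than the current token -
    # i.e. the token which gets attended to and copied. (The model never has to
    # predict anything after the largest value, since predictions stop at the
    # second-to-last sorted token.)
    if current is None:
        return min(unsorted)
    return min(v for v in unsorted if v > current)

unsorted = [7, 5, 12, 3]
assert predicted_next_value(unsorted) == 3      # prediction at SEP
assert predicted_next_value(unsorted, 3) == 5   # the second 3 attends to (and copies) 5
assert predicted_next_value(unsorted, 5) == 7
assert predicted_next_value(unsorted, 7) == 12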

Some more refinements to this basic idea:

  • The two attention heads split responsibilities across the vocabulary. Head 0.0 is the less important head; it deals with values in the range 28-37 (roughly). Head 0.1 deals with most other values.
  • Heads actually sometimes attend more to values like d+2, d+3 than to d+1 (when d is the destination token). So why aren't sequences with [d, d+1, d+2] adversarial examples (i.e. making the model incorrectly predict d+2 after d)?
    • Answer - the OV circuit shows that when we attend to source token s, we also boost things slightly less than s, and suppress things slightly more than s.
    • So imagine we have a sequence [d, d+1, d+2]:
      • Attention to d+1 will boost d+1 a lot, and suppress d+2 a bit.
      • Attention to d+2 will boost d+2 a lot, and boost d+1 a bit.
      • So even if d+2 gets slightly more attention, d+1 might end up getting slightly more boosting.
  • Sequences with large jumps are adversarial examples (because they're rare in the training data, which was randomly generated from choosing subsets without replacement).

Attention patterns

First, let's visualise attention like we usually do:

cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(10)),
    radioitems = True,
    return_mode = "view",
    batch_labels = ["<code>" + " ".join(s) + "</code>" for s in dataset.str_toks],
    mode = "small",
)

Note, we only care about the attention patterns from the second half of the sequence back to earlier values (since it's a 1-layer model, and that's where we're taking predictions from).

Some observations:

  • SEP consistently attends to the smallest value.
  • Most of the time, token d will attend to the smallest token which is strictly larger than d, in at least one of the heads (this is checked quantitatively in the snippet just after this list).
    • Seems like heads 0.0 and 0.1 split responsibility across the vocabulary: 0.1 deals with most values, 0.0 deals with a small range of values around 30.
  • This strongly suggests that the heads are predicting whatever they pay attention to.
  • One slightly confusing result - sometimes token d will pay attention more to the value which is 2 positions higher than d in the sorted list, rather than 1 position higher (e.g. very first example: 4 attends more to 7 than to 5). This is particularly common in sequences with 3 numbers very close together.
    • Further investigation (not shown here) suggests that these are not adversarial examples, i.e. attending more to 7 than to 5 doesn't stop 5 from being predicted. At this point, I wasn't sure what the reason for this was.
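These observations are anecdotal (from eyeballing a handful of sequences), so here's a quick batch-level check of the first two bullet points (a snippet I'm adding here; it isn't in the original notebook):

# For each destination position in the second half, check whether at least one head
# puts its highest attention weight on a source token equal to the smallest unsorted
# value strictly larger than the destination token (or the smallest value overall,
# when the destination is SEP).
L = dataset.list_len
SEP_ID = dataset.toks[0, L].item()
pattern = cache["pattern", 0]                        # [batch, head, seqQ, seqK]
hits = total = 0
for b in range(dataset.toks.shape[0]):
    toks = dataset.toks[b].tolist()
    unsorted = toks[:L]
    for q in range(L, 2 * L):                        # SEP up to second-to-last sorted token
        d = toks[q]
        expected = min(unsorted) if d == SEP_ID else min(v for v in unsorted if v > d)
        top_src = pattern[b, :, q, : q + 1].argmax(dim=-1)   # most-attended position, per head
        hits += int(any(toks[i] == expected for i in top_src.tolist()))
        total += 1
print(f"Fraction of positions where some head attends most to the expected token: {hits / total:.3f}")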

Next steps - confirm anecdotal observations about OV and QK circuits (plus run some basic head ablation experiments).

Ablating heads

Testing whether head 0.1 matters more (this was my hypothesis, since it seems to cover more of the vocabulary than 0.0). Conclusion - yes.

def get_loss_from_ablating_head(layer: int, head: int):

    def hook_fn(activation: Float[Tensor, "batch seq nheads d"], hook: HookPoint):
        activation_mean: Float[Tensor, "seq d_model"] = cache[hook.name][:, :, head].mean(0)
        activation[:, :, head] = activation_mean
        return activation
        
    model.reset_hooks()
    logits_orig = model(dataset.toks)
    logprobs_orig = logits_orig.log_softmax(-1)[:, dataset.list_len:-1, :]
    logits_ablated = model.run_with_hooks(dataset.toks, fwd_hooks=[(utils.get_act_name("result", layer), hook_fn)])
    logprobs_ablated = logits_ablated.log_softmax(-1)[:, dataset.list_len:-1, :]

    targets = dataset.toks[:, dataset.list_len+1:]
    logprobs_orig_correct = eindex(logprobs_orig, targets, "batch seq [batch seq]")
    logprobs_ablated_correct = eindex(logprobs_ablated, targets, "batch seq [batch seq]")

    return (logprobs_orig_correct - logprobs_ablated_correct).mean().item()


print("Loss from mean ablating the output of...")
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        print(f"  ...{layer}.{head} = {get_loss_from_ablating_head(layer, head):.3f}")
Loss from mean ablating the output of...
  ...0.0 = 0.920
  ...0.1 = 4.963

OV & QK circuits

We expect OV to be a copying circuit, and QK to be an "attend to anything bigger than self" circuit. SEP should attend to the smallest values.

W_OV = model.W_V[0] @ model.W_O[0] # [head d_model_in d_model_out]

W_QK = model.W_Q[0] @ model.W_K[0].transpose(-1, -2) # [head d_model_dest d_model_src]

W_OV_full = model.W_E @ W_OV @ model.W_U

W_QK_full = model.W_E @ W_QK @ model.W_E.T

imshow(
    W_OV_full,
    labels = {"x": "Prediction", "y": "Source token"},
    title = "W<sub>OV</sub> for layer 1 (shows that the heads are copying)",
    width = 900,
    height = 500,
    facet_col = 0,
    facet_labels = [f"W<sub>OV</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)

imshow(
    W_QK_full,
    labels = {"x": "Input token", "y": "Output logit"},
    title = "W<sub>QK</sub> for layer 1 (shows that the heads are attending to next largest thing)",
    width = 900,
    height = 500,
    facet_col = 0,
    facet_labels = [f"W<sub>QK</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)

Conclusion - this basically matches the previous hypotheses:

  • Strong diagonal pattern for the OV circuits shows that 0.1 is a copying head on most of the vocabulary (everything outside the values in the [28, 37] range), and 0.0 is a copying head on the other values.
  • Weak patchy diagonal pattern in QK circuit shows that most tokens attend more to ones which are slightly above them (and also that there are some cases where d attends more to d+2, d+3 etc than to d+1).

Visualising that last observation in more detail, for the case d=25:

def qk_bar(dest_posn: int):
    bar(
        [W_QK_full[0, dest_posn, :], W_QK_full[1, dest_posn, :]], # attention scores from destination token dest_posn to every source token, for heads 0.0 and 0.1
        title = f"Attention scores for destination token {dest_posn}",
        width = 900,
        height = 400,
        template = "simple_white",
        barmode = "group",
        names = ["0.0", "0.1"],
        labels = {"variable": "Head", "index": "Source token", "value": "Attention score"},
    )

qk_bar(dest_posn=25)

The most attended-to tokens are actually 28 and 29! We'll address this later, but first let's explain a simpler (though also confusing-seeming) result from the heatmap above.

What's with the attention to zero?

One weird observation from the heatmap is worth mentioning: some tokens with very high values (i.e. >35) attend a lot to very small tokens, e.g. zero.

qk_bar(dest_posn=40)

So why don't these tokens actually end up attending to zero in practice?

Answer - plotting the QK circuit with positional embeddings on the query side and token embeddings on the key side shows that queries near the end of the sequence have a bias against attending to very small tokens. Since the positions near the end of the sequence are exactly the ones likely to hold these larger values (i.e. >35), it's reasonable to guess that this effect cancels out the previously observed bias towards small tokens.

POSN_LABELS = [str(i) for i in range(dataset.seq_len)]
POSN_LABELS[dataset.list_len] = "SEP"

W_Qpos_Kemb = model.W_pos @ W_QK @ model.W_E.T

imshow(
    W_Qpos_Kemb,
    labels = {"x": "Key token", "y": "Query position"},
    title = "W<sub>QK</sub> for layer 1 (shows that the heads are attending to next largest thing)",
    y = POSN_LABELS,
    width = 950,
    height = 350,
    facet_col = 0,
    facet_labels = [f"W<sub>QK</sub> [0.{h0}]" for h0 in range(model.cfg.n_heads)]
)

Advexes

This plot also reveals a lot of potential advexes - for example, SEP consistently attends to the smallest value only up to around 30, where this pattern falls off. So if your entire sequence were in the range [30, 50], it's very possible that the model would fail to correctly identify the smallest token. Can you exhibit an example of this?

Another possible advex: if there's a large gap between tokens x and y, then x might attend to itself rather than to y. I created a CustomSortedListDataset class to confirm this. I also wrote a function show_multiple which can show multiple different plots in a batch at once (this was helpful for quickly testing out advexes) - you can see it in the Setup code section.

class CustomSortedListDataset(SortedListDataset):

    def __init__(self, unsorted_lists: List[List[int]], max_value: int):
        '''
        Creates a dataset from the unsorted lists in unsorted_lists.
        '''
        self.size = len(unsorted_lists)
        self.list_len = len(unsorted_lists[0])
        self.seq_len = 2*self.list_len + 1
        self.max_value = max_value

        self.vocab = [str(i) for i in range(max_value+1)] + ["SEP"]

        sep_toks = t.full(size=(self.size, 1), fill_value=self.vocab.index("SEP"))
        unsorted_list = t.tensor(unsorted_lists)
        sorted_list = t.sort(unsorted_list, dim=-1).values
        self.toks = t.concat([unsorted_list, sep_toks, sorted_list], dim=-1)

        self.str_toks = [[self.vocab[i] for i in toks] for toks in self.toks.tolist()]

        
custom_dataset = CustomSortedListDataset(
    unsorted_lists = [
        [0] + list(range(40, 49)),
        [5] + list(range(30, 48, 2)),
    ],
    max_value=50,
)

custom_logits, custom_cache = model.run_with_cache(custom_dataset.toks)

cv.attention.from_cache(
    cache = custom_cache,
    tokens = custom_dataset.str_toks,
    radioitems = True,
    return_mode = "view",
    batch_labels = ["<code>" + " ".join(s) + "</code>" for s in custom_dataset.str_toks],
    mode = "small",
)

show_multiple(custom_dataset)

Conclusion - yes, we correctly tricked x into self-attending rather than attending to y in these cases. The predictions were a bit unexpected, but we can at least see that the model predicts x with non-negligible probability (i.e. it incorrectly predicts the token it attends to), and doesn't predict y at all.
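As for the first potential advex mentioned above (an entire sequence in the [30, 50] range, where SEP's attention to the smallest value falls off), here's one way you could test it, reusing the same CustomSortedListDataset class. This is just a sketch - I haven't verified what the model actually does on this input:

# Probe the "SEP attends to the smallest value" advex with a list drawn entirely
# from the [30, 50] range.
sep_advex_dataset = CustomSortedListDataset(
    unsorted_lists=[[31, 34, 37, 39, 41, 43, 45, 47, 48, 50]],
    max_value=50,
)
sep_logits = model(sep_advex_dataset.toks)                          # [batch seq d_vocab]
sep_probs = sep_logits[0, sep_advex_dataset.list_len].softmax(-1)   # prediction at SEP
smallest = sep_advex_dataset.toks[0, sep_advex_dataset.list_len + 1].item()
print(f"P(correct smallest value = {smallest}) at SEP: {sep_probs[smallest]:.3f}")
print(f"Top prediction at SEP: {sep_probs.argmax().item()}")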

Solving the [d, d+1, d+2] mystery

At this point, I spent frankly too long trying to figure out why sequences of the form [d, d+1, d+2] aren't adversarial for this model. Before eventually finding the correct answer, the options I considered were:

  • Maybe the less important head does something valuable. For instance, if there's a token where 0.1 boosts d+1 and d+2, maybe head 0.0 suppresses d+2.
    • After all, it does seem from the OV circuit plot like head 0.0 is an anti-copying head at the tokens where 0.1 is a copying head.
    • (However, the same cannot be said for the tokens where 0.0 is a copying head, i.e. 0.1 doesn't seem like it's anti-copying here - which made me immediately suspicious of this explanation.)
  • The direct path W_E @ W_U is responsible for boosting tokens like d+1 much more than d+2.
    • This turns out to be somewhat true (see the plot below), but if it were the main factor, you'd expect [d, d+1, d+2] sequences to become advexes once you remove the direct path. I ran an ablation experiment to test this, and they didn't.

imshow(
    model.W_E @ model.W_U,
    title = "DLA from direct path",
    labels = {"x": "Prediction", "y": "Input token"},
    height = 500,
    width = 600,
)

Finally, I found the actual explanation. As described earlier, attending to d+2 will actually slightly boost d+1, and attending to d+1 will slightly suppress d+2 (and the same holds true for slightly larger gaps between source tokens). So even if d+2 is getting a bit more attention, the net effect will be that d+1 gets boosted more than d+2.
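Here's the argument in toy-arithmetic form (all the numbers below are invented purely for illustration - they aren't taken from the model):

attn = {"d+1": 0.45, "d+2": 0.55}        # d+2 gets slightly *more* attention than d+1
ov = {                                   # (source token, predicted token) -> logit contribution
    ("d+1", "d+1"): 4.0,  ("d+1", "d+2"): -0.5,   # attending to d+1 boosts d+1, suppresses d+2
    ("d+2", "d+2"): 4.0,  ("d+2", "d+1"): +0.5,   # attending to d+2 boosts d+2, and d+1 a little
}
logit = {pred: sum(attn[src] * ov[(src, pred)] for src in attn) for pred in ("d+1", "d+2")}
print(logit)   # d+1 comes out ahead (~2.08 vs ~1.98), despite getting less attention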

To visualise this in the real model, here's a set of 4 examples. Each contains three values x < y < z close together, chosen (judging from the QK bar charts earlier) so that x attends to z as much as, or more than, y. For each of them, I measured the direct logit attribution to y and z respectively, coming from the source tokens y and z respectively.

I already knew that I would see:

  • DLA from y -> y large, positive
  • DLA from z -> z large, positive (often larger than y -> y)

And if this hypothesis was correct, then I expected to see:

  • DLA from y -> z weakly negative
  • DLA from z -> y weakly positive
  • The total DLA to y should be larger than the total DLA to z (summing over source tokens y and z)

This is exactly what we see:

custom_dataset = CustomSortedListDataset(
    unsorted_lists = [
        [0, 5, 14, 15, 17, 25, 30, 35, 40, 45],
        [0, 5, 10, 15, 20, 25, 26, 27, 40, 45],
        [0, 5, 10, 15, 20, 25, 30, 31, 32, 45],
        [0, 5, 10, 15, 20, 25, 30, 31, 34, 45],
    ],
    max_value=50,
)
custom_logits, custom_cache = model.run_with_cache(custom_dataset.toks)

# For each sequence, define which head I expect to be the important one, and define
# which tokens are acting as (x, y, z) in each case
head_list = [1, 1, 0, 0]
x_tokens = [14, 25, 30, 30]
y_tokens = [15, 26, 31, 31]
z_tokens = [17, 27, 32, 34]
src_tokens = t.tensor([y_tokens, z_tokens]).T

# Get the positions of (x, y, z) by indexing into the string tokens lists (need to be
# careful that I'm taking them from the correct half of the sequence)
L = custom_dataset.list_len
x_posns = t.tensor([L + toks[L:].index(str(dt)) for dt, toks in zip(x_tokens, custom_dataset.str_toks)])
y_posns = t.tensor([toks.index(str(st)) for st, toks in zip(y_tokens, custom_dataset.str_toks)])
z_posns = t.tensor([toks.index(str(st)) for st, toks in zip(z_tokens, custom_dataset.str_toks)])
src_posns = t.stack([y_posns, z_posns]).T

out = einops.einsum(
    custom_cache["v", 0], model.W_O[0],
    "batch seqK head d_head, head d_head d_model -> batch seqK head d_model",
)
# out = out.sum(2)
out = out[range(4), :, head_list] # [batch seqK d_model]

attn = custom_cache["pattern", 0][range(4), head_list] # [batch seqQ seqK]
result_pre_sum = einops.einsum(
    out, attn,
    "batch seqK d_model, batch seqQ seqK -> batch seqQ seqK d_model",
)

scale = custom_cache["scale"].unsqueeze(-1) # [batch seqQ 1 1]
dla = (result_pre_sum / scale) @ model.W_U # [batch seqQ seqK d_vocab]

# want tensor of shape (4, 2, 2), with dimensions (batch dim, predicted token = y/z, src token = y/z)
dla_from_yz_to_yz = dla[t.arange(4)[:, None, None], x_posns[:, None, None], src_posns[:, None, :], src_tokens[:, :, None]]

fig = imshow(
    dla_from_yz_to_yz,
    facet_col = 0,
    facet_labels = [
        f"Seq #{i}<br>(x, y, z) = ({x}, {y}, {z})"
        for i, (x, y, z) in enumerate(zip(x_tokens, y_tokens, z_tokens))
    ],
    title = "DLA for custom dataset with (x, y, z) close together",
    title_y = 0.95,
    labels = {"y": "Effect on prediction", "x": "Source token"},
    x = ["src = y", "src = z"],
    y = ["pred = y", "pred = z"],
    width = 900,
    height = 400,
    text_auto = ".2f",
)