CALLUM MCDOUGALL


Monthly Algorithmic Challenge: Caesar Cipher




Problem

This post is the sixth in the sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of Stephen Casper's challenges, but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far. You can access the full ARENA materials from the Streamlit page. The purpose of the page you're currently reading is mainly to provide an easy-to-access place to read the solutions and understand the nature of these problems; I'd recommend using the Streamlit page instead if you're actually working through the problems. You can also find the solutions Colab here. I'll assume everyone has read the first problem, and has some overall context on the sequence.

Task & Dataset

The problem for this month is interpreting a model which has been trained to classify a sequence according to the Caesar cipher shift value which was used to encode it.

The sequences have been generated by taking English sentences containing only lowercase letters & punctuation, and choosing a random value X between 0 and 25 to rotate the letters (e.g. if the value was 3, then a becomes d, b becomes e, and so on, finishing with z becoming c). The model was trained using cross entropy loss to compute the shift value X for the text it's been fed, at every sequence position (so for a single sequence, the correct value will be the same at every sequence position, but since the model only has causal attention, it will find it easier to compute the value of X at later sequence positions, where it has more of the sequence to work with).
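To make the encoding concrete, here's a minimal sketch of the rotation (my own illustration, not the actual dataset code), which shifts each lowercase letter by X and leaves every other character unchanged:

import string

def caesar_shift(text: str, x: int) -> str:
    '''Rotate each lowercase letter forward by x places, leaving other characters unchanged.'''
    return "".join(
        string.ascii_lowercase[(string.ascii_lowercase.index(c) + x) % 26]
        if c in string.ascii_lowercase else c
        for c in text
    )

assert caesar_shift("the cat", 3) == "wkh fdw"
assert caesar_shift("xyz", 3) == "abc"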

There are 3 different modes to the problem, to give you some more options! Each mode corresponds to a different dataset, but the same task & same model architecture.

Easy mode

In easy mode, the data was generated by:

  • Choosing the 100 most frequent 3-letter words in the English language (as approximated from a text file containing the book "Hitchhiker's Guide To The Galaxy")
  • Choosing words from this len-100 list, with probabilities proportional to their frequency in the book
  • Separating these words with spaces

The model uses single-character tokenization. The vocabulary size is 27: each lowercase letter, plus whitespace.
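As a minimal sketch of what that tokenization amounts to (the exact index assignment here is my assumption, not taken from the dataset code):

import string

# Assumed vocabulary: the 26 lowercase letters followed by the space character (27 tokens total)
vocab = list(string.ascii_lowercase) + [" "]
char_to_tok = {c: i for i, c in enumerate(vocab)}

def tokenize(text: str) -> list[int]:
    return [char_to_tok[c] for c in text]

assert tokenize("wkh fdw") == [22, 10, 7, 26, 5, 3, 22]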

Medium mode

This is identical to easy mode; the only difference is that the words are drawn from this len-100 list uniformly, rather than according to their true frequencies. Can you see why this is harder?

Hard mode

In hard mode, the data was generated from random slices of OpenWebText (i.e. natural language text from the internet). It was processed by converting all uppercase characters to lowercase, then removing all characters except for the 26 lowercase letters plus the characters "\n .,:;?!'" (i.e. newline, space, and common punctuation characters).
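A rough sketch of that preprocessing step (my own reconstruction from the description above, using exactly the allowed characters listed; not the actual dataset code):

import string

# Keep lowercase letters, newline, space, and the punctuation characters listed above
ALLOWED = set(string.ascii_lowercase) | set("\n .,:;?!'")

def preprocess(text: str) -> str:
    text = text.lower()
    return "".join(c for c in text if c in ALLOWED)

assert preprocess("Don't Panic!") == "don't panic!"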

Model

The model is attention-only, with 2 layers and 2 attention heads per layer. It uses causal attention, and has a layernorm at the end of the model. It was trained with the Adam optimizer, using a weight decay of 0.001 and a linearly decaying learning rate.
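Since all three modes use the same model architecture, the config below mirrors the create_model call used for the medium-mode model later in this post:

model = create_model(
    d_vocab=27,       # 26 lowercase letters + space
    seq_len=32,
    seed=42,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=2,
    d_mlp=None,       # attention-only
    normalization_type="LN",
    device=device,
)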





Solutions

First, let's do some setup:
seq_len = 32
dataset = CodeBreakingDataset(mode="easy", seq_len=seq_len, size=1000, word_list_size=100, path="hitchhikers.txt").to(device)

logits, cache = model.run_with_cache(dataset.toks)

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1) # [batch seq_len vocab_out]

# We want to index like `logprobs_correct[batch, seq] = logprobs[batch, seq, labels[batch]]`
logprobs_correct = eindex(logprobs, dataset.labels, "batch seq [batch]")
probs_correct = eindex(probs, dataset.labels, "batch seq [batch]")

print(f"Average cross entropy loss: {-logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")

Average cross entropy loss: 0.121
Mean probability on correct label: 0.946
Median probability on correct label: 0.998
Min probability on correct label: 0.000

And a visualisation of its probability output for a single sequence (note how it generally gets better at predicting with a larger context window):
def show(model: HookedTransformer, dataset: CodeBreakingDataset, batch_idx: int):

    logits = model(dataset.toks[batch_idx].unsqueeze(0)).squeeze() # [seq_len vocab_out]
    probs = logits.softmax(dim=-1) # [seq_len vocab_out]

    imshow(
        probs.T,
        y=dataset.vocab_out,
        x=[f"{s}\n({j})" for j, s in enumerate(dataset.str_toks[batch_idx])],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:\n{''.join(dataset.str_toks[batch_idx])} ({''.join(dataset.str_toks_raw[batch_idx])})",
        text=[
            ["〇" if (s == dataset.str_labels[batch_idx]) else "" for _ in range(seq_len)]
            for s in dataset.vocab_out
        ],
        width=750,
        height=600,
    )

show(model, dataset, batch_idx=0)

Summary of how the model works

Easy and medium mode work in a pretty similar way. At the zeroth position, the model just predicts based on the frequency of first letters. For instance, in the easy dataset t has a frequency of about 29% across all words' first letters, so whatever the first token is, the direct path will assign about 29% probability to the rotation which would map t to that token.

At positions later than zero, some of the performance is explained by aggregation of independent evidence from different sequence positions. For example, in easy mode, t being the first token is evidence for a rotation of zero, and h being the second token is also evidence for a rotation of zero, but constructive interference between these two pieces of evidence means that combining them gives a stronger signal than either of them alone (the correct probability when combining these two bits of evidence independently & additively is 74%, much higher than the independent probabilities which are both 29%). However, the model's confidence in a rotation of zero in this particular case is 93%, suggesting it's doing something more sophisticated.
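The "additive combination" being referred to here is just summing the log-evidence from each position and renormalizing. Here's a standalone sketch of the mechanics (the frequency numbers are made up for illustration, not the real dataset statistics):

import torch as t

# Hypothetical letter-frequency distributions over the 26 letters: uniform except for one spike
freq_pos0 = t.full((26,), 0.71 / 25)
freq_pos0[19] = 0.29   # pretend 't' has 29% frequency as a first letter
freq_pos1 = t.full((26,), 0.71 / 25)
freq_pos1[7] = 0.29    # pretend 'h' has 29% frequency as a second letter

def evidence(freqs: t.Tensor, observed_tok: int) -> t.Tensor:
    '''Evidence for each rotation r from an observed token c is freqs[(c - r) % 26].'''
    return t.stack([freqs[(observed_tok - r) % 26] for r in range(26)])

# Observed prefix "th" (true rotation 0): combine the two pieces of evidence additively in log space
log_evidence = evidence(freq_pos0, 19).log() + evidence(freq_pos1, 7).log()
combined = log_evidence.softmax(dim=0)
print(combined[0])  # probability on rotation 0 - much higher than either 29% piece of evidence alone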

That more sophisticated algorithm turns out to be some kind of bigram / trigram frequency matching, which operates across multiple layers. Heads in layer 0 attend within a single word, storing "bigram information" in the second character and "bigram/trigram information" in the third character of each word. Heads in layer 1 attend back to tokens which contain important bigram information (mostly the second and third characters of the first word). Head 1.0 usually attends to the 3rd characters in a word, and 1.1 usually attends to the 2nd character in a word (it's necessary for a head to learn to use bigram information even though this is weaker than trigram information, because otherwise the model can't improve performance for sequence position 2). The attention patterns seem to form using K-composition with the positional embeddings; heads in layer 1 attend to tokens which themselves attended to the 1st or 2nd characters in the first word, in layer 0. The heads avoid attending to the 1st character in the first word because this position has a very large layernorm scale which nullifies the K-composition signal.

I've not gone into detail for the hard mode model, because this is the final problem in this sequence and I want to leave some threads open! However, you can read some guidance at the end of this document.

Easy mode

I'll start with the easy mode model. Most of the analysis for the medium mode will be similar, so this section contains most of the heavy lifting.

Notation

  • I'll use character arithmetic conventions, i.e. b - a = 1 and a - b = 25 as well as things like the + 18 = lzw
    • Where I need to use different variables (i.e. not representing those actual letters), I'll use capital letters, e.g. X doesn't represent the letter x
  • I'll denote the sequence positions as [0A, 0B, 0C, 0_, 1A, 1B, ...], i.e. the three letters of the first word then the following space, then the three letters of the second word, and so on.
    • I'll also use 0* to denote any of [0A, 0B, 0C, 0_] and *A to be any of [0A, 1A, 2A, ...].
  • I'll say two strings S1, S2 are rotationally equivalent if S1 + X = S2 for some integer X
    • e.g. ab and bc are rotationally equiv with X=1, same for one and bar with X=13

To help, we'll create a simple Char class to work with letters:

class Char:
    '''
    Class to handle single characters, with methods for getting character index, rotating characters, and
    returning character diff.
    '''
    def __init__(self, char):
        self.char = char
        self.idx = string.ascii_lowercase.index(char)
    
    def __sub__(self, other: "Char") -> int:
        return (self.idx - other.idx) % 26
    
    def __repr__(self):
        return f"Char({self.char})"

    def add(self, idx) -> "Char":
        new_idx = (self.idx + idx) % 26
        return Char(string.ascii_lowercase[new_idx])
    
assert Char("a").add(5).char == "f"
assert Char("f") - Char("a") == 5

TOKEN_SYMBOLS = [f"{x}{y}" for x in range(10) for y in "ABC_"][:seq_len]
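As a quick illustration of the rotational-equivalence notation above, here's a small helper built on Char (my own addition, not part of the original solution):

def rotational_shift(s1: str, s2: str) -> int | None:
    '''Returns X such that s1 + X = s2 character-wise, or None if the strings aren't rotationally equivalent.'''
    if len(s1) != len(s2):
        return None
    shifts = {Char(c2) - Char(c1) for c1, c2 in zip(s1, s2)}
    return shifts.pop() if len(shifts) == 1 else None

assert rotational_shift("ab", "bc") == 1
assert rotational_shift("one", "bar") == 13
assert rotational_shift("the", "lzw") == 18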

Position 0

At this position, all the model can do is bigram frequencies. In other words, if the first letter is X, and the frequency of a given letter Y (over all possible first letters of the decoded training sequences) is f(Y), then the model will predict X-Y with probability f(Y).

Let's first check that this is actually happening.

# Create a dataset of single-length sequences for any possible starting character
toks = t.arange(26).long().unsqueeze(1)

# Get the probabilities
probs = model(toks).softmax(-1).squeeze()

# Plot them
imshow(
    probs,
    size=(550, 550),
    title="Model probabilities for each possible first character",
    y=list(string.ascii_lowercase),
    labels=dict(x="Predicted rotation", y="First character", color="Prob"),
)

Conclusion

Yes, it is happening. Each row of predictions looks like the same probability distribution, just shifted. For example, we can see that if the first character is t then the highest prediction (31%) is on a rotation of zero, because t is the most common first letter. This is borne out in the plot below:

# Get total count for first letter frequencies, within the dataset we sampled from during training
letter_frequencies = {seq_pos: t.zeros(26).to(device) for seq_pos in range(3)}
for word, freq in dataset.words_freq:
    for seq_pos in range(3):
        letter_frequencies[seq_pos][Char(word[seq_pos]).idx] += freq

# Get the mean probability implied for that character being the first token, by our model
first_letter_probs = [
    sum(probs[(i + j) % 26, j].item() for j in range(26)) / 26
    for i in range(26)
]

# Plot them
scatter(
    x = letter_frequencies[0].tolist(), y = first_letter_probs,
    text = list(string.ascii_lowercase), textposition = 'top left',
    title = "First letter frequencies and mean probabilities",
    labels = dict(x="Frequency of first character", y="Model's implied prediction for first character"),
    template="ggplot2", size = (500, 550),
)

Let's also make a table of these frequencies, because it'll be helpful later on:

table = Table("Letter", "First", "Second", "Third", title="Character frequencies in words (%)")

for letter in string.ascii_lowercase:
    table.add_row(letter, *[
        f"{100 * letter_frequencies[seq_pos][Char(letter).idx].item():>5.1f}"
        for seq_pos in range(3)
    ])

rprint(table)

Position 1

Okay, now things are getting interesting! At this sequence position the model has access to two tokens, meaning in theory it can not only use the evidence from those tokens individually but also look at the relationship between them.

Does it look at the relationship between them, or does it just sum the independent evidence from them? To investigate this, I'll plot 2 things for comparison:

  1. The model's probability on the correct rotation, at each of the first 3 sequence positions, for each of the 100 possible first words (I'll average the predictions over all 26 possible rotations of each word).
  2. An idealized version of the model's predictions, which we'd get if we just combined the independent evidence from the first and second tokens.

What do I mean by "combining evidence from the first and second tokens"? Well, the plot in the previous cell showed us that the first letter of the sequence gives us a distribution over rotations, corresponding to the frequency of first letters across all words in the dataset. So we can do the same thing for the second letter, and then combine the two distributions by summing their log-probabilities and renormalizing.

For example, suppose the first two letters were th. Getting t should concentrate your probability mass around rotation 0, and maybe the same will be true for h given the frequency of the word the. Furthermore, constructive interference could in theory make it so all other possible rotations have much lower probability, since rotations which were plausible after seeing the first letter might be less plausible when looking at the second letter.

This is a nice theory, now let's test it!

# (1) Compute the actual probabilities for each word (averaged over all possible 26 rotations of each word)

# Get list of words (this will be useful later)
word_list_tokens = t.tensor([
    [Char(x).idx for x in seq]
    for seq in dataset.word_list
]).long()

# Get a list of 100 * 26 words, i.e. every rotation of each 100 words
word_list_tokens_all_rotations = t.tensor([
    [Char(x).add(y).idx for x in seq]
    for seq in dataset.word_list
    for y in range(26)
]).long()

# Get the tensor of correct rotations, for indexing. This will have shape (100*26,)
correct_rotations = t.tensor([
    y
    for seq in dataset.word_list
    for y in range(26)
]).long()

# Get probabilities, which will have shape (100*26, seq=3, vocab_out=26)
logits, cache_all_rotations = model.run_with_cache(word_list_tokens_all_rotations)
probs = logits.softmax(-1)

# Index into probs with the correct rotations, then average over the 26 rotations for each word
correct_probs = eindex(probs, correct_rotations, "words seq [words]")
correct_probs_per_word = einops.reduce(correct_probs, "(words rotations) seq -> words seq", "mean", rotations=26)

# (2) Compute the probabilities for each word, if we combined the "direct evidence" for sequence positions 0 & 1

# Use letter frequencies to get a dictionary of logprobs (so we can do things like average vectors together). Note, some
# of these values will be neginf, but that's fine because we'll convert them into probability space before plotting.
letter_logprobs = {seq_pos: t.log(x) for seq_pos, x in letter_frequencies.items()}

# Define a helper function to shift a row tensor by a given amount
def shift_row(x: Tensor, shift: int) -> Tensor:
    return t.cat([x[-shift:], x[:-shift]])

# Use this helper function to construct a vector of logprobs for each token in our `word_list_tokens`. For instance, if
# we see token `c` at position 1, we take the vector `letter_logprobs[1]` and shift it by 2 to get the frequencies of
# each different rotation value, conditional on `c` being the observed token at position 1.
direct_logprobs_per_token = t.stack([
    t.stack(([
        shift_row(letter_logprobs[seq_pos], -word_list_tokens[batch_idx, seq_pos].item())
        for seq_pos in range(3)
    ]))
    for batch_idx in range(100)
]) # shape (100, seq=3, vocab_out=26)

# Get the cumulative (summed) logprobs for each word, up to each point in the sequence
direct_logprobs_per_token = t.cat([
    direct_logprobs_per_token[:, :seq_pos].sum(dim=1, keepdim=True)
    for seq_pos in range(1, 4)
], dim=1) # shape (100, seq=3, vocab_out=26)

# Convert to probabilities, and get the correct probs (i.e. rotation=0)
direct_probs = direct_logprobs_per_token.softmax(-1)[..., 0]

imshow(
    t.stack([correct_probs_per_word.T, direct_probs.T]),
    facet_col = 0,
    facet_col_wrap = 1,
    facet_labels = ["Model probabilities", "Idealized direct evidence sum probabilities"],
    facet_row_spacing = 0.2,
    x = dataset.word_list,
    xaxis_tickangle = 60,
    size = (300, 2000),
    title = "Model probabilities for each word, compared to probabilities from aggregating token direct evidence",
)

Conclusion

This "evidence aggregating" seems to explain a lot of the model's performance, although the model is quite a bit better than we'd expect from this simplified model. For example: looking at just t gives you 29% evidence for a rotation of 0, and looking at h individually & combining this evidence additively boosts this to 74%, but the model's true prediction is 93%.

What kind of algorithm does better than the basic additive combination of evidence which we proposed? The model must be using the evidence provided by these two tokens in some nonlinear way. In the extreme case, the model could literally have learned every possible bigram of letters and map it to a distribution over rotations for that bigram. For example, the letters th could belong to a 0-rotated the (28.02%), or a 19-rotated man (0.69%) or may (0.21%), or a 10-rotated dry (0.04%). This implies that if the model sees th then it should be approximately 28.02 / (28.02 + 0.69 + 0.21 + 0.04) = 96.75% confident that the rotation is 0, which is close to (but higher than) the model's actual confidence of 93%. We might guess that the model is doing some weaker version of this bigram matching, constrained by the fact that there are a really large number of possible bigrams (nearly 100*26 = 2600).
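To make this "bigram lookup table" idea concrete (and to check the arithmetic above), here's a hedged sketch of the computation such a lookup would perform, given a table of word frequencies. It's my own illustration rather than anything the model literally contains, and it uses the encoding convention from the task description (plaintext + X = ciphertext):

def encode_shift(plain: str, observed: str) -> int | None:
    '''The single shift X such that plain + X = observed character-wise, or None if no such shift exists.'''
    if len(plain) != len(observed):
        return None
    shifts = {(ord(o) - ord(p)) % 26 for p, o in zip(plain, observed)}
    return shifts.pop() if len(shifts) == 1 else None

def bigram_posterior(observed_prefix: str, word_freqs: dict[str, float]) -> dict[int, float]:
    '''Posterior over rotations given an observed two-letter prefix, weighting consistent words by frequency.'''
    totals: dict[int, float] = {}
    for word, freq in word_freqs.items():
        shift = encode_shift(word[:2], observed_prefix)
        if shift is not None:
            totals[shift] = totals.get(shift, 0.0) + freq
    z = sum(totals.values())
    return {rot: freq / z for rot, freq in totals.items()}

# The example from the paragraph above (frequencies as quoted, in %)
posterior = bigram_posterior("th", {"the": 28.02, "man": 0.69, "may": 0.21, "dry": 0.04})
print(f"{posterior[0]:.4f}")  # ≈ 0.9675, i.e. the ~96.75% figure above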

DLA

To do some more analysis, let's start thinking about which components are responsible for the model's performance at positions 1 and beyond. First we'll stare at some attention probabilities and look at the DLA (direct logit attribution).

# Get residual stream decomposition (with mean subtracted)
resid = t.stack([
    cache["embed"],
    cache["result", 0][:, :, 0],
    cache["result", 0][:, :, 1],
    cache["result", 1][:, :, 0],
    cache["result", 1][:, :, 1],
])
resid = resid - resid.mean(dim=1, keepdim=True)

# Compute direct logit attribution per component, then index to get the logits on the correct token
dla = (resid / cache["scale"][:, :].unsqueeze(0)) @ model.W_U
dla_correct = eindex(dla, dataset.labels, "comp batch seq [batch]") # shape = [components=5, batch=1000, seqpos=32]

# Get data for first sequence position, and (flattened) data for all sequence positions
DATA_SEQPOS_1 = dla_correct[:, :, 1]
DATA_SEQPOS_ALL = einops.rearrange(dla_correct, "comp batch seq -> comp (batch seq)")

# Create a subplot, and fill in the histograms
titles = ["Direct", "0.0", "0.1", "1.0", "1.1"]
subplot_titles=[f"{title}, seq_pos={seq_pos}" for title in titles for seq_pos in ["1", "all"]]
fig = make_subplots(rows=5, cols=2, subplot_titles=subplot_titles, vertical_spacing=0.06)
fig.update_annotations(font_size=14).update_xaxes(range=[-15, 15])
for row, title in enumerate(titles):
    kwargs = dict(marker_color=px.colors.qualitative.Plotly[row], opacity=0.6, showlegend=title=="embeddings")
    for col, (seq_pos, data) in enumerate([("1", DATA_SEQPOS_1[row]), ("all", DATA_SEQPOS_ALL[row])]):
        fig.add_trace(go.Histogram(x=utils.to_numpy(data), **kwargs), row=row+1, col=col+1)
        fig.add_vline(x=data.mean(), line_width=1.5, row=row+1, col=col+1, annotation_text=f" mean={data.mean():.2f}")
fig.update_layout(width=900, height=900, title="Histograms of component DLA")
fig.show()

n = 15

cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(n)),
    radioitems = True,
    return_mode = "view",
    batch_labels = ["".join(s) + "  ====  " + "".join(s2) for s, s2 in zip(dataset.str_toks[:n], dataset.str_toks_raw[:n])],
    mode = "small",
)

Conclusion

The DLA plots make it clear that in most cases heads 1.0 and 1.1 do all the heavy lifting, although for sequence position 1 head 1.1 is a lot more important.

As for the attention patterns, we see that:

  • Heads 0.0 and 0.1 are mostly looking "within words", i.e. *B attending to *A and *C attending to (*B, *C).
  • Head 1.0 is attending from tokens back to *C tokens, mostly 0C.
  • Head 1.1 is attending from tokens back to *B tokens, mostly 0B.

From this, a possible theory emerges:

  • After layer 0, the *B tokens store information about (*A, *B), and the *C tokens store information about (*A, *B, *C).
  • The heads in layer 1 do something vaguely bigram-y, in other words:
    • Head 1.0 looks back at the (*A, *B, *C) information present in *C tokens (mostly 0C). It converts this information into a distribution over rotations.
    • Head 1.1 looks back at the (*A, *B) information present in *B tokens (mostly 0B). It converts this information into a distribution over rotations.

This algorithm would be sufficient to achieve performance at the level we've observed in this model. It would also explain a few things, e.g. why head 1.1 matters a lot more at sequence position 1 - the (*A, *B) information is the only type of information we have access to at this position.

Looking closer at the attention patterns, it seems like the heads in layer 1 are actually smart enough to attend more to the more common words, because they carry more definitive information. For example, if the first word is uncommon but the second word is the then more attention will go to the second word than the first.

Two experiments to do next:

  1. Layer 1 - OV circuit. Take all possible values of the second and third positions in the residual stream after the first layer (there are only 100*26 possible pairs of values), map these values through the OV matrices of heads 1.1 and 1.0 respectively, and see if the output boosts the correct rotation as I expect.
  2. Layer 1 - QK circuit. See whether they preferentially attend to early sequence positions, and to the evidence from more common words (both seem true looking at the attention patterns).

Assuming these investigations both turn out positive, I'll have a look at the medium and hard mode models. My guess is that they'll be pretty similar (since the discussion so far has been in terms of common bigram and trigram patterns, nothing that shouldn't generalize nicely to these different datasets).

Layer 1 - OV circuit

# Take the post-layer-0 residual stream for every rotated word
resid_post = cache_all_rotations["resid_post", 0]

# Approximate the pre-layer-1 layernorm, then map through the OV circuits of both layer-1 heads
layer0_output = einops.einsum(
    resid_post / resid_post.std(dim=-1, keepdim=True),
    model.W_V[1] @ model.W_O[1],
    "batch seq d_model, head d_model d_model_out -> batch head seq d_model_out",
)
# Unembed (using the average final layernorm scale)
layer0_output_logits = (layer0_output / cache_all_rotations["scale"].mean()) @ model.W_U

# Index the logit on the correct rotation, then average over the 26 rotations of each word
layer0_output_correct_logits = eindex(layer0_output_logits, correct_rotations, "batch head seq [batch]")

layer0_output_correct_logits_by_word = einops.reduce(layer0_output_correct_logits, "(word rot) head seq -> head word seq", "mean", rot=26)

imshow(
    layer0_output_correct_logits_by_word.transpose(-1, -2),
    facet_col = 0,
    facet_col_wrap = 1,
    facet_labels = ["1.0", "1.1"],
    facet_row_spacing = 0.2,
    x = dataset.word_list,
    y = ["0A", "0B", "0C"],
    xaxis_tickangle = 60,
    size = (320, 2000),
    title = "Logits for correct rotations, via OV circuit in layer 1",
)

This does support our hypothesis: head 1.0 is very good at boosting the correct rotation when it acts on the (0A, 0B, 0C) information stored in 0C (i.e. the 3rd row of the 1st plot is distinctly positive), and head 1.1 is very good at boosting the correct rotation when it acts on the (0A, 0B) information stored in 0B (i.e. the 2nd row of the 2nd plot is distinctly positive). A few other observations:

  • 1.1 also has a positive effect when attending to 0C most of the time; this makes sense given we saw in the attention patterns that 1.1 also attends to 0C a bit.
  • The positive effect is stronger for the more common words, which makes sense given these come up more often during training & are attended to more (we saw this from the attention patterns).
  • Heads 1.0 and 1.1 seem to have some kind of offset effect: in a word where one head fails to classify it correctly, the other head will often pick up the slack. This is similar to patterns we've seen before in this monthly series.

Hunt for linear structure - the

After this, I tried to take a deep dive into the linear structure of the model, specifically trying to find any kind of nice structure in how the model classifies the 26 possible rotations of the after the second sequence position (which it does very well, with the probability on the correct rotation ranging from 86% to 93%). These investigations were pretty unsuccessful, mainly I think because the model seems to deal with the 26 possible rotations in non-symmetric ways. For example, the amount that token 0B self-attends in head 1.0 varies a lot over these 26 rotations, from approximately 1 to approximately 0. Sadly, I ended up deciding that I probably wasn't going to find structure here without a lot more work. Partly, I think this is because the model can basically get away with treating bigrams as a lookup table - there are only 100 possible words, and 2600 possible bigrams (furthermore, 50% of the probability mass is concentrated on the first 4 words, and 90% on the first 33 words). So the model doesn't need to learn some kind of efficient rotationally symmetric structure. I'd be interested in whether a longer word list or a higher weight decay forces the model to learn more efficient structure, but I'm not going to investigate this here.

Layer 1 - QK circuit

First we'll plot the positional QK circuit:

scale_0 = cache["scale", 0, "ln1"].mean(0)[:, 0]
scale_1 = cache["scale", 1, "ln1"].mean(0)[:, 0]

# Get the keys, from the positional embeddings
W_pos_keys = (model.W_pos / scale_1) @ model.W_K[1] + model.b_K[1].unsqueeze(1)

# Get the queries, which are just the mean value of the residual stream pre-layer 1, mapped through query matrices
W_pos_queries = cache["resid_pre", 1].mean(0) @ model.W_Q[1] + model.b_Q[1].unsqueeze(1)

# Get attention scores from the queries and keys
attn_scores = (W_pos_queries @ W_pos_keys.transpose(-1, -2)) / (model.cfg.d_head ** 0.5)
attn_scores_masked = t.where(t.tril(t.ones_like(attn_scores)).bool(), attn_scores, -float("inf"))

imshow(
    attn_scores_masked.softmax(-1),
    size = (650, 1200),
    facet_col = 0,
    facet_labels = ["1.0", "1.1"],
    x = TOKEN_SYMBOLS,
    y = TOKEN_SYMBOLS,
    labels = {"y": "Dest", "x": "Src"},
)

This wasn't what I was expecting - no strong pattern showing higher attention paid to earlier tokens. From this, I'm guessing that most of what determines the attention patterns in 1.0 and 1.1 is the output of the layer-0 OV circuit, rather than the positional embeddings directly.

A second possible theory: there's K-composition going on. Specifically, heads in layer 1 are attending back to tokens which themselves attended to tokens in *B or *C positions in layer 0.

# Get the keys, from the positional embeddings mapped through the OV circuit of layer 0
W_OV = model.W_V @ model.W_O
W_pos_post_00 = (model.W_pos / scale_0) @ W_OV[0].sum(0)
W_pos_keys = (W_pos_post_00 / scale_1[1:].mean()) @ model.W_K[1] + model.b_K[1].unsqueeze(1)

# Get the queries, which are just the mean value of the residual stream pre-layer 1, mapped through query matrices
W_pos_queries = cache["resid_pre", 1].mean(0) @ model.W_Q[1] + model.b_Q[1].unsqueeze(1)

# Get attention scores from the queries and keys
attn_scores = (W_pos_queries @ W_pos_keys.transpose(-1, -2)) / (model.cfg.d_head ** 0.5)
attn_scores_masked = t.where(t.tril(t.ones_like(attn_scores)).bool(), attn_scores, -float("inf"))

imshow(
    attn_scores_masked.softmax(-1),
    size = (650, 1200),
    facet_col = 0,
    facet_labels = ["1.0", "1.1"],
    x = TOKEN_SYMBOLS,
    y = TOKEN_SYMBOLS,
    labels = {"y": "Dest", "x": "Src"},
)

This looks a lot clearer now! We can see from this plot that:

  • Head 1.0 likes to attend to tokens which themselves attended to 0A in layer 0
  • Head 1.1 likes to attend to tokens which themselves attended to 0B in layer 0

Although this explains the bias towards early positions, it's not sufficient for explaining how 1.0 attends to 0C and 1.1 to (0B, 0C). At best, this plot would explain why 1.0 attends to (0A, 0B, 0C) and 1.1 to (0B, 0C).

I think the reason 1.0 doesn't attend to 0A is that the pre-layer 1 attention layernorm scale factors for the 0A token are very large (see plot below). So even if the attention from 0A -> 0A is larger than the attention from 0B -> 0A, this effect is offset by the scale factor. I think a similar thing is happening at all *A positions (you can see the scale in the plot below spikes at all *A positions).

As for why 1.0 attends to 0C rather than 0B, here I'm less sure. I'll leave it as an open question for now.

line(
    cache["scale", 1, "ln1"].mean(0)[:, 0, 0],
    title = "Average pre-layer 1 layer norm scale, across sequence positions",
    labels = {"x": "Sequence position", "y": "Scale"},
    size = (400, 700),
)

I'm not going to investigate the second part of my QK circuit hypothesis (about whether more common words boost attention to themselves). This seems empirically true, but probably messy to verify.

Medium mode

I'll start by zipping straight through the key bits of analysis I did on the easy model, but for the medium model. I'll comment on any differences between the two models.

filename = section_dir / "caesar_cipher_model_medium.pt"

model = create_model(
    d_vocab=27,
    seq_len=32,
    seed=42,
    d_model=48,
    d_head=24,
    n_layers=2,
    n_heads=2,
    d_mlp=None,
    normalization_type="LN",
    device=device,
)
state_dict = model.center_writing_weights(t.load(filename))
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False);

Basic statistics

Slightly lower mean probability (91% vs 94.6%) - this is to be expected since in medium difficulty we don't have a positively skewed distribution of word frequencies, so after we see a combination of two letters it's harder to conclude with high probability what the next word is. However, that doesn't mean the model can't learn the same algorithm as it did in easy mode - let's see if that's the case!

seq_len = 32
dataset = CodeBreakingDataset(mode="medium", seq_len=seq_len, size=1000, word_list_size=100, path="hitchhikers.txt").to(device)

logits, cache = model.run_with_cache(dataset.toks)

logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1) # [batch seq_len vocab_out]

# We want to index like `logprobs_correct[batch, seq] = logprobs[batch, seq, labels[batch]]`
logprobs_correct = eindex(logprobs, dataset.labels, "batch seq [batch]")
probs_correct = eindex(probs, dataset.labels, "batch seq [batch]")

print(f"Average cross entropy loss: {-logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {probs_correct.mean():.3f}")
print(f"Median probability on correct label: {probs_correct.median():.3f}")
print(f"Min probability on correct label: {probs_correct.min():.3f}")
Average cross entropy loss: 0.200
Mean probability on correct label: 0.910
Median probability on correct label: 0.997
Min probability on correct label: 0.001

Attention patterns

Seems not too dissimilar. The layer 0 heads still attend within each word, and the layer 1 heads still attend back to the *B and *C tokens. 1.0 attends primarily to *C like last time, and 1.1 attends to both *B and *C (although a bit more to *C than *B). Possibly this is because the uniformity of the word distribution means that it's harder to deduce rotation just by looking at (*A, *B) information. As an example - recall in easy mode we said the following:

The letters th could belong to a 0-rotated the (28.02%), or a 19-rotated man (0.69%) or may (0.21%), or a 10-rotated dry (0.04%). This implies that if the model sees th then it should be approximately 28.02 / (28.02 + 0.69 + 0.21 + 0.04) = 96.75% confident that the rotation is 0.

But in medium mode, our probability isn't 96.75% in this case, it's 25%, since each of these 4 words is equally likely. It's only when factoring in *C information that we can boost the probability of the correct rotation.

Another difference here is that the heads don't have as strong a bias towards earlier sequence positions. Possibly this is because our easy-mode model could get away with having very non-uniform attention; if the first word is very common and offers strong evidence, then the model doesn't need to attend to any word other than that one.

n = 10

cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(n)),
    radioitems = True,
    return_mode = "view",
    batch_labels = ["".join(s) + "  ====  " + "".join(s2) for s, s2 in zip(dataset.str_toks[:n], dataset.str_toks_raw[:n])],
    mode = "small",
)

DLA

DLA tells a similar story as it did in easy mode:

  • The direct path, and heads in layer 0, have very little effect. Most of the effect comes from heads in layer 1.
  • Specifically for sequence position 1, head 1.1 is more important than head 1.0 (because, as we saw from the attention patterns above, it also deals with *B rather than only *C).

One notable difference is that the model is much worse at predicting the correct rotation at sequence position 1. This is what we'd expect, again because the uniformity of the distribution makes it harder to confidently predict the word just from the first 2 letters.

resid = t.stack([
    cache["embed"],
    cache["result", 0][:, :, 0],
    cache["result", 0][:, :, 1],
    cache["result", 1][:, :, 0],
    cache["result", 1][:, :, 1],
])
resid = resid - resid.mean(dim=1, keepdim=True)

dla = (resid / cache["scale"][:, :].unsqueeze(0)) @ model.W_U
dla_correct = eindex(dla, dataset.labels, "comp batch seq [batch]") # shape = [components=5, batch=1000, seqpos=32]

titles = ["Direct", "0.0", "0.1", "1.0", "1.1"]
subplot_titles=[f"{title}, seq_pos={seq_pos}" for title in titles for seq_pos in ["1", "all"]]

DATA_SEQPOS_1 = dla_correct[:, :, 1]
DATA_SEQPOS_ALL = einops.rearrange(dla_correct, "comp batch seq -> comp (batch seq)")

fig = make_subplots(rows=5, cols=2, subplot_titles=subplot_titles, vertical_spacing=0.06)
fig.update_annotations(font_size=14).update_xaxes(range=[-15, 15])
for row, title in enumerate(titles):
    kwargs = dict(marker_color=px.colors.qualitative.Plotly[row], opacity=0.6, showlegend=title=="embeddings")
    for col, (seq_pos, data) in enumerate([("1", DATA_SEQPOS_1[row]), ("all", DATA_SEQPOS_ALL[row])]):
        fig.add_trace(go.Histogram(x=utils.to_numpy(data), **kwargs), row=row+1, col=col+1)
        fig.add_vline(x=data.mean(), line_width=1.5, row=row+1, col=col+1, annotation_text=f" mean={data.mean():.2f}")

fig.update_layout(width=900, height=900, title="Histograms of component DLA")
fig.show()

Layer-1 OV circuit

Same pattern observed here as in easy mode. Two main differences:

  • 1.1 seems better at 0C than it is at 0B, but this is consistent with our observation that in medium mode this head attends more to *C tokens than to *B tokens.
  • The patterns seem more uniform, rather than starting strong then tailing off like they did in easy mode. This is because the easy mode model will have put extra effort into learning these patterns for the most frequent words, whereas here all 100 words have the same frequency.

logits, cache_all_rotations = model.run_with_cache(word_list_tokens_all_rotations)

resid_post = cache_all_rotations["resid_post", 0]

layer0_output = einops.einsum(
    resid_post / resid_post.std(dim=-1, keepdim=True),
    model.W_V[1] @ model.W_O[1],
    "batch seq d_model, head d_model d_model_out -> batch head seq d_model_out",
)
layer0_output_logits = (layer0_output / cache_all_rotations["scale"].mean()) @ model.W_U

layer0_output_correct_logits = eindex(layer0_output_logits, correct_rotations, "batch head seq [batch]")

layer0_output_correct_logits_by_word = einops.reduce(layer0_output_correct_logits, "(word rot) head seq -> head word seq", "mean", rot=26)

imshow(
    layer0_output_correct_logits_by_word.transpose(-1, -2),
    facet_col = 0,
    facet_col_wrap = 1,
    facet_labels = ["1.0", "1.1"],
    facet_row_spacing = 0.2,
    x = dataset.word_list,
    y = ["0A", "0B", "0C"],
    xaxis_tickangle = 60,
    size = (320, 2000),
    title = "Logits for correct rotations, via OV circuit in layer 1",
)

Layer-1 QK circuit

To cap off our analysis, we can see a similar pattern in the QK circuit. Head 1.0 likes to attend to tokens which themselves attended to 0A in layer 0, and head 1.1 likes to attend to tokens which themselves attended to 0B in layer 0. The reason head 1.0 doesn't attend to 0A is again because of the large layernorm scale, which nullifies the boost caused by high self-attention from 0A to 0A.

# Recompute the layernorm scale factors from the medium-mode cache (they were previously computed from the easy-mode cache)
scale_0 = cache["scale", 0, "ln1"].mean(0)[:, 0]
scale_1 = cache["scale", 1, "ln1"].mean(0)[:, 0]

W_OV = model.W_V @ model.W_O

W_pos_post_00 = (model.W_pos / scale_0) @ W_OV[0].sum(0)
W_pos_keys = (W_pos_post_00 / scale_1[1:].mean()) @ model.W_K[1] + model.b_K[1].unsqueeze(1)

W_pos_queries = cache["resid_pre", 1].mean(0) @ model.W_Q[1] + model.b_Q[1].unsqueeze(1)

attn_scores = (W_pos_queries @ W_pos_keys.transpose(-1, -2)) / (model.cfg.d_head ** 0.5)
attn_scores_masked = t.where(t.tril(t.ones_like(attn_scores)).bool(), attn_scores, -float("inf"))

imshow(
    attn_scores_masked.softmax(-1),
    size = (650, 1200),
    facet_col = 0,
    facet_labels = ["1.0", "1.1"],
    x = TOKEN_SYMBOLS,
    y = TOKEN_SYMBOLS,
    labels = {"y": "Dest", "x": "Src"},
)

line(
    cache["scale", 1, "ln1"].mean(0)[:, 0, 0],
    title = "Average pre-layer 1 layer norm scale, across sequence positions",
    labels = {"x": "Sequence position", "y": "Scale"},
    size = (400, 700),
)