CALLUM MCDOUGALL

Monthly Algorithmic Challenge (August 2023): First Unique Character

Problem

This post is the second in the sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of Stephen Casper's challenges, but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far. You can access the full ARENA materials from the Streamlit page. The purpose of the page you're currently reading is mainly to provide an easy-to-access place to read the solutions and understand the nature of these problems; I'd recommend using the Streamlit page instead if you're actually working through the problems. You can also find the solutions Colab here. I'll assume everyone has read the first problem, and has some overall context on the sequence.

Task & Dataset

The algorithmic task is as follows: the model is presented with a sequence of characters, and for each character it has to correctly identify the first character in the sequence (up to and including the current character) which is unique up to that point.

The null character "?" has two purposes:

  • In the input, it's used as the start character (because it's often helpful for interp to have a constant start character, to act as a "rest position").
  • In the output, it's also used as the start character, and to represent the classification "no unique character exists".

Here is an example of what this dataset looks like:

dataset = UniqueCharDataset(size=2, vocab=list("abc"), seq_len=6, seed=42)

for seq, first_unique_char_seq in zip(dataset.str_toks, dataset.str_tok_labels):
    print(f"Seq = {''.join(seq)}, Target = {''.join(first_unique_char_seq)}")

Seq = ?acbba, Target = ?aaaac
Seq = ?cbcbc, Target = ?ccb??

Explanation:

  1. In the first sequence, "a" is unique in the prefix substring "acbb", but it repeats at the 5th sequence position, meaning the final target character is "c" (which appears second in the sequence).
  2. In the second sequence, "c" is unique in the prefix substring "cb", then it repeats so "b" is the new first unique token, and for the last 2 positions there are no unique characters (since both "b" and "c" have been repeated) so the correct classification is "?" (the "null character").
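
To make the labelling rule fully concrete, here is a minimal reference implementation of it (a sketch of my own, not the actual dataset code - the function name first_unique_labels is made up):
def first_unique_labels(seq: str, null_char: str = "?") -> str:
    '''
    Returns, for each position, the first character which is unique up to (and including)
    that position, using the null character when no unique character exists.
    '''
    labels = [null_char]  # position 0 is always the null character
    for i in range(1, len(seq)):
        prefix = seq[1:i+1]  # all characters seen so far, excluding the start character
        unique = [c for c in prefix if prefix.count(c) == 1]
        labels.append(unique[0] if unique else null_char)
    return "".join(labels)

assert first_unique_labels("?acbba") == "?aaaac"
assert first_unique_labels("?cbcbc") == "?ccb??"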

Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels, at every sequence position simultaneously (including the zeroth sequence position, which is trivial because the input and target are both always "?"). You can inspect the notebook training_model.ipynb to see how it was trained. I used the version of the model which achieved highest accuracy over 50 epochs (accuracy ~99%).

The model is a 2-layer transformer with 3 attention heads per layer, and causal attention. It includes layernorm, but no MLP layers.
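
For reference, here is roughly what this looks like as a TransformerLens config (a sketch: the d_model and d_head values are my own assumptions, since they aren't stated in this post; the other fields follow from the description above and from the dataset):
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers = 2,
    n_heads = 3,
    d_model = 42,               # assumption - not stated in this post
    d_head = 14,                # assumption - not stated in this post
    n_ctx = 20,                 # sequence length used by the dataset
    d_vocab = 11,               # 10 characters plus the null character "?"
    d_vocab_out = 11,
    attn_only = True,           # no MLP layers
    normalization_type = "LN",
)
model = HookedTransformer(cfg)  # randomly initialised; the trained weights come from the training notebook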


Solutions

First, let's do some setup:
dataset = UniqueCharDataset(size=1000, vocab=list("abcdefghij"), seq_len=20, seed=42)

# Inspect the dataset in a simple way
for seq, first_unique_char_seq in zip(dataset.str_toks[:5], dataset.str_tok_labels[:5]):
    print(f"Seq = {''.join(seq)}, Target = {''.join(first_unique_char_seq)}")

# Get some activations
logits, cache = model.run_with_cache(dataset.toks)
logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
probs = logprobs.softmax(-1)

batch_size, seq_len = dataset.toks.shape
logprobs_correct = eindex(logprobs, dataset.labels, "batch seq [batch seq]")
probs_correct = eindex(probs, dataset.labels, "batch seq [batch seq]")

avg_cross_entropy_loss = -logprobs_correct.mean().item()
avg_correct_prob = probs_correct.mean().item()
min_correct_prob = probs_correct.min().item()

print(f"\nAverage cross entropy loss: {avg_cross_entropy_loss:.3f}")
print(f"Average probability on correct label: {avg_correct_prob:.3f}")
print(f"Min probability on correct label: {min_correct_prob:.3f}")

Average cross entropy loss: 0.017
Average probability on correct label: 0.988
Min probability on correct label: 0.001

And some simple (kinda hacky) visualisation:
def show(i):

    imshow(
        probs[i].T,
        y=dataset.vocab,
        x=[f"{dataset.str_toks[i][j]}<br><sub>({j})</sub>" for j in range(model.cfg.n_ctx)],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities (for batch idx = {i}), with correct classification highlighted",
        text=[
            ["〇" if str_tok == correct_str_tok else "" for correct_str_tok in dataset.str_tok_labels[i]]
            for str_tok in dataset.vocab
        ],
        width=900,
        height=450,
    )

show(0)

Some initial notes

I initially expected the same high-level story as July's model:

  • There are some layer-0 heads which are moving information into useful sequence positions, depending on whether the tokens at those sequence positions are the same as earlier tokens.
  • There are some layer-1 heads which are picking up on this information, and converting it into a useful classification.

Things I expect to see:

  • Attention patterns
    • There are layer-0 heads which act as duplicate token heads (abbrev. DTH), i.e. heads in which each token attends back to previous instances of itself.
    • There are layer-1 heads for which each token attends back to the first unique token up to that point.
  • Full matrices
    • The full QK matrix of heads in layer 0 should be essentially the identity, if they're acting as duplicate token heads.
    • The OV circuit of heads in layer 1 should basically be a copying circuit, because when they attend to token T they're using that as a prediction.
      • The OV circuit could be V-composing with layer-0 heads, but this doesn't seem strictly necessary.
    • The full QK matrix of heads in layer 1 should be privileging earlier tokens (because if there's more than one unique token, the head will have to attend to the earlier one).
    • The full QK matrix of heads in layer 1 (with Q-composition from layer 0) should have a negative stripe, because they'll be avoiding tokens which are duplicates.

Other thoughts:

  • With multiple low-dimensional heads, it's possible their functionality is being split. We kinda saw this in the July problem (with posn 20 being mostly handled by head 1.1 and the other positions in the second half being handled by head 1.0 to varying degrees of success).

Attention patterns

I've visualised attention patterns below. I wrote a function to perform some HTML formatting (aided by ChatGPT) to make the sequences easier to interpret, by highlighting all the tokens which are the first unique character at some point in the sequence. Also, note that batch_labels was supplied not as a list of strings, but as a function mapping (batch index, str toks) to a string. Either is accepted by the cv.attention.from_cache function.

def format_sequence(str_toks: List[str], str_tok_labels: Tensor, code: bool = True) -> str:
    '''
    Given a sequence (as list of strings) and labels (as Int-type tensor), formats the sequence by
    highlighting all the tokens which are the first unique char at some point.
    
    We add an option to remove the code tag, because this doesn't work as a plotly title.
    '''
    seq = "<b><code>" + " ".join([
        f"<span style='color:red;'>{tok}</span>"
        if (tok in str_tok_labels) and (tok not in str_toks[:i]) else tok
        for i, tok in enumerate(str_toks)
    ]) + "</code></b>"
    if not(code):
        seq = seq.replace("<code>", "").replace("</code>", "")
    return seq


cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(10)),
    attention_type = "info-weighted",
    radioitems = True,
    return_mode = "view",
    batch_labels = lambda batch_idx, str_tok_list: format_sequence(str_tok_list, dataset.str_tok_labels[batch_idx]),
    mode = "small",
)

Conclusions

Some of the evidence fits my model:

  • 0.1 is acting as a DTH, for all tokens except a
  • 0.2 seems to be acting as a DTH for a, patching this gap
  • I guessed some kind of split functionality might be happening in the layer-1 attention heads. The non-intersecting vertical lines in the layer-1 attention patterns support this.
    • To clarify - the vertical lines suggest that each of these heads has a particular set of tokens which it cares about, and these sets are disjoint. Further examination shows that the sets seem to be [a, c] for 1.0, [d, e, f, j] for 1.1, and [b, g, h, i] for 1.2.
    • Notation: we'll refer to these sets as the "domain" of the layer 1 head. So [a, c] is the domain of 1.0, etc.

But a lot of the evidence doesn't fit:

  • Head 0.0 attends pretty uniformly to early tokens which aren't the same as the destination token - this doesn't fit with my DTH model.
    • Guess: 0.0 is V-composing with heads in layer 1, because the combined heuristics "token is early" and "token is not the same as the destination token" both seem to point towards this token being a correct classification.
  • Head 0.2 looks a bit like 0.0 in some respects - guess that it's doing some combination of DTH + whatever 0.0 is doing.
  • Heads 0.0 and 0.2 are also strongly self-attending at the first non-null token. This fits with the guess above, because the first non-null token must be unique at that point (and is the best default guess for the first unique token later on in the sequence).
  • Although layer 1 attention heads are splitting functionality in a way I thought might happen, they're not actually doing what I was expecting. I thought I'd see these heads attending to & boosting the first unique token, but they don't seem to be doing this.

From all this evidence, we form a new hypothesis:

Some layer 0 heads (0.1 and partly 0.2) are DTH; they're K-composing with layer 1 heads to cause those heads to attend to & suppress duplicate tokens. Some layer 0 heads (0.0 and partly 0.2) are "early unique" heads; they're attending more to early tokens and are V-composing with layer 1 heads to boost these tokens. Also, layer 1 heads are splitting functionality across tokens in the vocabulary: each head in layer 1 has a particular set of tokens, whose disjoint union is the whole vocabulary.

This hypothesis turns out to be pretty much correct, minus a few details.
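
As a quick sanity check of the "domain" claim (my own sketch, not part of the original analysis), we can sum up how much attention mass each layer-1 head puts on each source-token identity across the dataset. If the heads really do split the vocabulary, each head's row should be concentrated on its own domain (plus the null character, which soaks up a lot of attention at position 0):
pattern_1 = cache["pattern", 1].detach()  # [batch, head, seqQ, seqK]
toks_on_device = dataset.toks.to(pattern_1.device)

# For each token identity, sum the attention each layer-1 head pays to occurrences of it
attn_mass = t.stack([
    einops.einsum(
        pattern_1, (toks_on_device == tok_id).float(),
        "batch head seqQ seqK, batch seqK -> head",
    )
    for tok_id in range(len(dataset.vocab))
], dim=-1)  # [head, vocab]
attn_mass = attn_mass / attn_mass.sum(dim=-1, keepdim=True)  # each head's row sums to 1

imshow(
    attn_mass,
    x = dataset.vocab,
    y = [f"1.{i}" for i in range(3)],
    title = "Fraction of layer-1 attention mass, by source token identity",
    labels = {"x": "Source token", "y": "Head"},
    width = 700,
    height = 300,
)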

OV circuits

The next thing I wanted to do was plot all OV circuits: both for the actual attention heads and the virtual attention heads. Before doing this, I wanted to make clear predictions about what I'd see based on the two different versions of my hypothesis:

  • Original hypothesis
    • The OV circuits of layer 1 heads will be negative copying circuits (because they're attending to & suppressing duplicated tokens).
    • The virtual OV circuits from composition of heads in layers 0 & 1 will either not be important, or be negative copying circuits too.
  • New hypothesis
    • The virtual OV circuits from 0.1 ➔ (head in layer 1) will be negative, specifically on that layer 1 head's domain. Same for the virtual circuit from 0.2 ➔ 1.0 for a (because a is in the domain of 1.0).
    • The other virtual OV circuits will be positive copying circuits.
    • The OV circuits for layer 0 heads will probably be positive copying circuits for 0.0 and (0.2, minus the a token). Not sure what the OV circuits for layer 1 heads will look like, since the heads in layer 1 have to boost and suppress tokens (rather than just suppressing tokens, as in my original hypothesis).
    • If the OV circuits for layer 0 are positive copying circuits, I'd weakly guess they'd be negative for layer 1.
scale_final = cache["scale"][:, :, 0][:, 1:].mean()
scale_0 = cache["scale", 0, "ln1"][:, 1:].mean()
scale_1 = cache["scale", 1, "ln1"][:, 1:].mean()
W_OV = model.W_V @ model.W_O
W_E = model.W_E
W_U = model.W_U

# ! Get direct path
W_E_OV_direct = (W_E / scale_final) @ W_U

# ! Get full OV matrix for path through just layer 0
W_E_OV_0 = (W_E / scale_0) @ W_OV[0]
W_OV_0_full = (W_E_OV_0 / scale_final) @ W_U # [head0 vocab_in vocab_out]

# ! Get full OV matrix for path through just layer 1
W_E_OV_1 = (W_E / scale_1) @ W_OV[1]
W_OV_1_full = (W_E_OV_1 / scale_final) @ W_U # [head1 vocab_in vocab_out]

# ! Get full OV matrix for path through heads in layer 0 and 1
W_E_OV_01 = einops.einsum(
    (W_E_OV_0 / scale_1), W_OV[1],
    "head0 vocab_in d_model_in, head1 d_model_in d_model_out -> head0 head1 vocab_in d_model_out",
)
W_OV_01_full = (W_E_OV_01 / scale_final) @ W_U # [head0 head1 vocab_in vocab_out]

# Stick 'em together
assert W_OV_01_full.shape == (3, 3, 11, 11)
assert W_OV_1_full.shape == (3, 11, 11)
assert W_OV_0_full.shape == (3, 11, 11)
assert W_E_OV_direct.shape == (11, 11)
W_OV_full_all = t.cat([
    t.cat([W_E_OV_direct[None, None], W_OV_0_full[:, None]]), # [head0 1 vocab_in vocab_out]
    t.cat([W_OV_1_full[None], W_OV_01_full]),  # [head0 head1 vocab_in vocab_out]
], dim=1) # [head0 head1 vocab_in vocab_out]
assert W_OV_full_all.shape == (4, 4, 11, 11)

# Visually check plots are in correct order
# W_OV_full_all[0, 1] += 100
# W_OV_full_all[1, 0] += 100

components_0 = ["W<sub>E</sub>"] + [f"0.{i}" for i in range(3)]
components_1 = ["W<sub>U</sub>"] + [f"1.{i}" for i in range(3)]

# Text added after creating this plot, to highlight the stand-out patterns
text = []
patterns = ["", "", "", "", "ac", "ac", "c", "ac", "defj", "dj", "defj", "efj", "bgh", "bgi", "bhi", "bgh"]
for pattern in patterns:
    text.append([[tok if (tok == j and tok in pattern) else "" for tok in dataset.vocab] for j in dataset.vocab])

imshow(
    W_OV_full_all.transpose(0, 1).flatten(0, 1), # .softmax(dim=-1),
    facet_col = 0,
    facet_col_wrap = 4,
    facet_labels = [" ➔ ".join(list(dict.fromkeys(["W<sub>E</sub>", c0, c1, "W<sub>U</sub>"]))) for c1 in components_1 for c0 in components_0],
    title = f"Full virtual OV circuits",
    x = dataset.vocab,
    y = dataset.vocab,
    labels = {"x": "Source", "y": "Dest"},
    height = 1200,
    width = 1200,
    text = text,
)

Conclusions

These results basically fit with my new hypothesis, and I consider this plot and the conclusions drawn from it to be the central figure for explaining this model.

To review the ways in which this plot fits with my new hypothesis:

  • The paths from 0.1 ➔ (head in layer 1) are mostly negative on that head's domain.
  • The path from 0.2 ➔ 1.0 is negative at a.
  • Most of the other paths from (0.0 or 0.2) ➔ (head in layer 1) are strongly positive on that layer 1 head's domain (and weakly positive outside of that head's domain).
  • The OV circuits for heads 0.0 and 0.2 are positive copying circuits, albeit weakly so.
  • We can see that the direct paths via a head in layer 1 are generally negative (more so than I expected).

Additionally, there is strong evidence against my original hypothesis: several of the virtual OV circuits are unmistakably positive copying circuits.

The ways in which this plot doesn't fit with my new hypothesis:

  • There are two more negative paths from (0.0 or 0.2) ➔ (head in layer 1) than I expected: both b and g have negative paths from 0.2 ➔ 1.2.
    • From looking at the rest of the graph, I'm guessing this is because the 0.1 ➔ 1.2 path doesn't do a very good job suppressing duplicated b or g, so another path has to step in and perform this suppression.

Some more notes on this visualisation:

  • This plot might make it seem like the virtual paths are way more important than the single attention head paths. This is partly true, but it can also be slightly misleading: the virtual paths will have smaller magnitudes than the plot suggests, because their effective attention patterns are the product of two different attention patterns (and, as we saw from the info-weighted attention patterns above, a lot of attention goes to the null character at the start of the sequence, whose result vector is very small and so unlikely to be used in composition). See the sketch below for an illustration.
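
To illustrate the point about multiplying attention patterns, here is a short sketch (my own; the particular pair of heads is arbitrary). The effective "virtual attention pattern" for a composed path 0.x ➔ 1.y is the product of the two heads' patterns, which is still causal with rows summing to 1, but whose entries are typically much smaller than either individual pattern:
head0, head1 = 2, 0  # e.g. the 0.2 ➔ 1.0 path; any pair works
A0 = cache["pattern", 0][:, head0]  # [batch, seqQ, seqK]
A1 = cache["pattern", 1][:, head1]  # [batch, seqQ, seqK]
virtual_pattern = A1 @ A0           # [batch, seqQ, seqK]: still causal, rows still sum to 1

print(f"Max entry of 1.{head1}'s pattern (ignoring position 0): {A1[:, 1:, 1:].max().item():.3f}")
print(f"Max entry of the virtual 0.{head0} ➔ 1.{head1} pattern:   {virtual_pattern[:, 1:, 1:].max().item():.3f}")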

QK circuits

I'm now going to plot some QK circuits. I expect to see the following:

  • Heads 0.1 (on all tokens except a) and 0.2 (on tokens [a, b, g]) will have a positive diagonal stripe in the (query-embedding) x (key-embedding) QK circuit.
  • Head 0.0, and head 0.2 on all tokens except [a, b, g], will attend more to early tokens, i.e. they'll have a smooth gradient over source positions for both the (query-embedding) x (key-pos-embed) and (query-pos-embed) x (key-pos-embed) QK circuits.

To be safe, I also wanted to make a bar chart of the mean & std of the layernorm scale factors which I'm using in this computation, to make sure they aren't implementing any complicated logic (they seem not to be).

W_pos_labels = [str(i) for i in range(model.cfg.n_ctx)]

# Check layernorm scale factor mean & std dev, verify that std dev is small
scale = cache["scale", 0, "ln1"][:, :, 0, 0] # shape (batch, seq)
df = pd.DataFrame({
    "std": scale.std(0).cpu().numpy(),
    "mean": scale.mean(0).cpu().numpy(),
})
px.bar(
    df, 
    title="Mean & std of layernorm before first attn layer", 
    template="simple_white", width=600, height=400, barmode="group"
).show()

W_QK: Tensor = model.W_Q[0] @ model.W_K[0].transpose(-1, -2) / (model.cfg.d_head ** 0.5)

W_E_scaled = model.W_E / scale.mean()
W_pos_scaled = model.W_pos / scale.mean(dim=0).unsqueeze(-1)

W_Qemb_Kemb = W_E_scaled @ W_QK @ W_E_scaled.T
W_Qboth_Kpos = t.concat([W_E_scaled, W_pos_scaled]) @ W_QK @ W_pos_scaled.T
# Apply causal masking
W_Qboth_Kpos[:, -len(W_pos_labels):].masked_fill_(t.triu(t.ones_like(W_Qboth_Kpos[:, -len(W_pos_labels):]), diagonal=1).bool(), float("-inf"))

imshow(
    W_Qemb_Kemb,
    facet_col = 0,
    facet_labels = [f"0.{head}" for head in range(model.cfg.n_heads)],
    title = f"Query = W<sub>E</sub>, Key = W<sub>E</sub>",
    labels = {"x": "Source", "y": "Dest"},
    x = dataset.vocab,
    y = dataset.vocab,
    height = 400,
    width = 750,
)
imshow(
    W_Qboth_Kpos,
    facet_col = 0,
    facet_labels = [f"0.{head}" for head in range(model.cfg.n_heads)],
    title = f"Query = W<sub>E</sub> & W<sub>pos</sub>, Key = W<sub>pos</sub>",
    labels = {"x": "Source", "y": "Dest"},
    x = W_pos_labels,
    y = dataset.vocab + W_pos_labels,
    height = 620,
    width = 1070,
)

Conclusions

This pretty much fits with both of the expectations I had in the previous section. The query-side positional embeddings actually seem to have a slight bias towards attending to later positions, but it looks like this is dominated by the effect from the query-side token embeddings (which show a stronger "attend to earlier positions" effect). Also, note that 0.1 has a bias against self-attention, which makes sense given its role as a DTH.

One other observation - heads 0.0 and 0.2 self-attending strongly at position 1 stands out here. This is a good indication that V-composition between these two heads & some heads in layer 1 is boosting tokens, because "boost logits for the very first non-null token in the sequence, mostly at the first position but also at the positions that come after" is a very easy and helpful heuristic to learn. In fact, we might speculate this was one of the first things the model learned (after the obvious "predict null character at the start of the sequence"). The algorithm proposed at the start (all heads in layer 0 acting as duplicate token heads, and heads in layer 1 attending to the first non-duplicated token) might actually have achieved better global loss properties. But if the model learned this heuristic early on, and a consequence of this heuristic is positive virtual copying circuits forming between (0.0, 0.2) and heads in layer 1, then it might have no choice but to implement the version of the algorithm we see here.

Exercise to the reader - can you find evidence for/against this claim? You can find all the details of training, including random seeds, in the notebook august23_unique_char/training_model.ipynb. Is this heuristic one of the first things the model learns? If you force the model not to learn this heuristic during training (e.g. by adding a permanent hook to make sure heads in layer 0 never self-attend), does the model learn a different algorithm more like the one proposed at the start?
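
For anyone attempting the last part of that exercise, here is a sketch of one way you might implement the intervention (my own sketch, not something from the training notebook - the hook function name is made up). It zeroes the diagonal of every layer-0 attention pattern and renormalizes, registered as a permanent hook so that it also applies during training:
from transformer_lens.hook_points import HookPoint

def remove_self_attention(pattern: Tensor, hook: HookPoint) -> Tensor:
    '''Zero out self-attention in this attention pattern, then renormalize each query's row.'''
    # pattern has shape [batch, head, seqQ, seqK]
    seq_len = pattern.shape[-1]
    diag = t.eye(seq_len, device=pattern.device, dtype=t.bool)
    pattern = pattern.masked_fill(diag, 0.0)
    # Position 0 can only attend to itself, so its row becomes all zeros;
    # the epsilon just avoids dividing by zero there.
    return pattern / (pattern.sum(dim=-1, keepdim=True) + 1e-6)

model.add_perma_hook("blocks.0.attn.hook_pattern", remove_self_attention)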

Direct Logit Attribution

The last thing I'll do here (before moving onto some adversarial examples) is write a function to decompose the model's logit attribution by path through the model. Specifically, I can split it up into the same 16 paths as we saw in the "full virtual OV circuits" heatmap earlier. This will help to see whether the theories I've proposed about the model's behaviour are correct.

It wasn't obvious what kind of visualisation was appropriate here. If I focused on a single sequence, then there are 3 dimensions I care about: the destination position, the path through the model, and which token is being suppressed / boosted. I experimented with heatmaps using each of these three dimensions as a facet column, and ended up settling on using the latter of these as the facet column - that way it's easier to compare different paths, at different points in the sequence (and because there's usually only a handful of tokens I care about the logit attribution of, at any given point).

One last note - I subtracted the mean DLA for every path. This turned out to be important, because there are a few effects we need to control for (in particular, head 1.0's direct effect on the c logit even when c isn't in the sequence). This is why the first column is all zeros (it doesn't mean the model is unable to predict ? as the first character!). I've also allowed the mean to be returned as a tensor and passed back in (instead of a boolean), in case I want to subtract this baseline when running on a much smaller dataset.

def dla_imshow(
    dataset: UniqueCharDataset,
    cache: ActivationCache,
    batch_idx: int,
    str_tok: Union[str, List[str]],
    subtract_mean: Union[bool, Tensor] = True,
):
    # ! Get DLA from the direct paths & paths through just heads in layer 0
    resid_decomposed = t.stack([
        cache["embed"] + cache["pos_embed"],
        *[cache["result", 0][:, :, head] for head in range(3)]
    ], dim=1)
    assert resid_decomposed.shape == (len(dataset), 4, 20, model.cfg.d_model), resid_decomposed.shape
    t.testing.assert_close(resid_decomposed.sum(1) + model.b_O[0], cache["resid_post", 0])

    dla = (resid_decomposed / cache["scale"].unsqueeze(1)) @ model.W_U
    assert dla.shape == (len(dataset), 4, 20, model.cfg.d_vocab), dla.shape

    # ! Get DLA from paths through layer 1
    resid_decomposed_post_W_OV = einops.einsum(
        (resid_decomposed / cache["scale", 0, "ln1"][:, None, :, 0]),
        model.W_V[1] @ model.W_O[1],
        "batch decomp seqK d_model, head d_model d_model_out -> batch decomp seqK head d_model_out"
    )
    resid_decomposed_post_attn = einops.einsum(
        resid_decomposed_post_W_OV,
        cache["pattern", 1],
        "batch decomp seqK head d_model, batch head seqQ seqK -> batch decomp seqQ head d_model"
    )
    new_dla = (resid_decomposed_post_attn / cache["scale"][:, None, :, None]) @ model.W_U
    dla = t.concat([
        dla,
        einops.rearrange(new_dla, "batch decomp seq head vocab -> batch (decomp head) seq vocab")
    ], dim=1)

    # ! Get DLA for batch_idx, subtract mean baseline, optionally return the mean
    dla_mean = dla.mean(0)
    if isinstance(subtract_mean, Tensor):
        dla = dla[batch_idx] - subtract_mean
    elif subtract_mean:
        dla = dla[batch_idx] - dla_mean
    else:
        dla = dla[batch_idx]

    # ! Plot everything
    if isinstance(str_tok, str):
        str_tok = [str_tok]
        kwargs = dict(
            title = f"Direct Logit Attribution by path, for token {str_tok[0]!r}",
            height = 550,
            width = 700,
        )
    else:
        assert len(str_tok) % 2 == 0, "Odd numbers mess up figure order for some reason"
        kwargs = dict(
            title = "Direct Logit Attribution by path",
            facet_col = -1,
            facet_labels = [f"DLA for token {s!r}" for s in str_tok],
            height = 100 + 450 * int(len(str_tok) / 2),
            width = 1250,
            facet_col_wrap = 2,
        )
    toks = [dataset.vocab.index(tok) for tok in str_tok]
    layer0 = [" "] + [f"0.{i} " for i in range(3)]
    layer1 = [f"1.{i} " for i in range(3)]
    imshow(
        dla[:, :, toks].squeeze(),
        x = [f"{s}<br><sub>({i})</sub>" for i, s in enumerate(dataset.str_toks[batch_idx])],
        y = layer0 + [f"{c0}➔ {c1}".lstrip(" ➔ ") for c0 in layer0 for c1 in layer1],
        # margin = dict.fromkeys("tblr", 40),
        aspect = "equal",
        text_auto = ".0f",
        **kwargs,
    )
    if isinstance(subtract_mean, bool) and subtract_mean:
        return dla_mean


print(f"Seq = {''.join(dataset.str_toks[0])}, Target = {''.join(dataset.str_tok_labels[0])}")

dla_mean = dla_imshow(
    dataset,
    cache,
    batch_idx = 0, 
    str_tok = ["c", "g"],
    subtract_mean = True,
)

Seq = ?chgegfaeadieaebcffh, Target = ?ccccccccccccccchhhd

Playing around with these plots for a while, I concluded that they pretty much fit my expectations. The paths which are doing boosting and suppression are almost always the ones I'd expect from the OV composition plot.

For example, take the plot above, which shows the attribution for [c, g] in the very first sequence. Consider the attribution for c:

  • At position 1, the paths (0.0, 0.2) ➔ 1.0 boost c, and (0.1, direct) ➔ 1.0 suppress it (the former is stronger, presumably because the first character in 0.0 and 0.2 strongly self-attends). This fits with our virtual OV circuits plot.
  • After position 1, the main positive attribution paths are (0.0, 0.2) ➔ 1.0, as expected.
  • Once c becomes duplicated for the first time, the negative attribution from (0.1, direct) ➔ 1.0 outweighs the positive attribution. This makes sense, because once c is duplicated head 0.1 and 1.0 will both attend more to the duplicated c (neither of which boosts the positive paths for c, since the duplicated c doesn't attend back to the first instance of c in 0.0 or 0.2).

Now consider the attribution for g:

  • The path (0.0 ➔ 1.2) boosts g at positions 3 & 4, as expected from the virtual OV circuits plot.
  • The paths (direct, 0.2 ➔ 1.2) kick in after the duplicated token to suppress g for the rest of the sequence (as well as 0.1 ➔ 1.2, but this one is weaker) - again, this is the path we expect to see.

Final summary

Some layer 0 heads (0.1 everywhere except a, and 0.2 on [a, b, g]) are duplicate token heads; they're composing with layer 1 heads to cause those heads to attend to & suppress duplicate tokens. This is done both with K-composition (heads in layer 1 attend more to duplicated tokens), and V-composition (the actual outputs of the DTHs are used as value input to heads in layer 1 to suppress duplicated tokens).

All other layer 0 head paths are involved in boosting, rather than suppression. They attend to early tokens, which are not the same as the current destination token. Their outputs are used as value input to heads in layer 1 to boost these tokens.

Layer 1 heads split their functionality across the vocabulary. 1.0 handles boosting / suppression for [a, c], 1.1 handles [d, e, f, j], and 1.2 handles [b, g, h, i]. These sets are disjoint, and their union is the whole vocabulary. This makes sense, because layer 1 attention is a finite resource, and it has to be used to suppress every duplicated token in the sequence (missing even one duplicated token could cause the model to make an incorrect classification).
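
To pull this together, here is a rough pseudocode paraphrase of the algorithm (my own sketch, not anything the model literally computes - the constants and the 1/position weighting are arbitrary placeholders):
BOOST, SUPPRESS = 1.0, 2.0
DOMAINS = {"1.0": set("ac"), "1.1": set("defj"), "1.2": set("bghi")}  # layer-1 head -> its tokens

def sketch_of_model_logits(seq: str, pos: int) -> dict:
    '''Rough picture of the logit contributions at position pos of sequence seq (where seq[0] == "?").'''
    logits = {tok: 0.0 for tok in "abcdefghij"}
    prefix = seq[1:pos + 1]
    for domain in DOMAINS.values():
        for tok in domain & set(prefix):
            if prefix.count(tok) >= 2:
                # DTH paths (0.1, plus 0.2 on some tokens) K/V-compose with the relevant
                # layer-1 head to attend to the duplicated token and suppress it
                logits[tok] -= SUPPRESS
            else:
                # "Early unique" paths (0.0, plus 0.2 on some tokens) V-compose with the
                # relevant layer-1 head to boost the token, more strongly the earlier it appears
                logits[tok] += BOOST / (prefix.index(tok) + 1)
    return logits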

Adversarial examples

To make things interesting, I'll try and find an example where the model incorrectly thinks the first unique character is X when the true answer is a different character Y (as opposed to thinking a unique character exists when it doesn't, or vice-versa).

Firstly, it seems like a good idea to overload single heads if we can (think of overloading chess pieces!). Head 1.0 manages ac (2 tokens), 1.1 manages defj (4 tokens), and 1.2 manages bghi (4 tokens). The latter heads have more responsibilities, so we should try overloading one of them.

Secondly, we don't want the correct token to be the one at the first non-null position - that would be too easy! Heads like 0.0 and 0.2 strongly self-attend at the first non-null position, and then heads in layer 1 attend to this position in order to boost those logits. We need to put a few duplicated tokens first in the sequence.

Thirdly, we should have 2 unique tokens right next to each other, in the hope that the model will think the second one is the correct answer rather than the first one. We saw a smooth gradient with the full QK circuits (when the key-side circuit was positional), so the differences between adjacent tokens should be minimal.

After searching for a bit, I found the example below. We're overloading head 1.2, by including duplicated tokens [g, b, i] before a non-duplicated h and non-duplicated a. Head 1.0 is able to boost a because it's not overloaded, but head 1.2 is unable to boost h because it's already attending to & suppressing the duplicated tokens [g, b, i]. There are a few more examples like this you can create if you play around with the exact order and identity of tokens.

class CustomDataset(UniqueCharDataset):
        
    def __init__(
        self,
        tokens: Union[Int[Tensor, "batch seq"], Callable],
        size: Optional[int] = None,
        vocab: List[str] = list("abcdefghij"), 
        seq_len: int = 20,
        seed: int = 42
    ):
        
        self.vocab = vocab + ["?"]
        self.null_tok = len(vocab)
        if isinstance(tokens, Tensor):
            self.size = tokens.shape[0]
        else:
            assert size is not None
            self.size = size
        t.manual_seed(seed)

        # Generate our sequences
        if isinstance(tokens, t.Tensor):
            self.toks = tokens
        else:
            self.toks = tokens(self.size, seq_len, self.null_tok)
        self.str_toks = [
            [self.vocab[tok] for tok in seq]
            for seq in self.toks
        ]

        # Generate our labels (i.e. the identity of the first non-repeating character in each sequence)
        self.labels = find_first_unique(self.toks[:, 1:], self.null_tok)
        self.labels = t.cat([
            t.full((self.size, 1), fill_value=self.null_tok),
            self.labels
        ], dim=1)
        self.str_tok_labels = [
            [self.vocab[tok] for tok in seq]
            for seq in self.labels
        ]


str_toks = "?ggbbiihaggbigbigbig"
toks = t.tensor([[dataset.vocab.index(tok) for tok in str_toks]])

advex_dataset = CustomDataset(tokens=toks)

advex_logits, advex_cache = model.run_with_cache(advex_dataset.toks)
advex_logprobs = advex_logits.squeeze().log_softmax(-1).T
advex_probs = advex_logits.squeeze().softmax(-1).T

print(f"Seq = {''.join(advex_dataset.str_toks[0])}, Target = {''.join(advex_dataset.str_tok_labels[0])}")

imshow(
    advex_probs,
    y=advex_dataset.vocab,
    x=[f"{s}<br><sub>({j})</sub>" for j, s in enumerate(str_toks)],
    labels={"x": "Position", "y": "Predicted token"},
    title="Probabilities for adversarial example",
    width=800,
    text=[
        ["〇" if str_tok == correct_str_tok else "" for correct_str_tok in advex_dataset.str_tok_labels[0]]
        for str_tok in advex_dataset.vocab
    ],
)

Seq = ?ggbbiihaggbigbigbig, Target = ?g?b?i?hhhhhhhhhhhhh

Verify that head 1.2 is attending strongly to the duplicated gbi tokens, and much less to h (and to the tokens after it):

cv.attention.from_cache(
    advex_cache,
    tokens = list(str_toks),
    attention_type = "standard",
)

Remaining questions / notes / things not discussed

Null character

I've not discussed how the model predicts the null character yet, because I didn't consider it a crucial part of the model's functionality. Some more investigation leads me to the following hypothesis:

  • The model predicts ? at the first position because every attention head has to self-attend here. Looking at DLA without subtracting the mean shows that almost every path contributes a small positive amount to ? (a quick way to check this is shown below this list).
  • In later positions, ? is predicted in much the same way other tokens are predicted: all duplicated tokens are strongly suppressed, and ? is boosted mainly via V-composition paths. This effect is weaker than for other tokens in the vocabulary - which it has to be, because if any non-duplicated token exists then it needs to dominate the boosting of ?.
  • Unlike the other characters in the vocabulary, ? doesn't have a dedicated head in layer 1.
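
As a quick check of the first bullet point, we can reuse the dla_imshow function from earlier and plot the DLA for ? without subtracting the mean:
dla_imshow(
    dataset,
    cache,
    batch_idx = 0,
    str_tok = "?",
    subtract_mean = False,
)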

Layer 1 QK circuits

There are some interesting patterns here, somewhat analogous to the patterns in the layer 0 QK circuits. I've not discussed them here though, because I don't consider them critical to understand how this model functions.

How can the model predict a token with high probability, without ever attending to it with high probability?

I'm including a short answer to this question, because it's something which confused me a lot when I started looking at this model.

Consider a sequence like ?aab...c as an example. How can the model correctly predict b at the final position (where the current token is c)? The answer, in short: in heads 0.0 and 0.2, all the tokens between b and c will slightly attend to b. Then in head 1.2, c will attend to these intermediate tokens, and the virtual OV circuits will boost b. Also, the duplicate token head 0.1 makes sure a is heavily suppressed, so that b is predicted with the highest probability.
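
If you want to see this concretely, here is a sketch (the specific filler tokens below are my own arbitrary choice, picked so that b stays the first unique character): run the model on an instance of the ?aab...c pattern, and compare the probability placed on b at the final position with the attention paid to b from that position:
str_toks = "?aabdddeeffgghhiijjc"  # an arbitrary instance of the "?aab...c" pattern
toks = t.tensor([[dataset.vocab.index(tok) for tok in str_toks]])

test_logits, test_cache = model.run_with_cache(toks)
test_probs = test_logits.squeeze().softmax(-1)

b_posn, b_idx = 3, dataset.vocab.index("b")
print(f"P(b) at the final position = {test_probs[-1, b_idx].item():.3f}")
print(f"Max attention to b from the final position, over layer-1 heads = "
      f"{test_cache['pattern', 1][0, :, -1, b_posn].max().item():.3f}")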