CALLUM MCDOUGALL

Monthly Algorithmic Challenge (November 2023): Cumulative Sum




Problem

This post is the fifth in the sequence of monthly mechanistic interpretability challenges. I designed them in the spirit of Stephen Casper's challenges, but with the more specific aim of working well in the context of the rest of the ARENA material, and helping people put into practice all the things they've learned so far. You can access the full ARENA materials from the Streamlit page. The purpose of the page you're currently reading is mainly to provide an easy-to-access place to read the solutions and understand the nature of these problems; I'd recommend using the Streamlit page instead if you're actually working through the problems. You can also find the solutions Colab here. I'll assume everyone has read the first problem, and has some overall context on the sequence.

Task & Dataset

The problem for this month is interpreting a model which has been trained to classify the cumulative sum of sequences. The model is fed sequences of integers between -max_value and +max_value inclusive, and is tasked with classifying the cumulative sum at each sequence position - the correct label is 0 for a negative cumulative sum, 1 for a cumulative sum of zero, and 2 for a positive cumulative sum.
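For concreteness, here's a minimal sketch of the labelling rule in plain PyTorch (my addition - not taken from the problem's code):

import torch as t

values = t.tensor([0, 1, -3, -3, -2, 3])   # the example sequence shown below
cumsum = values.cumsum(dim=-1)             # tensor([ 0,  1, -2, -5, -7, -4])
labels = t.sign(cumsum) + 1                # 0 = neg, 1 = zero, 2 = pos -> tensor([1, 2, 0, 0, 0, 0])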

Here is an example of what this dataset looks like:

dataset = CumsumDataset(size=1, seq_len=6, max_value=3, seed=40)

print(dataset[0]) # same as (dataset.toks[0], dataset.labels[0])

print(", ".join(dataset.str_toks[0])) # inputs to the model

print(", ".join(dataset.str_labels[0])) # whether the cumsum of inputs is strictly positive

(tensor([ 0, 1, -3, -3, -2, 3]), tensor([1, 2, 0, 0, 0, 0]))
+0, +1, -3, -3, -2, +3
zero, pos, neg, neg, neg, neg

Model

Our model was trained by minimising cross-entropy loss between its predictions and the true labels, at all sequence positions. You can check out training_model.ipynb to see how it was trained.

The model is a 1-layer transformer with one attention head and an MLP layer. I chose this problem to be roughly the simplest one that requires an MLP to get very high performance, due to the nonlinearity of the classification problem.
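For reference, here's a rough sketch of what the model's config plausibly looks like in TransformerLens - n_layers, n_heads, d_mlp and d_vocab_out follow from the description above and the analysis below, but d_model, d_head and the remaining settings are assumptions for illustration only:

from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1,               # 1-layer transformer
    n_heads=1,                # a single attention head
    d_mlp=8,                  # the analysis below refers to neurons #0-#7
    d_vocab=11,               # tokens -5, ..., +5 (max_value=5, as in the solutions)
    d_vocab_out=3,            # classes: neg / zero / pos
    n_ctx=20,                 # seq_len used in the solutions
    d_model=24,               # assumption - not stated in the post
    d_head=12,                # assumption - not stated in the post
    act_fn="relu",
    normalization_type=None,  # assumption - the analysis below ignores LayerNorm
)
model = HookedTransformer(cfg)  # the actual trained weights come from the Colab / ARENA repo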





Solutions

First, let's do some setup (all functions can be found in the Colab / ARENA repo):
dataset = CumsumDataset(size=1000, max_value=5, seq_len=20, seed=42).to(device)
fix_dataset(dataset)
fix_model(model)

logits, cache = model.run_with_cache(dataset.toks)

clean_logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
clean_probs = clean_logprobs.softmax(-1)

clean_logprobs_correct = eindex(clean_logprobs, dataset.labels, "batch seq [batch seq]")
clean_probs_correct = eindex(clean_probs, dataset.labels, "batch seq [batch seq]")

print(f"Average cross entropy loss: {-clean_logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {clean_probs_correct.mean():.3f}")
print(f"Median probability on correct label: {clean_probs_correct.median():.3f}")
print(f"Min probability on correct label: {clean_probs_correct.min():.3f}")

Average cross entropy loss: 0.077
Mean probability on correct label: 0.936
Median probability on correct label: 0.999
Min probability on correct label: 0.551

And some simple (kinda hacky) visualisation:
def show(dataset: CumsumDataset, batch_idx: int):

    logits = model(dataset.toks[batch_idx].unsqueeze(0)).squeeze() # [seq_len vocab_out]
    probs = logits.softmax(dim=-1) # [seq_len vocab_out]

    fig = imshow(
        probs.T,
        y=dataset.vocab_out,
        x=[f"{s}<br><sub>({j})</sub>" for j, s in enumerate(dataset.str_toks[batch_idx])],
        labels={"x": "Token", "y": "Vocab"},
        xaxis_tickangle=0,
        title=f"Sample model probabilities:<br>{', '.join(dataset.str_toks[batch_idx])}",
        text=[
            ["〇" if (s == target) else "" for target in dataset.str_labels[batch_idx]]
            for s in dataset.vocab_out
        ],
        width=750,
        height=350,
        return_fig=True,
    )
    fig.show()

show(dataset, 1)

Summary of how the model works

The single attention head implements uniform attention to all previous tokens in the sequence. The OV matrix is essentially one-dimensional: it projects each token with value $s$ onto $s \boldsymbol{u}$, where $\boldsymbol{u}$ is some vector in the residual stream learned by the model. The component of the residual stream in this direction then represents the cumulative mean (note, the cumulative mean rather than the cumulative sum, because attention probabilities sum to one, so uniform attention averages over previous tokens rather than adding them up - for example, we expect the component to be the same after the sequences (1, 1, 2) and (1, 1, 2, 1, 1, 2), because the net attention to each distinct token value will be the same).
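To make the cumulative-mean point concrete, here's a toy sketch (my addition) of what uniform causal attention composed with a one-dimensional OV map computes:

import torch as t

values = t.tensor([1., 1., 2., 1., 1., 2.])
u = t.randn(4)                                     # stand-in for the learned direction u
head_out = t.stack([values[: i + 1].mean() * u for i in range(len(values))])
# the components written after (1, 1, 2) and after (1, 1, 2, 1, 1, 2) are identical
print(t.allclose(head_out[2], head_out[5]))        # True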

The model's "positive cumsum prediction direction" aligns closely with $\boldsymbol{u}$, and vice-versa for the "negative cumsum prediction direction" - this allows the model to already get >50% accuracy before the MLP even comes into play. But without the MLP, the model has a hard time dealing with sequences that have cummeans close to zero: it usually defaults to predicting zero. The job of the MLP layer is to detect when the cummean is positive and boost the positive prediction + suppress negative prediction (neurons #0, #1, #3 and #4), or vice-versa when the cummean is negative (neurons #2 and #7). This sharp nonlinear behaviour is what allows the model to correctly classify sequences even when the cummean is close to zero.

First pass: searching for a "sum direction"

First, let's look at attention patterns. Noting my expectations beforehand:

  • I expect the model to be writing the total value of numbers to some subspace, i.e. there's a vector $\boldsymbol{u}$ s.t. the vector moved from source token $s$ to destination token $d$ is $s \boldsymbol{u}$. If we aggregate over all source tokens, this means that we'll be storing the sum information at token $d$ - then I expect the neurons can process this information in a nonlinear way to get the required output.
  • When I plot the raw attention patterns, they might look uniform, but I at least expect the info-weighted attention patterns to show this pattern (of being proportional to the size of the source token).
cv.attention.from_cache(
    cache = cache,
    tokens = dataset.str_toks,
    batch_idx = list(range(10)),
    attention_type = "info-weighted",
    radioitems = True,
    batch_labels = ["<code>" + ", ".join(s) + "</code>" for s in dataset.str_toks],
)
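For reference, here's a rough sketch of one way an "info-weighted" pattern could be computed by hand - scaling each attention probability by the norm of the vector actually moved from that source position, then renormalising. (This exact definition is my assumption, not necessarily the one used by the plotting function above.)

pattern = cache["pattern", 0][:, 0]                # [batch dest src]
v = cache["v", 0][:, :, 0]                         # [batch src d_head]
info_norm = (v @ model.W_O[0, 0]).norm(dim=-1)     # [batch src], size of the vector moved
weighted = pattern * info_norm[:, None, :]
info_weighted_pattern = weighted / weighted.sum(dim=-1, keepdim=True)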

Eyeballing info-weighted attention patterns, this looks like it holds up. Interestingly, standard attention patterns are almost perfectly uniform*, so this suggests that all the interesting behaviour comes from the OV matrix. The next step will be to examine the OV matrix, and see if I can find evidence of this $\boldsymbol{u}$ direction. Specifically, when I perform SVD on the OV matrix, I expect to find only one direction that matters, and the values (-5, -4, ..., +5) will be spread along this direction in a linear way.

*I realised after writing this that the attention patterns are uniform because the QK matrices are identically zero! This means all attention scores are zero, so the attention probabilities are uniform.
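A quick check of this claim (my addition, not in the original code):

print(f"Max |W_Q|: {model.W_Q[0, 0].abs().max().item():.6f}")
print(f"Max |W_K|: {model.W_K[0, 0].abs().max().item():.6f}")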

I've taken code from one of my previous monthly problems to plot the SVD, although I've adapted it so that the left-hand plot takes into account magnitudes, not just directions (that way I can more easily see whether there is indeed just one singular direction which matters).

def plot_svd_single(tensor, title=None):

    U_matrix, S_matrix, V_matrix = t.svd(tensor)

    singular_directions = V_matrix[:, :2]
    # This line of code is changed: we scale the singular directions by their singular values
    singular_directions_scaled = utils.to_numpy(singular_directions * S_matrix[:2])
    df = pd.DataFrame(singular_directions_scaled, columns=['Dir 1', 'Dir 2'])
    df['Labels'] = dataset.vocab

    fig = make_subplots(rows=1, cols=2, subplot_titles=["First two singular directions", "Singular values"])
    fig.add_trace(go.Scatter(x=df['Dir 1'], y=df['Dir 2'], mode='markers+text', text=df['Labels']), row=1, col=1)
    fig.update_traces(textposition='top center', marker_size=5)
    fig.add_trace(go.Bar(y=utils.to_numpy(S_matrix)), row=1, col=2)
    fig.update_layout(height=400, width=750, showlegend=False, title_text=title, template="simple_white")
    # Make sure the axes scales are the same (found this code from stackoverflow)
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    fig.show()


W_OV = model.W_V[0, 0] @ model.W_O[0, 0] # [d_model, d_model]
W_OV_full = model.W_E @ W_OV # [d_vocab, d_model]

plot_svd_single(W_OV_full.T, title="SVD of W<sub>E</sub>W<sub>OV</sub>")
Conclusion - yep, pretty cut and dried. We can take the existence of $\boldsymbol{u}$ as given, and this code allows us to define it:
# Define the u vector as a constant (we'll be using it later on too)
U_matrix, S_matrix, V_matrix = t.svd(W_OV_full.T)
U = U_matrix[:, 0]
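As a quick sanity check (my addition), we can confirm the OV map really is close to rank one, and that the token values are spread roughly linearly along this direction:

print(f"Ratio of 2nd to 1st singular value: {(S_matrix[1] / S_matrix[0]).item():.4f}")
print("Projection of each token's OV output onto u:")
print((W_OV_full @ U).detach().round(decimals=2))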

Finding $u$ in resid mid

Before I move on to analysing neurons, I'll do one last thing: take this $\boldsymbol{u}$-direction, and make some plots of the projections along $\boldsymbol{u}$ at each sequence position in the model (grouped by what the actual cumulative sum is at that point). Hopefully, I'll find that there's clean separation between these values, which is presumably how the neurons are able to extract the cumsum information from the residual stream.

# Get resid mid, and project it along the u-direction
resid_mid = cache["resid_mid", 0] # [batch seq d_model]
resid_mid_proj = einops.einsum(resid_mid, U, "batch seq d_model, d_model -> batch seq")

# Get the actual cumulative sums (subtracting max_value converts token IDs to signed values)
cumsums = (dataset.toks - dataset.max_value).cumsum(dim=-1) # [batch seq]

# Plotly code (generated from giving GPT4 a spec, then tweaking the code)
def create_violin_plot(floats, ints):
    # Create a DataFrame from the inputs
    df = pd.DataFrame({'Values': floats, 'Categories': ints})
    # Create the violin plot
    fig = px.violin(
        df, y='Values', color='Categories', 
        color_discrete_sequence=px.colors.sequential.Agsunset,
        labels={'Categories': 'Cumsum', 'Values': 'Value'},
        category_orders={'Categories': list(range(cumsums.min(), cumsums.max()+1))},
        width=1000, height=600,
        title="Projections of resid_mid along u-direction, grouped by cumsum",
    )
    fig.show()

create_violin_plot(floats=resid_mid_proj.flatten().tolist(), ints=cumsums.flatten().tolist())

Not terrible, but nowhere near a clean separation.

After making this plot, I realised the problem - the model has a finite amount of attention to spread uniformly over tokens. For example, I expect the projection along the $\boldsymbol{u}$-direction after the sequence (1, 2) to be of the same size as the projection after (1, 2, 1, 2), despite the latter being twice as long a sequence (in both cases, the same vector will have been added to the destination position). So a more informative plot would be a scatter plot, where the y-axis is the projection value and the x-axis is the cumulative sum divided by the current sequence position - or to put it another way, the cumulative mean.

cummeans = cumsums / t.arange(1, cumsums.shape[1]+1).to(device)

fig = px.scatter(
    x=cummeans.flatten().tolist(),
    y=resid_mid_proj.flatten().tolist(), 
    labels={'x': 'Cumulative mean', 'y': 'Projection in u-direction'},
    height=500, width=700,
    title="Projection of resid_mid in u-direction, against cumulative mean",
    template="ggplot2",
)
fig.show()
Looking great - this direction clearly does store the cumulative mean, with zero exceptions. For convenience I'll flip the sign of the $u$ vector, so that it represents the positive direction rather than negative.
U_matrix, S_matrix, V_matrix = t.svd(W_OV_full.T)
U = -U_matrix[:, 0]

Note that this scatter plot shows some cumulative mean values which are very close to zero (e.g. when cumsum is 1 but the sequence position is large). These will be hard for the model to tell apart, and in fact we can see that from the very first example which was visualised above (the model finds it difficult to distinguish between pos/zero and neg/zero towards the end of the sequence). This would be less of an issue if the model had more attention heads, because having finite attention is the bottleneck here. However, in the limit for long sequences this problem would always exist. 2-layer models might also be able to do a bit better, by splitting the cumulative sum calculation into multiple steps. Can you see how this might work? How long are the sequences you can train a 1-layer / 2-layer model on respectively?
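We can quickly sanity-check this (my addition): the model's least confident predictions should line up with positions where the cumulative mean is nonzero but tiny.

worst = clean_probs_correct.flatten().argsort()[:10]   # the 10 lowest-confidence positions
print("Cumulative means at the 10 lowest-confidence positions:")
print(cummeans.flatten()[worst])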

Analysing neurons

Firstly, let's see what happens if we ablate all the neurons. The scatter plot above shows that the model could probably do a pretty good job classifying neg/pos without help from the MLP! If the model predicts zero by default, but predicts positive strongly when the $\boldsymbol{u}$-projection is large and positive, and negative strongly when it's large and negative, then it would probably do okay.

# Hook function to ablate one or more neurons (zeroing their activations)
def hook_fn_ablate_neuron(post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint, neuron_idx: int):
    post[:, :, neuron_idx] = 0

ablated_logits: Tensor = model.run_with_hooks(
    dataset.toks,
    fwd_hooks = [(utils.get_act_name("post", 0), partial(hook_fn_ablate_neuron, neuron_idx=list(range(model.cfg.d_mlp))))]
)

ablated_logprobs = ablated_logits.log_softmax(-1) # [batch seq_len vocab_out]
ablated_probs = ablated_logprobs.softmax(-1)

ablated_logprobs_correct = eindex(ablated_logprobs, dataset.labels, "batch seq [batch seq]")
ablated_probs_correct = eindex(ablated_probs, dataset.labels, "batch seq [batch seq]")

print(f"Average cross entropy loss: {-ablated_logprobs_correct.mean().item():.3f}")
print(f"Mean probability on correct label: {ablated_probs_correct.mean():.3f}")
print(f"Median probability on correct label: {ablated_probs_correct.median():.3f}")
print(f"Min probability on correct label: {ablated_probs_correct.min():.3f}")
Average cross entropy loss: 0.536
Mean probability on correct label: 0.637
Median probability on correct label: 0.614
Min probability on correct label: 0.159

Okay, so the model still does pretty decently. Our suspicion that the $\boldsymbol{u}$-direction is being used for positive/negative classifications is confirmed by showing that the unembedding direction for "positive" has high positive cosine similarity with $\boldsymbol{u}$, and the unembedding direction for "negative" has high negative cosine similarity with it. Unsurprisingly, it's also the case that the positive and negative unembedding directions are negatively aligned.
pos_u_sim = t.cosine_similarity(U, model.W_U[:, 2], dim=-1)
neg_u_sim = t.cosine_similarity(U, model.W_U[:, 0], dim=-1)
pos_neg_sim = t.cosine_similarity(model.W_U[:, 0], model.W_U[:, 2], dim=-1)

print(f"Cos sim of pos unembedding and u: {pos_u_sim:+.3f}")
print(f"Cos sim of neg unembedding and u: {neg_u_sim:+.3f}")
print(f"Cos sim of pos and neg unembedding: {pos_neg_sim:+.3f}")
Cos sim of pos unembedding and u: +0.970
Cos sim of neg unembedding and u: -0.958
Cos sim of pos and neg unembedding: -0.949

Additionally, since I expect the neurons are implementing behaviour of the form "fire when the cumulative sum is positive & boost positive predictions" (or vice-versa for a negative cumulative sum), I would expect that the model's default behaviour in the absence of the MLP layer is to predict "zero", and that this is primarily what causes the higher loss when the MLP is ablated. I'll test this by plotting the model's assigned probability for "zero" when the MLP is ablated against the cumulative mean, and comparing it to the non-ablated case.

Note, I used probabilities rather than logprobs because I'm looking for qualitative behaviour not quantitative, and because probs have a peak at zero which is easier to visually interpret than logprobs. I also took a small sample from this dataset because plotting all of them was pretty intensive, given how many datapoints there are.

As expected, ablating the MLPs causes the model to have a strong bias towards predicting that the cumsum is zero (and it doesn't much harm the model's predictions when the answer actually is zero). Without the MLPs, the cumsum has to be really extreme for the model to be confident that the sum isn't zero (remember that the x-axis below is cumulative mean, not cumulative sum).

fig = make_subplots(cols=2, shared_yaxes=True, subplot_titles=["Clean", "MLP ablated"])

random_indices = t.randperm(cummeans.numel())[:2000]
x = cummeans.flatten()[random_indices].tolist()
y_clean = clean_probs[..., 1].flatten()[random_indices].tolist()
y_ablated = ablated_probs[..., 1].flatten()[random_indices].tolist()
cumsum_is_zero = ["#1F77B4" if v < 0.01 else "#FF7F0E" for v in cummeans.flatten()[random_indices].abs()]

fig.add_trace(go.Scatter(x=x, y=y_clean, mode="markers", marker=dict(color=cumsum_is_zero, opacity=0.5)), row=1, col=1)
fig.add_trace(go.Scatter(x=x, y=y_ablated, mode="markers", marker=dict(color=cumsum_is_zero, opacity=0.5)), row=1, col=2)
fig.update_layout(title="Model's P(zero sum), with / without MLPs (blue = cumsum is actually zero)", showlegend=False, height=500, width=1100)
fig.show()
One other interesting observation: despite being very close to opposite directions, both the pos and neg unembedding directions have small negative cosine similarity with the zero unembedding direction. This makes sense, because if a positive or a negative sum is detected, these are both reasons to push against a zero prediction.
pos_zero_sim = t.cosine_similarity(model.W_U[:, 2], model.W_U[:, 1], dim=-1)
neg_zero_sim = t.cosine_similarity(model.W_U[:, 0], model.W_U[:, 1], dim=-1)

print(f"Cos sim of pos and zero unembedding: {pos_zero_sim:+.3f}")
print(f"Cos sim of neg and zero unembedding: {neg_zero_sim:+.3f}")
Cos sim of pos and zero unembedding: -0.190
Cos sim of neg and zero unembedding: -0.128

Now, let's see which neurons are actually important by ablating them individually. We find that neurons #5 and #6 are almost completely useless, #2 is by far the most useful, and the rest are somewhere in the middle.
# Get the average correct logprobs with no ablation (we calculated this earlier)
avg_loss_clean = -clean_logprobs_correct.mean().item()

# Hook function to ablate a neuron (zeroing its activation)
def hook_fn_ablate_neuron(post: Float[Tensor, "batch seq d_mlp"], hook: HookPoint, neuron_idx: int):
    post[:, :, neuron_idx] = 0

# Iterate through neurons, see how much loss is changed from ablating each of them
print("Increase in loss from ablating neuron...")
for neuron_idx in range(model.cfg.d_mlp):
    logits = model.run_with_hooks(
        dataset.toks,
        fwd_hooks=[(utils.get_act_name("post", 0), partial(hook_fn_ablate_neuron, neuron_idx=neuron_idx))]
    )
    logprobs = logits.log_softmax(-1) # [batch seq_len vocab_out]
    avg_loss = eindex(-logprobs, dataset.labels, "batch seq [batch seq]").mean().item()
    print(f"  {neuron_idx}: {avg_loss-avg_loss_clean:.4f}")
Increase in loss from ablating neuron...
0: 0.0185
1: 0.0163
2: 0.0951
3: 0.0200
4: 0.0311
5: 0.0002
6: 0.0000
7: 0.0465

This problem is pretty close to solved now - the only thing left is to figure out exactly how these neurons are helping the model. Specifically, I expect to find that some of the neurons activate when the $\boldsymbol{u}$-projection crosses a threshold around zero, with some of them boosting pos and suppressing the other predictions (when the projection is positive), and others boosting neg and suppressing the other predictions (when it's negative).

Since the results for projecting in the $\boldsymbol{u}$-direction have been so clean, I'm satisfied to just consider what a neuron's output is when the $\boldsymbol{u}$-component takes different values. I spent a while brainstorming some possible visualisations, eventually settling on a line chart of the neuron's output: the x-axis is the $\boldsymbol{u}$-component, and the y-axis represents the neuron's effect on the prediction of pos/neg/zero respectively. I chose to do this rather than anything working with the actual cache, because the results from the previous sections (i.e. the projections onto the $\boldsymbol{u}$-direction) are so clear that we don't really need to. Also, I decided to keep things 1D rather than 2D, because a 2D plot includes some redundancy (initially I had a plot where the x-axis was the "pos minus neg" direction and the y-axis was the "zero" direction, but these are both proportional to the neuron activation, meaning all the data was on a boring line!).

# Get possible values of u-projection 
u_coeffs = t.linspace(-4, 4, 100).to(device)
u_proj = einops.einsum(U, u_coeffs, "d_model, n -> n d_model")

# Calculate what the neuron's activations are
act_pre = einops.einsum(u_proj, model.W_in[0], "n d_model, d_model d_mlp -> n d_mlp")
act_post = F.relu(act_pre + model.b_in[0])

# Calculate what the neuron's outputs are, projected onto each unembedding direction
neuron_output = einops.einsum(act_post, model.W_out[0], "n d_mlp, d_mlp d_model -> n d_model d_mlp")
neuron_output_projected = einops.einsum(neuron_output, model.W_U, "n d_model d_mlp, d_model d_vocab_out -> n d_mlp d_vocab_out")

# Create plotly chart
fig = make_subplots(rows=2, cols=4, shared_yaxes=True, subplot_titles=[f"Neuron {i}" for i in range(8)], vertical_spacing=0.1)
rows_and_cols = [(row, col) for row in range(1, 3) for col in range(1, 5)]  # one (row, col) per neuron

for neuron_idx in range(8):
    for vocab_out_idx, vocab_out in enumerate(dataset.vocab_out):
        row, col = rows_and_cols[neuron_idx]
        fig.add_trace(
            go.Scatter(
                x=u_coeffs.tolist(), 
                y=neuron_output_projected[:, neuron_idx, vocab_out_idx].tolist(),
                name=vocab_out if neuron_idx == 0 and vocab_out_idx < 3 else None,
                showlegend=True if neuron_idx == 0 and vocab_out_idx < 3 else False,
                line=dict(color=px.colors.qualitative.D3[vocab_out_idx])
            ),
            row=row, col=col,
        )

fig.update_layout(height=600, width=1000, title="Effect of neurons on model predictions, for a range of u-projection values")
ymax = neuron_output_projected.abs().max().item() + 1
fig.update_yaxes(range=[-ymax, ymax])
fig.show()

This basically confirms all of our previous observations. To summarize the important bits:

  • We can see why neurons #5 and #6 aren't important, because they don't respond to the $u$-direction at all.
  • Every other neuron either has the pattern of "activate on positive cummeans and boost the positive prediction & suppress the negative prediction" or vice-versa.
  • We can see why #2 is the most important: it has the largest response of any neuron. Also, only two neurons (#2 and #7) deal with cases where the cummean is negative, while four neurons (#0, #1, #3, #4) deal with cases where it's positive - this also helps explain why ablating neuron #2 had such a large effect (about 4x the average effect of ablating the other useful neurons, even though its response is not 4x larger than theirs).
  • Note how, if you zoom in, you can see that all of the important neurons boost the zero direction when they fire.

From our understanding of the model, we can guess that the best candidates for adversarial examples will be long sequences whose cumulative sum is nonzero but very close to zero at the end, i.e. the cumulative mean is tiny (only the cumsum matters, not the actual token values). Unlike some previous models in this sequence, this isn't particularly surprising - we might have guessed it without the mech interp analysis - but it's nice that we understand why a bit better now!
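As a final check (my addition), here's one way to pick such a candidate out of the dataset we already have, and look at the model's predictions on it:

final_cummean = cummeans[:, -1]
# among sequences whose final cumsum is nonzero, find the one with the smallest |cummean|
eligible = t.where(final_cummean != 0, final_cummean.abs(), t.full_like(final_cummean, float("inf")))
hardest_idx = eligible.argmin().item()
show(dataset, hardest_idx)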