CALLUM MCDOUGALL

Einops-inspired indexing

You can see the accompanying Colab here.

You can install the eindex library with:

pip install git+https://github.com/callummcdougall/eindex.git


A few libraries already offer functions like this, but here is my own interpretation of what indexing could look like with einops-style notation.

The idea is for the pattern string you pass to the eindex function to be as close as possible to how you'd think about defining each element of the output tensor. For example, suppose you have a logprobs tensor of shape (batch_size, seq_len, d_vocab), and a tensor of correct next tokens with shape (batch_size, seq_len), and you want to get the logprobs on the correct tokens. You would think about this as follows:

output[batch, seq] = logprobs[batch, seq, labels[batch, seq]]

In my library, this is implemented simply by writing the indices inside the square brackets on the right-hand side as the pattern string:

output = eindex(logprobs, labels, "batch seq [batch seq]")
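
In other words, the result is what you'd get from the explicit loop below (a conceptual sketch only, using hypothetical lowercase size variables; the library doesn't actually compute it this way):

output = torch.empty(batch_size, seq_len)
for b in range(batch_size):
    for s in range(seq_len):
        # Each output element picks out the logprob at the position given by the label
        output[b, s] = logprobs[b, s, labels[b, s]]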

Setup

%pip install git+https://github.com/callummcdougall/eindex.git

from eindex import eindex
import torch

Examples

Indexing logprobs

Here is the example given above, along with a few other ways you could get the same result (and a check that they're all equivalent):

BATCH_SIZE = 32
SEQ_LEN = 5
D_VOCAB = 100

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))

# (1) Using eindex
output_1 = eindex(logprobs, labels, "batch seq [batch seq]")

# (2) Normal PyTorch, using `gather`
output_2 = logprobs.gather(2, labels.unsqueeze(-1)).squeeze(-1)

# (3) Normal PyTorch, not using `gather` (this is like what `eindex` does under the hood)
output_3 = logprobs[torch.arange(BATCH_SIZE).unsqueeze(-1), torch.arange(SEQ_LEN), labels]

# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)

Multiple index dimensions

Suppose your output vocab shape was 2D rather than 1D (weird I know, but bear with me), and your labels tensor had shape (batch_size, seq_len, 2), i.e. each slice along the last dimension corresponds to a different dimension of the output vocab. You want to index as follows:

output[batch, seq, d1, d2] = logprobs[batch, seq, labels[batch, seq, 0], labels[batch, seq, 1]]

Again, this is implemented just like it's written:

D_VOCAB_1 = 100
D_VOCAB_2 = 50

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB_1, D_VOCAB_2).log_softmax(-1)
labels = torch.stack([
    torch.randint(0, D_VOCAB_1, (BATCH_SIZE, SEQ_LEN)), 
    torch.randint(0, D_VOCAB_2, (BATCH_SIZE, SEQ_LEN))
], dim=-1)

# (1) Using eindex
output_1 = eindex(logprobs, labels, "batch seq [batch seq 0] [batch seq 1]")

# (2) Normal PyTorch, using `gather` (apparently GPT4 couldn't come up with anything less janky)
combined_index = labels[..., 0] * D_VOCAB_2 + labels[..., 1]
logprobs_flattened = logprobs.view(BATCH_SIZE, SEQ_LEN, D_VOCAB_1 * D_VOCAB_2)
output_2 = logprobs_flattened.gather(2, combined_index.unsqueeze(-1)).squeeze(-1)

# (3) Normal PyTorch, not using `gather`
output_3 = logprobs[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN)[None, :], labels[:, :, 0], labels[:, :, 1]]

# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)

Using 2 different label tensors (rather than 2 different dimensions of the same labels tensor) is also supported. We want to index the tensor as:

output[batch, seq, d1, d2] = logprobs[batch, seq, labels_1[batch, seq], labels_2[batch, seq]]

and this is implemented as:

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB_1, D_VOCAB_2).log_softmax(-1)
labels_1 = torch.randint(0, D_VOCAB_1, (BATCH_SIZE, SEQ_LEN))
labels_2 = torch.randint(0, D_VOCAB_2, (BATCH_SIZE, SEQ_LEN))

# (1) Using eindex
output_1 = eindex(logprobs, labels_1, labels_2, "batch seq [batch seq] [batch seq]")

# (2) Normal PyTorch, using `gather`
combined_index = labels_1 * D_VOCAB_2 + labels_2
logprobs_flattened = logprobs.view(BATCH_SIZE, SEQ_LEN, D_VOCAB_1 * D_VOCAB_2)
output_2 = logprobs_flattened.gather(2, combined_index.unsqueeze(-1)).squeeze(-1)

# (3) Normal PyTorch, not using `gather`
output_3 = logprobs[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN)[None, :], labels_1, labels_2]

# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)

Note - when using multiple index tensors, the square bracket groups are assumed to refer to the index tensors in the order they're passed.
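
For instance, swapping the order of the index tensors swaps which bracket group each one fills. Here's a minimal sketch of that (with hypothetical names, and a square trailing shape so that both orderings are valid indices):

D = 10
x = torch.randn(BATCH_SIZE, SEQ_LEN, D, D)
idx_a = torch.randint(0, D, (BATCH_SIZE, SEQ_LEN))
idx_b = torch.randint(0, D, (BATCH_SIZE, SEQ_LEN))

# The first bracket group binds to the first index tensor passed, the second to the second
out_ab = eindex(x, idx_a, idx_b, "batch seq [batch seq] [batch seq]")  # x[b, s, idx_a[b, s], idx_b[b, s]]
out_ba = eindex(x, idx_b, idx_a, "batch seq [batch seq] [batch seq]")  # x[b, s, idx_b[b, s], idx_a[b, s]]

assert torch.allclose(out_ab, x[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN), idx_a, idx_b])
assert torch.allclose(out_ba, x[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN), idx_b, idx_a])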


Offsetting dimensions

Let's go back to our logprobs and labels example from earlier. Assume our labels are the tokens fed to an autoregressive transformer. Usually, logprobs and tokens come in a form where they need to be offset by one, i.e. we want the tensor:

output[batch, seq] = logprobs[batch, seq, tokens[batch, seq+1]]

which has shape (batch_size, seq_len-1).

Using the tools so far, we could implement this by just slicing logprobs and tokens before doing the eindexing operation shown in the very first example:

output = eindex(logprobs[:, :-1], tokens[:, 1:], "batch seq [batch seq]")

However, there's also a way to perform this slicing within the indexing function itself:

output = eindex(logprobs, tokens, "batch seq [batch seq+1]")

This functionality is definitely more on the "optional" side, since I imagine most users will prefer to do the slicing themselves. However, it seemed like an intuitive extension to offer, so I thought I'd include it!

Note - you shouldn't have a space around the + sign!

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
tokens = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))

# (1) Using eindex directly
output_1 = eindex(logprobs, tokens, "batch seq [batch seq+1]")

# (2) Using eindex plus slicing first
output_2 = eindex(logprobs[:, :-1], tokens[:, 1:], "batch seq [batch seq]")

# Check they're the same
assert output_1.shape == (BATCH_SIZE, SEQ_LEN - 1)
assert torch.allclose(output_1, output_2)
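
For reference, the plain PyTorch equivalent is just the `gather` approach from the first example, applied to the sliced tensors (a sketch, mirroring the earlier comparisons):

# (3) Normal PyTorch, using `gather` on the sliced tensors
output_3 = logprobs[:, :-1].gather(2, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
assert torch.allclose(output_1, output_3)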

Rearranged dimensions

By default, eindex assumes that the order of dimensions in the final shape is the same as the order in which each named dimension first appears in the string. But you can override this using " -> ..." syntax, just like in einops.rearrange (in fact, that's exactly what gets called under the hood when you do this).

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))

# (1) Using eindex, without " -> "
output_1 = eindex(logprobs, labels, "batch seq [batch seq]")

# (2) Using eindex, with " -> " (same operation as above, just more explicit)
output_2 = eindex(logprobs, labels, "batch seq [batch seq] -> batch seq")

# (3) Using eindex, with a reordered " -> " (different operation, because the output is transposed)
output_3 = eindex(logprobs, labels, "batch seq [batch seq] -> seq batch")

# Check they're all consistent (output_3 is just the transpose)
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3.T)
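
Since it's einops.rearrange doing the work under the hood, you can also sanity-check the transposed version against rearrange directly (a sketch, assuming you have einops installed):

import einops

output_4 = einops.rearrange(output_1, "batch seq -> seq batch")
assert torch.allclose(output_3, output_4)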

Error checking

I've tried to make the errors in this library as informative as possible. For example, if you use the same named dimension twice in the pattern string but it corresponds to dimensions of different sizes, the exact mistake will be printed out. Here are 2 examples:

BATCH_SIZE = 32
BATCH_SIZE_INCORRECT = 33

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE_INCORRECT, SEQ_LEN))

try:
    output = eindex(logprobs, labels, "batch seq [batch seq]")
except AssertionError as e:
    print(e)

Incompatible sizes for 'batch' dimension.
Based on your inputs, the inferred dimension sizes are 'batch=32 seq=5 [batch=33 seq=5]'.

logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))

try:
    output = eindex(logprobs, labels, "batch [batch seq]")
except AssertionError as e:
    print(e)

Invalid indices.
Number of terms in your string pattern = 2
Number of terms in your array to index into = 3
These should match.

If you come across any other errors which seem common enough that they should have more readable error messages, please let me know! This library is very small and easy to maintain (it's mostly just one function), so I'm pretty likely to be able to implement most small changes if I think they'd improve the library.