Einops-inspired indexing
You can see the accompanying Colab here.
You can install the eindex library with:
pip install git+https://github.com/callummcdougall/eindex.git
There are a few libraries which have made functions like this, but here is my own interpretation of what indexing would look like with einops-style notation.
The idea is for the pattern string you pass to the eindex function to be as close as possible to how you'd think about defining each element of the output tensor. For example, suppose you have a logprobs tensor of shape (batch_size, seq_len, d_vocab), and a tensor of correct next tokens with shape (batch_size, seq_len), and you want to get the logprobs on the correct tokens. You would think about this as follows:
output[batch, seq] = logprobs[batch, seq, labels[batch, seq]]
In my library, you implement this by passing the contents of the right-hand square brackets as the pattern string, with each nested index tensor written as its own bracketed group:
output = eindex(logprobs, labels, "batch seq [batch seq]")
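To make the semantics concrete, here is a minimal pure-Python sketch of the element-wise definition above, using nested lists in place of tensors. The helper name `index_by_labels` and the toy values are hypothetical and not part of the library; the sketch only illustrates what the pattern string describes, not how eindex computes it.

```python
def index_by_labels(logprobs, labels):
    """Return output[b][s] = logprobs[b][s][labels[b][s]] for nested lists.

    Hypothetical reference helper for illustration, not part of eindex.
    """
    return [
        [logprobs[b][s][labels[b][s]] for s in range(len(labels[b]))]
        for b in range(len(labels))
    ]


# Toy inputs: batch_size=2, seq_len=2, d_vocab=3 (values are made up)
logprobs = [
    [[-0.1, -2.0, -3.0], [-1.5, -0.4, -2.2]],
    [[-2.5, -0.3, -1.9], [-0.7, -1.1, -1.6]],
]
labels = [[0, 2], [1, 1]]

# One value per (batch, seq) position: the logprob at that position's label
output = index_by_labels(logprobs, labels)
print(output)
```

The output has one entry per (batch, seq) pair, which is why the pattern string starts with "batch seq" and only the vocab dimension is replaced by a bracketed lookup.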
Setup
%pip install git+https://github.com/callummcdougall/eindex.git
from eindex import eindex
import torch
Examples
Indexing logprobs
Here is the example given above (along with a few other ways you could get the same result, and showing that they're equivalent):
BATCH_SIZE = 32
SEQ_LEN = 5
D_VOCAB = 100
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))
# (1) Using eindex
output_1 = eindex(logprobs, labels, "batch seq [batch seq]")
# (2) Normal PyTorch, using `gather`
output_2 = logprobs.gather(2, labels.unsqueeze(-1)).squeeze(-1)
# (3) Normal PyTorch, not using `gather` (this is like what `eindex` does under the hood)
output_3 = logprobs[torch.arange(BATCH_SIZE).unsqueeze(-1), torch.arange(SEQ_LEN), labels]
# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)
Multiple index dimensions
Suppose that your output vocab shape was 2D rather than 1D (weird I know but bear with me), and your labels tensor had shape (batch_size, seq_len, 2) (i.e. each slice corresponded to a different dimension of the output vocab). You want to index the following:
output[batch, seq] = logprobs[batch, seq, labels[batch, seq, 0], labels[batch, seq, 1]]
Again, this is implemented just like it's written:
D_VOCAB_1 = 100
D_VOCAB_2 = 50
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB_1, D_VOCAB_2).log_softmax(-1)
labels = torch.stack([
    torch.randint(0, D_VOCAB_1, (BATCH_SIZE, SEQ_LEN)),
    torch.randint(0, D_VOCAB_2, (BATCH_SIZE, SEQ_LEN)),
], dim=-1)
# (1) Using eindex
output_1 = eindex(logprobs, labels, "batch seq [batch seq 0] [batch seq 1]")
# (2) Normal PyTorch, using `gather` (apparently GPT4 couldn't come up with anything less janky)
combined_index = labels[..., 0] * D_VOCAB_2 + labels[..., 1]
logprobs_flattened = logprobs.view(BATCH_SIZE, SEQ_LEN, D_VOCAB_1 * D_VOCAB_2)
output_2 = logprobs_flattened.gather(2, combined_index.unsqueeze(-1)).squeeze(-1)
# (3) Normal PyTorch, not using `gather`
output_3 = logprobs[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN)[None, :], labels[:, :, 0], labels[:, :, 1]]
# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)
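The `gather` version above relies on row-major flattening: once two trailing dimensions of sizes (D_VOCAB_1, D_VOCAB_2) are merged into one, the element at position (i, j) sits at flat index i * D_VOCAB_2 + j. Here is a minimal pure-Python sketch of that arithmetic, using small made-up sizes `D1` and `D2` as stand-ins:

```python
D1, D2 = 3, 4  # illustrative sizes standing in for D_VOCAB_1 and D_VOCAB_2

# A 2D grid whose entries record their own (i, j) position
grid = [[(i, j) for j in range(D2)] for i in range(D1)]

# Row-major flattening: rows are laid end to end
flat = [cell for row in grid for cell in row]

# Every 2D position (i, j) maps to flat index i * D2 + j
for i in range(D1):
    for j in range(D2):
        assert flat[i * D2 + j] == grid[i][j]

# Flat position of the element at (2, 1)
combined_index = 2 * D2 + 1
print(combined_index, flat[combined_index])
```

Note that the multiplier is the size of the second dimension, not the first, which is why `combined_index` in the PyTorch code multiplies by D_VOCAB_2.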
If you had 2 different labels (rather than 2 different dimensions of the same label), this is also supported. We want to index the tensor as:
output[batch, seq] = logprobs[batch, seq, labels_1[batch, seq], labels_2[batch, seq]]
and this is implemented as:
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB_1, D_VOCAB_2).log_softmax(-1)
labels_1 = torch.randint(0, D_VOCAB_1, (BATCH_SIZE, SEQ_LEN))
labels_2 = torch.randint(0, D_VOCAB_2, (BATCH_SIZE, SEQ_LEN))
# (1) Using eindex
output_1 = eindex(logprobs, labels_1, labels_2, "batch seq [batch seq] [batch seq]")
# (2) Normal PyTorch, using `gather`
combined_index = labels_1 * D_VOCAB_2 + labels_2
logprobs_flattened = logprobs.view(BATCH_SIZE, SEQ_LEN, D_VOCAB_1 * D_VOCAB_2)
output_2 = logprobs_flattened.gather(2, combined_index.unsqueeze(-1)).squeeze(-1)
# (3) Normal PyTorch, not using `gather`
output_3 = logprobs[torch.arange(BATCH_SIZE)[:, None], torch.arange(SEQ_LEN)[None, :], labels_1, labels_2]
# Check they're all the same
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3)
Note - when using multiple tensors, the square brackets are assumed to refer to the index tensors in the order they appear.
Offsetting dimensions
Let's go back to our logprobs and labels example from earlier. Assume our labels are tokens in an autoregressive transformer. Usually, we'd have logprobs and tokens in a form such that they'd need to be offset by one, i.e. we want the tensor:
output[batch, seq] = logprobs[batch, seq, tokens[batch, seq+1]]
which has shape (batch_size, seq_len-1).
Using the tools so far, we could implement this by just slicing logprobs and tokens before doing the eindexing operation shown in the very first example:
output = eindex(logprobs[:, :-1], tokens[:, 1:], "batch seq [batch seq]")
However, there's also a way to perform this slicing within the indexing function itself:
output = eindex(logprobs, tokens, "batch seq [batch seq+1]")
This functionality is definitely more on the "optional" side, because I can imagine most users might prefer to do the slicing themselves. However, it seemed an intuitive extension to offer so I thought I'd include it!
Note - you shouldn't have a space around the + sign!
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
tokens = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))
# (1) Using eindex directly
output_1 = eindex(logprobs, tokens, "batch seq [batch seq+1]")
# (2) Using eindex plus slicing first
output_2 = eindex(logprobs[:, :-1], tokens[:, 1:], "batch seq [batch seq]")
# Check they're the same
assert output_1.shape == (BATCH_SIZE, SEQ_LEN - 1)
assert torch.allclose(output_1, output_2)
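As a sanity check on why the two forms agree, here is a minimal pure-Python sketch with nested lists standing in for tensors (the toy values are made up, and this is not how eindex is implemented). It computes the offset definition directly, then again by slicing first, and confirms that both give the same values with one fewer sequence position.

```python
# Toy inputs: batch_size=1, seq_len=3, d_vocab=2 (values are made up)
logprobs = [[[-0.2, -1.7], [-1.2, -0.4], [-0.9, -0.5]]]
tokens = [[1, 0, 1]]

# Direct definition: output[b][s] = logprobs[b][s][tokens[b][s + 1]]
# The last position has no "next token", so s stops at seq_len - 2.
direct = [
    [logprobs[b][s][tokens[b][s + 1]] for s in range(len(tokens[b]) - 1)]
    for b in range(len(tokens))
]

# Slice first: drop the last logprobs position and the first token,
# then index position by position with no offset.
logprobs_sliced = [seq[:-1] for seq in logprobs]
tokens_sliced = [seq[1:] for seq in tokens]
sliced = [
    [lp[tok] for lp, tok in zip(lp_seq, tok_seq)]
    for lp_seq, tok_seq in zip(logprobs_sliced, tokens_sliced)
]

assert direct == sliced
print(direct)
```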
Rearranged dimensions
By default, eindex assumes that the order of dimensions in the final shape is the same as the order in which each named dimension first appears in the string. But you can override this using " -> ..." syntax, just like how it works in einops.rearrange (in fact that's exactly what gets called under the hood when you do this).
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))
# (1) Using eindex, without " -> "
output_1 = eindex(logprobs, labels, "batch seq [batch seq]")
# (2) Using eindex, with " -> " (same operation, but more explicit)
output_2 = eindex(logprobs, labels, "batch seq [batch seq] -> batch seq")
# (3) Using eindex, with " -> " (different operation, because the output is transposed)
output_3 = eindex(logprobs, labels, "batch seq [batch seq] -> seq batch")
# Check they're all the same (after transposing output_3 back)
assert torch.allclose(output_1, output_2)
assert torch.allclose(output_1, output_3.T)
Error checking
I've tried to make the errors in this library as informative as possible. For example, if you use the same named dimension twice in the pattern string but it corresponds to different length dimensions, the exact mistake will be printed out. Here are 2 examples:
BATCH_SIZE = 32
BATCH_SIZE_INCORRECT = 33
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE_INCORRECT, SEQ_LEN))
try:
    output = eindex(logprobs, labels, "batch seq [batch seq]")
except AssertionError as e:
    print(e)
Incompatible sizes for 'batch' dimension. Based on your inputs, the inferred dimension sizes are 'batch=32 seq=5 [batch=33 seq=5]'.
logprobs = torch.randn(BATCH_SIZE, SEQ_LEN, D_VOCAB).log_softmax(-1)
labels = torch.randint(0, D_VOCAB, (BATCH_SIZE, SEQ_LEN))
try:
    output = eindex(logprobs, labels, "batch [batch seq]")
except AssertionError as e:
    print(e)
Invalid indices. Number of terms in your string pattern = 2 Number of terms in your array to index into = 3 These should match.
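For readers curious about what this kind of check involves, here is a rough sketch of binding dimension names to sizes and flagging a conflict. The helper `check_named_dims` is hypothetical and is not the library's code: it raises `ValueError` rather than the `AssertionError` shown above, and its messages are simplified.

```python
def check_named_dims(names, shape, known=None):
    """Bind each dimension name to a size, raising if a name is reused with a different size.

    Hypothetical helper for illustration only, not part of eindex.
    """
    known = dict(known or {})
    if len(names) != len(shape):
        raise ValueError(
            f"Pattern has {len(names)} names but the shape has {len(shape)} dimensions."
        )
    for name, size in zip(names, shape):
        # setdefault records the size on first sight, otherwise returns the earlier one
        if known.setdefault(name, size) != size:
            raise ValueError(
                f"Incompatible sizes for '{name}' dimension: {known[name]} vs {size}."
            )
    return known


# Sizes inferred from the first two dimensions of the tensor being indexed
sizes = check_named_dims(["batch", "seq"], (32, 5))

# The index tensor reuses 'batch' with a different length, so the check fails
try:
    check_named_dims(["batch", "seq"], (33, 5), known=sizes)
    message = None
except ValueError as e:
    message = str(e)
print(message)
```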
If you come across any other errors which seem common enough that they should have more readable error messages, please let me know! This library is very small and easy to maintain (it's mostly just one function), so I'm pretty likely to be able to implement most small changes if I think they'd improve the library.