Check out the paper here: Byte Latent Transformer: Patches Scale Better Than Tokens
Introduction
It’s no secret Transformers are the workhorse behind LLMs and many other GenAI models. What is probably less well known is that LLMs don’t operate directly on raw text input. Instead, when you type a question into your favorite LLM user interface and hit enter, that text goes through a transformation process called “tokenization”. Essentially, each language model has its own vocabulary that it was trained on. That vocabulary consists of a set of “tokens” which, depending on which tokenization algorithm was used, are generally small words and word fragments. Transformers are great at modeling dependencies across sequential chunks of information, so tokens work really well and they’ve been standard practice pretty much forever.
If Transformers work well with sequential data, this raises the question: why not just feed raw bytes into the Transformer and skip tokenization? The biggest reason is compute efficiency. The Attention mechanism in Transformers notoriously scales quadratically with the sequence length, O(n²). A raw byte stream is much longer than a tokenized sequence of text and would use substantially more FLOPs. Given the size of modern LLMs and the cost of the hardware they run on, being smart about how we chunk our sequences is super important.
Enter the Byte Latent Transformer (BLT), from the team at Meta AI, which completely changes this calculus. Remember that bytes themselves aren’t the problem; it’s the length of the sequence we send into the Transformer. What the BLT does is introduce a new strategy for encoding raw bytes that efficiently groups them into “patches”. Some initial overhead is spent on this patching process, but the resulting sequence of patches is what is used for the vast majority of the model’s computation, and it is much shorter than the byte sequence. In fact, as demonstrated in the BLT paper, this patching process results in shorter sequences on average than tokenization. As if being more compute efficient wasn’t enough, the researchers also found improvements in performance when comparing BLT to tokenizer-based LLM architectures (e.g. Llama 3). Finally, BLT also has the advantage that its vocabulary is basically the space of all possible byte values (0–255), making the architecture much more flexible in terms of input.
Patching Strategies
The BLT paper introduces three patching strategies for converting a sequence of bytes into a sequence of patches.
- Strided Patching (K-Bytes): The simplest strategy, which uses a fixed patch size, k. Basically, just group every k bytes into a patch.
- Space Patching: Every time a space-like byte is encountered, start a new patch. For text this roughly works out to patches being individual words.
- Entropy Patching: Use the “entropy” associated with the next byte in the sequence to determine patch boundaries. The idea is to group together regions of the byte sequence that are common or very likely to occur together in our dataset. When we encounter a byte position with a lot of uncertainty around its value, we start a new patch.
I won’t go deeper into Strided and Space Patching as these are very straightforward. Additionally, those strategies underperform both Entropy Patching and tokenization methods. Instead, I’ll focus on exploring Entropy Patching.
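For reference, here’s a minimal sketch of those two simpler strategies as boundary-selection functions (the helper names and the set of space-like bytes are my own choices, not from the paper’s code). Each returned index marks the byte position where a new patch begins.

```python
def strided_boundaries(byte_seq: bytes, k: int = 4) -> list[int]:
    """Start a new patch every k bytes."""
    return list(range(0, len(byte_seq), k))


def space_boundaries(byte_seq: bytes) -> list[int]:
    """Start a new patch at position 0 and right after every space-like byte."""
    boundaries = [0]
    for i in range(1, len(byte_seq)):
        if byte_seq[i - 1] in (0x20, 0x09, 0x0A):  # space, tab, newline
            boundaries.append(i)
    return boundaries


text = "Patches scale better than tokens".encode("utf-8")
print(strided_boundaries(text, k=4))  # fixed-size patches
print(space_boundaries(text))         # roughly one patch per word
```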
Now let’s look at how we quantify entropy for each byte so that we know how to create byte patches. First, the following expression represents the probability that byte i is equal to the value k given all of the previous bytes in the sequence.
p(b_i = k | b_1, …, b_(i-1))
If we are confident that the next byte is equal to k, then this expression will be close to 1. Alternatively, if we are confident that the next byte is not equal to k, then the expression will be close to 0, and if we are uncertain, it will be somewhere in between. I’ll discuss how we compute this expression a bit later. Next, we use this expression of confidence to calculate the entropy, H(i), associated with byte i:
H(i) = - Σ_k p(b_i = k | b_1, …, b_(i-1)) · ln p(b_i = k | b_1, …, b_(i-1))

where the sum runs over all 256 possible byte values k.
If we plot the function f(x) = -x ln(x), we can see that this results in low entropy when we are confident about the value of the next byte (x close to 0 or 1) and larger values in the middle, where we are uncertain about its value.
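To make that concrete, here’s a toy calculation (not from the paper) of H for a confident versus an uncertain next-byte distribution over the 256 possible values:

```python
import math

def entropy(probs: list[float]) -> float:
    """H = -sum_k p_k * ln(p_k), skipping zero-probability terms."""
    return -sum(p * math.log(p) for p in probs if p > 0)

# Confident: one byte value gets almost all of the probability mass.
confident = [0.99] + [0.01 / 255] * 255
# Uncertain: probability spread uniformly over all 256 byte values.
uncertain = [1 / 256] * 256

print(f"confident: H = {entropy(confident):.3f} nats")  # low, close to 0
print(f"uncertain: H = {entropy(uncertain):.3f} nats")  # ln(256) ≈ 5.545
```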
Now, if we plot the value of H(i) for a sequence of bytes, we get something like the figure in the BLT paper, which shows the entropy of each byte in a sample sentence, with vertical lines marking the start of new patches.
The paper describes two methods for selecting patch boundaries using the next-byte entropy. In both cases a threshold value, t, is set as a hyperparameter and a new patch is started whenever the chosen criterion exceeds that threshold.
- Global Threshold: H(i) > t
- Relative Threshold: H(i) - H(i-1) > t
Now that we know how to calculate next-byte entropy and use this value to determine patch boundaries, we have one big unanswered question that I kind of glossed over: How do we calculate the probability associated with the value of the next byte? After all, we need this to compute entropy, but we never talked about where it comes from. The short answer is: we train a model. Yes, you heard that right! We train a small language model to predict the next byte in our sequence of bytes, and we use this model both during training and inference to assist in our entropy calculation and patching. This model is not part of our primary model; it is only used for entropy patching. The BLT paper doesn’t go into much detail on this model’s size or architecture, but presumably a small, simple Transformer trained on a subset of the full dataset would be sufficient.
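Putting those pieces together, here is a sketch of entropy-based boundary selection under both thresholding rules. In a real setup the per-position entropies would come from that small byte-level language model; here they’re just passed in as a list, and the function and variable names are mine:

```python
def entropy_patch_boundaries(
    entropies: list[float],   # H(i) for each byte position, from the small byte LM
    threshold: float,         # the hyperparameter t
    relative: bool = False,   # False: global rule, True: relative rule
) -> list[int]:
    """Return the byte positions where a new patch starts."""
    boundaries = [0]  # the first byte always begins a patch
    for i in range(1, len(entropies)):
        if relative:
            start_new = entropies[i] - entropies[i - 1] > threshold  # H(i) - H(i-1) > t
        else:
            start_new = entropies[i] > threshold                     # H(i) > t
        if start_new:
            boundaries.append(i)
    return boundaries


# Toy entropies: spikes tend to appear at the start of hard-to-predict spans.
H = [2.9, 0.4, 0.3, 0.2, 3.1, 0.5, 0.4, 2.8, 0.3]
print(entropy_patch_boundaries(H, threshold=2.0))                 # global rule
print(entropy_patch_boundaries(H, threshold=2.0, relative=True))  # relative rule
```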
BLT Architecture
Now that we’ve discussed how to convert a sequence of bytes into a sequence of patches let’s walk through step by step how these patches are encoded and used. Let’s start by looking at the components of the BLT architecture:
- Latent Global Transformer: This is the largest model in the architecture. In traditional tokenizer-based architectures this would be the only model. In BLT it accepts the Latent Patch Representations generated by the Local Encoder as input and outputs another sequence of patch representations that the Local Decoder turns into a sequence of bytes.
- Local Encoder: This model translates the patches we created from our byte stream into an enriched set of latent representations that can be passed to the Latent Global Transformer.
- Local Decoder: This model takes the sequence of patch representations produced by the Latent Global Transformer and transforms them into a sequence of bytes representing the final output.
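Before walking through the steps in detail, here’s a rough structural sketch of how the three components fit together. The class and argument names are illustrative rather than taken from the paper’s code, and the three submodules are left abstract (they’re fleshed out in the walkthrough below):

```python
import torch.nn as nn

class ByteLatentTransformer(nn.Module):
    """Structural sketch only: wires up the three BLT components."""

    def __init__(self, local_encoder: nn.Module,
                 latent_global_transformer: nn.Module,
                 local_decoder: nn.Module):
        super().__init__()
        self.local_encoder = local_encoder                  # byte embeddings + patches -> (P, B)
        self.latent_global_transformer = latent_global_transformer  # P -> o
        self.local_decoder = local_decoder                  # (o, B) -> output bytes

    def forward(self, byte_embeddings, patch_boundaries):
        # Local Encoder: produces latent patch and byte representations.
        P, B = self.local_encoder(byte_embeddings, patch_boundaries)
        # Latent Global Transformer: the vast majority of the compute, over patches only.
        o = self.latent_global_transformer(P)
        # Local Decoder: turns the global output back into a byte-level output.
        return self.local_decoder(o, B)
```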
Full BLT Walkthrough
Step 1: Translate our byte sequence, b, into a sequence of embeddings, e.
Here we will swap out our raw byte values for embedding vectors. This is similar to the process in traditional Transformer models where we embed our tokens. The length of these embeddings will match the dimensionality of the Local Encoder.
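A minimal sketch of this lookup in PyTorch, assuming a vocabulary of the 256 possible byte values and using d_model as a stand-in for the Local Encoder’s hidden dimension:

```python
import torch
import torch.nn as nn

d_model = 512                                   # assumed Local Encoder hidden dimension
byte_embedding = nn.Embedding(num_embeddings=256, embedding_dim=d_model)

text = "Patches scale better than tokens"
b = torch.tensor(list(text.encode("utf-8")))    # byte sequence, values in 0-255
e = byte_embedding(b)                           # shape: (len(b), d_model)
print(b.shape, e.shape)
```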
Step 2: Generate byte-grams, g, for all byte positions.
Byte-grams are similar to n-grams in NLP where you group tokens that are close to each other, except here we operate on the raw byte stream. In the next step we will use these byte-grams to further enrich our embeddings, e, with more information about the company each byte is keeping. We can choose a single gram size or we can use a set of gram sizes. To keep things simple here we’ll use a single gram size, k.
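Here’s one simple way to form the byte-grams for a single gram size k, where each position’s gram is the k bytes ending at that position (a reasonable convention; the exact windowing is a detail an implementation could handle differently):

```python
def byte_grams(byte_seq: bytes, k: int = 3) -> list[bytes]:
    """For each position i, the gram is the k bytes ending at i (shorter at the start)."""
    return [byte_seq[max(0, i - k + 1): i + 1] for i in range(len(byte_seq))]


b = "Patches".encode("utf-8")
for i, g in enumerate(byte_grams(b, k=3)):
    print(i, g)
# 0 b'P'
# 1 b'Pa'
# 2 b'Pat'
# 3 b'atc'
# ...
```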
Step 3: Append byte-gram embeddings to our embedding sequence, e.
Here we will look up an embedding for each byte-gram. To keep the size of our embedding table to something manageable we use a Hash function to map each byte-gram to an integer value corresponding to a table index. This will enrich each byte embedding with information about what other bytes are nearby.
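A sketch of the hash-and-lookup step. The hash function (crc32), the table size, and the choice to fold the gram embedding in by simple addition (which keeps the dimensionality at d_model) are illustrative assumptions on my part:

```python
import zlib
import torch
import torch.nn as nn

d_model, table_size, k = 512, 50_000, 3              # placeholder sizes
byte_embedding = nn.Embedding(256, d_model)
gram_embedding = nn.Embedding(table_size, d_model)   # shared table for all hashed grams

def gram_index(gram: bytes) -> int:
    """Hash a byte-gram into a fixed-size embedding table (crc32 used for illustration)."""
    return zlib.crc32(gram) % table_size

b = "Patches".encode("utf-8")
grams = [b[max(0, i - k + 1): i + 1] for i in range(len(b))]  # as in the previous step

byte_ids = torch.tensor(list(b))
gram_ids = torch.tensor([gram_index(g) for g in grams])

# Enrich each byte embedding with its gram embedding; adding keeps the dimension at d_model.
e = byte_embedding(byte_ids) + gram_embedding(gram_ids)       # shape: (len(b), d_model)
print(e.shape)
```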
Step 4: Use our patching strategy to determine patch boundaries and apply these boundaries to our embedding sequence.
Here, we are assuming a constant patch size, s, as a simplification. In practice patch sizes will vary when using Space Patching and Entropy Patching.
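With a constant patch size s this grouping is just a reshape; with variable-length patches from Space or Entropy Patching you’d split at the boundary indices instead. A small sketch of both (toy sizes, names mine):

```python
import torch

seq_len, d_model, s = 12, 512, 4
e = torch.randn(seq_len, d_model)            # enriched byte embeddings from the previous step

# Constant patch size s: a simple reshape into (num_patches, s, d_model).
patches_fixed = e.view(seq_len // s, s, d_model)
print(patches_fixed.shape)                   # torch.Size([3, 4, 512])

# Variable-length patches: split at boundary indices (e.g. from entropy patching).
boundaries = [0, 3, 7]                                          # each index starts a new patch
sizes = [b2 - b1 for b1, b2 in zip(boundaries, boundaries[1:] + [seq_len])]
patches_var = list(torch.split(e, sizes))                       # list of (patch_len, d_model) tensors
print([p.shape[0] for p in patches_var])                        # [3, 4, 5]
```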
Step 5: Use the Local Encoder to transform each patch into a latent patch representation.
Here the Local Encoder operates on a single patch at a time (ignoring batches) by passing in the set of all byte embeddings contained within that patch. The encoder has two outputs:
- Latent Patch Representation, P, is used directly by the global transformer for processing.
- Latent Byte Representations, B, are used later in the decoding process.
Step 6: Pass the Latent Patch Representations, P, to the Latent Global Transformer for primary processing.
This is where the primary processing takes place. In a traditional Transformer architecture the global transformer would be the only model. You can imagine this as similar to passing a set of text embeddings into a Transformer, except that we’ve gone through a different set of pre-processing steps to translate bytes into rich embedding representations.
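Conceptually, the Latent Global Transformer is just a large decoder-style (causal) Transformer that happens to run over patch representations instead of token embeddings. A toy stand-in using PyTorch’s built-in layers (the dimensions here are placeholders; the real model is far larger):

```python
import torch
import torch.nn as nn

d_global, n_layers, n_patches = 1024, 4, 3             # toy sizes for illustration

layer = nn.TransformerEncoderLayer(d_model=d_global, nhead=8, batch_first=True)
global_transformer = nn.TransformerEncoder(layer, num_layers=n_layers)

P = torch.randn(1, n_patches, d_global)                # latent patch representations (batch of 1)
causal_mask = nn.Transformer.generate_square_subsequent_mask(n_patches)
o = global_transformer(P, mask=causal_mask)            # same shape as P: (1, n_patches, d_global)
print(o.shape)
```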
Step 7: Use the Local Decoder to translate the global output, o, into the final output byte stream, O.
The Local Decoder has two inputs:
- The output from the Latent Global Transformer, o
- The Latent Byte Representations, B, from the Local Encoder
Local Encoder Architecture
Now that we’ve discussed the flow of patches through the BLT architecture, let’s dive into the Local Encoder. The encoder is composed of a series of layers, each of which makes use of a special Cross Attention mechanism designed for processing patches and a standard decoder-style Transformer Layer. Each layer has the same inputs and outputs: a Patch Representation, P, and Byte Representations, B. Let’s start with the input, one of our patches with its enriched byte embeddings, and compute the output.
The first step in the Local Encoder is to compute the Patch Representation for layer 0 from the input patch. To do this, we first pool all of the byte embeddings in the patch using Max Pooling. This collapses the byte embedding vectors into a single vector, keeping the largest value of each feature across the bytes in the patch. After pooling, the result is multiplied by a projection matrix that maps it to the same dimensionality as the Latent Global Transformer. This is important since that’s where the patch representations go once they are encoded. For the first layer, the Byte Representations are simply set to the byte embedding vectors from our patch, p.
P(0) = Projection(MaxPool(p))
B(0) = p
Then in each subsequent layer the Local Encoder uses a special formulation of Cross Attention and a Transformer Layer to compute the outputs. Note the formulation of Cross Attention below where the calculation of the Patch Representation for the current layer uses the previous layer’s Patch Representation for the Queries and the Byte Representations from the previous layer for the Keys and Values.
P(l) = CrossAttention(Q = P(l-1), K = B(l-1), V = B(l-1))
B(l) = TransformerLayer(B(l-1))
This process is repeated through all the layers of the Local Encoder, and the Patch Representation and Byte Representations from the final layer are used as the Local Encoder’s output for the given input patch.
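Here’s a condensed sketch of the Local Encoder as described above: max-pool and project the patch’s byte embeddings to initialize P, then alternate the patch-to-byte Cross Attention with a Transformer Layer over the byte representations. The layer counts, head counts, dimensions, and class names are illustrative, not taken from the paper’s code:

```python
import torch
import torch.nn as nn

class LocalEncoder(nn.Module):
    """Sketch of the Local Encoder for a single patch (no batching subtleties)."""

    def __init__(self, d_bytes: int = 512, d_global: int = 1024,
                 n_layers: int = 2, n_heads: int = 8):
        super().__init__()
        self.project = nn.Linear(d_bytes, d_global)        # maps the pooled bytes to the global dim
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(d_global, n_heads, kdim=d_bytes, vdim=d_bytes,
                                  batch_first=True)
            for _ in range(n_layers)])
        self.byte_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_bytes, n_heads, batch_first=True)
            for _ in range(n_layers)])

    def forward(self, p: torch.Tensor):                    # p: (1, patch_len, d_bytes)
        # Layer 0: P is the max-pooled, projected byte embeddings; B is the embeddings themselves.
        P = self.project(p.max(dim=1, keepdim=True).values)   # (1, 1, d_global)
        B = p
        for attn, layer in zip(self.cross_attn, self.byte_layers):
            # Patch representation: Queries = previous P, Keys/Values = previous B.
            P, _ = attn(query=P, key=B, value=B)
            # Byte representations: a standard Transformer layer over the bytes.
            B = layer(B)
        return P, B                                        # latent patch + byte representations

encoder = LocalEncoder()
p = torch.randn(1, 6, 512)                                 # one patch of 6 enriched byte embeddings
P, B = encoder(p)
print(P.shape, B.shape)                                    # (1, 1, 1024), (1, 6, 512)
```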
Local Decoder Architecture
Now let’s take a look at how we decode the final output of our network by digging into the Local Decoder, which has two inputs: 1) o, the output from the Latent Global Transformer, and 2) B, the Latent Byte Representations output by the Local Encoder. The decoder has a similar architecture to the Local Encoder in that each layer makes use of Cross Attention and a Transformer Layer to compute its outputs, though with a few important differences. First, the Keys and Values of the Cross Attention mechanism are fixed for all decoder layers, which improves efficiency during decoding. Additionally, Cross Attention and the Transformer Layer are computed sequentially, whereas in the encoder they run in parallel. In the following formulation you’ll see that the output of Cross Attention is passed into the Transformer Layer to compute the layer output.
D(0) = B
D(l) = TransformerLayer(CrossAttention(Q = D(l-1), K = o, V = o))

Here D(l) denotes the byte-level representations after decoder layer l: the Queries evolve from layer to layer, while the Keys and Values, taken from the global output o, stay fixed.
This process is repeated through all layers of the Local Decoder and the output, O, from the last layer is our output byte stream.
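And a matching sketch of the Local Decoder: the encoder’s byte representations B initialize the byte-level states, the global output o supplies the fixed Keys and Values for every layer, and each layer feeds its Cross Attention output into a Transformer Layer. As before, the sizes and names are illustrative; I’ve also added a final projection to next-byte logits over the 256 byte values, which the walkthrough glosses over:

```python
import torch
import torch.nn as nn

class LocalDecoder(nn.Module):
    """Sketch of the Local Decoder (no batching subtleties)."""

    def __init__(self, d_bytes: int = 512, d_global: int = 1024,
                 n_layers: int = 2, n_heads: int = 8):
        super().__init__()
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(d_bytes, n_heads, kdim=d_global, vdim=d_global,
                                  batch_first=True)
            for _ in range(n_layers)])
        self.byte_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_bytes, n_heads, batch_first=True)
            for _ in range(n_layers)])
        self.to_byte_logits = nn.Linear(d_bytes, 256)      # my addition: predicts the next byte value

    def forward(self, o: torch.Tensor, B: torch.Tensor):
        # o: (1, n_patches, d_global) global output; B: (1, n_bytes, d_bytes) from the encoder.
        D = B                                              # byte-level states start from B
        for attn, layer in zip(self.cross_attn, self.byte_layers):
            # Queries evolve layer to layer; Keys/Values come from o and stay fixed.
            x, _ = attn(query=D, key=o, value=o)
            D = layer(x)                                   # Cross Attention output -> Transformer Layer
        return self.to_byte_logits(D)                      # (1, n_bytes, 256) next-byte logits

decoder = LocalDecoder()
o = torch.randn(1, 3, 1024)                                # output of the Latent Global Transformer
B = torch.randn(1, 12, 512)                                # latent byte representations from the encoder
logits = decoder(o, B)
print(logits.shape)                                        # torch.Size([1, 12, 256])
```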
Conclusion
The Byte Latent Transformer is a significant departure from the tokenization process used in the traditional Transformer architecture. While it offers benefits in terms of compute efficiency, it also introduces architectural complexity. Instead of one model we now have a Local Encoder, a Latent Global Transformer, and a Local Decoder. Additionally, if you intend to use Entropy Patching, which is the preferred patching strategy, you will also need to train a separate model for computing next-byte likelihoods. Time will tell if this approach catches on, but I imagine that tooling and best practices will reduce the overhead associated with this architecture. The opportunity to improve performance while also getting a large reduction in compute cost seems too good to pass up! 🥪