Key-Value (KV) Caching: Mistral, Transformers, xFormers

Ever wondered why the time to first token in LLMs is high, but subsequent tokens are super fast?

In this post, I dive into the details of KV-caching as used in Mistral, a topic I initially found quite daunting. However, as I dug deeper, it became a fascinating subject, especially once it explained why the time to first token (TTFT) in these language models is often high, a pattern I had noticed across numerous API calls 🙂.

I’ll cover:

  1. What exactly is KV-Caching?
  2. The concept of the rolling cache buffer
  3. The prefill and decode stages
  4. Formulating attention masks with the help of the xFormers library

Consider our input token sequence x1, x2, x3, …, xt, and suppose we are determining the output at time step t. To find the attention output (at each transformer layer), we need the dot product of the current token’s query vector with the key vectors of the current and past tokens. After normalizing via softmax, these become the attention weights over the value vectors. Here are two key observations:

  1. Single Token Decoding: Decoding happens one token at a time. We are only interested in the self-attention output for the current token, so we only need its query vector, not the query vectors of other tokens.
  2. Precomputed Keys and Values: We need the dot products with the keys of past tokens, which were already computed when calculating the self-attention output of the token at time step t−1. The same goes for the value vectors. (A small sketch of this follows below.)
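To make these two observations concrete, here is a minimal PyTorch sketch (hypothetical tensor names, a single head, random values) of computing the attention output for the current token against keys and values that are already available:

import torch
import torch.nn.functional as F

head_dim = 128
K = torch.randn(10, head_dim)   # key vectors of tokens x1..xt (t = 10 here)
V = torch.randn(10, head_dim)   # value vectors of tokens x1..xt
q_t = torch.randn(1, head_dim)  # query vector of the current token only

scores = (q_t @ K.T) / head_dim ** 0.5   # dot products with current + past keys, [1, t]
weights = F.softmax(scores, dim=-1)      # attention weights over the value vectors
attn_out = weights @ V                   # self-attention output for the current token, [1, head_dim]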

The dimensions of the key quantities are as follows:

  1. Token Embedding Vectors: dim
  2. Size of Query, Key, Value Heads: head_dim
  3. Number of Query Heads: n_heads
  4. Number of Key and Value Heads: n_kv_heads
  5. Number of Transformer Layers: n_layers

(Note: Mistral uses grouped-query attention, where for each token, 4 of its query heads attend to the same key-value pair. With n_heads=32, we have n_kv_heads=32/4=8.)
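For reference, these are roughly the values used by Mistral 7B (quoted from memory, so treat them as approximate):

# Approximate Mistral 7B configuration (for illustration only)
dim = 4096             # token embedding size
head_dim = 128         # size of each query/key/value head
n_heads = 32           # number of query heads
n_kv_heads = 8         # number of key/value heads (grouped-query attention)
n_layers = 32          # number of transformer layers
sliding_window = 4096  # window size W used by sliding window attention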

In the unoptimized implementation:

Assuming a single transformer layer, at each time step we calculate the query for the current token, and the key and value vectors for both the current and past tokens. This process involves three matrix multiplications.

a. Query Calculation (Q): multiply the current token’s embedding (shape [1, dim]) with Wq (shape [dim, n_heads*head_dim]).

b. Key Calculation (K): multiply the embeddings of all t tokens (shape [t, dim]) with Wk (shape [dim, n_kv_heads*head_dim]).

c. Value Calculation (V): multiply the embeddings of all t tokens (shape [t, dim]) with Wv (shape [dim, n_kv_heads*head_dim]).

Once we have the query, key, and value vectors, we can then compute the attention output using softmax(Q·Kᵀ/√head_dim)·V per head.
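As a rough sketch (hypothetical names, single layer, all heads folded into one matrix per projection), every decoding step of the unoptimized version looks like this:

import torch

dim, n_heads, n_kv_heads, head_dim = 4096, 32, 8, 128
Wq = torch.randn(dim, n_heads * head_dim)
Wk = torch.randn(dim, n_kv_heads * head_dim)
Wv = torch.randn(dim, n_kv_heads * head_dim)

x = torch.randn(10, dim)   # embeddings of tokens x1..xt (t = 10 here)
q_t = x[-1:] @ Wq          # (a) query, current token only:    [1, n_heads*head_dim]
K = x @ Wk                 # (b) keys, recomputed for all t:   [t, n_kv_heads*head_dim]
V = x @ Wv                 # (c) values, recomputed for all t: [t, n_kv_heads*head_dim]
# ...then reshape into heads and compute softmax(q_t @ K.T / sqrt(head_dim)) @ V per head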

In the optimized implementation:

However, as mentioned in observation 2, the keys and values of tokens up to time step t−1 would have already been computed when determining the output at time step t−1. This means we can avoid redundant computation by storing the keys and values of tokens up to time step t−1.

Note: Mistral uses a sliding window attention mechanism, so we only attend to a specific number of previous tokens. More details on this will be covered later.

What this means is that during decoding, we compute the key and value vectors only for the current token and not for the previous ones. So, operations (b) and (c) above are performed for just one token instead of t tokens. Specifically:

Key Calculation (K): multiply only the current token’s embedding (shape [1, dim]) with Wk, and append the result to the cached keys.

Value Calculation (V): multiply only the current token’s embedding (shape [1, dim]) with Wv, and append the result to the cached values.
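The corresponding sketch with a KV cache (again with hypothetical names): only the new token’s key and value are computed and appended.

import torch

dim, n_heads, n_kv_heads, head_dim = 4096, 32, 8, 128
Wq = torch.randn(dim, n_heads * head_dim)
Wk = torch.randn(dim, n_kv_heads * head_dim)
Wv = torch.randn(dim, n_kv_heads * head_dim)

k_cache = torch.randn(9, n_kv_heads * head_dim)  # keys of tokens x1..x(t-1), cached earlier
v_cache = torch.randn(9, n_kv_heads * head_dim)  # values of tokens x1..x(t-1)

x_t = torch.randn(1, dim)                 # embedding of the current token xt
q_t = x_t @ Wq
k_cache = torch.cat([k_cache, x_t @ Wk])  # only one new row of K is computed
v_cache = torch.cat([v_cache, x_t @ Wv])  # only one new row of V is computed
# attention for xt uses q_t against the full k_cache / v_cache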

FLOPs Saved

At every step of decoding, we save 2*(t-1)*n_kv_heads*dim² FLOPs. For a sequence of length T, this translates to savings of 2*(T*(T-1)/2)*n_kv_heads*dim² FLOPs.

Considering we have assumed a single transformer layer, and knowing that Mistral uses 32 transformer layers, the savings get multiplied by 32. That is significant!

For a typical sequence length of 10,000 tokens, with n_kv_heads=8 and dim=4096, we get roughly 4.294e+17 FLOPs (10000*10000*8*4096*4096*32).

An Nvidia A100 GPU delivers roughly 312e+12 FLOP/s, meaning we could save around 23 minutes when generating this sequence of 10,000 tokens!
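A quick back-of-the-envelope check of these numbers, taking the estimate above at face value:

# Sanity-check the savings estimate
T, n_kv_heads, dim, n_layers = 10_000, 8, 4096, 32
flops_saved = 2 * (T * (T - 1) / 2) * n_kv_heads * dim**2 * n_layers
a100_flop_per_s = 312e12                                    # approximate A100 throughput
print(f"{flops_saved:.3e} FLOPs")                           # roughly 4.29e+17
print(f"{flops_saved / a100_flop_per_s / 60:.1f} minutes")  # roughly 23 minutes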

Note: This is a simplified calculation to give an idea of the benefits, which are indeed substantial. Actual improvements will depend on various factors such as the maximum feasible cache size, GPU memory, parallelization across multiple GPUs, etc.

Now that we understand the KV cache, I’ll discuss how we leverage it during output generation!

First, let’s establish some terminology used by Mistral:

  1. Sliding Window Attention (SWA): Mistral uses SWA, meaning each token attends to itself and the previous W−1 tokens, where W is the window size.
  2. KV Cache Size: We set our KV cache to size W. This means we can store W key vectors and W value vectors in the cache. This ensures we have the necessary context to compute the self-attention output for the next token.
  3. Chunk Size: We also process user input prompt sequences W tokens at a time (more on this in the next section on prefill; see the sketch after this list). This chunk size limits GPU memory usage. Self-attention requires K, Q, and V to be on the GPU, and these grow with the input size, making it impractical to process the entire input sequence in a single batch.
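As a small illustration of the chunking idea (a hypothetical helper, not Mistral’s actual code), the prompt can be split into window-sized chunks that are prefilled one after another:

# Split a prompt into window-sized chunks for prefill (illustrative only)
def chunks(prompt_tokens, window_size):
    for start in range(0, len(prompt_tokens), window_size):
        yield prompt_tokens[start:start + window_size]

prompt = list(range(10_000))        # pretend token ids
for chunk in chunks(prompt, 4096):  # W = 4096
    print(len(chunk))               # 4096, 4096, 1808: each chunk is processed and cached in turn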

Note:

Each transformer layer in Mistral has its own separate KV cache.

At first, it might seem (it did to me!) that calculating and caching only the keys and values of the last W−1 tokens in the input sequence would be sufficient to generate the first output token. However, that’s not the case! This is because Mistral has more than one transformer layer. To compute the second-layer output for our next token, we need the first-layer outputs of the last W−1 tokens, which in turn depend on the last (2W−1) input tokens (similar to the receptive field in CNNs!)

Mistral uses a window size of W = 4096 tokens.

The input to these models usually starts with user-provided tokens (the famous user prompt 😊), followed by the generation of output tokens. The stage where we populate the KV cache with the keys and values from the user prompt, so we can use them when generating output tokens, is called the prefill stage. This is the key reason why the time to first token (TTFT) is often high.

To understand the workings of the prefill stage, let’s walk through an example:

Imagine we have 3 sequences in our inference batch, with user prompt token lengths of 4, 1, and 3 respectively. Suppose we have a window size W=3, and we want to generate the next 5 tokens for each sequence.

Given:

  1. seqlens = [4,1,3]
  2. sliding_window_size = cache_size = 3
  3. chunk_size = 2 (for illustration purposes; ideally this would also be = W = 3, as mentioned before)

In the prefill stage, since we already have all the input tokens, we can process them in parallel. With a chunk_size of 2, we require two iterations, as explained below.

Iteration 1

We have a chunk size of 2, so we’ll process the first 2 tokens from each sequence. This means the sequence lengths under consideration for this step are [2,1,2].

To batch the three sequences together, one approach is to pad the shorter sequences to match the longest one. However, if the sequences vary greatly in length, padding results in a lot of wasted memory. Hence, this approach is generally not used.

The preferred approach is to concatenate all the sequences in the batch into a single larger sequence. We then create an appropriate attention mask so that tokens attend only to those within the same sequence.

This means our input shape is: [2+1+2, dim] = [5, dim]

We compute our Q, K, and V vectors for this input by multiplying with the matrices Wq, Wk, and Wv. Assuming the number of heads = 1 for simplicity, the outputs will have the following shapes:

a. Q: [5, head_dim]

b. K: [5, head_dim]

c. V: [5, head_dim]

Next, we add rotary positional encodings to our Q and K vectors.
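A minimal sketch of these two steps for the concatenated chunk (single head, hypothetical names, and a simple “rotate-half” RoPE for illustration rather than Mistral’s exact interleaved implementation):

import torch

dim, head_dim = 4096, 128
Wq, Wk, Wv = (torch.randn(dim, head_dim) for _ in range(3))

x = torch.randn(5, dim)           # concatenated chunk: [2+1+2, dim]
q, k, v = x @ Wq, x @ Wk, x @ Wv  # each of shape [5, head_dim]

def rope(t, positions, theta=10000.0):
    # rotate pairs of dimensions by an angle proportional to the token's position
    half = t.shape[-1] // 2
    freqs = 1.0 / theta ** (torch.arange(half) / half)
    ang = positions[:, None] * freqs[None, :]
    t1, t2 = t[:, :half], t[:, half:]
    return torch.cat([t1 * ang.cos() - t2 * ang.sin(),
                      t1 * ang.sin() + t2 * ang.cos()], dim=-1)

positions = torch.tensor([0, 1, 0, 0, 1])  # position of each token within its own sequence
q, k = rope(q, positions), rope(k, positions)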

With these preparations, we’re ready to calculate the self-attention output!

Step 1: Retrieve from KV-Cache and Compute Attention

Since this is the first chunk, we look at the KV cache and find it empty, with no vectors stored there. This means there are no previous tokens to attend to, only the current token itself. Consequently, the number of key-value vectors (kv_seqlen) matches the number of query vectors (q_seqlen) in each sequence.

To handle this, we create our mask using the BlockDiagonalCausalMask from the xFormers library, like so:

from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

mask = BlockDiagonalCausalMask.from_seqlens(q_seqlen=[2,1,2], kv_seqlen=[2,1,2]).make_local_attention(window_size=3)

The attention mask can be visualized using

mask.materialize(shape=(5,5)).exp()
# The 'shape' argument: the first dimension is the total number of query vectors and the second dimension is the total number of key/value vectors.
# materialize() returns 0 for allowed positions and -inf for masked ones, so .exp() turns these into 1s and 0s for easier viewing.

and the output is

[[1., 0., 0., 0., 0.],
 [1., 1., 0., 0., 0.],
 [0., 0., 1., 0., 0.],
 [0., 0., 0., 1., 0.],
 [0., 0., 0., 1., 1.]]

Let’s understand how we got this mask and why it makes sense. Focus on q_seqlen = [2,1,2] and kv_seqlen = [2,1,2].

Image by author

The first sequence has 2 query vectors and 2 key-value (kv) vectors. The attention mask for this sequence is the 2×2 matrix in the top left:

[[1, 0],
 [1, 1]]

The second element in the first row is 0 because this is a causal mask, and we don’t want the first token to attend to the second token (which lies in the future).

The second sequence has just 1 query and 1 kv vector, represented by the center 1×1 matrix. The third sequence, similar to the first, has an identical 2×2 matrix in the bottom right.

Notice that the attention masks for the individual sequences are logically concatenated along the diagonal.

Setting the window size to 3 in our mask creation ensures that we only consider up to 3 tokens for attention per sequence.

This mask is applied to the output of the matrix product of Q and K.T. Thus, dot products of queries and keys from different sequences are nullified by the 0s in the combined attention mask, and causality is preserved.

Note: Under the hood, xFormers doesn’t compute the dot products that would be nullified by the 0s in the attention mask at all.

The BlockDiagonalCausalMask in xFormers starts filling 1s from the top left of each block, which is exactly what we need for our first prefill.
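In practice, this mask object isn’t materialized; it is passed straight to xFormers’ memory-efficient attention kernel. A sketch of that call (random tensors; xFormers expects inputs of shape [batch, seq, heads, head_dim], and these kernels need a GPU):

import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

head_dim = 128
# The 5 concatenated tokens of the chunk, as a single "batch" with 1 head
q = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(1, 5, 1, head_dim, device="cuda", dtype=torch.float16)

mask = BlockDiagonalCausalMask.from_seqlens(
    q_seqlen=[2, 1, 2], kv_seqlen=[2, 1, 2]
).make_local_attention(window_size=3)

out = memory_efficient_attention(q, k, v, attn_bias=mask)  # [1, 5, 1, head_dim]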

Step 2: Cache Update

Next, we update the cache with the computed keys and values. Our cache size is initialized to W×batch_size = W×3, that is, one block per sequence, with one such cache each for keys and for values. This is a rolling cache, meaning tokens in the first sequence will use cache positions [0, 1, 2, 0, 1, 2 …], tokens in the second sequence will use cache positions [3, 4, 5, 3, 4, 5 …], and tokens in the third sequence will use cache positions [6, 7, 8, 6, 7, 8 …].
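One way to picture this rolling (ring-buffer) indexing, as an illustrative helper rather than Mistral’s actual code:

# Cache position used by the i-th token of a given sequence (illustrative)
W = 3  # window / cache size

def cache_position(seq_id: int, token_index: int) -> int:
    return seq_id * W + (token_index % W)

print([cache_position(0, i) for i in range(6)])  # [0, 1, 2, 0, 1, 2]
print([cache_position(1, i) for i in range(6)])  # [3, 4, 5, 3, 4, 5]
print([cache_position(2, i) for i in range(6)])  # [6, 7, 8, 6, 7, 8]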

So, our KV cache after the first iteration (after processing 2, 1, and 2 tokens from the respective sequences) looks like this:

Image by author

Iteration 2

We now move on to the remaining part of our sequences. The remaining tokens to process for each sequence are [2, 0, 1]. In the Mistral code, this stage is called the ‘subsequent prefill’ stage.

Step 1: Retrieve from KV-Cache and Compute Attention

As in iteration 1, we first look at the KV cache, but now we find entries in it. We retrieve the entries and perform an unroll/unrotate step on them to restore the correct sequence order. Why do we do this?

Remember, this is a rolling cache. If we had processed, say, 5 tokens, the keys and values for the 4th and 5th tokens would occupy the first two cache positions, followed by those of the 3rd token. After unrolling, we would have the keys and values of the 3rd, 4th, and 5th tokens, in that order. However, in this case, since we haven’t processed more than 3 tokens, the current cache order already matches the token order.

Note: The reason we need to unrotate is that during the prefill stage, we process multiple tokens per sequence, and we need to identify which queries should attend to which keys within the sequence. In contrast, during the decode stage (described in the following section), we process just one token of a sequence at a time. In that case, unrotation isn’t necessary, because this single token attends to all entries in the cache.
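A tiny illustration of this unrotation for one sequence’s cache block (a toy example, not Mistral’s actual implementation):

import torch

W = 3
# After 5 tokens, the rolling block holds the keys of tokens [4, 5, 3] (labelled here by token id)
cache_block = torch.tensor([[4.0], [5.0], [3.0]])
tokens_seen = 5
unrotated = torch.roll(cache_block, shifts=-(tokens_seen % W), dims=0)
print(unrotated.squeeze().tolist())  # [3.0, 4.0, 5.0] -> keys restored to token order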

Currently, the number of query vectors for each sequence is [2, 0, 1]. The number of key vectors is calculated as the number of query vectors plus the number of valid entries in the cache:

kv_seqlen = [2+2, 0+1, 1+2] = [4, 1, 3]

We create the mask using the make_local_attention_from_bottomright() method of the BlockDiagonalMask class from xFormers:

from xformers.ops.fmha.attn_bias import BlockDiagonalMask

mask = BlockDiagonalMask.from_seqlens(
    q_seqlen=[2,0,1],
    kv_seqlen=[4,1,3],
).make_local_attention_from_bottomright(window_size=3)

This mask looks like:

Image by author

Similar to the logic explained in iteration 1, we have three matrices concatenated diagonally, where the rows represent the number of queries and the columns represent the number of keys in each sequence.

Here, we need to use make_local_attention_from_bottomright() instead of make_local_attention(), because we want to start filling from the bottom right of each block.

Step 2: Cache Update

We store the computed keys and values into the cache, just like in iteration 1, in a rolling fashion. Our updated cache then looks like this:

Image by author

After the prefill stage, we move on to the decode stage, where we begin generating our output tokens one at a time.

Unlike the prefill stage, where Step 1 involves reading cache entries and computing attention and Step 2 involves updating the cache with the new entries, in the decode stage we reverse these steps. First, we update the cache with the new entries, and then we read all the entries (including the ones we just added) to compute self-attention.

This approach works neatly because decoding happens one token at a time, and we know that all entries in the cache are within our context window (of size W) and needed for self-attention.
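Putting the two steps together, one decode step for a single sequence looks roughly like this sketch (hypothetical names, a single layer and a single head):

import torch

W, dim, head_dim = 3, 4096, 128
Wq, Wk, Wv = (torch.randn(dim, head_dim) for _ in range(3))

k_cache = torch.randn(W, head_dim)  # this sequence's rolling key cache, filled during prefill
v_cache = torch.randn(W, head_dim)
tokens_seen = 4                     # tokens processed so far for this sequence

x_t = torch.randn(dim)              # embedding of the newly generated token

# Step 1: cache update -- write the new key/value into the rolling slot
slot = tokens_seen % W
k_cache[slot] = x_t @ Wk
v_cache[slot] = x_t @ Wv

# Step 2: attention -- the single query attends to every valid cache entry,
# so the order of entries (and hence unrotation) doesn't matter
q_t = x_t @ Wq
scores = (k_cache @ q_t) / head_dim ** 0.5          # [W]
attn_out = torch.softmax(scores, dim=0) @ v_cache   # [head_dim]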

Step 1: Cache Update

We compute the key and value vectors for the current input token and add them to the cache. The new tokens are #4, #1, and #3 for the three sequences (counting positions from 0). The updated cache looks like this:

Image by author

Step 2: Retrieve from KV-Cache and Compute Attention

We now proceed to compute self-attention and the associated mask!

  1. We have one query for each sequence in the batch, so
    q_seqlen = [1, 1, 1].
  2. The number of keys is the number of valid entries in the cache, given by kv_seqlen = [3, 2, 3].

In the Mistral codebase, for simplicity, the attention mask shape is fixed to (W×batch_size, W×batch_size) = (9,9).

We create our attention mask again with xFormers, like so:

from xformers.ops.fmha.attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask

mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1,1,1],
    kv_padding=3,
    kv_seqlen=[3,2,3]
)

This mask looks like:

Image by author

We have 3 blocks of 1×3 matrices concatenated diagonally. Since we fixed our attention mask to 9×9 for simplicity, the initial attention score matrix (before applying the mask) considers dot products between all queries and all key entries in the cache (valid or not). This is evident, for example, in sequence 2 above, where we place a 0 in the third entry of its block to invalidate that cache entry.