To better understand MLA and to make this article self-contained, we will revisit several related concepts in this section before diving into the details of MLA.
MHA in Decoder-only Transformers
Note that MLA was developed to speed up inference in autoregressive text generation, so the MHA we discuss in this context is for decoder-only Transformers.
The figure below compares three Transformer architectures used for decoding, where (a) shows both the encoder and decoder proposed in the original "Attention Is All You Need" paper. Its decoder part was then simplified by [6], leading to the decoder-only Transformer model shown in (b), which was later used in many generation models such as GPT [8].
Nowadays, LLMs more commonly adopt the structure shown in (c) for more stable training, with normalization applied to the input rather than the output, and LayerNorm upgraded to RMSNorm. This will serve as the baseline architecture we discuss in this article.
Within this context, the MHA computation largely follows the process in [6], as shown in the figure below:
Assume we have n_h attention heads, and the dimension of each attention head is d_h, so that the concatenated dimension is (n_h · d_h).
Given a model with l layers, if we denote the input for the t-th token in that layer as h_t with dimension d, we need to map the dimension of h_t from d to (n_h · d_h) using linear mapping matrices.
More formally, we have (equations from [3]):
where W^Q, W^K and W^V are the linear mapping matrices:
After this mapping, q_t, k_t and v_t will be split into n_h heads to compute scaled dot-product attention:
where W^O is another projection matrix that maps the dimension back from (n_h · d_h) to d:
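The projection-split-attend-concatenate pipeline above can be sketched in a few lines of numpy. This is a minimal illustration under toy assumptions (all shapes and random weights are invented for the example, and it processes a short sequence at once rather than a single token):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha(H, W_Q, W_K, W_V, W_O, n_h):
    # H: (t, d) token inputs; project to (t, n_h * d_h) via the mapping matrices
    t, d = H.shape
    d_h = W_Q.shape[1] // n_h
    # reshape to (n_h, t, d_h): split q, k, v into n_h heads
    q = (H @ W_Q).reshape(t, n_h, d_h).transpose(1, 0, 2)
    k = (H @ W_K).reshape(t, n_h, d_h).transpose(1, 0, 2)
    v = (H @ W_V).reshape(t, n_h, d_h).transpose(1, 0, 2)
    # scaled dot-product attention, computed independently per head
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_h)    # (n_h, t, t)
    out = softmax(scores) @ v                           # (n_h, t, d_h)
    out = out.transpose(1, 0, 2).reshape(t, n_h * d_h)  # concatenate heads
    return out @ W_O                                    # W^O maps back to d

rng = np.random.default_rng(0)
d, n_h, d_h, t = 8, 2, 4, 3                             # toy sizes
W_Q, W_K, W_V = (rng.normal(size=(d, n_h * d_h)) for _ in range(3))
W_O = rng.normal(size=(n_h * d_h, d))
H = rng.normal(size=(t, d))
out = mha(H, W_Q, W_K, W_V, W_O, n_h)
print(out.shape)  # (3, 8)
```

Note that the output dimension matches the input dimension d, which is what lets the block be stacked l times.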
Note that the process described by Eqn. (1) to (8) above is for a single token only. During inference, we need to repeat this process for each newly generated token, which involves a lot of repeated computation. This leads to a technique called the Key-Value cache.
Key-Value Cache
As suggested by its name, the Key-Value cache is a technique designed to speed up the autoregressive process by caching and reusing the previous keys and values, rather than re-computing them at each decoding step.
Note that the KV cache is typically used only during inference, since in training we still need to process the entire input sequence in parallel.
The KV cache is commonly implemented as a rolling buffer. At each decoding step, only the new query Q is computed, while the K and V stored in the cache are reused, so that attention is computed using the new Q and the reused K, V. Meanwhile, the new token's K and V are also appended to the cache for later use.
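A single decoding step with a KV cache can be sketched as follows (a single-head toy version with invented shapes; a plain Python list stands in for the rolling buffer):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def decode_step(h_t, cache_k, cache_v, W_Q, W_K, W_V):
    # only the new token's q, k, v are computed at this step
    q, k, v = h_t @ W_Q, h_t @ W_K, h_t @ W_V
    cache_k.append(k)                               # append new K, V for later steps
    cache_v.append(v)
    K, V = np.stack(cache_k), np.stack(cache_v)     # (t, d_h): reused + new
    attn = softmax(q @ K.T / np.sqrt(K.shape[-1]))  # new Q attends over all cached K
    return attn @ V

rng = np.random.default_rng(0)
d, d_h = 8, 4
W_Q, W_K, W_V = (rng.normal(size=(d, d_h)) for _ in range(3))
cache_k, cache_v = [], []
for step in range(3):                               # three decoding steps
    h_t = rng.normal(size=d)
    out = decode_step(h_t, cache_k, cache_v, W_Q, W_K, W_V)
print(len(cache_k))  # 3 -- one cached key per generated token
```

Each step does O(1) projections for the new token, while the cache grows by one entry per token, which is exactly the memory cost discussed next.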
However, the speedup achieved by the KV cache comes at the cost of memory, since the KV cache typically scales with batch size × sequence length × hidden size × number of heads, leading to a memory bottleneck with larger batch sizes or longer sequences.
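To get a feel for the scale, here is a back-of-the-envelope calculation. All model sizes below are hypothetical (chosen only for illustration, not taken from the article):

```python
# Hypothetical fp16 model (assumed sizes, for illustration only):
layers, batch, seq, n_h, d_h = 32, 8, 4096, 32, 128
bytes_per_elem = 2                                  # fp16
# factor of 2: one K tensor and one V tensor per layer
total = 2 * layers * batch * seq * n_h * d_h * bytes_per_elem
print(f"{total / 2**30:.1f} GiB")  # 16.0 GiB
```

Even at a modest batch size, the cache for a long sequence can rival the model weights themselves in memory.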
This further leads to two techniques aimed at addressing this limitation: Multi-Query Attention and Grouped-Query Attention.
Multi-Query Attention (MQA) vs Grouped-Query Attention (GQA)
The figure below shows the comparison between the original MHA, Grouped-Query Attention (GQA) [10] and Multi-Query Attention (MQA) [9].
The basic idea of MQA is to share a single key head and a single value head across all query heads, which can significantly reduce memory usage but will also affect the accuracy of attention.
GQA can be seen as an interpolation between MHA and MQA, where a single pair of key and value heads is shared only by a group of query heads, rather than by all of them. Still, this leads to inferior results compared to MHA.
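The sharing scheme can be illustrated by expanding a small number of cached KV heads to match the query heads (toy shapes, invented for this sketch):

```python
import numpy as np

def expand_kv(k_heads, v_heads, n_q_heads):
    # k_heads, v_heads: (n_kv, t, d_h). Each group of query heads shares
    # one KV pair. n_kv == 1 recovers MQA; n_kv == n_q_heads recovers MHA.
    group = n_q_heads // k_heads.shape[0]
    return (np.repeat(k_heads, group, axis=0),
            np.repeat(v_heads, group, axis=0))

rng = np.random.default_rng(0)
n_kv, n_q, t, d_h = 2, 8, 5, 4      # 8 query heads share 2 KV heads
k = rng.normal(size=(n_kv, t, d_h))
v = rng.normal(size=(n_kv, t, d_h))
k_full, v_full = expand_kv(k, v, n_q)
print(k_full.shape)  # (8, 5, 4)
```

Only the n_kv original head slices need to be cached, shrinking the KV cache by a factor of n_q / n_kv (here, 4×).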
In later sections, we will see how MLA manages to strike a balance between memory efficiency and modeling accuracy.
RoPE (Rotary Positional Embeddings)
One last piece of background we need to mention is RoPE [11], which encodes positional information directly into the attention mechanism by rotating the query and key vectors in multi-head attention using sinusoidal functions.
More specifically, RoPE applies a position-dependent rotation matrix to the query and key vectors at each token. It uses sine and cosine functions as its basis, but applies them in a unique way to achieve rotation.
To see what makes it position-dependent, consider a toy embedding vector with only four elements, i.e., (x_1, x_2, x_3, x_4).
To apply RoPE, we first group consecutive dimensions into pairs:
- (x_1, x_2) -> position 1
- (x_3, x_4) -> position 2
Then, we apply a rotation matrix to rotate each pair:
where θ = θ(p) = p ⋅ θ_0, and θ_0 is a base frequency. In our 4-d toy example, this means (x_1, x_2) will be rotated by θ_0, and (x_3, x_4) will be rotated by 2 ⋅ θ_0.
This is why we call this rotation matrix position-dependent: at each position (or each pair), we apply a different rotation matrix whose angle is determined by the position.
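The toy example can be written out directly. This follows the article's 4-d convention (pair p rotated by angle p ⋅ θ_0); the base frequency value is an arbitrary assumption:

```python
import numpy as np

def rope_toy(x, theta0=0.1):
    # x: 4-d toy vector; pair p (1-based) is rotated by angle p * theta0,
    # i.e. (x1, x2) by theta0 and (x3, x4) by 2 * theta0
    out = np.empty(4)
    for p in (1, 2):
        theta = p * theta0
        c, s = np.cos(theta), np.sin(theta)
        a, b = x[2 * (p - 1)], x[2 * (p - 1) + 1]
        out[2 * (p - 1)]     = c * a - s * b   # 2-d rotation of the pair
        out[2 * (p - 1) + 1] = s * a + c * b
    return out

x = np.array([1.0, 0.0, 1.0, 0.0])
y = rope_toy(x)
# rotations preserve the norm of each pair, hence of the whole vector
print(np.allclose(np.linalg.norm(y), np.linalg.norm(x)))  # True
```

Because the transformation is a pure rotation, it changes only the angle of each pair, never its magnitude, which is what lets relative positions show up in the Q·K dot products.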
RoPE is widely used in modern LLMs due to its efficiency in encoding long sequences, but as we can see from the formula above, it is position-sensitive for both Q and K, making it incompatible with MLA in certain ways.