To Mask or Not to Mask: The Effect of Prompt Tokens on Instruction Tuning | by David Vaughn | Sep, 2024

These plots suggest that when a dataset’s Rg distribution spans several orders of magnitude, or has non-negligible representation in both the Rg>1 and Rg<1 regions (as is the case with OpenOrca and other datasets with R̅g>1), the distribution can become highly skewed. As a result, the arithmetic mean may be disproportionately influenced by larger values, potentially misrepresenting the distribution’s central tendency. In such cases, computing the mean in log-space (then optionally transforming it back to the original scale) may provide a more meaningful summary statistic. In other words, it may make sense to use the geometric mean:
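To see the difference concretely, here is a minimal sketch (the Rg values below are made up purely for illustration, not taken from any of the datasets above):

import numpy as np

# hypothetical generation ratios (Rg) spanning several orders of magnitude
rg = np.array([0.01, 0.02, 0.05, 0.1, 0.5, 2.0, 50.0, 400.0])

arithmetic_mean = rg.mean()                 # dominated by the few large values (~56.6)
geometric_mean = np.exp(np.log(rg).mean())  # mean computed in log-space (~0.61)

print(arithmetic_mean, geometric_mean)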

The RACE Reading Comprehension Dataset

Based on the above R̅g table, I decided the RACE ReAding Comprehension Dataset from Examinations (R̅g=0.01) would be a good candidate for investigation. Multiple choice QA seemed like a good test-bed for exploring the effects of prompt-masking, since the prompt is naturally very long relative to the completion. Regardless of prompt length, the completion is always 1 character long, namely A, B, C or D (if you ignore special tokens, delimiters, etc). My hunch was that if there are any effects from modulating prompt token weights, they would surely be noticeable here.

As stated in the dataset card:

RACE is a large-scale reading comprehension dataset with more than 28,000 passages and nearly 100,000 questions. The dataset is collected from English examinations in China, which are designed for middle school and high school students. The dataset can be served as the training and test sets for machine comprehension.

The QA schema is simple: the prompt presents a question, possibly some context (the article field), and then lists 4 options. The completion (answer) is always one of: A, B, C, D. The dataset viewer hosted on HuggingFace allows browsing the full set, but here’s a small example:

RACE example (screenshot from https://huggingface.co/datasets/ehovy/race/viewer/all/train)
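If you want to poke at the data directly, here is a minimal sketch using the datasets library (the field names below follow the dataset card, but are worth double-checking against the viewer):

from datasets import load_dataset

# the "all" config combines the middle-school and high-school exams
race = load_dataset("ehovy/race", "all")

sample = race["train"][0]
print(sample["question"])  # the question text
print(sample["options"])   # the four answer choices
print(sample["answer"])    # the gold label: "A", "B", "C" or "D"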

Before we jump into the full implementation of prompt-loss-weight, and try it out on the RACE data, we need a basic understanding of loss and where it comes from. Simply put, loss is a measure of how well our model (LLM) “fits” (explains, predicts) our data. During fine-tuning (and also pre-training), we “move” the model closer to the data by tweaking the network weights in such a way that decreases the loss. The chain rule (of calculus) gives us a precise algorithm for computing these tweaks, given the loss function and the network architecture.

The most common loss function in LLM fine-tuning is called Cross Entropy Loss (CEL). For this reason, most discussions of CEL are framed around the definition of cross-entropy, which comes from information theory. While it’s true that “cross-entropy” is right there in the name, a more intuitive understanding can be achieved when approaching CEL through the lens of maximum likelihood estimation (MLE). I’ll try to explain it from both angles.

We have already established that LLMs are wired for next token prediction. What this means is that the LLM is basically just a mathematical function that takes as input a sequence of tokens, and outputs a conditional probability distribution for the next token over the entire token vocabulary V. In other words, it outputs a vector of probability values of dimension |V| that sums to 1. (in set notation |S| denotes the number of elements, or cardinality, of a set S)
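To make this concrete, here is a minimal sketch of pulling the next-token distribution out of a causal LM with the transformers library (gpt2 is used here only because it is small — it is not the model used in the experiments below):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The bird flew", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits                       # shape [1, seq_len, |V|]

next_token_probs = torch.softmax(logits[0, -1], dim=-1)   # distribution over |V|
print(next_token_probs.shape)   # torch.Size([50257]) for GPT-2
print(next_token_probs.sum())   # ~1.0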

Let’s take a small toy example to illustrate how this works. Imagine that our training data contains the 4-token sequence: The bird flew away. Given the first 3 tokens (The bird flew), an LLM might output the following vector of probabilities for every possible 4ᵗʰ token — for the sake of simplicity, we’ll imagine that the 5 candidate tokens listed (in magenta) are the only possibilities (i.e. |V|=5). The function p() represents the conditional probabilities output by the LLM (notice they sum to 1):

(image by the author)

When training (or fine-tuning) an LLM on a token sequence, we step through the sequence token-by-token and compare the next-token-distribution generated by the LLM to the actual next token in the sequence, and from there we calculate the CEL for that token.

Notice here that the actual 4ᵗʰ token in the sequence (away) does not have the highest probability in the table. During training, we would like to tweak the weights slightly so as to increase the probability of away, while decreasing the others. The key is having the right loss function… it allows us to compute exactly how much to tweak each weight, for each token.

Once the loss is computed for each token, the final loss is computed as the average per-token-loss over all tokens. But first we must establish the formula for this per-token-loss.

Information Theory Interpretation

Continuing the toy problem, to compute CEL for the 4ᵗʰ token position, we compare the actual 4ᵗʰ token to the generated distribution p() over all 5 possible 4ᵗʰ tokens. In fact, we treat the actual 4ᵗʰ token as a distribution q() in its own right (albeit a degenerate one) that has a value of 1 for the token appearing in the data –away– and a value of 0 for all other possible 4ᵗʰ tokens (this is sometimes called one-hot encoding).

(image by the author)

The reason we contort the training data into this strange one-hot encoded probability representation q() is so that we can apply the formula for cross-entropy, which is a measure of the divergence between two discrete probability distributions (BTW, not symmetric w.r.t. q,p):

H(q,p) = −Σₓ q(x) · log p(x)

where x indexes over all possible states (i.e. 5 tokens). This works out to:

H(q,p) = −log p(away)

So basically CEL is just using the q vector to select from the p vector the single value corresponding to the token that actually appears in the data –away– (i.e. multiplying it by 1), and throwing away all other values (i.e. multiplying by 0). So we are indexing over all possible states (tokens) only to select one and ignore the rest.
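Putting the toy example into code (the candidate tokens and probability values below are made up, since the real ones only live in the figure, but the mechanics are identical):

import numpy as np

# hypothetical next-token distribution p() over 5 candidate tokens
p = {"away": 0.25, "home": 0.40, "south": 0.20, "the": 0.10, "quickly": 0.05}
# one-hot "distribution" q() for the token that actually appears in the data
q = {tok: float(tok == "away") for tok in p}

# full cross-entropy sum: -sum over x of q(x) * log p(x)
cel = -sum(q[tok] * np.log(p[tok]) for tok in p)

# ...which is just -log p(away)
assert np.isclose(cel, -np.log(p["away"]))
print(cel)  # ~1.386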

MLE Interpretation

When fine-tuning an LLM, we seek the LLM weights θ that maximize the probability of the training data given those weights, often called the likelihood of the weights ℒ(θ) = ℙ(D|θ). And so we require an expression for this quantity. Luckily, there’s an easy way to compute this from next token probabilities, which the LLM already provides.

Starting with the other chain rule (of probability), we decompose the joint probability of a token sequence S into a product of conditional probabilities:

ℙ(S) = ℙ(t₁) · ℙ(t₂|t₁) · ℙ(t₃|t₁,t₂) · … · ℙ(tᵢ|t₁,…,tᵢ₋₁) · …

Chain Rule (probability)

This decomposition establishes the connection between next-token-prediction and the joint probability of the full token sequence — the joint probability is just the product of all the conditionals.

Using i to index over the tokens of a token sequence S = (t₁,t₂,t₃,…, tᵢ ,…), we’ll use the following shorthand to denote the conditional probability output by an LLM for the iᵗʰ token in a sequence, given the LLM weights θ and the previous i-1 tokens:

pᵢ = ℙ(tᵢ | t₁, t₂, …, tᵢ₋₁ ; θ)

It should be emphasized that pᵢ is not a vector here (i.e. a distribution over all possible next tokens) but represents only the probability computed for the actual iᵗʰ token, i.e. the yellow highlighted row in the above example.

If we take the logarithm of the joint probability of a sequence, the product becomes a sum (since log is monotonic, this doesn’t affect optimization):

log ℙ(S|θ) = log (p₁ · p₂ · p₃ · …) = Σᵢ log pᵢ

Now we can connect the final sum-of-logs expression (right here ☝️) to the formula for Average Cross Entropy Loss L over a token sequence:

L = −(1/N) · Σᵢ log pᵢ     (N = number of tokens in the sequence)

which is the causal language model objective function. Often the “Average” is dropped from the name, and it’s just called “Cross Entropy Loss,” but it’s good to remember that CEL is technically computed at the token level, and then averaged across tokens. From this final expression it should hopefully be clear that minimizing the CEL is equivalent to maximizing the probability of the token sequence, which is what MLE seeks.
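A tiny numeric check of this equivalence (the per-token probabilities below are arbitrary):

import numpy as np

# probabilities p_i assigned by the model to the actual tokens of a sequence
p = np.array([0.20, 0.05, 0.50, 0.30])

joint_prob = np.prod(p)        # probability of the whole sequence
avg_cel = -np.mean(np.log(p))  # average cross entropy loss

# minimizing avg_cel is the same as maximizing joint_prob
assert np.isclose(avg_cel, -np.log(joint_prob) / len(p))
print(joint_prob, avg_cel)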

One convenience resulting from the form of the CEL expression above is that it is very easy to modify if we want to compute the loss over any subset of the tokens. Recall that we may sometimes be interested in finding the LLM weights θ that maximize the probability of the completion given the prompt:

ℒ(θ) = ℙ(completion | prompt, θ)

We could easily adjust the loss for this scenario by simply averaging only over the completion tokens. If we use “𝕀c” to denote the set of all completion token indices, then we can express completion loss as:

L_completion = −(1/|𝕀c|) · Σ_{i∈𝕀c} log pᵢ

Since the loss for each token is already conditioned on all previous tokens in the sequence, this means the prompt is automatically accounted for in the conditional, even if we average over completion tokens only.
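In code, this amounts to nothing more than averaging over a subset of the per-token losses. A minimal sketch (the per-token log-probabilities and the completion mask are assumed to be already computed):

import torch

# per-token log-probabilities log p_i for one sequence (assumed precomputed)
log_probs = torch.tensor([-2.1, -0.8, -1.5, -0.3, -0.9])
# True for completion tokens, False for prompt tokens
completion_mask = torch.tensor([False, False, False, True, True])

full_loss = -log_probs.mean()                         # standard full-sequence CEL
completion_loss = -log_probs[completion_mask].mean()  # average over completion tokens only
print(full_loss.item(), completion_loss.item())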

Now that we have established CEL as an average of per-token losses over a token sequence, we can define the weighted average version of CEL:

Lw = −( Σᵢ wᵢ · log pᵢ ) / ( Σᵢ wᵢ )

Depending on how we set the weights wᵢ, we can use this formula to define several losses. For example, if we set all weights wᵢ =1 then we recover the standard, full sequence CEL from before. However, if we set wᵢ =1 only for completion tokens, and wᵢ = 0 for prompt tokens, then we get completion loss. And likewise, prompt loss is defined by setting wᵢ =1 only over prompt tokens, and wᵢ = 0 otherwise.

Since we rarely (if ever) want to down-weight the completion tokens, we fix the completion token weights at wᵢ =1, but for the prompt tokens we can define a continuous value on the [0:1] interval called prompt_loss_weight. This way we can tune how much to weight the prompt tokens during training, from wᵢ = 0 (completion loss) all the way up to wᵢ =1 (standard full sequence loss). Or, we could even use wᵢ =0.1 to give the prompt tokens a small but non-zero weight.
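Here is one way this weighted average could be sketched, with prompt_loss_weight controlling the prompt-token weights (a simplified illustration with toy numbers, not the exact implementation we’ll build later):

import torch

def weighted_cel(log_probs, completion_mask, prompt_loss_weight=0.0):
    # completion tokens get weight 1, prompt tokens get weight prompt_loss_weight
    weights = torch.where(completion_mask,
                          torch.ones_like(log_probs),
                          torch.full_like(log_probs, prompt_loss_weight))
    return -(weights * log_probs).sum() / weights.sum()

log_probs = torch.tensor([-2.1, -0.8, -1.5, -0.3, -0.9])
mask = torch.tensor([False, False, False, True, True])
print(weighted_cel(log_probs, mask, 0.0).item())  # completion loss
print(weighted_cel(log_probs, mask, 1.0).item())  # standard full-sequence loss
print(weighted_cel(log_probs, mask, 0.1).item())  # small but non-zero prompt weight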

Loss Implementation

Let’s take a look under the hood at how loss is typically computed in the HuggingFace transformers package. Since we’ll be fine-tuning the Llama-2-7b-chat-hf model in our experiments, we’ll look at LlamaForCausalLM, specifically at the forward pass, where loss is computed during training.

Recall that loss is a way of comparing each actual token to the LLM’s prediction for that token (given the preceding actual tokens) — and so the loss function needs access to these two data structures. In this case, loss is fed two tensors: logits and labels. The labels tensor holds the actual tokens (token ids to be exact). The logits tensor holds the predicted next-token-probabilities, prior to softmax normalization (which forces them to sum to 1 — it turns out that it’s more efficient to leave these values in their raw, pre-normalized form).

The logits tensor is 3D, with shape [B,N,|V|], where B is batch size, N is sequence length (in tokens), and |V| is token vocabulary size. The 2D labels tensor just contains the token sequence itself, so it has shape [B,N]. Here is the key section of code where CEL is typically computed:

# Shift by 1 so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)

# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)

# Compute loss
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits, shift_labels)

For each position i along the 2nd dimension of logits, this tensor contains probabilities for predicting the next token (token i+1) given all the preceding tokens up through the iᵗʰ token. These probabilities need to be compared to the actual i+1ˢᵗ token in labels. This is why the shift-by-1 happens in the first few lines — to bring these two values into alignment for each token.
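As a preview of where this is headed, one possible way to adapt the snippet above to per-token weights is to request the unreduced loss and apply the weights manually. The following is only a sketch with toy tensors (the weight values are illustrative, and building the real weight tensor from prompt/completion boundaries is covered later):

import torch
from torch.nn import CrossEntropyLoss

B, N, V = 2, 6, 11                     # toy batch size, sequence length, vocab size
logits = torch.randn(B, N, V)
labels = torch.randint(V, (B, N))
# 1.0 for completion tokens, prompt_loss_weight (here 0.1) for prompt tokens
weights = torch.tensor([[0.1, 0.1, 0.1, 0.1, 1.0, 1.0]]).repeat(B, 1)

# same shift-by-1 and flatten as in the snippet above
shift_logits = logits[..., :-1, :].contiguous().view(-1, V)
shift_labels = labels[..., 1:].contiguous().view(-1)
shift_weights = weights[..., 1:].contiguous().view(-1)

# per-token losses instead of a single averaged scalar, then a weighted average
loss_fct = CrossEntropyLoss(reduction="none")
per_token_loss = loss_fct(shift_logits, shift_labels)   # shape [B*(N-1)]
loss = (per_token_loss * shift_weights).sum() / shift_weights.sum()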