How PyTorch NestedTensors, FlashAttention2, and xFormers can Boost Performance and Reduce AI Costs
As generative AI (genAI) models grow in both popularity and scale, so do the computational demands and costs associated with their training and deployment. Optimizing these models is crucial for enhancing their runtime performance and reducing their operational expenses. At the heart of modern genAI systems is the Transformer architecture and its attention mechanism, which is notably compute-intensive.
In a previous post, we demonstrated how using optimized attention kernels can significantly accelerate the performance of Transformer models. In this post, we continue our exploration by addressing the challenge of variable-length input sequences, an inherent property of real-world data, including documents, code, time-series, and more.
The Challenge of Batching Variable-Size Input
In a typical deep learning workload, individual samples are grouped into batches before being copied to the GPU and fed to the AI model. Batching improves computational efficiency and often aids model convergence during training. Usually, batching involves stacking all of the sample tensors along a new dimension, the batch dimension. However, torch.stack requires that all tensors have the same shape, which is not the case with variable-length sequences.
Padding and its Inefficiencies
The traditional way to address this challenge is to pad the input sequences to a fixed length and then perform stacking. This solution requires appropriate masking within the model so that the output is not affected by the irrelevant tensor elements. In the case of attention layers, a padding mask indicates which tokens are padding and should not be attended to (e.g., see PyTorch MultiheadAttention). However, padding can waste considerable GPU resources, increasing costs and slowing development. This is especially true for large-scale AI models.
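As a minimal illustration of the padding approach, the pure-Python sketch below (no PyTorch required) pads a batch to a fixed length and derives the matching boolean mask. The helper name pad_and_mask is hypothetical. Note that mask conventions differ between APIs: here, as in a boolean SDPA attn_mask, True marks tokens that take part in attention, whereas for the key_padding_mask of nn.MultiheadAttention, True marks positions to ignore.

```python
from typing import List, Tuple


def pad_and_mask(seqs: List[List[int]], length: int, pad_id: int = 0
                 ) -> Tuple[List[List[int]], List[List[bool]]]:
    """Pad each sequence to `length`; the mask is True for real tokens."""
    padded, mask = [], []
    for seq in seqs:
        n_pad = length - len(seq)
        if n_pad < 0:
            raise ValueError("sequence is longer than the padded length")
        padded.append(seq + [pad_id] * n_pad)
        mask.append([True] * len(seq) + [False] * n_pad)
    return padded, mask


padded, mask = pad_and_mask([[5, 7, 9], [3], [8, 2]], length=4)
wasted = sum(row.count(False) for row in mask)
total = sum(len(row) for row in mask)
print(padded)                                  # [[5, 7, 9, 0], [3, 0, 0, 0], [8, 2, 0, 0]]
print(f"{wasted}/{total} slots are padding")   # 6/12 slots are padding
```

Even in this tiny batch, half of the slots hold padding, and every layer of the model still spends compute on them.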
Don’t Pad, Concatenate
One way to avoid padding is to concatenate sequences along an existing dimension instead of stacking them along a new dimension. Contrary to torch.stack, torch.cat allows inputs of different shapes. The output of concatenation is a single sequence whose length equals the sum of the lengths of the individual sequences. For this solution to work, our single sequence would need to be supplemented by an attention mask that ensures that each token only attends to other tokens in the same original sequence, in a process known as document masking. Denoting the sum of the lengths of all of the individual sequences by N and adopting "big O" notation, the size of this mask would need to be O(N²), as would the compute complexity of a standard attention layer, making this solution highly inefficient.
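To see why a dense document mask scales poorly, the pure-Python sketch below (the document_mask helper is hypothetical) materializes the mask for three short concatenated sequences. All N² entries must be stored and, in a standard attention layer, scored, yet only the Σ l_i² entries inside the diagonal blocks (where l_i is the length of sequence i) are relevant. Those are exactly the entries that the specialized attention kernels discussed next restrict themselves to.

```python
from typing import List


def document_mask(lengths: List[int]) -> List[List[bool]]:
    """Materialize the N x N document mask for concatenated sequences.

    mask[i][j] is True when tokens i and j come from the same original
    sequence (non-causal document masking).
    """
    doc_id = [d for d, n in enumerate(lengths) for _ in range(n)]
    n_tokens = len(doc_id)
    return [[doc_id[i] == doc_id[j] for j in range(n_tokens)]
            for i in range(n_tokens)]


lengths = [3, 1, 2]
mask = document_mask(lengths)
n_entries = len(mask) * len(mask[0])        # N^2 entries stored and scored
n_useful = sum(sum(row) for row in mask)    # sum of l_i^2 entries that matter
print(n_entries, n_useful)                  # 36 14
```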
Attention Layer Optimization
The solution to this problem comes in the form of specialized attention layers. Contrary to the standard attention layer that computes the full set of O(N²) attention scores only to mask out the irrelevant ones, these optimized attention kernels are designed to calculate only the scores that matter. In this post we will explore several solutions, each with its own distinct characteristics. These include:
- PyTorch SDPA (scaled dot product attention) with PyTorch NestedTensors
- FlashAttention2 (flash_attn_varlen_func)
- xFormers memory-efficient attention (with BlockDiagonalMask)
Integration into Existing HuggingFace Models
For teams working with pre-trained models, transitioning to these optimizations might seem challenging. We will demonstrate how HuggingFace's APIs simplify this process, enabling developers to integrate these techniques with minimal code changes and effort.
Disclaimers
- Please do not interpret our use of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your own use-case.
- Some of the APIs discussed here are in prototype or beta stages and may change in the future.
- The code examples provided are for demonstrative purposes only. We make no claims regarding their accuracy, optimality, or robustness.
Special thanks to Yitzhak Levi and Peleg Nahaliel for their contributions to this post.
To facilitate our discussion we will define a simple generative model (partially inspired by the GPT model defined here). For a more comprehensive guide on building language models, please see one of the many excellent tutorials available online (e.g., here).
Transformer Block
We begin by constructing a basic Transformer block, specifically designed to facilitate experimentation with different attention mechanisms and optimizations. While our block performs the same computation as standard Transformer blocks, we make slight modifications to the usual choice of operators in order to support PyTorch NestedTensor inputs (as described here).
# general imports
import time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# Define Transformer settings
BATCH_SIZE = 32
NUM_HEADS = 16
HEAD_DIM = 64
DIM = NUM_HEADS * HEAD_DIM
DEPTH = 24
NUM_TOKENS = 1024
MAX_SEQ_LEN = 1024
PAD_ID = 0
DEVICE = 'cuda'


class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            dim,
            num_heads,
            format=None,
            **kwargs
    ):
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim, bias=False)
        self.norm2 = nn.LayerNorm(dim, bias=False)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # mlp layers
        self.fc1 = nn.Linear(dim, dim * 4)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * 4, dim)

        # by default, move the heads dimension in front of the sequence
        # dimension: (b, s, h, d) -> (b, h, s, d), as expected by SDPA
        self.permute = functools.partial(torch.transpose, dim0=1, dim1=2)
        if format == 'bshd':
            # kernels that expect (b, s, h, d) inputs need no permutation
            self.permute = nn.Identity()

    def mlp(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

    def reshape_and_permute(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return self.permute(x)

    def forward(self, x_in, attn_mask=None):
        batch_size = x_in.size(0)
        x = self.norm1(x_in)
        qkv = self.qkv(x)

        # rather than first reformatting and then splitting the input
        # state, we first split and then reformat q, k, v in order to
        # support PyTorch Nested Tensors
        q, k, v = qkv.chunk(3, -1)
        q = self.reshape_and_permute(q, batch_size)
        k = self.reshape_and_permute(k, batch_size)
        v = self.reshape_and_permute(v, batch_size)

        # call the attn_fn with the input attn_mask
        x = self.attn_fn(q, k, v, attn_mask=attn_mask)

        # reformat output
        x = self.permute(x).reshape(batch_size, -1, self.dim)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
Transformer Decoder Model
Building on our programmable Transformer block, we construct a typical Transformer decoder model.
class MyDecoder(nn.Module):
    def __init__(
            self,
            block_fn,
            num_tokens,
            dim,
            num_heads,
            num_layers,
            max_seq_len,
            pad_idx=None
    ):
        super().__init__()
        self.num_heads = num_heads
        self.pad_idx = pad_idx
        self.embedding = nn.Embedding(num_tokens, dim, padding_idx=pad_idx)
        self.positional_embedding = nn.Embedding(max_seq_len, dim)
        self.blocks = nn.ModuleList([
            block_fn(
                dim=dim,
                num_heads=num_heads
            )
            for _ in range(num_layers)])
        self.output = nn.Linear(dim, num_tokens)

    def embed_tokens(self, input_ids, position_ids=None):
        x = self.embedding(input_ids)
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[1],
                                        device=x.device)
        x = x + self.positional_embedding(position_ids)
        return x

    def forward(self, input_ids, position_ids=None, attn_mask=None):
        # Embed tokens and add positional encoding
        x = self.embed_tokens(input_ids, position_ids)
        if self.pad_idx is not None:
            assert attn_mask is None
            # create a padding mask - we assume boolean masking
            # (True = attend); broadcast it over heads and query positions
            attn_mask = (input_ids != self.pad_idx)
            attn_mask = (attn_mask.view(input_ids.size(0), 1, 1, -1)
                         .expand(-1, self.num_heads, -1, -1))
        for b in self.blocks:
            x = b(x, attn_mask)

        logits = self.output(x)
        return logits
Variable-Length Sequence Input
Next, we create a dataset containing sequences of variable lengths, where each sequence is made up of randomly generated tokens. For simplicity, we (arbitrarily) select a fixed distribution for the sequence lengths. In real-world scenarios, the distribution of sequence lengths typically reflects the nature of the data, such as the length of documents or audio segments. Note that the distribution of lengths directly affects the computational inefficiencies caused by padding.
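To put a rough number on that, the pure-Python sketch below estimates the fraction of token slots that hold padding when lengths are drawn uniformly from [1, MAX_SEQ_LEN), as our dataset does, and batches of BATCH_SIZE are padded either to MAX_SEQ_LEN or only to the longest sequence in each batch (an alternative evaluated later in this post). It counts tokens only and is not a runtime measurement; since attention cost grows quadratically with sequence length, the compute wasted inside attention layers may be larger still. The helper name padding_overhead is hypothetical.

```python
import random

MAX_SEQ_LEN = 1024  # same values as the Transformer settings in this post
BATCH_SIZE = 32


def padding_overhead(num_batches: int = 2000, seed: int = 0):
    """Estimate the padded fraction of token slots for uniform lengths in
    [1, MAX_SEQ_LEN), padding to MAX_SEQ_LEN vs. to each batch's longest."""
    rng = random.Random(seed)
    real = fixed = longest = 0
    for _ in range(num_batches):
        lengths = [rng.randrange(1, MAX_SEQ_LEN) for _ in range(BATCH_SIZE)]
        real += sum(lengths)
        fixed += BATCH_SIZE * MAX_SEQ_LEN
        longest += BATCH_SIZE * max(lengths)
    return 1 - real / fixed, 1 - real / longest


fixed_waste, longest_waste = padding_overhead()
print(f"pad to MAX_SEQ_LEN: {fixed_waste:.1%} padding")
print(f"pad to batch max:   {longest_waste:.1%} padding")
```

With 32 uniformly distributed lengths, the longest sequence in a batch is almost always close to MAX_SEQ_LEN, so under this particular distribution padding to the longest sequence should remove only a small share of the roughly 50% padding overhead.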
# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,)).item()
        sequence = torch.randint(1, NUM_TOKENS, (length + 1,))
        inputs = sequence[:-1]
        targets = sequence[1:]
        return inputs, targets


def pad_sequence(sequence, length, pad_val):
    return torch.nn.functional.pad(
        sequence,
        (0, length - sequence.shape[0]),
        value=pad_val
    )


def collate_with_padding(batch):
    padded_inputs = []
    padded_targets = []
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], MAX_SEQ_LEN, PAD_ID))
        padded_targets.append(pad_sequence(b[1], MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }


def data_to_device(data, device):
    if isinstance(data, dict):
        return {
            key: data_to_device(val, device)
            for key, val in data.items()
        }
    elif isinstance(data, (list, tuple)):
        return type(data)(
            data_to_device(val, device) for val in data
        )
    elif isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    else:
        # non-tensor objects (such as the xFormers mask used below)
        # must provide their own .to() method
        return data.to(device=device)
Training/Evaluation Loop
Finally, we implement a main function that performs training/evaluation on input sequences of varying length.
def main(
        block_fn,
        data_collate_fn=collate_with_padding,
        pad_idx=None,
        train=True,
        compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = FakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=data_collate_fn,
        num_workers=12,
        pin_memory=True,
        drop_last=True
    )

    model = MyDecoder(
        block_fn=block_fn,
        num_tokens=NUM_TOKENS,
        dim=DIM,
        num_heads=NUM_HEADS,
        num_layers=DEPTH,
        max_seq_len=MAX_SEQ_LEN,
        pad_idx=pad_idx
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    def train_step(model, inputs, targets,
                   position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            outputs = outputs.view(-1, NUM_TOKENS)
            targets = targets.flatten()
            loss = criterion(outputs, targets)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    @torch.no_grad()
    def eval_step(model, inputs, targets,
                  position_ids=None, attn_mask=None):
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(inputs, position_ids, attn_mask)
            if outputs.is_nested:
                # compute the loss on the flat values of the jagged tensors
                outputs = outputs.data._values
                targets = targets.data._values
            else:
                outputs = outputs.view(-1, NUM_TOKENS)
                targets = targets.flatten()
            loss = criterion(outputs, targets)
        return loss

    if train:
        model.train()
        step_fn = train_step
    else:
        model.eval()
        step_fn = eval_step

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device)
        step_fn(model, data['inputs'], data['targets'],
                position_ids=data.get('indices'),
                attn_mask=data.get('attn_mask'))

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')
PyTorch SDPA with Padding
For our baseline experiments, we configure our Transformer block to utilize PyTorch's SDPA mechanism. In our experiments, we run both training and evaluation, both with and without torch.compile. These were run on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1.
from torch.nn.functional import scaled_dot_product_attention as sdpa

block_fn = functools.partial(MyAttentionBlock, attn_fn=sdpa)
causal_block_fn = functools.partial(
    MyAttentionBlock,
    attn_fn=functools.partial(sdpa, is_causal=True)
)

for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn if mode == 'train' else block_fn
        print(f'{mode} with padding, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)
Performance Results:
- Evaluation: 132 milliseconds (ms) without torch.compile, 130 ms with torch.compile
- Training: 342 ms without torch.compile, 299 ms with torch.compile
In this section, we will explore several optimization techniques for handling variable-length input sequences in Transformer models.
Padding Optimization
Our first optimization relates not to the attention kernel but to our padding mechanism. Rather than padding the sequences in each batch to a constant length, we pad to the length of the longest sequence in the batch. The following block of code includes our revised collation function and updated experiments.
def collate_pad_to_longest(batch):
    padded_inputs = []
    padded_targets = []
    max_length = max([b[0].shape[0] for b in batch])
    for b in batch:
        padded_inputs.append(pad_sequence(b[0], max_length, PAD_ID))
        padded_targets.append(pad_sequence(b[1], max_length, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_targets = torch.stack(padded_targets, dim=0)
    return {
        'inputs': padded_inputs,
        'targets': padded_targets
    }

for mode in ['eval', 'train']:
    for compile in [False, True]:
        block_func = causal_block_fn if mode == 'train' else block_fn
        print(f'{mode} with padding to longest, '
              f'{"compiled" if compile else "uncompiled"}')
        main(block_fn=block_func,
             data_collate_fn=collate_pad_to_longest,
             pad_idx=PAD_ID,
             train=mode == 'train',
             compile=compile)
Padding to the longest sequence in each batch results in a slight performance acceleration:
- Evaluation: 129 ms without torch.compile, 116 ms with torch.compile
- Training: 337 ms without torch.compile, 294 ms with torch.compile
SDPA with PyTorch NestedTensors
Next, we take advantage of the built-in support for PyTorch NestedTensors in SDPA in evaluation mode. Currently a prototype feature, PyTorch NestedTensors allow for grouping together tensors of varying length. These are sometimes referred to as jagged or ragged tensors. In the code block below, we define a collation function for grouping our sequences into NestedTensors. We also define an indices entry so that we can properly calculate the positional embeddings.
PyTorch NestedTensors are supported by a limited number of PyTorch ops. Working around these limitations can require some creativity. For example, addition between NestedTensors is only supported when they share precisely the same "jagged" shape. In the code below we use a workaround to ensure that the indices entry shares the same shape as the model inputs.
def nested_tensor_collate(batch):
    inputs = torch.nested.as_nested_tensor([b[0] for b in batch],
                                           layout=torch.jagged)
    targets = torch.nested.as_nested_tensor([b[1] for b in batch],
                                            layout=torch.jagged)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])

    # workaround for creating a NestedTensor with identical "jagged" shape
    xx = torch.empty_like(inputs)
    xx.data._values[:] = indices

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': xx
    }

for compile in [False, True]:
    print(f'eval with nested tensors, '
          f'{"compiled" if compile else "uncompiled"}')
    main(
        block_fn=block_fn,
        data_collate_fn=nested_tensor_collate,
        train=False,
        compile=compile
    )
Without torch.compile, the NestedTensor optimization results in a step time of 131 ms, similar to our baseline result. In compiled mode, however, the step time drops to 42 ms, an impressive ~3x improvement.
FlashAttention2
In our previous post we demonstrated the use of FlashAttention and its impact on the performance of a transformer model. In this post we demonstrate the use of flash_attn_varlen_func from flash-attn (2.7.0), an API designed for use with variable-sized inputs. To use this function, we concatenate all of the sequences in the batch into a single sequence. We also create a cu_seqlens tensor that points to the indices within the concatenated tensor where each of the individual sequences starts. The code block below includes our collation function followed by evaluation and training experiments. Note that flash_attn_varlen_func does not support torch.compile (at the time of this writing).
def collate_concat(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = torch.tensor([b[0].shape[0] for b in batch])
    seqlens = torch.cumsum(seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = torch.nn.functional.pad(seqlens, (1, 0))

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': cu_seqlens
    }

from flash_attn import flash_attn_varlen_func

fa_varlen = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN
).unsqueeze(0)

fa_varlen_causal = lambda q, k, v, attn_mask: flash_attn_varlen_func(
    q.squeeze(0),
    k.squeeze(0),
    v.squeeze(0),
    cu_seqlens_q=attn_mask,
    cu_seqlens_k=attn_mask,
    max_seqlen_q=MAX_SEQ_LEN,
    max_seqlen_k=MAX_SEQ_LEN,
    causal=True
).unsqueeze(0)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa_varlen,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=fa_varlen_causal,
                                    format='bshd')

print('flash-attn eval')
main(
    block_fn=block_fn,
    data_collate_fn=collate_concat,
    train=False
)

print('flash-attn train')
main(
    block_fn=causal_block_fn,
    data_collate_fn=collate_concat,
    train=True,
)
The impact of this optimization is dramatic: 51 ms for evaluation and 160 ms for training, amounting to 2.6x and 2.1x performance boosts compared to our baseline experiment.
xFormers Memory Efficient Attention
In our previous post we demonstrated the use of the memory_efficient_attention operator from xFormers (0.0.28). Here we demonstrate the use of BlockDiagonalMask, specifically designed for input sequences of arbitrary length. The required collation function appears in the code block below, followed by the evaluation and training experiments. Note that torch.compile failed in training mode.
from xformers.ops import fmha
from xformers.ops import memory_efficient_attention as mea


def collate_xformer(batch):
    inputs = torch.concat([b[0] for b in batch]).unsqueeze(0)
    targets = torch.concat([b[1] for b in batch]).unsqueeze(0)
    indices = torch.concat([torch.arange(b[0].shape[0]) for b in batch])
    seqlens = [b[0].shape[0] for b in batch]
    batch_sizes = [1 for b in batch]
    block_diag = fmha.BlockDiagonalMask.from_seqlens(seqlens, device='cpu')
    block_diag._batch_sizes = batch_sizes

    return {
        'inputs': inputs,
        'targets': targets,
        'indices': indices,
        'attn_mask': block_diag
    }

mea_eval = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask)

mea_train = lambda q, k, v, attn_mask: mea(
    q, k, v, attn_bias=attn_mask.make_causal())

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea_eval,
                             format='bshd')

causal_block_fn = functools.partial(MyAttentionBlock,
                                    attn_fn=mea_train,
                                    format='bshd')

print(f'xFormer Attention ')
for compile in [False, True]:
    print(f'eval with xFormer Attention, '
          f'{"compiled" if compile else "uncompiled"}')
    main(block_fn=block_fn,
         train=False,
         data_collate_fn=collate_xformer,
         compile=compile)

print(f'train with xFormer Attention')
main(block_fn=causal_block_fn,
     train=True,
     data_collate_fn=collate_xformer)
The resulting step times were 50 ms and 159 ms for evaluation and training, respectively, without torch.compile. Evaluation with torch.compile resulted in a step time of 42 ms.
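Both flash_attn_varlen_func and xFormers' BlockDiagonalMask describe the batch by its sequence boundaries rather than by a dense N x N mask. The pure-Python sketch below illustrates that idea under stated assumptions: varlen_metadata computes cu_seqlens the same way as collate_concat (a cumulative sum with a leading zero), and the hypothetical ToyBlockDiagonal class answers "may query i attend to key j?" from the boundaries alone. Its make_causal method only echoes the name of the xFormers method used above; none of this reflects how the actual kernels are implemented. The sketch also computes the true per-batch max_seqlen, a tighter value than the constant MAX_SEQ_LEN passed in our FlashAttention code; whether that affects performance is not something measured in this post.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def varlen_metadata(seqlens: List[int]) -> Tuple[List[int], int]:
    """cu_seqlens (start offsets plus total length) and the longest length."""
    return [0, *accumulate(seqlens)], max(seqlens)


class ToyBlockDiagonal:
    """Hypothetical pure-Python analogue of a block-diagonal attention bias.

    Stores only the sequence boundaries and answers mask queries on demand,
    instead of materializing an N x N mask.
    """

    def __init__(self, cu_seqlens: List[int], causal: bool = False):
        self.cu_seqlens = cu_seqlens
        self.causal = causal

    def make_causal(self) -> "ToyBlockDiagonal":
        return ToyBlockDiagonal(self.cu_seqlens, causal=True)

    def _block(self, i: int) -> int:
        # index of the sequence that token i belongs to
        return bisect_right(self.cu_seqlens, i) - 1

    def allowed(self, i: int, j: int) -> bool:
        same_seq = self._block(i) == self._block(j)
        return same_seq and (not self.causal or j <= i)


cu_seqlens, max_seqlen = varlen_metadata([3, 2, 4])
print(cu_seqlens, max_seqlen)                      # [0, 3, 5, 9] 4

mask = ToyBlockDiagonal(cu_seqlens)
causal = mask.make_causal()
print(mask.allowed(0, 2), mask.allowed(2, 3))      # True False
print(causal.allowed(2, 0), causal.allowed(0, 2))  # True False
```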
Results
The table below summarizes the results of our optimization experiments.
The best performer for our toy model is xFormers' memory_efficient_attention, which delivered a ~3x performance boost for evaluation and a ~2x boost for training. We caution against deriving any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.
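The speedup figures quoted above can be recomputed from the step times reported in the previous sections. The short script below does only that arithmetic, relative to the uncompiled SDPA-with-padding baseline; the numbers are copied from the text, no new measurements are involved, and combinations that were not run (such as compiled training with the varlen kernels) are simply omitted.

```python
# Step times in milliseconds, copied from the experiments reported above.
BASELINE_MS = {"eval": 132, "train": 342}  # SDPA with padding, uncompiled

REPORTED_MS = {
    "SDPA + NestedTensors (compiled)": {"eval": 42},
    "flash_attn_varlen_func": {"eval": 51, "train": 160},
    "xFormers BlockDiagonalMask": {"eval": 50, "train": 159},
    "xFormers BlockDiagonalMask (compiled)": {"eval": 42},
}


def speedups(times_ms):
    """Speedup of each reported step time relative to the uncompiled baseline."""
    return {mode: round(BASELINE_MS[mode] / ms, 1)
            for mode, ms in times_ms.items()}


for name, times in REPORTED_MS.items():
    print(f"{name:40s} {speedups(times)}")
```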
The tools and techniques described above are easy to implement when creating a model from scratch. However, these days it is not uncommon for ML developers to adopt existing (pretrained) models and finetune them for their use case. While the optimizations we have described can be integrated without changing the set of model weights and without altering the model behavior, it is not entirely clear what the best way to do this is. In an ideal world, our ML framework would allow us to program the use of an attention mechanism that is optimized for variable-length inputs. In this section we demonstrate how to optimize HuggingFace models for variable-length inputs.
A Toy HuggingFace Model – GPT2LMHeadModel
To facilitate the discussion, we create a toy example in which we train a HuggingFace GPT2LMHead model on variable-length sequences. This requires adapting our random dataset and data-padding collation function according to HuggingFace's input specifications.
from transformers import GPT2Config, GPT2LMHeadModel

# Use random data
class HuggingFaceFakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        length = torch.randint(1, MAX_SEQ_LEN, (1,)).item()
        input_ids = torch.randint(1, NUM_TOKENS, (length,))
        labels = input_ids.clone()
        labels[0] = PAD_ID  # ignore first token
        return {
            'input_ids': input_ids,
            'labels': labels
        }


def hf_collate_with_padding(batch):
    padded_inputs = []
    padded_labels = []
    for b in batch:
        input_ids = b['input_ids']
        labels = b['labels']
        padded_inputs.append(pad_sequence(input_ids, MAX_SEQ_LEN, PAD_ID))
        padded_labels.append(pad_sequence(labels, MAX_SEQ_LEN, PAD_ID))
    padded_inputs = torch.stack(padded_inputs, dim=0)
    padded_labels = torch.stack(padded_labels, dim=0)
    return {
        'input_ids': padded_inputs,
        'labels': padded_labels,
        'attention_mask': (padded_inputs != PAD_ID)
    }
Training Function
Our training function instantiates a GPT2LMHeadModel based on the requested GPT2Config and proceeds to train it on our variable-length sequences.
def hf_main(
        config,
        collate_fn=hf_collate_with_padding,
        compile=False
):
    torch.random.manual_seed(0)
    device = torch.device(DEVICE)
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    data_set = HuggingFaceFakeDataset()
    data_loader = DataLoader(
        data_set,
        batch_size=BATCH_SIZE,
        collate_fn=collate_fn,
        num_workers=12 if DEVICE == "cuda" else 0,
        pin_memory=True,
        drop_last=True
    )

    model = GPT2LMHeadModel(config).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_ID)
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0

    for step, data in enumerate(data_loader):
        # Copy data to GPU
        data = data_to_device(data, device)
        input_ids = data['input_ids']
        labels = data['labels']
        position_ids = data.get('position_ids')
        attn_mask = data.get('attention_mask')
        with torch.amp.autocast(DEVICE, dtype=torch.bfloat16):
            outputs = model(input_ids=input_ids,
                            position_ids=position_ids,
                            attention_mask=attn_mask)
            # shift so that the logits at position t predict token t+1
            logits = outputs.logits[..., :-1, :].contiguous()
            labels = labels[..., 1:].contiguous()
            loss = criterion(logits.view(-1, NUM_TOKENS), labels.flatten())

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step >= 100:
            break
    print(f'average step time: {summ / count}')
SDPA with Padding
In the code block below we call our training function with the default sequence-padding collator.
config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
)

for compile in [False, True]:
    print(f"HF GPT2 train with SDPA, compile={compile}")
    hf_main(config=config, compile=compile)
The resulting step times are 815 ms without torch.compile and 440 ms with torch.compile.
FlashAttention2
We now take advantage of HuggingFace's built-in support for FlashAttention2 by setting the attn_implementation parameter to "flash_attention_2". Behind the scenes, HuggingFace will unpad the padded data input and then pass it to the optimized flash_attn_varlen_func function we saw above:
flash_config = GPT2Config(
    n_layer=DEPTH,
    n_embd=DIM,
    n_head=NUM_HEADS,
    vocab_size=NUM_TOKENS,
    attn_implementation='flash_attention_2'
)

print(f"HF GPT2 train with flash")
hf_main(config=flash_config)
The resulting step time is 620 ms, amounting to a 30% boost (in uncompiled mode) with just a simple flick of a switch.
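Conceptually, the unpadding step that precedes flash_attn_varlen_func turns a padded batch and its attention_mask into the flat indices of the real tokens, the cu_seqlens boundaries, and max_seqlen. The pure-Python sketch below shows that bookkeeping; the unpad helper is hypothetical and is not HuggingFace's internal code. In this mode only the attention computation skips the padding; the remaining layers presumably still process the full padded tensors, which would help explain why removing the padding altogether pays off further.

```python
from itertools import accumulate
from typing import List, Tuple


def unpad(attention_mask: List[List[bool]]) -> Tuple[List[int], List[int], int]:
    """Conceptual unpadding of a padded batch.

    Returns the indices of real tokens in the row-major flattened batch,
    the cumulative sequence lengths, and the longest sequence length.
    """
    width = len(attention_mask[0])
    indices = [b * width + t
               for b, row in enumerate(attention_mask)
               for t, keep in enumerate(row) if keep]
    seqlens = [sum(row) for row in attention_mask]
    return indices, [0, *accumulate(seqlens)], max(seqlens)


mask = [[True, True, True, False],
        [True, False, False, False]]
indices, cu_seqlens, max_seqlen = unpad(mask)
print(indices, cu_seqlens, max_seqlen)  # [0, 1, 2, 4] [0, 3, 4] 3
```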
FlashAttention2 with Unpadded Input
Of course, padding the sequences in the collation function only to have them unpadded hardly seems sensible. In a recent update to HuggingFace, support was added for passing in concatenated (unpadded) sequences to a select number of models. Unfortunately, (as of the time of this writing) our GPT2 model did not make the cut. However, adding support requires just five small line additions to modeling_gpt2.py in order to propagate the sequence position_ids to the flash-attention kernel. The full patch appears in the block below:
@@ -370,0 +371 @@
+ position_ids = None
@@ -444,0 +446 @@
+ position_ids=position_ids
@@ -611,0 +614 @@
+ position_ids=None
@@ -621,0 +625 @@
+ position_ids=position_ids
@@ -1140,0 +1145 @@
+ position_ids=position_ids
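With flattened input, the model receives one long concatenated sequence, so the sequence boundaries must be recovered from position_ids. A natural rule, and to our understanding the one HuggingFace's flash-attention utilities rely on for such inputs, is that a new sequence starts wherever the position id resets to 0. The pure-Python sketch below (hypothetical helper name) derives cu_seqlens and max_seqlen that way, and shows why the collate function must restart torch.arange at 0 for every sequence.

```python
from typing import List, Tuple


def cu_seqlens_from_position_ids(position_ids: List[int]) -> Tuple[List[int], int]:
    """Recover (cu_seqlens, max_seqlen) from flattened position ids.

    Assumes each packed sequence starts at position 0, as produced by
    restarting arange() for every sequence in the collate function.
    """
    starts = [i for i, p in enumerate(position_ids) if p == 0]
    cu_seqlens = starts + [len(position_ids)]
    max_seqlen = max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    return cu_seqlens, max_seqlen


packed = [0, 1, 2, 0, 1, 0, 1, 2, 3]            # three sequences: 3, 2, 4 tokens
print(cu_seqlens_from_position_ids(packed))     # ([0, 3, 5, 9], 4)

# If positions did not restart, all tokens would look like one long sequence.
print(cu_seqlens_from_position_ids(list(range(9))))  # ([0, 9], 9)
```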
We define a collate function that concatenates our sequences and train our HuggingFace model on unpadded sequences. (Also see the built-in DataCollatorWithFlattening utility.)
def collate_flatten(batch):
    input_ids = torch.concat([b['input_ids'] for b in batch]).unsqueeze(0)
    labels = torch.concat([b['labels'] for b in batch]).unsqueeze(0)
    # restart positions at 0 for every sequence so that the
    # sequence boundaries can be recovered from position_ids
    position_ids = [torch.arange(b['input_ids'].shape[0]) for b in batch]
    position_ids = torch.concat(position_ids)

    return {
        'input_ids': input_ids,
        'labels': labels,
        'position_ids': position_ids
    }

print(f"HF GPT2 train with flash, no padding")
hf_main(config=flash_config, collate_fn=collate_flatten)
The resulting step time is 323 ms, 90% faster than running flash-attention on the padded input.
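The flattened setup relies on a detail of HuggingFaceFakeDataset: each sequence's first label is set to PAD_ID, and hf_main shifts logits and labels by one position before applying CrossEntropyLoss with ignore_index=PAD_ID. In the packed sequence, the shifted pair at each boundary would otherwise ask the last token of one sequence to predict the first token of the next. The pure-Python sketch below (the shifted_pairs helper is hypothetical) lists the (position, target) pairs that survive the shift and masking, showing that the boundary pair drops out. Since token ids are drawn from [1, NUM_TOKENS), no real token collides with PAD_ID = 0.

```python
PAD_ID = 0


def shifted_pairs(input_ids, labels, ignore_index=PAD_ID):
    """(input position, target) pairs that contribute to a shifted LM loss."""
    return [(i, labels[i + 1])
            for i in range(len(input_ids) - 1)
            if labels[i + 1] != ignore_index]


# Two packed sequences, [5, 6, 7] and [8, 9]; each first label is masked.
input_ids = [5, 6, 7, 8, 9]
labels = [PAD_ID, 6, 7, PAD_ID, 9]
print(shifted_pairs(input_ids, labels))  # [(0, 6), (1, 7), (3, 9)]
```

Position 2 (token 7) no longer contributes a target, so no loss term crosses the boundary between the two sequences.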
Results
The results of our HuggingFace experiments are summarized below.
With little effort, we were able to boost our runtime performance by 2.5x when compared to the uncompiled baseline experiment, and by 36% when compared to the compiled version.
In this section, we demonstrated how the HuggingFace APIs allow us to leverage the optimized kernels in FlashAttention2, significantly boosting the training performance of existing models on sequences of varying length.
As AI models continue to grow in both popularity and complexity, optimizing their performance has become essential for reducing runtime and costs. This is especially true for compute-intensive components like attention layers. In this post, we have continued our exploration of attention layer optimization and demonstrated new tools and techniques for enhancing Transformer model performance. For more insights on AI model optimization, be sure to check out the first post in this series as well as our many other posts on this topic.