How paying “better” attention can drive ML cost savings
Introduced in the landmark 2017 paper “Attention Is All You Need” (Vaswani et al., 2017), the Transformer architecture is widely regarded as one of the most influential scientific breakthroughs of the past decade. At the core of the Transformer is the attention mechanism, a novel approach that enables AI models to understand complex structures by focusing on different parts of input sequences based on the task at hand. First demonstrated in the world of natural language processing, the success of the Transformer architecture has quickly spread to many other domains, including speech recognition, scene understanding, reinforcement learning, protein structure prediction, and more. However, attention layers are highly resource-intensive, and as these layers become the standard across increasingly large models, the costs associated with their training and deployment have surged. This has created an urgent need for strategies that reduce the computational cost of this core layer so as to increase the efficiency and scalability of Transformer-based AI models.
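For reference, the scaled dot-product attention at the core of this mechanism can be written as follows (standard notation from the original paper, where $Q$, $K$, and $V$ denote the query, key, and value matrices and $d_k$ is the key dimension):

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$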
In this post, we will explore several tools for optimizing attention in PyTorch. Our focus will be on methods that maintain the accuracy of the attention layer. These will include PyTorch SDPA, FlashAttention, TransformerEngine Attention, FlexAttention, and xFormers attention. Other methods that reduce the computational cost via approximation of the attention calculation (e.g., DeepSpeed's Sparse Attention, Longformer, Linformer, and more) will not be considered. Additionally, we will not discuss general optimization techniques that, while beneficial to attention performance, are not specific to the attention computation itself (e.g., FP8 training, model sharding, and more).
Importantly, attention optimization is an active area of research, with new methods coming out on a fairly regular basis. Our goal is to increase your awareness of some of the existing solutions and provide you with a foundation for further exploration and experimentation. The code we share below is intended for demonstrative purposes only; we make no claims regarding its accuracy, optimality, or robustness. Please do not interpret our mention of any platforms, libraries, or optimization techniques as an endorsement of their use. The best options for you will depend greatly on the specifics of your own use-case.
Many thanks to Yitzhak Levi for his contributions to this post.
To facilitate our discussion, we build a Vision Transformer (ViT)-backed classification model using the popular timm Python package (version 0.9.7). We will use this model to illustrate the performance impact of various attention kernels.
We begin by defining a simplified Transformer block that allows for programming the attention function by passing it into its constructor. Since attention implementations assume specific input tensor formats, we also include an option for controlling the format, ensuring compatibility with the attention kernel of our choosing.
# general imports
import os, time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp

IMG_SIZE = 224
BATCH_SIZE = 128

# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2  # 196
class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            format=None,
            dim: int = 768,
            num_heads: int = 12,
            **kwargs
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
        permute = (2, 0, 3, 1, 4)
        self.permute_attn = functools.partial(torch.transpose, dim0=1, dim1=2)

        if format == 'bshd':
            permute = (2, 0, 1, 3, 4)
            self.permute_attn = nn.Identity()
        self.permute_qkv = functools.partial(torch.permute, dims=permute)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # permute tensor based on the specified format
        qkv = self.permute_qkv(qkv)
        q, k, v = qkv.unbind(0)
        # use the attention function specified by the user
        x = self.attn_fn(q, k, v)
        # permute output according to the specified format
        x = self.permute_attn(x).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
We define a randomly generated dataset which we will feed to our model during training.
# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
                                 dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label
Next, we define our ViT training function. While our example focuses on demonstrating a training workload, it is important to emphasize that optimizing the attention layer is equally, if not more, important during model inference.
The training function we define accepts the customized Transformer block and a flag that controls the use of torch.compile.
def train_fn(block_fn, compile):
    torch.random.manual_seed(0)
    device = torch.device("cuda:0")
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    train_set = FakeDataset()
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE,
        num_workers=12, pin_memory=True, drop_last=True)

    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=NUM_HEADS*HEAD_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        class_token=False,
        global_pool="avg",
        block_fn=block_fn
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # Copy data to GPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break

    print(f'average step time: {summ / count}')

# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)
In the code block below we define a PyTorch-native attention function and use it to train our ViT model:
def attn_fn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x

block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)

print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)
We ran this on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1. The uncompiled variant resulted in an average step time of 370 milliseconds (ms), while the compiled variant improved to 242 ms. We will use these results as a baseline for comparison as we consider alternative solutions for performing the attention computation.
One of the easiest ways to boost the performance of our attention layers in PyTorch is to use the scaled_dot_product_attention (SDPA) function. Currently in beta, PyTorch SDPA consolidates multiple kernel-level optimizations and dynamically selects the most efficient one based on the input's properties. Supported backends (as of now) include: FlashAttention-2, Memory-Efficient Attention, a C++-based Math Attention, and CuDNN. These backends fuse together high-level operations while employing GPU-level optimizations to increase compute efficiency and memory utilization.
SDPA is continuously evolving, with new and improved backend implementations being introduced regularly. Staying up to date with the latest PyTorch releases is key to leveraging the most recent performance improvements. For example, PyTorch 2.5 introduced an updated CuDNN backend featuring a specialized SDPA primitive specifically tailored for training on NVIDIA Hopper architecture GPUs.
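Note that recent PyTorch versions also expose a context manager for selecting SDPA backends. The snippet below is a minimal sketch of that alternative interface (available since roughly PyTorch 2.3); in the experiments that follow we stick with the per-backend enable flags:

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention

# dummy inputs in 'bhsd' format: (batch, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 16, 196, 64, dtype=torch.bfloat16, device='cuda')

# restrict SDPA to the FlashAttention backend within this scope
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = scaled_dot_product_attention(q, k, v)
```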
In the code block below, we iterate through the list of supported backends and assess the runtime performance of training with each. We use a helper function, set_sdpa_backend, for programming the SDPA backend:
from torch.nn.functional import scaled_dot_product_attention as sdpa

def set_sdpa_backend(backend):
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_cudnn_sdp(False)

    if backend in ['flash_sdp','all']:
        torch.backends.cuda.enable_flash_sdp(True)
    if backend in ['mem_efficient_sdp','all']:
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    if backend in ['math_sdp','all']:
        torch.backends.cuda.enable_math_sdp(True)
    if backend in ['cudnn_sdp','all']:
        torch.backends.cuda.enable_cudnn_sdp(True)

for backend in ['flash_sdp', 'mem_efficient_sdp',
                'math_sdp', 'cudnn_sdp']:
    set_sdpa_backend(backend)
    block_fn = functools.partial(MyAttentionBlock,
                                 attn_fn=sdpa)

    print(f'PyTorch SDPA - {backend}')
    train(block_fn)
    print(f'Compiled PyTorch SDPA - {backend}')
    train_compile(block_fn)
We summarize our interim results in the table below:
While the choice of SDPA backend has a noticeable impact on performance when running in eager mode, the optimizations performed by model compilation appear to overshadow the differences between the attention kernels. Once again, we caution against drawing any conclusions from these results, as the performance impact of different attention functions can vary significantly depending on the specific model and use case.
While PyTorch SDPA is a great place to start, using third-party attention kernels can help accelerate your ML workloads further. These alternatives often come with added flexibility, offering a wider range of configuration options for attention. Some may also include optimizations tailored for specific hardware accelerators or newer GPU architectures.
In this section, we will explore some of the third-party attention kernels available and assess their potential impact on runtime performance.
FlashAttention-3
While PyTorch SDPA supports a FlashAttention backend, more advanced FlashAttention implementations can be found in the flash-attn library. Here we will explore the FlashAttention-3 beta release, which boasts a speedup of up to 2x compared to FlashAttention-2. Given the early stage of its development, FlashAttention-3 can only be installed directly from the GitHub repository, and its use is limited to certain head dimensions. Additionally, it does not yet support model compilation. In the following code block, we configure our Transformer block to use flash-attn-3 while setting the attention input format to "bshd" (batch, sequence, head, depth) to meet the expectations of the library.
# flash attention 3
from flash_attn_interface import flash_attn_func as fa3

attn_fn = lambda q, k, v: fa3(q, k, v)[0]
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Flash Attention 3')
train(block_fn)
The resultant step time was 240 ms, making it 5% faster than the SDPA flash-attn backend.
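If installing the FlashAttention-3 beta is not an option, the stable flash-attn 2 package exposes a similar interface through flash_attn.flash_attn_func, which (to our understanding) returns the output tensor directly rather than a tuple. A minimal sketch of plugging it into our block, under those assumptions:

```python
# flash attention 2 (the stable flash-attn package)
from flash_attn import flash_attn_func as fa2

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=fa2,
                             format='bshd')
print('Flash Attention 2')
train(block_fn)
```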
Transformer Engine
Transformer Engine (TE) is a specialized library designed to accelerate Transformer models on NVIDIA GPUs. TE is updated regularly with optimizations that leverage the capabilities of the latest NVIDIA hardware and software offerings, giving users access to specialized kernels long before they are integrated into general-purpose frameworks such as PyTorch.
In the code block below we use DotProductAttention from TE version 1.11.0. Similar to PyTorch SDPA, TE supports a number of backends which are controlled via environment variables. Here we demonstrate the use of the NVTE_FUSED_ATTN backend.
def set_te_backend(backend):
    # must be applied before the first use of
    # transformer_engine.pytorch.attention
    os.environ["NVTE_FLASH_ATTN"] = '0'
    os.environ["NVTE_FUSED_ATTN"] = '0'
    os.environ["NVTE_UNFUSED_ATTN"] = '0'
    if backend == 'flash':
        os.environ["NVTE_FLASH_ATTN"] = '1'
    if backend == 'fused':
        os.environ["NVTE_FUSED_ATTN"] = '1'
    if backend == 'unfused':
        os.environ["NVTE_UNFUSED_ATTN"] = '1'

from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
                              qkv_format='bshd',
                              # disable masking (default is causal mask)
                              attn_mask_type='no_mask')

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)
TE attention resulted in average step times of 243 ms and 204 ms for the eager and compiled model variants, respectively.
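Note that DotProductAttention is just one of the building blocks TE exposes; the library also includes higher-level modules that bundle the projection and MLP layers together with its optimized kernels. The snippet below is a minimal sketch of one such module; the constructor arguments shown are assumptions based on TE 1.11 and should be verified against the official documentation:

```python
import transformer_engine.pytorch as te

# a sketch of a full TE transformer block (not used in our benchmarks)
te_block = te.TransformerLayer(
    hidden_size=NUM_HEADS * HEAD_DIM,         # 1024 in our first configuration
    ffn_hidden_size=4 * NUM_HEADS * HEAD_DIM,
    num_attention_heads=NUM_HEADS,
)
```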
xFormer Attention
Underlying the memory-efficient backend of PyTorch SDPA is an attention kernel provided by the xFormers library. Once again, we can go to the source to benefit from the latest kernel optimizations and from the full set of API capabilities. In the following code block we use the memory_efficient_attention operator from xFormers version 0.0.28.
# xformers memory efficient attention
from xformers.ops import memory_efficient_attention as mea

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea,
                             format='bshd')

print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)
The eager model variant resulted in an average step time of 246 ms, making it 10.5% faster than the SDPA memory-efficient kernel. The compiled variant resulted in a step time of 203 ms.
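As an example of the extended API surface mentioned above, memory_efficient_attention accepts an attn_bias argument that can take one of xFormers' predefined bias/mask classes. The sketch below (not part of our benchmarks, assuming xFormers 0.0.28) shows how a causal mask could be applied without materializing a full attention-mask tensor:

```python
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# causal attention via xFormers' built-in lower-triangular bias
causal_mea = functools.partial(memory_efficient_attention,
                               attn_bias=LowerTriangularMask())
```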
Results
The table below summarizes our experiments:
The winner for the eager model was flash-attn-3, with an average step time that is 54% faster than our baseline model. This translates to a similar 54% reduction in training costs. In compiled mode, the performance across the optimized kernels was more or less equal, with the fastest implementations achieving 202 ms, representing a 20% improvement compared to the baseline experiment.
As mentioned above, the precise impact on cost savings is greatly dependent on the model definition. To assess this variability, we reran the experiments using modified settings that increased the attention sequence length to 3136 tokens.
IMG_SIZE = 224
BATCH_SIZE = 8

# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2  # 3136
The results are summarized in the table below:
Our immediate observation is that when the sequence length is greater, the performance impact of the attention kernels is far more pronounced. Once again, flash-attn-3 came out in front in eager execution mode, this time with a ~5x increase in performance compared to the PyTorch-native function. For the compiled model we see that the TE kernel broke away from the pack with an overall best step time of 53 ms.
So far, we have focused on the standard attention function. However, sometimes we may want to use a variant of the typical attention computation in which we either mask out some of the values of intermediate tensors or apply some operation on them. These types of modifications may interfere with our ability to use the optimized attention blocks we covered above. In this section, we discuss some of the ways to address this:
Leverage Advanced Kernel APIs
Many optimized attention kernels provide extensive APIs with controls for customizing the attention computation. Before implementing a new solution, explore these APIs to determine whether they already support your required functionality.
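For example, recent releases of the flash-attn library expose optional arguments on flash_attn_func that cover common attention variants out of the box. The snippet below is a sketch based on our understanding of flash-attn 2.6 (the causal, window_size, and softcap arguments are assumptions to be verified against the version you install):

```python
import torch
from flash_attn import flash_attn_func

# dummy inputs in 'bshd' format: (batch, seq_len, num_heads, head_dim)
q = k = v = torch.randn(1, 3136, 12, 64, dtype=torch.bfloat16, device='cuda')

# causal attention restricted to a sliding window of 256 past tokens,
# with tanh soft-capping of the attention scores at 30
out = flash_attn_func(q, k, v,
                      causal=True,
                      window_size=(256, 0),  # (left, right) context window
                      softcap=30.0)
```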
Implement a Custom Kernel:
If the existing APIs do not meet your needs, you could consider developing your own custom attention implementation. In previous posts (e.g., here) we discussed some of the pros and cons of custom kernel development. Achieving optimal performance can be extremely difficult. If you do go down this path, one approach might be to start with an existing (optimal) kernel and apply minimal modifications to integrate the desired change.
Use FlexAttention:
A recent addition to PyTorch, FlexAttention empowers users to implement a wide variety of attention variants without needing to compromise on performance. Denoting the result of the dot product of the query and key tokens by score, flex_attention allows for programming either a score_mod function or a block_mask mask that is automatically applied to the score tensor. See the documentation as well as the accompanying attention-gym repository for examples of the types of operations the API enables.
FlexAttention works by compiling the score_mod operator into the attention operator, thereby creating a single fused kernel. It also leverages the sparsity of block_masks to avoid unnecessary computations. The benchmarks reported in the FlexAttention documentation show considerable performance gains for a variety of use cases.
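To make the API shapes concrete before turning to our experiments, the snippet below sketches the canonical causal-attention example from the FlexAttention documentation: a mask_mod that keeps only key positions at or before the query position, compiled into a sparse block mask (signatures per PyTorch 2.5):

```python
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # keep only key positions at or before the query position
    return q_idx >= kv_idx

# precompute the sparse block mask for a 3136-token sequence
causal_block_mask = create_block_mask(causal_mask, B=None, H=None,
                                      Q_LEN=3136, KV_LEN=3136)
causal_attn = functools.partial(flex_attention, block_mask=causal_block_mask)
```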
Let's see both the score_mod and block_mask in action.
Score Mod Example — Soft-Capping with Tanh
Soft-capping is a common technique used to control the magnitude of the logits (e.g., see here). The following code block extends our PyTorch-native attention kernel with soft-capping:
def softcap_attn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    # apply soft-capping
    attn = 30 * torch.tanh(attn/30)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
In the code block below we train our model, first with our PyTorch-native kernel, and then with the optimized FlexAttention API. These experiments were run with the 3136-length sequence settings.
# flex attention imports
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention
)
compiled_flex = torch.compile(flex_attention)

# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
    return 30 * torch.tanh(score/30)

block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)

print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)

flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)
The results of the experiments are captured in the table below:
The impact of the FlexAttention kernel is clearly evident, delivering performance boosts of roughly 3.5x in eager mode and 1.5x in compiled mode.
Mask Mod Example — Neighborhood Masking
We assess the mask_mod functionality by applying a sparse mask to our attention score. Recall that each token in our sequence represents a patch in our 2D input image. We modify our kernel so that each token only attends to other tokens that are within a 5×5 window in the corresponding 2D token array.
# convert the token id to a 2D index
def seq_indx_to_2d(idx):
    n_row_patches = IMG_SIZE // PATCH_SIZE
    r_ind = idx // n_row_patches
    c_ind = idx % n_row_patches
    return r_ind, c_ind

# only attend to tokens in a 5x5 surrounding window in our 2D token array
def mask_mod(b, h, q_idx, kv_idx):
    q_r, q_c = seq_indx_to_2d(q_idx)
    kv_r, kv_c = seq_indx_to_2d(kv_idx)
    return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)
As a baseline for our experiment, we use PyTorch SDPA, which includes support for passing in an attention mask. The following block includes the masked SDPA experiment followed by the FlexAttention implementation:
# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')

set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)

block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)
The results of the experiments are captured below:
Once again, FlexAttention offers a considerable performance boost, amounting to 2.19x in eager mode and 2.59x in compiled mode.
FlexAttention Limitations
Although we have succeeded in demonstrating the power and potential of FlexAttention, there are a few limitations that should be noted:
- Limited Scope of Modifications: With FlexAttention you can (as of the time of this writing) only modify the attention score (the result of the dot product between the query and key tokens). It does not support modifications at other stages of the attention computation.
- Dependency on torch.compile: Given the reliance on torch.compile, great care must be taken to avoid excessive recompilations, which could greatly degrade runtime performance. For instance, while the support for Document Masking is very compelling (see the sketch after this list), it will only perform as expected if the sum of the lengths of all of the documents remains fixed.
- No Support for Trainable Parameters in score_mod: At the time of this writing, FlexAttention does not support a score_mod implementation that includes trainable parameters. For example, while the documentation highlights support for relative position encodings, these are commonly implemented with trainable parameters (rather than fixed values), which cannot currently be accommodated.
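As a brief illustration of the document-masking case mentioned in the second limitation, the sketch below (in the spirit of the attention-gym examples; the document_ids tensor and the document lengths are hypothetical) masks out attention between tokens that belong to different concatenated documents:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# hypothetical example: three documents of lengths 1000, 1000, and 1136
# concatenated into a single 3136-token sequence
document_ids = torch.repeat_interleave(
    torch.arange(3), torch.tensor([1000, 1000, 1136])).to('cuda')

def document_mask(b, h, q_idx, kv_idx):
    # allow attention only within the same document
    return document_ids[q_idx] == document_ids[kv_idx]

doc_block_mask = create_block_mask(document_mask, B=None, H=None,
                                   Q_LEN=3136, KV_LEN=3136)
```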
In the face of these limitations, we can return to one of the other optimization opportunities discussed above.
As the reliance on Transformer architectures and attention layers in ML models increases, so does the need for tools and techniques for optimizing these components. In this post, we have explored a number of attention kernel variants, each with its own unique properties, capabilities, and limitations. Importantly, one size does not fit all: different models and use cases will warrant the use of different kernels and different optimization strategies. This underscores the importance of having a wide variety of tools and techniques for optimizing attention layers.
In a future post, we hope to further explore attention layer optimization by focusing on applying some of the tools we discussed here to tackle the challenge of handling variable-sized input sequences. Stay tuned…