Memory-Efficient Model Weight Loading in PyTorch

I recently came across a post by Sebastian that caught my attention, and I wanted to dive deeper into its content. As models grow larger and more complex, efficiently managing memory during model loading becomes increasingly important, especially when working with limited GPU or CPU resources. In his post, Sebastian covers practical tips for loading larger pre-trained or fine-tuned models in constrained memory environments, which is particularly relevant when working with PyTorch.

This guide focuses on how to handle situations where models are saved using torch.save(model.state_dict(), "model.pth") and later need to be loaded for continued pre-training or further fine-tuning. While the examples focus on a large language model (LLM), Sebastian's techniques are broadly applicable to any PyTorch model. They also provide useful insights into memory-efficient model weight loading in PyTorch, helping optimize memory usage during the loading process.

Overview

  • Efficient memory management is crucial for loading large neural networks in PyTorch, especially on systems with limited GPU or CPU resources.
  • Instead of loading the entire model at once, you can load weights incrementally. Normally, calling model.to(device) moves all of the model's parameters to the device (such as a GPU), which can consume significant memory.
  • PyTorch introduced the "meta" device, which allows tensors to be created without allocating memory for their data.
  • By using the meta device, you can load weights directly into GPU memory, bypassing the CPU and optimizing memory usage.

Initial Setup: Environment Check

Before diving into the specifics, let's make sure the necessary packages and versions are available. Here's a snippet that checks the version of PyTorch and other useful tools.

from importlib.metadata import version

pkgs = [
    "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

Benchmark Utilities for Memory Tracking

The first step is to set up a utility to track GPU memory (VRAM). Monitoring memory usage helps in understanding how different techniques affect the memory load during model loading and inference. Later, we will also track the system's RAM (CPU memory).

Here's the utility code for GPU memory tracking:


import gc
import time
import torch

def start_memory_tracking():
    """Initialize GPU memory tracking."""
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    else:
        print("This notebook is intended for CUDA GPUs, but CUDA is not available.")

def print_memory_usage():
    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB
    print(f"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB")

def cleanup():
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(3)  # Allow time for memory to clear
    torch.cuda.reset_peak_memory_stats()
    max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB")

These functions help track GPU memory usage before, during, and after model operations. The cleanup() function is especially useful for clearing unused memory to avoid running out of VRAM.
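As a quick illustration of how these utilities fit together (a minimal sketch assuming a CUDA GPU is available; the allocation itself is arbitrary), you can wrap any GPU operation between them:

# Minimal usage sketch: measure the peak VRAM of a single allocation
start_memory_tracking()

x = torch.randn(1024, 1024, device="cuda")  # placeholder workload

print_memory_usage()  # reports the peak GPU memory allocated so far

del x
cleanup()  # frees cached memory and resets the peak-memory counter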

Model Setup

Next, we set up the model. For demonstration, we will use a large GPT-2 variant, though you can adjust the model size to suit your memory constraints: by changing the configuration, the model size can range from "gpt2-small" (124M parameters) to "gpt2-xl" (1558M parameters).

Here's the configuration:

from previous_chapters import GPTModel

BASE_CONFIG = {
    "vocab_size": 50257,     # Vocabulary measurement
    "context_length": 1024,  # Context size
    "drop_rate": 0.0,        # Dropout price
    "qkv_bias": True         # Question-key-value bias
}

model_configs = {
    "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
    "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
    "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
    "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}

CHOOSE_MODEL = "gpt2-xl (1558M)"
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])

This configuration allows flexibility in choosing models based on available memory resources. For lower memory consumption, picking a smaller variant (like gpt2-small) is advisable.

Once the model configuration is set up, the next steps dive into loading, managing, and optimizing the model weights for efficient memory usage.

Tracking GPU Memory During Model Loading

Let's now put the GPU memory tracking utilities into action. First, we initialize memory tracking and load the model to observe memory consumption. The code below tracks GPU memory usage as we load and run a GPT model.

start_memory_tracking()

model = GPTModel(BASE_CONFIG)
device = torch.device("cuda")
model.to(device)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB

This shows that loading and placing the model onto the GPU consumes around 6.4 GB of VRAM, which is typical for larger models like GPT-2. However, this is just the initial setup.

Running the Model

To verify that everything works correctly, we pass a simple input tensor to the model. Although we aren't tracking memory during this step, it's important to check that the model operates as expected.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

Saving the Model

Now, imagine we are pretraining the model (or finetuning it). For this example, we skip the actual pretraining process and directly save the initialized model. The following code saves the model's weights using torch.save().

# Training code would go here...

model.train()
torch.save(model.state_dict(), "model.pth")

Memory Cleanup

After saving the model, it's important to free up GPU memory to ensure efficient resource management in subsequent operations. By deleting the model and the test input tensor, and then running our cleanup() function, we clear up VRAM.

del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

At this point, the GPU memory usage is reset to zero, as expected.

Loading Pretrained Model Weights

The next step involves reloading the saved model weights to continue training or finetuning. However, loading pretrained weights requires more GPU memory than initializing a fresh model, because the weights exist in memory twice: once in the newly instantiated model, and again when the state_dict is loaded.

# Start tracking memory
start_memory_tracking()

# Recreate the model architecture
model = GPTModel(BASE_CONFIG)
model.to(device)

# Load the saved state_dict
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

print_memory_usage()
# Output: Maximum GPU memory allocated: 12.8 GB

The GPU memory usage has now doubled compared to the initial load, peaking at 12.8 GB. This happens because, for a short period, both the original model and the newly loaded weights are held in memory. Eventually, the loaded weights are copied into the model, and the temporary state_dict is discarded. However, this memory spike can cause issues when working with limited resources.

Resetting GPU Memory

After loading the model weights and testing them, it's essential to reset the GPU memory once again. Testing the model ensures it works as expected, and clearing memory is crucial for efficient resource utilization.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

del model, test_input
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

This reset brings GPU memory usage back to zero, ensuring a clean state for future operations.

Loading Weights Sequentially

One effective workaround for this double memory usage when loading model weights is sequential loading. Instead of loading both the model and the weights into GPU memory simultaneously, we can load the model first, keep the weights in CPU memory, and then copy each parameter one by one to the GPU. This significantly reduces the peak memory usage.

Here's how to implement sequential weight loading:

Step-by-Step Breakdown:

  1. Load the Model onto the GPU: First, we load the model architecture into GPU memory, as usual.
  2. Load the Weights onto the CPU: The model weights are loaded into CPU memory, avoiding the initial memory spike caused by moving both the model and the weights to the GPU.
  3. Copy Weights Parameter by Parameter: Each weight is then copied sequentially from the CPU to the GPU, so at no point do we have both the model and the full state_dict in GPU memory.

The code below demonstrates this approach:

start_memory_tracking()

# Load the model into GPU memory
model = GPTModel(BASE_CONFIG).to(device)

# Load the model's saved state_dict onto the CPU
state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.4 GB

# Copy each parameter to GPU memory one by one
with torch.no_grad():
    for name, param in model.named_parameters():
        if name in state_dict:
            param.copy_(state_dict[name].to(device))
        else:
            print(f"Warning: {name} not found in state_dict.")

print_memory_usage()
# Output: Maximum GPU memory allocated: 6.7 GB

Memory Comparison:

  • Initially, the model alone occupies 6.4 GB of GPU memory.
  • As we copy each parameter sequentially, memory usage rises only slightly, to about 6.7 GB.

However, this is a much smaller peak compared to the 12.8 GB required when loading everything at once. By loading the weights sequentially, we avoid having both the full model and the full set of weights in GPU memory at the same time.

Model Testing & Memory Reset:

After copying the weights, we test the model to make sure everything works as expected. Finally, we reset the GPU memory to clear any lingering objects, just as we did in earlier steps.

# Test if the model works (no need to track memory here)
test_input = torch.tensor([[1, 2, 3]]).to(device)
model.eval()

with torch.no_grad():
    model(test_input)

# Clean up GPU memory
del model, test_input, state_dict, param
cleanup()
# Output: Maximum GPU memory allocated: 0.0 GB

Loading the Model with Low CPU Memory

In the previous section, we reduced GPU memory usage by loading the model weights into CPU memory first and then copying them sequentially into the GPU. But what if the machine has limited CPU memory and a larger GPU? To tackle this, we can use PyTorch's "meta" device approach, which is ideal for machines with constrained CPU resources.

Meta Device: A Smart Tradeoff

The "meta" device is a special device type in PyTorch that creates "meta" tensors. These tensors represent only the shape and dtype of the data, without allocating memory for the data itself. This lets us define models without consuming CPU or GPU memory until necessary.

Using the meta device, we can first initialize the model without any memory allocation, and then load the model weights directly into GPU memory, bypassing the CPU.
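As a minimal sketch of what the meta device does (the tensor shape here is arbitrary and purely for illustration), a tensor created under it carries metadata but no backing storage:

import torch

# Tensors created on the "meta" device hold only metadata (shape, dtype)
with torch.device("meta"):
    t = torch.empty(10_000, 10_000)  # would need ~400 MB in float32 on a real device

print(t.shape, t.dtype, t.device)  # torch.Size([10000, 10000]) torch.float32 meta
print(t.is_meta)                   # True -- no memory was allocated for the data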

Tracking CPU Memory Usage

Before we dive into the meta device approach, we define a utility function to track CPU memory usage:


import os
import time
import psutil
from threading import Thread

def memory_usage_in_gb(func, *args, **kwargs):
    process = psutil.Process(os.getpid())
    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB
    mem_usage = []
    done = False

    def monitor_memory():
        # Sample the process's resident memory every 0.1 s until `done` is set
        while not done:
            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB
            time.sleep(0.1)

    t = Thread(target=monitor_memory)
    t.start()

    func(*args, **kwargs)
    done = True
    t.join()

    peak_mem_usage_gb = max(mem_usage) - baseline_mem
    return peak_mem_usage_gb

Now that we can measure CPU memory usage, let's track the memory used during the sequential weight loading approach from the previous section:

def load_sequentially():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG).to(device)
    state_dict = torch.load("model.pth", map_location="cpu", weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name].to(device))

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

This method outputs:

  • Maximum GPU memory allocated: 6.7 GB
  • Maximum CPU memory allocated: 6.3 GB

Meta Device Approach

To further reduce CPU memory usage, we can use the meta device to initialize the model without allocating memory until we actually need it. Here's the implementation:

def load_sequentially_with_meta():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)
    state_dict = torch.load("model.pth", map_location=device, weights_only=True)

    print_memory_usage()

    # Sequentially copy weights to the model's parameters
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in state_dict:
                param.copy_(state_dict[name])

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Memory Usage with the Meta Device:

  • Maximum GPU memory allocated: 12.8 GB
  • Maximum CPU memory allocated: 1.3 GB

By using the meta device and loading the model weights directly into GPU memory, we drastically reduce CPU memory consumption from 6.3 GB to just 1.3 GB.

Comparison with Baseline

Finally, let's compare this method with the straightforward PyTorch weight loading approach, where no meta device or sequential loading is used:

def baseline():
    start_memory_tracking()

    model = GPTModel(BASE_CONFIG)
    model.to(device)
    model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
    model.to(device)
    model.eval()

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(baseline)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

For this method:

  • Maximum GPU memory allocated: 12.8 GB
  • Maximum CPU memory allocated: 4.4 GB

Using mmap=True for Efficient Model Loading

For more advanced PyTorch users, there is an alternative approach to handling memory constraints when loading large models: the mmap=True setting in torch.load(). This setting leverages memory-mapped file I/O, which allows tensor data to be read directly from disk without fully loading it into RAM. This is particularly useful on systems with limited CPU memory, as it minimizes the memory footprint during model loading.

What is mmap=True?

Memory-mapped I/O (mmap) is a mechanism that allows a file to be read directly from disk by mapping it into the virtual address space. Instead of loading the entire model into RAM, PyTorch can read parts of it on demand, effectively reducing memory usage. This is particularly advantageous when dealing with large pretrained or finetuned models, such as GPT-2 or GPT-3, on machines with limited resources.

The mmap=True option can be passed to torch.load() to achieve this behavior.
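In isolation, the call is just one extra argument to torch.load() (a minimal sketch reusing the "model.pth" checkpoint saved earlier):

# Memory-map the checkpoint: tensor data stays on disk and is paged in on demand
state_dict = torch.load(
    "model.pth",
    map_location="cpu",
    weights_only=True,
    mmap=True,
)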

Example Implementation of mmap=True

Let's see how the mmap=True option works in practice. Below is a sample implementation where we load a model using this setting:

def best_practices():
    start_memory_tracking()  # reset the peak-memory counter before measuring

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Results with mmap=True

  • Maximum GPU memory allocated: 6.4 GB
  • Maximum CPU memory allocated: 5.9 GB

Here, GPU memory usage remains efficient (6.4 GB), and CPU memory usage is fairly high because this machine has enough CPU RAM to support it. However, on a system with limited CPU RAM, the mmap=True approach would use less memory by avoiding loading the full model into RAM.

When to Use mmap=True

The mmap=True option is especially useful in the following scenarios:

  • Limited CPU RAM: the full state_dict does not need to fit in system memory at once.
  • Fast disk I/O: reading tensors on demand works best when the checkpoint sits on fast storage such as an SSD.

Performance Considerations

At first glance, the mmap=True approach might seem less efficient than the sequential weight loading approach. However, for machines with limited CPU memory, mmap=True can be a game-changer, providing an effective way to load large models without overwhelming the CPU's available memory.

By using mmap=True, you balance disk access against memory availability, which helps in environments where memory is scarce but disk I/O is fast.

Other Methods for Model Weight Loading

In this notebook, we have focused on simple, built-in methods for efficiently loading model weights in PyTorch, particularly when memory (either GPU or CPU) is constrained. The recommended method for managing limited CPU memory is the mmap=True approach, as explained above.

However, if you are dealing with extreme memory limitations or need more control over the process, there is another brute-force approach: saving and loading each weight tensor separately.

Saving Model Weights Individually

Instead of saving the entire state_dict as a single file, this method stores each model parameter (tensor) separately. This allows you to load the parameters one at a time, so the full model never needs to be held in memory at once.

Here's how you can save the model weights individually:

model = GPTModel(BASE_CONFIG)
# Assume `model` is your trained model
state_dict = model.state_dict()

# Create a directory to store individual parameter files
os.makedirs("model_parameters", exist_ok=True)

# Save each parameter tensor separately
for name, param in state_dict.items():
    torch.save(param.cpu(), f"model_parameters/{name}.pt")

del model  # Free up GPU memory

This breaks the model into individual components, saving each tensor to its own file in the "model_parameters" directory.

Loading Weights Individually

Now, let's see how we can load these weights one by one to avoid overwhelming memory usage.

def load_individual_weights():
    start_memory_tracking()

    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    model = model.to_empty(device=device)

    print_memory_usage()
    param_dir = "model_parameters"

    with torch.no_grad():
        for name, param in model.named_parameters():
            weight_path = os.path.join(param_dir, f"{name}.pt")
            if os.path.exists(weight_path):
                param_data = torch.load(weight_path, map_location="cpu", weights_only=True)
                param.copy_(param_data.to(device))  # Move tensor to GPU
                del param_data  # Free memory after copying
            else:
                print(f"Warning: {name} not found in {param_dir}.")

    print_memory_usage()
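Following the same measurement pattern as the previous approaches, we then run this loader under the CPU memory monitor:

peak_memory_used = memory_usage_in_gb(load_individual_weights)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")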

Results from Individual Weight Loading

  • Maximum GPU memory allocated: 6.4 GB
  • Maximum CPU memory allocated: 0.3 GB

The memory footprint here is significantly reduced, both on the GPU and the CPU. By loading weights individually, you ensure that no unnecessary memory is consumed at any stage, making this approach ideal for extremely memory-limited environments.

When to Use This Method

  • Extreme memory limitations

When CPU and GPU memory are both highly constrained, this method offers precise control, ensuring that only one parameter tensor is loaded into memory at any given time.

On machines where you cannot afford to use more than minimal resources, this brute-force method makes it possible to load even the largest models.

Performance Considerations

The trade-off here is performance. Since each tensor is loaded separately, this method incurs extra disk I/O, which can slow down the loading process compared to methods that load the entire model or larger chunks of data at once.

When working with large models, such as GPT variants or other deep learning models, memory efficiency is crucial. Techniques like sequential weight loading, using the meta device, and enabling mmap=True help reduce memory usage on both CPU and GPU. These memory-efficient model weight loading techniques in PyTorch are highly versatile and can be adapted to the specific constraints of your hardware environment, whether you have limited CPU RAM, limited GPU VRAM, or both.

By employing these techniques, you can work with large models even on constrained hardware, ensuring smooth model training and fine-tuning workflows.

Hope you liked the article! Memory-efficient model weight loading in PyTorch helps save resources; for example, torch.load() with memory mapping can noticeably lower RAM usage.

Frequently Asked Questions

Q1. What is the significance of memory-efficient model loading in PyTorch?

As deep learning models grow larger (especially models like GPT-2 and GPT-3), loading them efficiently becomes essential to prevent running out of GPU or CPU memory. Memory-efficient loading allows you to work with large models even in constrained environments.

Q2. How can I track GPU memory usage during model loading in PyTorch?

You can use torch.cuda.reset_peak_memory_stats() and torch.cuda.max_memory_allocated() to track GPU memory usage before, during, and after loading or training models. The utility functions provided above help monitor memory usage efficiently.

Q3. What is sequential weight loading in PyTorch, and how does it help?

Sequential weight loading involves loading the model architecture onto the GPU and then transferring weights one at a time from CPU to GPU. This reduces the peak memory usage compared to loading both the model and its weights at once, helping manage limited GPU memory.

Q4. How do I reduce memory usage in PyTorch?

Use lower precision: float16 or mixed precision (see the sketch below).
Optimize tensor operations: avoid unnecessary copies, use efficient shapes and views.
Gradient accumulation: update weights less frequently.
Reduce model size: prune connections, quantize weights, use smaller models.
Optimize data loading: efficient data loaders, prefetching, memory-mapped files.
GPU memory efficiency: monitor usage, free unused memory, use multiple GPUs.
Advanced techniques: knowledge distillation, low-rank approximation.
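As a minimal sketch of the first point (an illustrative assumption on our part, not part of the original walkthrough; bfloat16 support depends on your hardware), you can keep a model in a lower-precision dtype after loading it:

# Hypothetical example: cast the model to bfloat16 to roughly halve its
# memory footprint compared to float32
model = GPTModel(BASE_CONFIG)
model.load_state_dict(torch.load("model.pth", map_location="cpu", weights_only=True))
model = model.to(device=device, dtype=torch.bfloat16)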

Q5. What is the "meta" device in PyTorch, and how does it help with memory constraints?

The "meta" device allows you to initialize models without allocating memory for their parameters. This is useful when you have limited CPU memory, since you can later load weights directly onto the GPU, bypassing the need for large memory allocations on the CPU.

Hi, I'm Janvi, a passionate data science enthusiast currently working at Analytics Vidhya. My journey into the world of data began with a deep curiosity about how we can extract meaningful insights from complex datasets.