You’re training your latest AI model, anxiously watching as the loss steadily decreases, when suddenly: boom! Your logs are flooded with NaNs (Not a Number), your model is irreparably corrupted, and you’re left staring at your screen in despair. To make matters worse, the NaNs don’t appear consistently. Sometimes your model trains just fine; other times, it fails inexplicably. Sometimes it will crash immediately, sometimes after many days of training.
NaNs in Deep Learning workloads are among the most frustrating issues to encounter. Because they often appear sporadically, triggered by a specific combination of model state, input data, and stochastic factors, they can be incredibly difficult to reproduce and debug.
Given the considerable cost of training AI models and the potential waste caused by NaN failures, it is recommended to have dedicated tools for capturing and analyzing NaN occurrences. In a previous post, we discussed the challenge of debugging NaNs in a TensorFlow training workload. We proposed an efficient scheme for capturing and reproducing NaNs and shared a sample TensorFlow implementation. In this post, we adopt and demonstrate a similar mechanism for debugging NaNs in PyTorch workloads. The general scheme is as follows:
On each training step:
- Save a copy of the training input batch.
- Check the gradients for NaN values. If any appear, save a checkpoint with the current model weights before the model is corrupted. Also save the input batch and, if necessary, the stochastic state. Discontinue the training job.
- Reproduce and debug the NaN occurrence by loading the saved experiment state.
Although this scheme can be easily implemented in native PyTorch, we will take the opportunity to demonstrate some of the conveniences of PyTorch Lightning, a powerful open-source framework designed to streamline the development of machine learning (ML) models. Built on PyTorch, Lightning abstracts away many of the boilerplate components of an ML experiment, such as training loops, data distribution, logging, and more, enabling developers to focus on the core logic of their models.
To implement our NaN capturing scheme, we will use Lightning’s callback interface, a dedicated structure that allows inserting custom logic at specific points during the flow of execution.
Importantly, please do not view our choice of Lightning or any other tool or technique that we mention as an endorsement of its use. The code that we will share is intended for demonstration purposes only; please do not rely on its correctness or optimality.
Many thanks to Rom Maltser for his contributions to this post.
NaNCapture Callback
To implement our NaN capturing solution, we create a NaNCapture Lightning callback. The constructor receives a directory path for storing/loading checkpoints and sets up the NaNCapture state. We also define utilities for checking for NaNs, storing checkpoints, and halting the training job.
import os
import torch
from copy import deepcopy
import lightning.pytorch as pl


class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when a NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        # communicate stop command to all other ranks
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)
Callback Function: on_train_batch_start
We begin by implementing the on_train_batch_start hook to store a copy of each input batch. In case of a NaN event, this batch will be stored in the checkpoint.
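A minimal sketch of this hook is shown below. It simply copies the incoming batch into the callback state; the complete version, which also handles the random state, appears later in this post.
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if not self.nan_captured:
            # keep a copy of the current batch so that it can be
            # stored in the checkpoint if a NaN is detected
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx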
Callback Function: on_before_optimizer_step
Next, we implement the on_before_optimizer_step hook. Here, we check for NaN entries in all of the gradient tensors. If found, we save a checkpoint with the uncorrupted model weights and halt the training.
Python"> def on_before_optimizer_step(self, coach, pl_module, optimizer):
if not self.nan_captured:
# Verify if gradients comprise NaN
grads = [p.grad.view(-1) for p in pl_module.parameters()
if p.grad is not None]
all_grads = torch.cat(grads)
if self.contains_nan(all_grads):
print("nan discovered")
self.save_ckpt(coach)
self.halt_training(coach)
Capturing the Training State
To enable reproducibility, we include the NaNCapture state in the checkpoint by appending it to the training state dictionary. Lightning provides dedicated utilities for saving and loading a callback state:
    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
Reproducing the NaN Occurrence
We have described how our NaNCapture callback can be used to store the training state that resulted in a NaN, but how do we reload this state in order to reproduce the issue and debug it? To accomplish this, we leverage Lightning’s dedicated data loading class, LightningDataModule.
DataModule Function: on_before_batch_transfer
In the code block below, we extend the LightningDataModule class to allow injecting a fixed training input batch. This is achieved by overriding the on_before_batch_transfer hook, as shown below:
from lightning.pytorch import LightningDataModule


class InjectableDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.cached_batch = None

    def set_custom_batch(self, batch):
        self.cached_batch = batch

    def on_before_batch_transfer(self, batch, dataloader_idx):
        if self.cached_batch:
            return self.cached_batch
        return batch
Callback Function: on_train_start
The final step is modifying the on_train_start hook of our NaNCapture callback to inject the stored training batch into the LightningDataModule.
    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)
In the next section we will demonstrate the end-to-end solution using a toy example.
Toy Example
To test our new callback, we create a resnet50-based image classification model with a loss function deliberately designed to trigger NaN occurrences.
Instead of using the standard CrossEntropy loss, we compute binary_cross_entropy_with_logits for each class independently and divide the result by the number of samples belonging to that class. Inevitably, we will encounter a batch in which one or more classes are missing, leading to a divide-by-zero operation, resulting in NaN values and corrupting the model.
The implementation below follows Lightning’s introductory tutorial.
import lightning.pytorch as pl
import torch
import torchvision
import torch.nn.functional as F

num_classes = 20


# define a lightning module
class ResnetModel(pl.LightningModule):
    def __init__(self):
        """Initializes a new instance of the ResnetModel class."""
        super().__init__()
        self.model = torchvision.models.resnet50(num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        x, y = batch
        outputs = self(x)
        # uncomment for default loss
        # return F.cross_entropy(outputs, y)

        # calculate binary_cross_entropy for each class individually
        losses = []
        for c in range(num_classes):
            count = torch.count_nonzero(y == c)
            masked = torch.where(y == c, 1., 0.)
            loss = F.binary_cross_entropy_with_logits(
                outputs[..., c],
                masked,
                reduction='sum'
            )
            mean_loss = loss / count  # may result in NaN
            losses.append(mean_loss)
        total_loss = torch.stack(losses).mean()
        return total_loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
We define a synthetic dataset and encapsulate it in our InjectableDataModule class:
import os
import random
from torch.utils.data import Dataset, DataLoader

batch_size = 128
num_steps = 800


# A dataset with random images and labels
class FakeDataset(Dataset):
    def __len__(self):
        return batch_size * num_steps

    def __getitem__(self, index):
        rand_image = torch.randn([3, 224, 224], dtype=torch.float32)
        label = torch.tensor(random.randint(0, num_classes - 1),
                             dtype=torch.int64)
        return rand_image, label


# define a lightning datamodule
class FakeDataModule(InjectableDataModule):

    def train_dataloader(self):
        dataset = FakeDataset()
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=os.cpu_count(),
            pin_memory=True
        )
Finally, we initialize a Lightning Trainer with our NaNCapture callback and call trainer.fit with our Lightning module and Lightning DataModule.
import time

if __name__ == "__main__":

    # Initialize a lightning module
    lit_module = ResnetModel()

    # Initialize a DataModule
    mnist_data = FakeDataModule()

    # Train the model
    ckpt_dir = "./ckpt_dir"
    trainer = pl.Trainer(
        max_epochs=1,
        callbacks=[NaNCapture(ckpt_dir)]
    )

    ckpt_path = None

    # check if a nan ckpt exists
    if os.path.isdir(ckpt_dir):
        dir_contents = [os.path.join(ckpt_dir, f)
                        for f in os.listdir(ckpt_dir)]
        ckpts = [f for f in dir_contents
                 if os.path.isfile(f) and f.endswith('.ckpt')]
        if ckpts:
            ckpt_path = ckpts[0]

    t0 = time.perf_counter()
    trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)
    print(f"total runtime: {time.perf_counter() - t0}")
After a number of training steps, a NaN event will occur. At this point a checkpoint is saved with the full training state and the training is halted.
When the script is run again, the exact state that caused the NaN will be reloaded, allowing us to easily reproduce the issue and debug its root cause.
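Once the failing state has been reloaded, one way to pinpoint the operation that produces the NaN is PyTorch’s anomaly detection mode. This is an optional addition, not part of the callback described above, enabled before calling trainer.fit:
# enable anomaly detection for the debug run only; it produces stack traces
# pointing at the backward operation that generated the NaN, at the cost of
# a significant slowdown
torch.autograd.set_detect_anomaly(True)

trainer.fit(lit_module, mnist_data, ckpt_path=ckpt_path)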
Performance Overhead
To assess the impact of our NaNCapture callback on runtime performance, we modified our experiment to use CrossEntropyLoss (to avoid NaNs) and measured the average throughput when running with and without the NaNCapture callback. The experiments were performed on an NVIDIA L40S GPU, with a PyTorch 2.5.1 Docker image.

For our toy model, the NaNCapture callback adds a minimal 1.5% overhead to the runtime performance, a small price to pay for the valuable debugging capabilities it provides.
Naturally, the actual overhead will depend on the specifics of the model and runtime environment.
How to Handle Stochasticity
The solution we have described thus far will succeed in reproducing the training state provided that the model does not include any randomness. However, introducing stochasticity into the model definition is often critical for convergence. A common example of a stochastic layer is torch.nn.Dropout.
You may find that your NaN event depends on the precise state of randomness at the time the failure occurred. Consequently, we would like to enhance our NaNCapture callback to capture and restore the random state at the point of failure. The random state is determined by a number of libraries. In the code block below, we attempt to capture the full state of randomness:
import os
import torch
import random
import numpy as np
from copy import deepcopy
import lightning.pytorch as pl


class NaNCapture(pl.Callback):

    def __init__(self, dirpath: str):
        # path to checkpoint
        self.dirpath = dirpath

        # update to True when a NaN is identified
        self.nan_captured = False

        # stores a copy of the last batch
        self.last_batch = None
        self.batch_idx = None

        # rng state
        self.rng_state = {
            "torch": None,
            "torch_cuda": None,
            "numpy": None,
            "random": None
        }

    @staticmethod
    def contains_nan(tensor):
        return torch.isnan(tensor).any().item()
        # alternatively check for finite
        # return not torch.isfinite(tensor).all().item()

    @staticmethod
    def halt_training(trainer):
        trainer.should_stop = True
        trainer.strategy.reduce_boolean_decision(trainer.should_stop,
                                                 all=False)

    def save_ckpt(self, trainer):
        os.makedirs(self.dirpath, exist_ok=True)
        # include trainer.global_rank to avoid conflict
        filename = f"nan_checkpoint_rank_{trainer.global_rank}.ckpt"
        full_path = os.path.join(self.dirpath, filename)
        print(f"saving ckpt to {full_path}")
        trainer.save_checkpoint(full_path, False)

    def on_train_start(self, trainer, pl_module):
        if self.nan_captured:
            # inject batch
            datamodule = trainer.datamodule
            datamodule.set_custom_batch(self.last_batch)

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.nan_captured:
            # restore random state
            torch.random.set_rng_state(self.rng_state["torch"])
            torch.cuda.set_rng_state_all(self.rng_state["torch_cuda"])
            np.random.set_state(self.rng_state["numpy"])
            random.setstate(self.rng_state["random"])
        else:
            # capture current batch
            self.last_batch = deepcopy(batch)
            self.batch_idx = batch_idx

            # capture current random state
            self.rng_state["torch"] = torch.random.get_rng_state()
            self.rng_state["torch_cuda"] = torch.cuda.get_rng_state_all()
            self.rng_state["numpy"] = np.random.get_state()
            self.rng_state["random"] = random.getstate()

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        if not self.nan_captured:
            # Check if gradients contain NaN
            grads = [p.grad.view(-1) for p in pl_module.parameters()
                     if p.grad is not None]
            all_grads = torch.cat(grads)
            if self.contains_nan(all_grads):
                print("nan found")
                self.save_ckpt(trainer)
                self.halt_training(trainer)

    def state_dict(self):
        d = {"nan_captured": self.nan_captured}
        if self.nan_captured:
            d["last_batch"] = self.last_batch
            d["rng_state"] = self.rng_state
        return d

    def load_state_dict(self, state_dict):
        self.nan_captured = state_dict.get("nan_captured", False)
        if self.nan_captured:
            self.last_batch = state_dict["last_batch"]
            self.rng_state = state_dict["rng_state"]
Importantly, setting the random state may not guarantee full reproducibility. The GPU owes its power to its massive parallelism. In some GPU operations, multiple threads may read or write concurrently to the same memory locations, resulting in nondeterminism. PyTorch allows for some control over this via its use_deterministic_algorithms API, though this may impact the runtime performance. Additionally, there is a possibility that the NaN event will not reproduce once this configuration setting is modified. Please see the PyTorch documentation on reproducibility for more details.
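For reference, the snippet below shows two common ways of requesting deterministic behavior, either directly through PyTorch or via the Lightning Trainer. This is a sketch of an optional configuration, not part of the NaNCapture callback, and as noted above it may slow down training and change whether the NaN reproduces.
import os
import torch
import lightning.pytorch as pl

# some CUDA operations require this environment variable in order to
# run deterministically
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# option 1: request deterministic algorithms directly from PyTorch
# (operations without a deterministic implementation will raise an error)
torch.use_deterministic_algorithms(True)

# option 2: let the Lightning Trainer configure determinism for the run
trainer = pl.Trainer(
    max_epochs=1,
    deterministic=True,
    callbacks=[NaNCapture("./ckpt_dir")]
)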
Summary
Encountering NaN failures is one of the most discouraging events that can occur in machine learning development. These errors not only waste valuable computation and development resources, but often indicate fundamental issues in the model architecture or experiment design. Due to their sporadic, often elusive nature, debugging NaN failures can be a nightmare.
This post introduced a proactive approach for capturing and reproducing NaN errors using a dedicated Lightning callback. The solution we shared is a proposal that can be modified and extended for your specific use case.
While this solution may not address every possible NaN scenario, it significantly reduces debugging time when applicable, potentially saving developers countless hours of frustration and wasted effort.