Single node multi-GPU
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os
def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "9999"          # the port must be set as a string
    torch.cuda.set_device(rank)                 # bind this process to its own GPU
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
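ddp_setup runs once inside each worker process. For the single-node case, one process per GPU is typically launched with mp.spawn (which is why torch.multiprocessing is imported above). The sketch below is illustrative; main and the epoch/save values are assumed names and numbers, not code from this section.

def main(rank: int, world_size: int, total_epochs: int, save_every: int):
    ddp_setup(rank, world_size)
    # ... build the model, data, and Trainer here, then train ...
    destroy_process_group()

if __name__ == "__main__":
    total_epochs, save_every = 50, 10               # illustrative values
    world_size = torch.cuda.device_count()          # one process per visible GPU
    # mp.spawn invokes main(rank, *args) once for every rank in [0, world_size)
    mp.spawn(main, args=(world_size, total_epochs, save_every), nprocs=world_size)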
Constructing the DDP model
self.model = DDP(model, device_ids=[gpu_id])
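For context, the DDP wrapping happens inside the Trainer's constructor. A minimal sketch of that constructor for the one-GPU-per-process case; the field names follow the snippets in this section, so treat it as illustrative rather than the complete class:

class Trainer:
    def __init__(self, model, train_data, optimizer, gpu_id: int, save_every: int):
        self.gpu_id = gpu_id                    # which GPU this process drives
        self.model = model.to(gpu_id)           # move the raw model to its device first
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(self.model, device_ids=[gpu_id])   # then wrap it for gradient sync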
Distributing input data
train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,                              # shuffling is handled by the DistributedSampler
    sampler=DistributedSampler(train_dataset),
)
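The sampler partitions the dataset indices across ranks so each process trains on its own shard. A quick illustrative check (the replica count of 4 is an assumption):

# Each rank's sampler yields roughly len(train_dataset) / num_replicas indices
sampler = DistributedSampler(train_dataset, num_replicas=4, rank=0)
print(len(sampler))   # ~= ceil(len(train_dataset) / 4)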
- Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in every epoch.
def _run_epoch(self, epoch):
    batch_size = len(next(iter(self.train_data))[0])
    self.train_data.sampler.set_epoch(epoch)    # reshuffle with a different seed each epoch
    for source, targets in self.train_data:
        self._run_batch(source, targets)
Saving model checkpoints
ckp = self.model.module.state_dict()            # .module unwraps the DDP container

if self.gpu_id == 0 and epoch % self.save_every == 0:
    self._save_checkpoint(epoch)                # only rank 0 writes the checkpoint
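A minimal sketch of the corresponding _save_checkpoint method; the file name is an assumption:

def _save_checkpoint(self, epoch):
    ckp = self.model.module.state_dict()        # save the raw model weights, not the DDP wrapper
    PATH = "checkpoint.pt"                      # hypothetical path
    torch.save(ckp, PATH)
    print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")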
Torchrun
A failure in any single process can disrupt the whole distributed training job. torchrun provides fault tolerance and elastic training: when a worker fails, it restarts all workers, and the Trainer below resumes from the last saved snapshot.
def ddp_setup():
    # torchrun already sets MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE for each worker
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")
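With torchrun there is no need to set the master address or pass rank and world_size by hand: the launcher exports them as environment variables for every worker, and init_process_group reads them. A quick way to inspect this from inside a worker:

print(
    os.environ["RANK"],         # global rank across all nodes
    os.environ["LOCAL_RANK"],   # rank within this node, used as the GPU index
    os.environ["WORLD_SIZE"],   # total number of workers
)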
class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])     # GPU index on this node
        self.global_rank = int(os.environ["RANK"])          # unique rank across all nodes
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)
        self.model = DDP(self.model, device_ids=[self.local_rank])
    def _save_snapshot(self, epoch):
        snapshot = {}
        snapshot["MODEL_STATE"] = self.model.module.state_dict()
        snapshot["EPOCHS_RUN"] = epoch
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")
    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)   # load onto this process's GPU
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
    def train(self, max_epochs: int):
        # Resume from the last finished epoch if a snapshot was loaded
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)    # reshuffle across epochs (see note above)
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)
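_run_batch is not shown above; a minimal sketch of one optimization step, where the cross-entropy loss is an assumption about the task:

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)                 # DDP forward pass
        loss = F.cross_entropy(output, targets)     # assumed loss; gradients are all-reduced during backward
        loss.backward()
        self.optimizer.step()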
def main(save_every: int, total_epochs: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size=32)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()
if __name__ == "__main__":
    import sys
    total_epochs = int(sys.argv[1])
    save_every = int(sys.argv[2])
    main(save_every, total_epochs)
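load_train_objs and prepare_dataloader are not shown above. A sketch with placeholder data and model (the toy dataset, the Linear classifier, and the SGD settings are assumptions to keep the example self-contained):

def load_train_objs():
    inputs = torch.randn(2048, 20)                      # toy features
    labels = torch.randint(0, 10, (2048,))              # toy labels for 10 hypothetical classes
    train_set = torch.utils.data.TensorDataset(inputs, labels)
    model = nn.Linear(20, 10)                           # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer

def prepare_dataloader(dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,                                  # shuffling is delegated to the sampler
        sampler=DistributedSampler(dataset),
    )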
torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
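Here --standalone runs a single-node job, --nproc_per_node=4 starts one worker per GPU, and the positional arguments map to total_epochs=50 and save_every=10. To exercise the fault tolerance mentioned above, torchrun can also relaunch the worker group automatically after a failure; the restart budget below is an illustrative choice:

torchrun --standalone --nproc_per_node=4 --max_restarts=3 multigpu_torchrun.py 50 10

When a worker dies, all workers are restarted, each new process finds snapshot.pt on disk, and training resumes from the last saved epoch.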