PyTorch Distributed Data Parallel (DDP)


Single node multi-GPU

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "9999"  # environment variables must be strings, not ints
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
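
The imports above include torch.multiprocessing, but the snippet does not show how the worker processes are launched. A minimal sketch, assuming a main(rank, world_size, ...) entry point (the names and argument values are illustrative):

def main(rank: int, world_size: int, total_epochs: int, save_every: int):
    ddp_setup(rank, world_size)
    # ... build the dataset, model, optimizer and training loop here ...
    destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    # mp.spawn starts world_size processes and passes each one its rank as the first argument
    mp.spawn(main, args=(world_size, 50, 10), nprocs=world_size)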

Constructing the DDP model

self.model = DDP(model, device_ids=[gpu_id])
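
For context, DDP expects the model to already live on this process's GPU, and the forward/backward calls then look exactly like single-GPU code. A minimal sketch (gpu_id is this process's device index):

model = model.to(gpu_id)                      # move parameters to this process's GPU first
self.model = DDP(model, device_ids=[gpu_id])  # DDP synchronizes gradients across processes
output = self.model(source)                   # forward and backward passes are unchanged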

Distributing input data

train_data = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,  # shuffling is delegated to the DistributedSampler
    sampler=DistributedSampler(train_dataset),
)
  • Calling the set_epoch() method on the DistributedSampler at the beginning of each epoch is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
def _run_epoch(self, epoch):
    batch_size = len(next(iter(self.train_data))[0])
    self.train_data.sampler.set_epoch(epoch)
    for source, targets in self.train_data:
        self._run_batch(source, targets)
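
_run_batch is referenced above but not shown. A minimal sketch, assuming a classification objective (the cross-entropy loss is an assumption; substitute your own):

def _run_batch(self, source, targets):
    self.optimizer.zero_grad()
    output = self.model(source)
    loss = F.cross_entropy(output, targets)  # assumed loss function
    loss.backward()                          # DDP averages gradients across processes here
    self.optimizer.step()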

Saving model checkpoints

# .module unwraps the underlying model from the DDP wrapper before saving
ckp = self.model.module.state_dict()

# only rank 0 saves, and only every save_every epochs, to avoid redundant copies
if self.gpu_id == 0 and epoch % self.save_every == 0:
    self._save_checkpoint(epoch)

Torchrun

A single process failure can disrupt the whole distributed training job. torchrun provides fault tolerance and elastic training: when a worker fails, torchrun restarts all workers, and the training script resumes from the last saved snapshot instead of starting over.

def ddp_setup():
    init_process_group(backend="nccl")
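
torchrun exports MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE as environment variables, so init_process_group can pick them up without explicit arguments. A common optional addition is to pin each process to its local GPU:

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # pin this process to its own GPU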
    
class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        save_every: int,
        snapshot_path: str,
    ) -> None:
        self.local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node, set by torchrun
        self.global_rank = int(os.environ["RANK"])        # unique rank across all processes
        self.model = model.to(self.local_rank)
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.epochs_run = 0
        self.snapshot_path = snapshot_path
        if os.path.exists(snapshot_path):
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model = DDP(self.model, device_ids=[self.local_rank])
    
    def _save_snapshot(self, epoch):
        snapshot = {}
        snapshot["MODEL_STATE"] = self.model.module.state_dict()
        snapshot["EPOCHS_RUN"] = epoch
        torch.save(snapshot, self.snapshot_path)
        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")
    
    def _load_snapshot(self, snapshot_path):
        loc = f"cuda:{self.local_rank}"
        snapshot = torch.load(snapshot_path, map_location=loc)  # load tensors onto this rank's GPU
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.epochs_run = snapshot["EPOCHS_RUN"]
        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
        
    def train(self, max_epochs: int):
        # resume from self.epochs_run if a snapshot was loaded
        for epoch in range(self.epochs_run, max_epochs):
            self._run_epoch(epoch)
            if self.local_rank == 0 and epoch % self.save_every == 0:
                self._save_snapshot(epoch)
    
    def _run_epoch(self, epoch):
        batch_size = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.global_rank}] Epoch {epoch} | Batchsize: {batch_size} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)  # re-shuffle the data differently every epoch
        for source, targets in self.train_data:
            source = source.to(self.local_rank)
            targets = targets.to(self.local_rank)
            self._run_batch(source, targets)
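
main() below calls load_train_objs() and prepare_dataloader(), which are not shown in this post. A minimal sketch (the dataset and model are placeholders; substitute your own):

def load_train_objs():
    train_set = MyTrainDataset(2048)  # placeholder dataset
    model = torch.nn.Linear(20, 1)    # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    return train_set, model, optimizer

def prepare_dataloader(dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,  # shuffling is handled by the DistributedSampler
        sampler=DistributedSampler(dataset),
    )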

def main(save_every: int, total_epochs: int, snapshot_path: str = "snapshot.pt"):
    ddp_setup()
    dataset, model, optimizer = load_train_objs()
    train_data = prepare_dataloader(dataset, batch_size=32)
    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path)
    trainer.train(total_epochs)
    destroy_process_group()
            
if __name__ == "__main__":
    import sys
    total_epochs = int(sys.argv[1])
    save_every = int(sys.argv[2])
    main(save_every, total_epochs)

Launch with torchrun (here 4 processes on one node, 50 total epochs, saving a snapshot every 10 epochs); torchrun assigns the rank and rendezvous environment variables automatically:

torchrun --standalone --nproc_per_node=4 multigpu_torchrun.py 50 10
