The Core PyTorch Model Training Procedure

1. Define the Model (nn.Module)

In PyTorch, models are defined by subclassing torch.nn.Module.

Minimal example:

import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc(x)

Key concepts:

- Layers assigned as attributes in __init__ (like self.fc) are registered as submodules, so their parameters are tracked automatically.
- forward() defines the computation; calling model(x) goes through __call__, which dispatches to forward() and runs any registered hooks.
- model.parameters() / model.named_parameters() expose everything the optimizer needs to update.
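
To make this concrete, a minimal sketch of instantiating and calling the model defined above (the batch size of 4 is arbitrary):

import torch

model = MyModel()
x = torch.randn(4, 128)              # a batch of 4 feature vectors
logits = model(x)                    # __call__ dispatches to forward()
print(logits.shape)                  # torch.Size([4, 10])

for name, p in model.named_parameters():
    print(name, p.shape)             # fc.weight: (10, 128), fc.bias: (10,)
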
2. Define Loss Function (Criterion)

The loss function (criterion) reduces a batch of predictions and targets to a single scalar that measures prediction error.

Common examples:

- nn.CrossEntropyLoss: multi-class classification (expects raw logits and integer class targets)
- nn.BCEWithLogitsLoss: binary or multi-label classification (expects raw logits)
- nn.MSELoss: regression

criterion = nn.CrossEntropyLoss()

Role: maps a batch of predictions and targets to a single scalar tensor; calling .backward() on that scalar drives gradient computation for every parameter that produced it.

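A minimal sketch of what the criterion computes on a batch (shapes are illustrative; nn.CrossEntropyLoss expects raw logits and integer class labels):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(4, 10)             # raw, unnormalized scores: (batch, num_classes)
targets = torch.tensor([1, 0, 3, 9])    # integer class indices: (batch,)

loss = criterion(logits, targets)       # a scalar tensor
print(loss.item())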

3. Define Optimizer

The optimizer updates model parameters using the gradients stored in each parameter's .grad attribute after backward().

Common optimizers:

- optim.SGD (optionally with momentum)
- optim.Adam
- optim.AdamW (Adam with decoupled weight decay)

import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-3)
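
Optimizers also accept parameter groups with per-group hyperparameters. A sketch (the split below, which exempts 1-D tensors such as biases from weight decay, is a common convention, not something the API requires):

import torch.optim as optim

decay, no_decay = [], []
for name, p in model.named_parameters():
    (no_decay if p.ndim == 1 else decay).append(p)

optimizer = optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)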

4. Data Pipeline (Dataset + DataLoader)

A Dataset defines how to fetch individual samples; a DataLoader wraps it and encapsulates batching, shuffling, and multi-process loading.

from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)

Key roles:

- Dataset: implements __len__ and __getitem__, i.e. how to fetch a single sample.
- DataLoader: batches samples, reshuffles every epoch (shuffle=True), and loads data in parallel worker processes (num_workers).
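
The DataLoader above assumes a Dataset. A minimal map-style Dataset sketch (the tensors here are stand-ins for real data):

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features          # e.g. a (N, 128) float tensor
        self.labels = labels              # e.g. a (N,) long tensor

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

dataset = MyDataset(torch.randn(512, 128), torch.randint(0, 10, (512,)))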

5. Training Loop (Core Procedure)

This is the heart of PyTorch training.

Canonical Pattern

for epoch in range(num_epochs):
    model.train()                # Enable training mode

    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # (1) Forward pass
        logits = model(x)

        # (2) Compute loss
        loss = criterion(logits, y)

        # (3) Zero gradients
        optimizer.zero_grad()

        # (4) Backward pass (autograd)
        loss.backward()

        # (5) Parameter update
        optimizer.step()

What Each Step Really Does

Step      | Operation             | What happens internally
Forward   | model(x)              | Builds the computation graph
Loss      | criterion(...)        | Adds the loss node to the graph
Zero grad | optimizer.zero_grad() | Clears accumulated gradients
Backward  | loss.backward()       | Reverse-mode autodiff fills each parameter's .grad
Step      | optimizer.step()      | Updates parameters using those gradients

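Putting sections 1-5 together, a self-contained toy run on synthetic data (shapes, hyperparameters, and the 3 epochs are illustrative only):

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

X = torch.randn(512, 128)                     # 512 samples, 128 features
y = torch.randint(0, 10, (512,))              # 10 classes
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = MyModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

for epoch in range(3):
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)                    # forward
        loss = criterion(logits, yb)          # loss
        optimizer.zero_grad()                 # zero grads
        loss.backward()                       # backward
        optimizer.step()                      # update
        running_loss += loss.item() * xb.size(0)
    print(f"epoch {epoch}: loss {running_loss / len(train_loader.dataset):.4f}")
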
6. Autograd: Gradient Computation

PyTorch uses dynamic computation graphs: the graph is built on the fly during the forward pass and freed after backward().

Key mechanics:

- Every operation on tensors that require gradients records a node in the graph.
- loss.backward() traverses that graph in reverse, accumulating gradients into each leaf parameter's .grad.
- By default the graph is freed after backward(), so a fresh graph is built on the next forward pass.

Example:

loss.backward()

print(model.fc.weight.grad)  # ∂loss / ∂weight

Important: gradients accumulate in .grad across backward() calls, which is why the training loop zeroes them every iteration; .detach() and torch.no_grad() exclude tensors from the graph.

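A self-contained sketch of autograd on plain tensors, outside any training loop:

import torch

w = torch.tensor(3.0, requires_grad=True)   # leaf tensor tracked by autograd
x = torch.tensor(2.0)

y = w * x                                   # each op records a node in the graph
loss = (y - 1.0) ** 2

loss.backward()                             # reverse-mode autodiff from the scalar loss
print(w.grad)                               # d(loss)/dw = 2*(w*x - 1)*x = 20.0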

7. Training vs Evaluation Mode

Some layers (e.g. Dropout and BatchNorm) behave differently in training and evaluation:

model.train()   # Enable dropout, batchnorm updates
model.eval()    # Disable dropout, use running stats
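
To see the difference directly, a sketch with a standalone Dropout layer:

import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(5)

drop.train()
print(drop(x))   # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2

drop.eval()
print(drop(x))   # identity: all ones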

8. Validation Loop (No Gradients)

Validation should not build graphs:

model.eval()
val_loss = 0.0
with torch.no_grad():                      # no graph is built inside this block
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        val_loss += criterion(logits, y).item() * x.size(0)
val_loss /= len(val_loader.dataset)        # average per-sample loss over the validation set

Why: torch.no_grad() skips graph construction, saving memory and compute, and model.eval() switches layers such as Dropout and BatchNorm to inference behavior.

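If a metric such as accuracy is also needed, it can be accumulated in the same no-grad pass (a sketch, reusing val_loader, model, and device from above):

correct = 0
model.eval()
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)            # predicted class per sample
        correct += (preds == y).sum().item()

accuracy = correct / len(val_loader.dataset)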

9. Learning Rate Scheduling (Optional but Common)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100
)

scheduler.step()   # for epoch-based schedulers, typically called once per epoch after optimizer.step()

Used for:

- Warmup at the start of training
- Decaying the learning rate as training progresses (step, exponential, cosine annealing)
- Generally improving convergence and final accuracy

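Where the scheduler sits relative to the loops (a sketch, reusing num_epochs, train_loader, and optimizer from the sections above; CosineAnnealingLR is stepped once per epoch here):

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    for x, y in train_loader:
        ...                              # forward / loss / zero_grad / backward / optimizer.step()
    scheduler.step()                     # adjust the learning rate once per epoch
    print(scheduler.get_last_lr())       # current LR for each parameter group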

10. Checkpointing

torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}, "checkpoint.pt")

Load:

ckpt = torch.load("checkpoint.pt")   # add map_location=... when loading onto a different device
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
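
Checkpoints often also carry bookkeeping state so a run can resume where it left off. A sketch extending the dict above (the epoch and scheduler entries are optional additions, reusing objects from the sections above):

torch.save({
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
}, "checkpoint.pt")

ckpt = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
scheduler.load_state_dict(ckpt["scheduler"])
start_epoch = ckpt["epoch"] + 1          # resume from the next epoch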

Minimal Mental Model (Interview-Style)

PyTorch training = Forward → Loss → Backward → Step

Formally:

\[ \theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}(f_\theta(x), y) \]

where \theta are the model parameters, \eta the learning rate, f_\theta the model, and \mathcal{L} the loss.


Common Pitfalls (Important)

Mistake                                 | Consequence
Forgetting optimizer.zero_grad()        | Gradients from previous iterations accumulate and corrupt updates
Forgetting model.train() / model.eval() | Dropout and BatchNorm behave incorrectly for the current phase
No torch.no_grad() during validation    | Computation graphs are retained for every batch, blowing up memory
Wrong loss for raw logits (e.g. nn.NLLLoss without log_softmax) | Training instability or silently poor results

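The first pitfall is easy to demonstrate in isolation (a sketch):

import torch

w = torch.tensor(1.0, requires_grad=True)

for _ in range(3):
    loss = 2.0 * w
    loss.backward()        # without zeroing, each call adds to w.grad

print(w.grad)              # tensor(6.): three gradients of 2.0 summed together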
