PyTorch Core Procedure in Model Training
1. Define the Model (nn.Module)
In PyTorch, models are defined by subclassing torch.nn.Module.
Core ideas:
- Parameters are registered automatically.
- Layers are defined in `__init__`.
- Forward computation is defined in `forward()`.
```python
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        return self.fc(x)
```
Key concepts:
- `model.parameters()` exposes trainable tensors
- `model.to(device)` moves the model to CPU/GPU
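A quick sketch of both calls in action, reusing `MyModel` from above (the device-selection line is the usual CUDA-availability idiom):

```python
import torch

model = MyModel()

# Each registered layer contributes its tensors to model.parameters()
for name, p in model.named_parameters():
    print(name, tuple(p.shape))  # fc.weight (10, 128), fc.bias (10,)

# Move every registered parameter and buffer to the chosen device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
```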
2. Define Loss Function (Criterion)
The loss function measures prediction error.
Common examples:
- `nn.CrossEntropyLoss()` → classification
- `nn.MSELoss()` → regression
- `nn.BCEWithLogitsLoss()` → binary classification
```python
criterion = nn.CrossEntropyLoss()
```
Role:
- Produces a scalar loss
- Defines training objective
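To make "scalar" concrete, here is a tiny sketch with made-up logits and targets (shapes are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# 4 samples, 10 classes; CrossEntropyLoss expects raw logits + class indices
logits = torch.randn(4, 10)
targets = torch.tensor([1, 0, 4, 9])

loss = criterion(logits, targets)
print(loss.shape)  # torch.Size([]) -- a 0-dim (scalar) tensor
```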
3. Define Optimizer
Optimizer updates parameters using gradients.
Common optimizers:
- `torch.optim.SGD`
- `torch.optim.Adam`
- `torch.optim.AdamW` (standard for Transformers)
```python
import torch.optim as optim

optimizer = optim.AdamW(model.parameters(), lr=1e-3)
```
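All of these share the same interface, so swapping one in is a one-line change. A sketch using SGD with momentum, reusing `model` from step 1 (hyperparameter values are illustrative):

```python
import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)

# Hyperparameters live in param groups and can be inspected or edited
print(optimizer.param_groups[0]["lr"])  # 0.01
```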
4. Data Pipeline (Dataset + DataLoader)
The `Dataset`/`DataLoader` pair encapsulates batching, shuffling, and multiprocess data loading.
```python
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
```
Key roles:
- Mini-batch training
- Randomization
- Efficient I/O
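The `dataset` passed to the loader above can be any map-style `torch.utils.data.Dataset`. A minimal synthetic sketch (class name and sizes are made up, matching the 128-feature model from step 1):

```python
import torch
from torch.utils.data import Dataset

class RandomVectorDataset(Dataset):
    """Synthetic data: 128-dim feature vectors with labels in [0, 10)."""

    def __init__(self, n_samples=1024):
        self.x = torch.randn(n_samples, 128)
        self.y = torch.randint(0, 10, (n_samples,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataset = RandomVectorDataset()
```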
5. Training Loop (Core Procedure)
This is the heart of PyTorch training.
Canonical Pattern
```python
for epoch in range(num_epochs):
    model.train()  # Enable training mode
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # (1) Forward pass
        logits = model(x)

        # (2) Compute loss
        loss = criterion(logits, y)

        # (3) Zero gradients
        optimizer.zero_grad()

        # (4) Backward pass (autograd)
        loss.backward()

        # (5) Parameter update
        optimizer.step()
```
What Each Step Really Does
| Step | Operation | What Happens Internally |
|---|---|---|
| Forward | `model(x)` | Builds computation graph |
| Loss | `criterion(...)` | Adds loss node |
| Zero grad | `zero_grad()` | Clears accumulated grads |
| Backward | `loss.backward()` | Reverse-mode autodiff |
| Step | `optimizer.step()` | Updates parameters |
6. Autograd: Gradient Computation
PyTorch uses dynamic computation graphs.
Key mechanics:
- Tensors with `requires_grad=True` are tracked
- The graph is built during the forward pass
- `.backward()` traverses the graph in reverse
Example:

```python
loss.backward()
print(model.fc.weight.grad)  # ∂loss / ∂weight
```
Important:
- Gradients accumulate by default
- You must call `zero_grad()` every iteration
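A self-contained demo of the accumulation behavior (the toy scalar `w` is made up for illustration):

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(w ** 2).backward()
print(w.grad)  # tensor(4.) -- d(w^2)/dw = 2w

(w ** 2).backward()  # no zeroing in between
print(w.grad)  # tensor(8.) -- gradients were summed, not replaced

w.grad.zero_()  # what optimizer.zero_grad() does for each parameter
(w ** 2).backward()
print(w.grad)  # tensor(4.)
```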
7. Training vs Evaluation Mode
Some layers behave differently:
- Dropout
- BatchNorm
```python
model.train()  # Enable dropout, batchnorm updates
model.eval()   # Disable dropout, use running stats
```
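The effect is easy to see with a standalone `nn.Dropout` layer (p=0.5 chosen for illustration):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()
print(drop(x))  # roughly half the entries zeroed, survivors scaled to 2.0

drop.eval()
print(drop(x))  # identity: all ones
```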
8. Validation Loop (No Gradients)
Validation should not build graphs:
```python
model.eval()
val_loss = 0.0
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        val_loss += criterion(logits, y).item()
val_loss /= len(val_loader)  # average over batches
```
Why:
- Saves memory
- Faster
- Prevents accidental gradient updates
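Alongside the loss, a classification metric is usually computed under the same no-grad discipline; a sketch for top-1 accuracy:

```python
import torch

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        preds = model(x).argmax(dim=1)  # highest-logit class per sample
        correct += (preds == y).sum().item()
        total += y.numel()
print(f"val accuracy: {correct / total:.3f}")
```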
9. Learning Rate Scheduling (Optional but Common)
```python
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=100
)

scheduler.step()  # typically called once per epoch, after optimizer.step()
```
Used for:
- Stabilizing training
- Better convergence
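A runnable sketch of the annealing schedule (a dummy parameter is used just so the optimizer has something to manage):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ...one epoch of training would run here...
    optimizer.step()
    scheduler.step()  # decay lr along a cosine curve toward eta_min (default 0)
    if epoch % 25 == 0:
        print(epoch, scheduler.get_last_lr())
```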
10. Checkpointing
```python
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}, "checkpoint.pt")
```
Load:

```python
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
```
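A common extension is to store the epoch counter too, so training can resume where it stopped (the `"epoch"` key and the `map_location` usage below are conventions, not requirements):

```python
import torch

# Save progress alongside the weights
torch.save({
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}, "checkpoint.pt")

# Resume: map_location loads tensors directly onto the current device
ckpt = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"] + 1
```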
Minimal Mental Model (Interview-Style)
PyTorch training = Forward → Loss → Backward → Step
Formally:
$$\theta \leftarrow \theta - \eta \, \nabla_\theta \mathcal{L}(f_\theta(x), y)$$
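The same update written out by hand for one step on a toy scalar objective (numbers made up):

```python
import torch

eta = 0.1
theta = torch.tensor(3.0, requires_grad=True)

loss = (theta - 1.0) ** 2  # toy L(f_theta(x), y)
loss.backward()            # grad = 2 * (theta - 1) = 4

with torch.no_grad():
    theta -= eta * theta.grad  # theta <- theta - eta * grad
print(theta)  # tensor(2.6000, requires_grad=True)
```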
Common Pitfalls (Important)
| Mistake | Consequence |
|---|---|
| Forget `zero_grad()` | Gradient accumulation bugs |
| Forget `model.train()` | Dropout/BN behave incorrectly |
| No `torch.no_grad()` in validation | Memory blow-up from retained graphs |
| Wrong loss for logits | Training instability |