6.7.3 Training Monitoring and Diagnosis

  • Classify underfitting, overfitting, and unstable training from curves.
  • Inspect prediction distribution and gradient norm.
  • Use a repeatable troubleshooting order.
  • Decide one next experiment from evidence.
  • Know what to save in every training run.

[Figure: Training curve diagnosis chart]

The first question is not “which model should I switch to?” It is:

What symptom is visible in the training evidence?

| Symptom | Likely direction | First check |
|---|---|---|
| train and val both bad | underfitting | learning rate, model capacity, data quality |
| train improves but val worsens | overfitting | regularization, data split, augmentation |
| loss jumps up and down | instability | learning rate, batch size, gradients |
| predictions mostly one class | collapse or data issue | labels, class balance, output layer |
| metrics suddenly change | pipeline bug or distribution shift | data loader, preprocessing, validation split |

[Figure: Training diagnosis dashboard troubleshooting map]

# Each entry maps a case name to (train_losses, val_losses), one value per epoch.
histories = {
    "underfit_case": ([1.20, 1.08, 0.99, 0.94], [1.25, 1.13, 1.04, 1.02]),
    "overfit_case": ([0.90, 0.55, 0.31, 0.18], [0.92, 0.63, 0.68, 0.82]),
    "unstable_case": ([0.80, 1.65, 0.72, 1.48], [0.85, 1.70, 0.79, 1.55]),
}


def diagnose(train, val):
    train_drop = train[0] - train[-1]
    val_best = min(val)
    # A large swing in training loss suggests instability.
    if max(train) - min(train) > 0.8:
        return "possible_lr_too_high_or_unstable_batches"
    # Both losses still high at the end: the model has not fit the data.
    if train[-1] > 0.8 and val[-1] > 0.8:
        return "possible_underfitting"
    # Training improved, but validation ended well above its best value.
    if train_drop > 0.3 and val[-1] > val_best + 0.1:
        return "possible_overfitting"
    return "need_more_signals"


print("curve_diagnosis")
for name, (train, val) in histories.items():
    print(name, "->", diagnose(train, val))

Expected output:

curve_diagnosis
underfit_case -> possible_underfitting
overfit_case -> possible_overfitting
unstable_case -> possible_lr_too_high_or_unstable_batches

This code is not a replacement for judgment. It teaches the first habit: classify the visible symptom before changing the system.
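
As a rough sketch of how a healthy reference pattern could be added to the same rule-based check, the version below keeps the three original rules and appends one more branch. The `good_case` values, the `looks_healthy` label, and the 0.3 validation-drop threshold are illustrative assumptions, not part of the lab.

```python
def diagnose(train, val):
    """Lab 1 rules plus one assumed 'healthy' branch."""
    train_drop = train[0] - train[-1]
    val_drop = val[0] - val[-1]
    val_best = min(val)
    if max(train) - min(train) > 0.8:
        return "possible_lr_too_high_or_unstable_batches"
    if train[-1] > 0.8 and val[-1] > 0.8:
        return "possible_underfitting"
    if train_drop > 0.3 and val[-1] > val_best + 0.1:
        return "possible_overfitting"
    # Assumed extra rule: both curves dropped clearly and the overfitting
    # rule above did not fire, so validation ended near its best value.
    if train_drop > 0.3 and val_drop > 0.3:
        return "looks_healthy"
    return "need_more_signals"


histories = {
    "underfit_case": ([1.20, 1.08, 0.99, 0.94], [1.25, 1.13, 1.04, 1.02]),
    "overfit_case": ([0.90, 0.55, 0.31, 0.18], [0.92, 0.63, 0.68, 0.82]),
    "unstable_case": ([0.80, 1.65, 0.72, 1.48], [0.85, 1.70, 0.79, 1.55]),
    # Hypothetical healthy run: train and val fall together.
    "good_case": ([1.00, 0.70, 0.50, 0.35], [1.05, 0.78, 0.60, 0.50]),
}

for name, (train, val) in histories.items():
    print(name, "->", diagnose(train, val))
```

The three original cases keep their labels, and `good_case` is reported as `looks_healthy`. Note that the instability rule fires on any training-loss range above 0.8, so a healthy run that starts from a much higher loss would need a different rule.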

Lab 2: Check Gradients and Prediction Distribution

Loss alone is not enough. A model can have a reasonable loss while predicting the same class for every sample.

import torch
from torch import nn

torch.manual_seed(5)

# Tiny synthetic batch: 12 samples, 3 features, 2 balanced classes.
X = torch.randn(12, 3)
y = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0])

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
loss_fn = nn.CrossEntropyLoss()

logits = model(X)
loss = loss_fn(logits, y)
loss.backward()

# Global L2 norm over all parameter gradients.
grad_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        grad_norm += p.grad.pow(2).sum().item()
grad_norm = grad_norm**0.5

# Predicted class per sample, per-class counts, and mean top-class probability.
preds = logits.argmax(dim=1)
counts = torch.bincount(preds, minlength=2)
confidence = torch.softmax(logits, dim=1).max(dim=1).values.mean().item()

print("training_signals")
print("loss:", round(loss.item(), 3))
print("grad_norm:", round(grad_norm, 3))
print("pred_counts:", counts.tolist())
print("avg_confidence:", round(confidence, 3))

Expected output:

training_signals
loss: 0.687
grad_norm: 0.445
pred_counts: [0, 12]
avg_confidence: 0.69

[Figure: Training diagnosis signal result map]

The important signal is pred_counts: [0, 12]. This initial model predicts class 1 for every sample. During real training, if this pattern persists, check class imbalance, labels, output layer shape, and loss setup.
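
A small helper can turn "predictions mostly one class" into an explicit check on the class counts. The sketch below uses plain Python lists such as the output of `counts.tolist()`; the function names and the 0.95 threshold are assumptions for illustration and should be tuned to the known class balance.

```python
def dominant_class_share(pred_counts):
    """Return (class_index, share) for the most frequently predicted class."""
    total = sum(pred_counts)
    if total == 0:
        raise ValueError("pred_counts contains no predictions")
    top = max(range(len(pred_counts)), key=lambda i: pred_counts[i])
    return top, pred_counts[top] / total


def looks_collapsed(pred_counts, threshold=0.95):
    # threshold is an assumed cutoff, not a standard value.
    _, share = dominant_class_share(pred_counts)
    return share >= threshold


print(dominant_class_share([0, 12]))  # (1, 1.0)
print(looks_collapsed([0, 12]))       # True
print(looks_collapsed([5, 7]))        # False
```

A single flagged batch at initialization is not alarming on its own; the signal matters when it persists across epochs.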

Use this order before changing the architecture:

  1. Curves: train/val loss and metrics.
  2. Predictions: class counts, confidence, best and worst examples.
  3. Gradients: norm, NaN/Inf, exploding or near-zero updates.
  4. Data: labels, leakage, split, preprocessing, augmentation.
  5. Hyperparameters: learning rate, batch size, regularization.
  6. Model: capacity, architecture, initialization.

This order is deliberately boring. That is why it works.
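
For the gradient step in the order above, one possible check is sketched here. It assumes PyTorch, as in Lab 2; `gradient_report` is a hypothetical helper name, and `has_nan_grad` is set for any non-finite value (NaN or Inf).

```python
import torch
from torch import nn


def gradient_report(model):
    """Summarize gradients after backward(): global L2 norm and a non-finite flag."""
    total_sq = 0.0
    has_nan_grad = False
    for p in model.parameters():
        if p.grad is None:
            continue
        # torch.isfinite is False for NaN, +Inf, and -Inf entries.
        if not torch.isfinite(p.grad).all():
            has_nan_grad = True
        total_sq += p.grad.pow(2).sum().item()
    return {"grad_norm": total_sq ** 0.5, "has_nan_grad": has_nan_grad}


torch.manual_seed(5)
X = torch.randn(12, 3)
y = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0])
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
nn.CrossEntropyLoss()(model(X), y).backward()

report = gradient_report(model)
print("has_nan_grad:", report["has_nan_grad"])
```

If the flag is ever true during training, stop the run and inspect learning rate, loss, input scale, and labels before continuing.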

| Artifact | Why save it |
|---|---|
| train/val curves | diagnose trend and overfitting |
| config and seed | reproduce the run |
| best checkpoint | compare without retraining |
| prediction samples | inspect failures directly |
| gradient statistics | catch instability early |
| data split version | detect leakage or drift |
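
A minimal sketch of saving two of these artifacts with the standard library follows: the config as JSON and the curves as a CSV with `epoch,train_loss,val_loss,val_acc`. The losses reuse `overfit_case` from the first lab; the `val_acc` values, the config keys, and the file names are made up for illustration.

```python
import csv
import json
import tempfile
from pathlib import Path

# Losses reuse overfit_case; val_acc is illustrative, not measured.
train_loss = [0.90, 0.55, 0.31, 0.18]
val_loss = [0.92, 0.63, 0.68, 0.82]
val_acc = [0.55, 0.71, 0.70, 0.66]
# Hypothetical config; record whatever is needed to reproduce the run.
config = {"seed": 5, "lr": 0.01, "batch_size": 12, "split_version": "v1"}

run_dir = Path(tempfile.mkdtemp(prefix="run_"))
(run_dir / "config.json").write_text(json.dumps(config, indent=2))

# One row per epoch so the curve can be rebuilt later.
with open(run_dir / "history.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "val_loss", "val_acc"])
    for epoch, row in enumerate(zip(train_loss, val_loss, val_acc), start=1):
        writer.writerow([epoch, *row])

# Read the log back and find the epoch with the lowest validation loss.
with open(run_dir / "history.csv", newline="") as f:
    rows = list(csv.DictReader(f))
best = min(rows, key=lambda r: float(r["val_loss"]))
print("best_epoch:", best["epoch"], "val_loss:", best["val_loss"])
```

Reading the log back shows why the full curve matters: here the best validation loss is at epoch 2, not at the final epoch.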

Every diagnosis should leave a symptom-to-action note:

  • Curve Pattern: underfit, overfit, unstable, collapse, or unclear
  • Prediction Signal: class counts and confidence
  • Gradient Signal: norm plus NaN/Inf check
  • Data Check: labels, split, leakage, preprocessing
  • Chosen Action: one targeted next experiment
  • Success Rule: what metric or artifact will prove the fix worked
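
One way to keep such notes consistent is a small record type with the six fields. This is only a sketch: `DiagnosisNote` and `render` are hypothetical names, and the field values below are invented examples.

```python
from dataclasses import dataclass, fields


@dataclass
class DiagnosisNote:
    curve_pattern: str
    prediction_signal: str
    gradient_signal: str
    data_check: str
    chosen_action: str
    success_rule: str

    def render(self):
        """Format the note as 'Field Name: value' lines, in field order."""
        return "\n".join(
            f"{f.name.replace('_', ' ').title()}: {getattr(self, f.name)}"
            for f in fields(self)
        )


# Invented example values for an overfitting run.
note = DiagnosisNote(
    curve_pattern="overfit",
    prediction_signal="both classes predicted, high confidence on errors",
    gradient_signal="norm stable, no NaN/Inf",
    data_check="split verified, no leakage found",
    chosen_action="add early stopping",
    success_rule="final val loss stays within 0.1 of its best value",
)
print(note.render())
```

Because every field is required, a note cannot be created without naming a chosen action and a success rule.
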

| Diagnosis | First action |
|---|---|
| possible underfitting | raise LR within reason, train longer, increase capacity, inspect labels |
| possible overfitting | early stopping, stronger regularization, more data, augmentation |
| unstable training | lower LR, increase batch, add gradient clipping |
| prediction collapse | check class balance, target encoding, output shape, loss function |
| data pipeline issue | print sample batches, verify preprocessing and split |
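
The diagnosis-to-action table can be sketched as a lookup keyed by the labels that `diagnose` returns in the first lab. The `FIRST_ACTIONS` and `next_experiment` names, the entry for `need_more_signals`, and the fallback text are assumptions added for illustration.

```python
# Candidate actions per diagnosis, in the order listed in the table.
FIRST_ACTIONS = {
    "possible_underfitting": [
        "raise LR within reason", "train longer",
        "increase capacity", "inspect labels",
    ],
    "possible_overfitting": [
        "early stopping", "stronger regularization",
        "more data", "augmentation",
    ],
    "possible_lr_too_high_or_unstable_batches": [
        "lower LR", "increase batch", "add gradient clipping",
    ],
    # Assumed entry: the table has no row for this label.
    "need_more_signals": ["inspect prediction counts and gradient norm"],
}


def next_experiment(diagnosis):
    """Return one action to try first, or a fallback for unknown labels."""
    actions = FIRST_ACTIONS.get(diagnosis)
    if not actions:
        return "collect more evidence before changing anything"
    return actions[0]


print(next_experiment("possible_overfitting"))  # early stopping
print(next_experiment("unknown_label"))
```

Returning a single action rather than the whole list reflects the rule of one targeted experiment at a time.
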

| Mistake | Fix |
|---|---|
| only reading final accuracy | save full curves and best epoch |
| changing model before checking data | inspect sample batches and labels first |
| ignoring prediction distribution | print class counts or output summaries |
| assuming low train loss means success | compare validation and failure cases |
| making multiple fixes at once | choose one action and verify the result |

  1. Add a good_case history where train and val both improve.
  2. Modify Lab 2 so the model has 3 classes. What changes in torch.bincount?
  3. Add a check that reports has_nan_grad.
  4. Write one next action for each diagnosis in Lab 1.
  5. Save a CSV-style log with epoch,train_loss,val_loss,val_acc.

Reference implementation and walkthrough

  1. A good_case should show train and validation loss both decreasing, with validation accuracy improving or staying stable. It is the reference pattern for healthy training.
  2. With 3 classes, torch.bincount(preds, minlength=3) should report three bins. The classifier output and labels must also use three classes.
  3. has_nan_grad should scan parameter gradients after backward(). If any gradient is non-finite, stop and inspect learning rate, loss, input scale, and labels.
  4. Underfitting needs capacity, time, or learning-rate checks; overfitting needs regularization or more data; instability needs smaller LR or clipping; collapse needs label/output/loss checks.
  5. A CSV log should let you reconstruct the curve later. At minimum, each row needs epoch, train loss, validation loss, and validation accuracy.
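
For the second item in the walkthrough, a possible three-class variant of Lab 2 is sketched below. The label values are illustrative and `num_classes` is a name introduced here; the parts that must change together are the labels, the size of the final layer, and `minlength`.

```python
import torch
from torch import nn

torch.manual_seed(5)
num_classes = 3

X = torch.randn(12, 3)
# Labels now cover three classes (values are illustrative).
y = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2])
# The final layer must emit one logit per class.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, num_classes))

logits = model(X)
loss = nn.CrossEntropyLoss()(logits, y)
preds = logits.argmax(dim=1)
# minlength keeps a bin for classes that were never predicted.
counts = torch.bincount(preds, minlength=num_classes)
print("pred_counts:", counts.tolist())
```

Without `minlength`, `torch.bincount` only returns bins up to the largest predicted class, so a class that is never predicted could silently disappear from the report.
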
  • Symptoms are not root causes.
  • Curves are the first diagnostic surface.
  • Predictions and gradients reveal failures that loss can hide.
  • Data checks come before architecture changes.
  • A good diagnosis ends with one targeted next experiment.