Skip to main content

6.1.7 权重初始化

本节定位

初始化决定神经网络一开始能不能带着可用信号进入训练。你通常先用 PyTorch 默认初始化,但也要知道训练异常时怎样检查 Xavier、He、全零、过小和过大的初始化。

学习目标

  • 解释为什么全零权重会破坏学习。
  • 知道 Tanh/Sigmoid 常配 Xavier,ReLU 类激活常配 He。
  • 在训练前运行一次信号探针。
  • 在一个小型分类任务上比较初始化选择。
  • 遇到早期训练不稳定时,按顺序排查,而不是随机改参数。

先看图

先别急着背公式,先看初始化要完成什么任务:

权重初始化信号稳定图

从上往下读这张图:

  • 前向信号不能一层层消失;
  • 激活值不能一开始就大面积饱和;
  • 反向梯度还要能传回来;
  • 普通 nn.Linearnn.Conv2d 模型,先用 PyTorch 默认值通常是好选择。

最小概念

神经网络训练基本上是这个循环:

  1. 初始化权重;
  2. 前向传播;
  3. 计算损失;
  4. 反向传播;
  5. 优化器更新权重。

如果第 1 步就坏了,后面几步虽然还能运行,但其实是在一个很差的起点上运行。

常见失败很直观:

坏起点会发生什么你会看到什么
全零神经元一直一样loss 不下降
太小信号随深度衰减深层输出接近 0
太大激活饱和或爆炸初始 loss 很大,梯度不稳
初始化和激活不匹配非线性后的尺度不合适训练慢或很脆

两个术语要先认识:

  • fan_in:进入一层的输入特征数。
  • fan_out:离开一层的输出特征数。

初始化公式会用它们来控制每一层的尺度。

Xavier 和 He 一张表记住

第一次学不用先死背公式,先记搭配:

激活函数常用选择PyTorch 函数原因
Tanh / SigmoidXavier,也叫 Glorotnn.init.xavier_normal_尽量平衡输入和输出方差
ReLU / Leaky ReLUHe,也叫 Kaimingnn.init.kaiming_normal_补偿 ReLU 把很多值变成 0
普通 PyTorch 模型但不确定PyTorch 默认值不写手动初始化适合作为第一版 baseline
实用规则

普通新项目不要一开始就手动初始化所有层。先用 PyTorch 默认值,确认学习率和数据流程没问题;如果信号或梯度明显异常,再检查初始化。

实验准备

可以在 Notebook 单元格里运行,也可以保存成 weight_init_lab.py

如果缺包,先安装:

pip install torch scikit-learn

实验 1:训练前检查信号

这个实验把随机数据送进 8 层网络,打印第一层和最后一层的激活统计。目标不是看准确率,而是看信号能不能穿过深层网络。

import torch
import torch.nn as nn

torch.manual_seed(7)


def build_probe(activation):
layers = []
in_features = 32
for _ in range(8):
layer = nn.Linear(in_features, 128)
layers.append(layer)
layers.append(activation())
in_features = 128
return nn.Sequential(*layers)


def apply_init(model, strategy):
for module in model:
if isinstance(module, nn.Linear):
if strategy == "tiny":
nn.init.normal_(module.weight, 0.0, 0.01)
elif strategy == "large":
nn.init.normal_(module.weight, 0.0, 1.0)
elif strategy == "xavier":
nn.init.xavier_normal_(module.weight)
elif strategy == "he":
nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(module.bias)


def probe(strategy, activation_cls):
model = build_probe(activation_cls)
apply_init(model, strategy)
x = torch.randn(512, 32)
stats = []

for layer in model:
x = layer(x)
if isinstance(layer, activation_cls):
stats.append(
{
"mean": x.mean().item(),
"std": x.std().item(),
"zero_ratio": (x == 0).float().mean().item(),
"saturated_ratio": (x.abs() > 0.98).float().mean().item(),
}
)

return stats[0], stats[-1]


print("signal_probe")
for label, strategy, activation in [
("tiny + ReLU", "tiny", nn.ReLU),
("large + Tanh", "large", nn.Tanh),
("Xavier + Tanh", "xavier", nn.Tanh),
("He + ReLU", "he", nn.ReLU),
]:
first, last = probe(strategy, activation)
print(
f"{label:14s} "
f"first_std={first['std']:.4f} "
f"last_std={last['std']:.4f} "
f"last_zero={last['zero_ratio']:.2f} "
f"last_saturated={last['saturated_ratio']:.2f}"
)

预期输出:

signal_probe
tiny + ReLU first_std=0.0337 last_std=0.0000 last_zero=0.52 last_saturated=0.00
large + Tanh first_std=0.9273 last_std=0.9633 last_zero=0.00 last_saturated=0.84
Xavier + Tanh first_std=0.4872 last_std=0.2276 last_zero=0.00 last_saturated=0.00
He + ReLU first_std=0.8304 last_std=0.6937 last_zero=0.49 last_saturated=0.19

这样读结果:

  • tiny + ReLU:最后一层标准差几乎变成 0,深层信号已经衰减。
  • large + Tanh:很多值贴近 -1 或 1,Tanh 的梯度会变弱。
  • Xavier + Tanh:信号尺度更可控。
  • He + ReLU:ReLU 本来就会产生很多 0,但信号还能传到深层。

实验 2:训练一个小分类器

现在把同样想法放进训练里比较。这是一个很小的二分类数据集,所以某些坏起点也可能被救回来。真正要观察的是初始 loss,以及全零初始化是否卡死。

import torch
import torch.nn as nn
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

torch.manual_seed(9)

X, y = make_moons(n_samples=600, noise=0.22, random_state=9)
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

train_idx, val_idx = train_test_split(
torch.arange(len(X)),
test_size=0.25,
random_state=9,
stratify=y,
)
X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]


class MoonMLP(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(2, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
nn.Linear(64, 2),
)

def forward(self, x):
return self.net(x)


def apply_init(model, strategy):
if strategy == "default":
return

for module in model.modules():
if isinstance(module, nn.Linear):
if strategy == "zeros":
nn.init.zeros_(module.weight)
elif strategy == "tiny":
nn.init.normal_(module.weight, 0.0, 0.01)
elif strategy == "large":
nn.init.normal_(module.weight, 0.0, 1.0)
elif strategy == "xavier":
nn.init.xavier_normal_(module.weight)
elif strategy == "he":
nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(module.bias)


def accuracy(model, X, y):
with torch.no_grad():
preds = model(X).argmax(dim=1)
return (preds == y).float().mean().item()


def train_once(strategy):
torch.manual_seed(9)
model = MoonMLP()
apply_init(model, strategy)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
start_loss = loss_fn(model(X_train), y_train).item()

for _ in range(120):
loss = loss_fn(model(X_train), y_train)
optimizer.zero_grad()
loss.backward()
optimizer.step()

end_loss = loss_fn(model(X_train), y_train).item()
return start_loss, end_loss, accuracy(model, X_val, y_val)


print("training_probe")
for strategy in ["default", "zeros", "tiny", "large", "xavier", "he"]:
start, end, acc = train_once(strategy)
print(f"{strategy:8s} start_loss={start:.3f} end_loss={end:.3f} val_acc={acc:.3f}")

预期输出:

training_probe
default start_loss=0.671 end_loss=0.047 val_acc=0.973
zeros start_loss=0.693 end_loss=0.693 val_acc=0.500
tiny start_loss=0.693 end_loss=0.067 val_acc=0.973
large start_loss=20.040 end_loss=0.068 val_acc=0.980
xavier start_loss=0.696 end_loss=0.046 val_acc=0.980
he start_loss=0.924 end_loss=0.053 val_acc=0.980

权重初始化实验结果图

重点看三件事:

  • zeros 会卡住,因为隐藏神经元从一开始就是彼此的复制品。
  • large 初始 loss 很大,即使这个小模型后来能恢复,也是一种警告。
  • defaultxavierhe 在这里都能工作,这也说明默认值适合做第一版 baseline。

排错清单

如果训练前几轮就明显不对,按这个顺序查:

  1. 数据 shape 对吗?
  2. 目标 dtype 对吗?CrossEntropyLoss 需要 torch.long 类型的类别标签。
  3. 学习率是不是太高?
  4. 激活值是不是大部分为 0、饱和、naninf
  5. 初始化和激活函数是否匹配?

不要靠猜,先用小探针:

with torch.no_grad():
sample = X_train[:32]
out = model(sample)
print(out.mean().item(), out.std().item(), torch.isfinite(out).all().item())

如果输出不是有限数,或者几乎每个值都一样,就把初始化、输入缩放和学习率一起检查。

练习

  1. 把信号探针里的网络深度从 8 改成 20。哪种初始化最先失败?
  2. MoonMLP 里的 ReLU 改成 Tanh。Xavier 会不会更有竞争力?
  3. 把 Adam 改成 lr=0.1 的 SGD。哪种初始化更脆?

小结

  • 初始化是前向信号和反向梯度的起跑条件。
  • 全零权重会破坏对称性,不要用于隐藏层。
  • Xavier 适合 Tanh/Sigmoid;He 适合 ReLU 类激活。
  • PyTorch 默认值通常是第一步的正确选择,但训练异常时要会用信号探针检查。