

AMP(Automatic Mixed Precision,自动混合精度训练)是一种在深度学习训练中加速计算并节省显存的技术,同时几乎不损失模型精度。
AMP(自动混合精度训练)的工作流程如下: 在训练过程中,AMP 通过 autocast 自动将模型的前向计算切换到半精度(FP16),以加速运算并减少显存占用;同时,为防止 FP16 表示范围有限导致梯度下溢(变为零),它使用 GradScaler 对损失值进行放大(如乘以1024),使反向传播产生的梯度落在 FP16 的有效范围内;随后,这些缩放后的梯度被转换回单精度(FP32),并与优化器中维护的 FP32 主权重结合,在去除缩放因子后完成参数更新。整个过程由框架自动管理哪些操作使用 FP16、哪些必须保留 FP32(如 BatchNorm 或 softmax),从而在几乎不损失模型精度的前提下,显著提升训练速度并降低显存消耗。
流程图如下:

首先加载手写数字数据集(Digits Dataset)并将其转换为 PyTorch 张量,同时放到 GPU 上进行后续深度学习训练或推理。
x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
# torch.Size([1797, 64]) torch.float32
# torch.Size([1797]) torch.int64bash其中x/16是进行归一化操作,原始像素值范围是 0~16(因为 8×8 图像来自 sklearn 的预处理版本,最大值为 16)
定义一个网络:
class MLP(torch.nn.Module):
def __init__(self,input_size,hidden_size,output_size):
super(MLP,self).__init__()
self.fc1=torch.nn.Linear(input_size, hidden_size)
self.fc2=torch.nn.Linear(hidden_size, output_size)
def forward(self,x):
out=self.fc1(x)
out=torch.relu(out)
out=self.fc2(out)
return outbash然后进行模型定义、损失函数、优化器和自动混合精度(AMP)训练组件的初始化:
model=MLP(input_size=64,hidden_size=256,output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
scaler=torch.amp.GradScaler()bash其中scaler = torch.amp.GradScaler()创建一个 梯度缩放器(GradScaler),用于 自动混合精度(AMP)训练。在 FP16(半精度)训练时,防止梯度下溢(underflow → 变成 0)。
tips:下溢是指一个数太小了,小到当前数据类型无法表示,结果被强制变成 0。由于链式法则,所以可能会导致梯度越来越小。
接下来在模型的某一层(model.fc1)上注册一个前向传播钩子(forward hook),用于在第一次前向计算时打印该层的输入、输出和权重的形状与数据类型,便于调试模型的数据流和精度(如 FP16/FP32):
print_once=False
def debug_forward(module,input,output):
global print_once
if not print_once:
print_once=True
print(f'{module}\ninput_shape={input[0].shape} input_dtype={input[0].dtype}\noutput_shape={output.shape} output_dtype={output.dtype}\nweight_shape={module.weight.shape} weight_dtype={module.weight.dtype}')
model.fc1.register_forward_hook(debug_forward)bash打印的内容如下:
Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32bash它揭示了模型在混合精度训练(AMP)环境下某一层(fc1)的实际运行状态。由output_dtype可知在训练过程中,将float32自动转换成float16进行训练了。模型参数weight_shape始终以 FP32 形式存储(这是 AMP 的标准做法,称为 “master weights”)。
最后开始训练:
iter=0
while True:
optimizer.zero_grad()
with torch.amp.autocast(device_type='cuda',dtype=torch.float16): # FP16 Mix
out=model(x)
loss=loss_fn(out,y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
iter+=1
if iter%100000==0:
print(f'iter={iter} loss={loss.item()} cuda_mem={torch.cuda.memory_allocated()}Bytes')
if loss.item()<=1e-3:
breakbash其中核心是:
with torch.amp.autocast(device_type='cuda', dtype=torch.float16): # FP16 Mix
out = model(x)
loss = loss_fn(out, y)bash在这个上下文中,PyTorch 会自动为支持的操作选择 FP16 或 FP32 精度:大多数计算(如 Linear, Conv, MatMul)→ 使用 FP16(更快、更省显存);数值敏感操作(如 Softmax, Log, BatchNorm)→ 自动回退到 FP32(保精度)。
再进行梯度缩放:
scaler.scale(loss).backward()bashtips:在 FP16 中,梯度可能太小而下溢(underflow → 变成 0)。解决方案:先把 loss 放大(如 ×1024),使得反向传播时梯度也放大,落在 FP16 可表示范围内
scaler.scale(loss)会返回一个 scaled loss(FP16),调用 .backward() 时,计算的是 放大后的梯度。
最后参数更新:
scaler.step(optimizer)
scaler.update()bash其中scaler.step(optimizer)先检查梯度是否包含 NaN/Inf(FP16 容易溢出)。如果正常,则自动将梯度从 FP16 转回 FP32,并除以缩放因子,再调用 optimizer.step()。如果检测到异常(如梯度过大导致 Inf),则跳过本次更新(避免破坏模型)。
scaler.update()动态调整下一次的缩放因子(scale)。如果连续几次都无异常 → 尝试增大 scale(更激进)。如果出现 Inf/NaN → 减小 scale(更保守)。
混合精度训练:
Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.00813820119947195 cuda_mem=17707008Bytes
iter=200000 loss=0.0030544002074748278 cuda_mem=17707008Bytes
iter=300000 loss=0.0017446494894102216 cuda_mem=17707008Bytes
iter=400000 loss=0.0011826277477666736 cuda_mem=17707008Bytesbash不用混合精度训练:
Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float32
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.008329853415489197 cuda_mem=17742848Bytes
iter=200000 loss=0.003189701121300459 cuda_mem=17742848Bytes
iter=300000 loss=0.0018277550116181374 cuda_mem=17742848Bytes
iter=400000 loss=0.001241791993379593 cuda_mem=17742848Bytesbash可以看到,使用AMP混合精度训练所占用的显存明显更低!