Soup's Blog

Back

模型训练(一)分布式训练之DDPBlur image

AMP(Automatic Mixed Precision,自动混合精度训练)是一种在深度学习训练中加速计算并节省显存的技术,同时几乎不损失模型精度。

AMP(自动混合精度训练)的工作流程如下: 在训练过程中,AMP 通过 autocast 自动将模型的前向计算切换到半精度(FP16),以加速运算并减少显存占用;同时,为防止 FP16 表示范围有限导致梯度下溢(变为零),它使用 GradScaler 对损失值进行放大(如乘以1024),使反向传播产生的梯度落在 FP16 的有效范围内;随后,这些缩放后的梯度被转换回单精度(FP32),并与优化器中维护的 FP32 主权重结合,在去除缩放因子后完成参数更新。整个过程由框架自动管理哪些操作使用 FP16、哪些必须保留 FP32(如 BatchNorm 或 softmax),从而在几乎不损失模型精度的前提下,显著提升训练速度并降低显存消耗。

流程图如下: 在这里插入图片描述

首先加载手写数字数据集(Digits Dataset)并将其转换为 PyTorch 张量,同时放到 GPU 上进行后续深度学习训练或推理。

x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
# torch.Size([1797, 64]) torch.float32
# torch.Size([1797]) torch.int64
bash

其中x/16是进行归一化操作,原始像素值范围是 0~16(因为 8×8 图像来自 sklearn 的预处理版本,最大值为 16)

定义一个网络:

class MLP(torch.nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super(MLP,self).__init__()
        self.fc1=torch.nn.Linear(input_size, hidden_size)
        self.fc2=torch.nn.Linear(hidden_size, output_size)

    def forward(self,x):
        out=self.fc1(x)
        out=torch.relu(out)
        out=self.fc2(out)
        return out
bash

然后进行模型定义、损失函数、优化器和自动混合精度(AMP)训练组件的初始化:

model=MLP(input_size=64,hidden_size=256,output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
scaler=torch.amp.GradScaler()
bash

其中scaler = torch.amp.GradScaler()创建一个 梯度缩放器(GradScaler),用于 自动混合精度(AMP)训练。在 FP16(半精度)训练时,防止梯度下溢(underflow → 变成 0)。 tips:下溢是指一个数太小了,小到当前数据类型无法表示,结果被强制变成 0。由于链式法则,所以可能会导致梯度越来越小。

接下来在模型的某一层(model.fc1)上注册一个前向传播钩子(forward hook),用于在第一次前向计算时打印该层的输入、输出和权重的形状与数据类型,便于调试模型的数据流和精度(如 FP16/FP32):

print_once=False
def debug_forward(module,input,output):
    global print_once
    if not print_once:
        print_once=True
        print(f'{module}\ninput_shape={input[0].shape} input_dtype={input[0].dtype}\noutput_shape={output.shape} output_dtype={output.dtype}\nweight_shape={module.weight.shape} weight_dtype={module.weight.dtype}')
    
model.fc1.register_forward_hook(debug_forward)
bash

打印的内容如下:

Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
bash

它揭示了模型在混合精度训练(AMP)环境下某一层(fc1)的实际运行状态。由output_dtype可知在训练过程中,将float32自动转换成float16进行训练了。模型参数weight_shape始终以 FP32 形式存储(这是 AMP 的标准做法,称为 “master weights”)。

最后开始训练:

iter=0
while True:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type='cuda',dtype=torch.float16): # FP16 Mix
        out=model(x)
        loss=loss_fn(out,y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    iter+=1
    if iter%100000==0:
        print(f'iter={iter} loss={loss.item()} cuda_mem={torch.cuda.memory_allocated()}Bytes')
    if loss.item()<=1e-3:
        break
bash

其中核心是:

with torch.amp.autocast(device_type='cuda', dtype=torch.float16):  # FP16 Mix
    out = model(x)
    loss = loss_fn(out, y)
bash

在这个上下文中,PyTorch 会自动为支持的操作选择 FP16 或 FP32 精度:大多数计算(如 Linear, Conv, MatMul)→ 使用 FP16(更快、更省显存);数值敏感操作(如 Softmax, Log, BatchNorm)→ 自动回退到 FP32(保精度)。

再进行梯度缩放:

scaler.scale(loss).backward()
bash

tips:在 FP16 中,梯度可能太小而下溢(underflow → 变成 0)。解决方案:先把 loss 放大(如 ×1024),使得反向传播时梯度也放大,落在 FP16 可表示范围内

scaler.scale(loss)会返回一个 scaled loss(FP16),调用 .backward() 时,计算的是 放大后的梯度。

最后参数更新:

scaler.step(optimizer)
scaler.update()
bash

其中scaler.step(optimizer)先检查梯度是否包含 NaN/InfFP16 容易溢出)。如果正常,则自动将梯度从 FP16 转回 FP32,并除以缩放因子,再调用 optimizer.step()。如果检测到异常(如梯度过大导致 Inf),则跳过本次更新(避免破坏模型)。 scaler.update()动态调整下一次的缩放因子(scale)。如果连续几次都无异常 → 尝试增大 scale(更激进)。如果出现 Inf/NaN → 减小 scale(更保守)。

混合精度训练:

Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float16
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.00813820119947195 cuda_mem=17707008Bytes
iter=200000 loss=0.0030544002074748278 cuda_mem=17707008Bytes
iter=300000 loss=0.0017446494894102216 cuda_mem=17707008Bytes
iter=400000 loss=0.0011826277477666736 cuda_mem=17707008Bytes
bash

不用混合精度训练:

Linear(in_features=64, out_features=256, bias=True)
input_shape=torch.Size([1797, 64]) input_dtype=torch.float32
output_shape=torch.Size([1797, 256]) output_dtype=torch.float32
weight_shape=torch.Size([256, 64]) weight_dtype=torch.float32
iter=100000 loss=0.008329853415489197 cuda_mem=17742848Bytes
iter=200000 loss=0.003189701121300459 cuda_mem=17742848Bytes
iter=300000 loss=0.0018277550116181374 cuda_mem=17742848Bytes
iter=400000 loss=0.001241791993379593 cuda_mem=17742848Bytes
bash

可以看到,使用AMP混合精度训练所占用的显存明显更低!

模型训练(一)分布式训练之DDP
http://www.soupcola.top/blog/distri_trainning/distri_trainning-2
Author Soup Cola
Published at 2026年2月6日