Soup's Blog

Back

模型训练(三)激活值检查点Activation CheckpointBlur image

激活值检查点(Activation Checkpointing)是一种用于减少深度神经网络训练过程中显存占用的技术。在标准的反向传播中,所有中间层的激活值(即前向传播的输出)都需要保存在显存中,以便计算梯度时使用。对于深层网络或大批次训练,这些激活值会消耗大量显存。

激活值检查点的核心思想是:不在前向传播时保存所有中间激活值,而只保存部分关键层的输出(称为“检查点”);在反向传播需要某段中间激活时,临时从最近的检查点重新执行前向计算来恢复,用时间换空间。

这种方法显著降低了显存需求(通常可节省30%~70%),代价是增加了少量计算开销(因为部分前向过程需重复执行)。它特别适用于训练非常深的模型(如Transformer、ResNet-152等)或在有限显存设备上进行大模型训练。PyTorch通过 torch.utils.checkpoint.checkpoint 提供了对该技术的原生支持。

首先加载训练数据和标签:

x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)
bash

然后定义模型:

注意,在forward方法中:

if checkpoint:
    out=torch.utils.checkpoint.checkpoint(lambda x:self.fc_middle(x),out,use_reentrant=False)
else:
    out=self.fc_middle(out)
bash

这段代码实现了选择性地对模型中间层(fc_middle)启用激活值检查点(Activation Checkpointing),以在训练时节省显存。

具体来说:

checkpoint=False(默认情况)时,直接执行 out = self.fc_middle(out),即按常规方式完成前向传播,所有中间激活值都会被保存在显存中,供后续反向传播使用。 当 checkpoint=True 时,不直接计算 fc_middle 的输出并保留其所有中间结果,而是通过 torch.utils.checkpoint.checkpoint 包装该计算过程。PyTorch 会不在前向传播中保存 fc_middle 内部各层的激活值,而只保留输入 out;在反向传播需要这些中间激活时,PyTorch 会临时重新运行 fc_middle 的前向计算(从保存的输入开始)来重建所需激活值。

接下来开始训练,指定checkpoint=True

输出结果(消耗内存约43MB):

iter=10000 loss=0.020305516198277473 peak_cuda_mem=45377536 Bytes
iter=20000 loss=0.003331078216433525 peak_cuda_mem=45377536 Bytes
iter=30000 loss=0.0013336377451196313 peak_cuda_mem=45377536 Bytes
bash

指定checkpoint=False时:

输出结果(消耗内存约48MB):

iter=10000 loss=0.019843708723783493 peak_cuda_mem=50066432 Bytes
iter=20000 loss=0.0028206356801092625 peak_cuda_mem=50066432 Bytes
iter=30000 loss=0.0011666719801723957 peak_cuda_mem=50066432 Bytes
bash
模型训练(三)激活值检查点Activation Checkpoint
http://www.soupcola.top/blog/distri_trainning/distri_trainning-3
Author Soup Cola
Published at 2026年2月6日