

激活值检查点(Activation Checkpointing)是一种用于减少深度神经网络训练过程中显存占用的技术。在标准的反向传播中,所有中间层的激活值(即前向传播的输出)都需要保存在显存中,以便计算梯度时使用。对于深层网络或大批次训练,这些激活值会消耗大量显存。
激活值检查点的核心思想是:不在前向传播时保存所有中间激活值,而只保存部分关键层的输出(称为“检查点”);在反向传播需要某段中间激活时,临时从最近的检查点重新执行前向计算来恢复,用时间换空间。
这种方法显著降低了显存需求(通常可节省30%~70%),代价是增加了少量计算开销(因为部分前向过程需重复执行)。它特别适用于训练非常深的模型(如Transformer、ResNet-152等)或在有限显存设备上进行大模型训练。PyTorch通过 torch.utils.checkpoint.checkpoint 提供了对该技术的原生支持。
首先加载训练数据和标签:
x,y=sklearn.datasets.load_digits(return_X_y=True)
x=torch.tensor(x/16).float().cuda() # FP32
y=torch.tensor(y).long().cuda()
print(x.shape,x.dtype)
print(y.shape,y.dtype)bash然后定义模型:
class MLP(torch.nn.Module):
def __init__(self,input_size,hidden_sizes,output_size):
super(MLP,self).__init__()
self.fc_first=torch.nn.Linear(input_size, hidden_sizes[0])
fc_middle=[]
for i in range(1,len(hidden_sizes)-1):
fc_middle.append(torch.nn.Linear(hidden_sizes[i-1],hidden_sizes[i]))
fc_middle.append(torch.nn.ReLU())
self.fc_middle=torch.nn.Sequential(*fc_middle)
self.fc_final=torch.nn.Linear(hidden_sizes[-1], output_size)
def forward(self,x,checkpoint=False):
out=self.fc_first(x)
out=torch.relu(out)
if checkpoint:
out=torch.utils.checkpoint.checkpoint(lambda x:self.fc_middle(x),out,use_reentrant=False)
else:
out=self.fc_middle(out)
out=self.fc_final(out)
return outbash注意,在forward方法中:
if checkpoint:
out=torch.utils.checkpoint.checkpoint(lambda x:self.fc_middle(x),out,use_reentrant=False)
else:
out=self.fc_middle(out)bash这段代码实现了选择性地对模型中间层(fc_middle)启用激活值检查点(Activation Checkpointing),以在训练时节省显存。
具体来说:
当 checkpoint=False(默认情况)时,直接执行 out = self.fc_middle(out),即按常规方式完成前向传播,所有中间激活值都会被保存在显存中,供后续反向传播使用。
当 checkpoint=True 时,不直接计算 fc_middle 的输出并保留其所有中间结果,而是通过 torch.utils.checkpoint.checkpoint 包装该计算过程。PyTorch 会不在前向传播中保存 fc_middle 内部各层的激活值,而只保留输入 out;在反向传播需要这些中间激活时,PyTorch 会临时重新运行 fc_middle 的前向计算(从保存的输入开始)来重建所需激活值。
接下来开始训练,指定checkpoint=True:
model=MLP(input_size=64,hidden_sizes=[256,512,512,128],output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
iter=0
while True:
optimizer.zero_grad()
torch.cuda.reset_peak_memory_stats() # 👈 重置峰值统计
out = model(x, checkpoint=True)
loss = loss_fn(out, y)
loss.backward()
optimizer.step()
peak_mem = torch.cuda.max_memory_allocated() # 👈 获取本次迭代峰值
iter += 1
if iter % 10000 == 0:
print(f'iter={iter} loss={loss.item()} peak_cuda_mem={peak_mem} Bytes')
if loss.item() <= 1e-3:
breakbash输出结果(消耗内存约43MB):
iter=10000 loss=0.020305516198277473 peak_cuda_mem=45377536 Bytes
iter=20000 loss=0.003331078216433525 peak_cuda_mem=45377536 Bytes
iter=30000 loss=0.0013336377451196313 peak_cuda_mem=45377536 Bytesbash指定checkpoint=False时:
model=MLP(input_size=64,hidden_sizes=[256,512,512,128],output_size=10).cuda()
loss_fn=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
iter=0
while True:
optimizer.zero_grad()
torch.cuda.reset_peak_memory_stats() # 👈 重置峰值统计
out = model(x, checkpoint=False)
loss = loss_fn(out, y)
loss.backward()
optimizer.step()
peak_mem = torch.cuda.max_memory_allocated() # 👈 获取本次迭代峰值
iter += 1
if iter % 10000 == 0:
print(f'iter={iter} loss={loss.item()} peak_cuda_mem={peak_mem} Bytes')
if loss.item() <= 1e-3:
breakbash输出结果(消耗内存约48MB):
iter=10000 loss=0.019843708723783493 peak_cuda_mem=50066432 Bytes
iter=20000 loss=0.0028206356801092625 peak_cuda_mem=50066432 Bytes
iter=30000 loss=0.0011666719801723957 peak_cuda_mem=50066432 Bytesbash