Skip to content

OOM 排查指南

Relax RL 训练中 CUDA 显存不足 (OOM) 错误的诊断与解决实践指南。本文提到的所有参数均可在配置参考手册中查阅。


诊断 OOM

步骤 1:确定 OOM 发生阶段

Relax 中的 OOM 错误通常发生在以下阶段:

阶段表现常见原因
模型初始化启动时 OOM模型对于可用 GPU 来说太大
前向传播(训练)train_actor 期间 OOM--max-tokens-per-gpu 过高或未启用重计算
Log probs 计算train_log_probs 期间 OOM长序列消耗过多激活显存
权重同步update_weights 期间 OOM权重缓冲区对于剩余 GPU 显存来说太大
NCCL 通信all-reduce/all-gather 内部 OOM通信缓冲区显存不足

步骤 2:捕获内存快照

Relax 提供内置的内存分析工具来捕获详细的分配信息。

PyTorch 内存快照

记录内存分配历史并自动导出快照:

bash
python3 relax/entrypoints/train.py \
    --memory-snapshot-dir /path/to/snapshots \
    --memory-snapshot-num-steps 3 \
    --memory-recorder torch \
    # ... 其他参数

这会从一开始就记录内存分配历史,并在指定步数后导出快照。如果发生 OOM,会在故障点自动导出快照。

使用 PyTorch Memory Visualizer 可视化快照:

bash
python -m torch.utils.viz._memory_viz trace_plot /path/to/snapshots/*.pickle -o memory_trace.html

Memray 分析器

如需 CPU+GPU 内存分析,使用 memray 记录器:

bash
python3 relax/entrypoints/train.py \
    --memory-snapshot-dir /path/to/snapshots \
    --memory-snapshot-num-steps 3 \
    --memory-recorder memray \
    # ... 其他参数

WARNING

使用 memray 时必须设置 --memory-snapshot-num-steps

步骤 3:启用 NCCL 通信内存检查

当 OOM 发生在 NCCL 集合通信操作内部时,标准调用栈可能无法显示可用显存量。--enable-cuda-memory-check 标志在每个底层 NCCL 调用周围添加内存监控:

bash
python3 relax/entrypoints/train.py \
    --enable-cuda-memory-check \
    # ... 其他参数

启用后:

  • 每次 NCCL 调用(all-reduce、all-gather、broadcast 等)前检查可用 GPU 显存。
  • 如果可用显存低于 5 GB,自动调用 torch.cuda.empty_cache() 回收碎片化显存。
  • 如果 NCCL 调用失败,内存信息会附加到异常中用于诊断。

TIP

--enable-cuda-memory-check 会导致约 20% 的训练性能劣化。建议仅在调试时使用,不用于生产训练。

步骤 4:使用 Profiler 内存追踪

进行更细粒度的分析时,在 PyTorch Profiler 中启用内存追踪:

bash
python3 relax/entrypoints/train.py \
    --profile-target train_overall \
    --profile-step-start 2 \
    --profile-step-end 4 \
    --profile-with-memory \
    --use-tensorboard \
    --tb-project-name /path/to/tb_logs \
    # ... 其他参数

--profile-with-memory 标志在 profiler trace 中记录 CUDA 显存分配/释放,可在 TensorBoard 的 Memory 视图中查看。


常见解决方案

1. 启用激活重计算

减少训练显存最有效的方法。以计算时间换取显存:

bash
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1

所有标准的 Relax 训练脚本都使用此配置,强烈推荐启用。

2. 降低 Max Tokens Per GPU

降低 --max-tokens-per-gpu 以减少每个 micro-batch 中打包的 Token 数量:

bash
# 之前(OOM)
--max-tokens-per-gpu 12288

# 调整后(降低)
--max-tokens-per-gpu 8192

如需单独控制 log probs 计算的内存预算:

bash
--log-probs-max-tokens-per-gpu 8192

3. 启用动态批处理防止 OOM

使用固定的 --micro-batch-size 时,一批异常长的序列可能超出 GPU 显存。动态批处理通过限制每个 micro-batch 的总 Token 数,使显存使用可预测,从而防止因变长输入导致的 OOM:

bash
--use-dynamic-batch-size \
--max-tokens-per-gpu 8192

从保守的 --max-tokens-per-gpu 值开始,逐步增大。该参数替代 --micro-batch-size——启用动态批处理后,micro-batch 大小会根据 Token 预算自动确定。

4. 使用固定 Micro-Batch Size

如果已经在使用动态批处理仍然遇到 OOM,可能是 --max-tokens-per-gpu 设置过高。也可以切换到固定的 micro-batch size 并使用较短的序列:

bash
# 去掉 --use-dynamic-batch-size 并显式设置
--micro-batch-size 1

5. 启用优化器 CPU Offload

将优化器状态(Adam 动量)移到 CPU 内存。对大模型(30B+)至关重要:

bash
--optimizer-cpu-offload

如需更好的 CPU offload 性能,重叠数据传输:

bash
--optimizer-cpu-offload \
--overlap-cpu-optimizer-d2h-h2d

6. 重计算损失函数

通过重计算损失函数而非缓存中间结果来节省显存:

bash
--recompute-loss-function

7. 分块计算 Log Probs

将 log probs 计算拆分为更小的块以降低峰值显存:

bash
--log-probs-chunk-size 4

值为 -1(默认)时一次性全部计算。更小的值使用更少的显存但耗时更长。

8. 减小权重更新缓冲区

对于参数量大的 MoE 模型,权重更新缓冲区可能消耗大量显存。减小缓冲区大小:

bash
# 默认是 512 MB
--update-weight-buffer-size 268435456  # 256 MB

9. 禁用权重备份

权重备份器在主机内存中保留一份模型权重的副本用于恢复。禁用它可以节省主机内存:

bash
--disable-weights-backuper

WARNING

禁用权重备份器意味着训练失败时无法自动恢复权重。

10. 调整训练内存预留

Relax 预留内存以防止碎片化。调整预留量:

bash
# 默认是 1 GB (1073741824 字节)
--train-memory-margin-bytes 536870912  # 512 MB

11. 调整 SGLang 显存比例

在 colocate 模式下,SGLang 和训练共享 GPU 显存。降低 SGLang 的分配比例为训练留出更多空间:

bash
# 默认因场景而异,典型值
--sglang-mem-fraction-static 0.7  # 从 0.8 下调

特定阶段的 OOM

训练前向传播 OOM

症状train_actor 步骤期间 OOM。

排查清单

  1. 启用重计算:--recompute-granularity full --recompute-method uniform --recompute-num-layers 1
  2. 启用 --use-dynamic-batch-size 并设置保守的 --max-tokens-per-gpu
  3. 降低 --max-tokens-per-gpu
  4. 启用 --recompute-loss-function
  5. 尝试 --optimizer-cpu-offload

Log Probs 计算 OOM

症状train_log_probs 步骤期间 OOM。

排查清单

  1. 设置更低的 --log-probs-max-tokens-per-gpu(与训练分开控制)
  2. 使用 --log-probs-chunk-size 4 分块计算
  3. 降低 --max-tokens-per-gpu

权重同步 OOM

症状update_weights(训练到推理的权重传输)期间 OOM。

排查清单

  1. 减小 --update-weight-buffer-size
  2. 在 colocate 模式下降低 --sglang-mem-fraction-static
  3. 尝试 --disable-weights-backuper

NCCL 通信 OOM

症状:NCCL 调用内部 OOM,调用栈不透明。

排查清单

  1. 启用 --enable-cuda-memory-check 获取故障点的详细内存信息
  2. 使用 --memory-snapshot-dir--memory-recorder torch 捕获内存快照
  3. 如果内存预留过多,减小 --train-memory-margin-bytes
  4. 如果碎片化是问题所在,增大 --train-memory-margin-bytes

快速参考

目标参数
减少激活显存--recompute-granularity full --recompute-method uniform --recompute-num-layers 1
限制每批次 Token 数(防止 OOM)--use-dynamic-batch-size --max-tokens-per-gpu <值>
减少每批次显存--max-tokens-per-gpu <更小的值>
将优化器移到 CPU--optimizer-cpu-offload
重计算损失函数--recompute-loss-function
分块计算 log probs--log-probs-chunk-size 4
减少权重同步显存--update-weight-buffer-size <更小的值>
节省主机内存--disable-weights-backuper
调试 NCCL OOM--enable-cuda-memory-check
捕获内存快照--memory-snapshot-dir <路径> --memory-recorder torch
分析显存使用--profile-with-memory
调整内存预留--train-memory-margin-bytes <字节数>
SGLang 显存(colocate)--sglang-mem-fraction-static <比例>

下一步

基于 Apache 2.0 许可发布