OOM Troubleshooting
A practical guide to diagnosing and resolving CUDA Out-Of-Memory (OOM) errors during Relax RL training. All parameters mentioned here are documented in the Configuration Reference.
Diagnosing OOM
Step 1: Determine Where OOM Occurs
OOM errors in Relax typically happen in one of these phases:
| Phase | Symptom | Typical Cause |
|---|---|---|
| Model initialization | OOM during startup | Model too large for available GPUs |
| Forward pass (training) | OOM during train_actor | --max-tokens-per-gpu too high or recomputation not enabled |
| Log probs computation | OOM during train_log_probs | Long sequences consuming excessive activation memory |
| Weight synchronization | OOM during update_weights | Weight buffer too large for remaining GPU memory |
| NCCL communication | OOM inside all-reduce/all-gather | Insufficient memory for communication buffers |
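If the phase is not obvious, it can usually be read off the traceback. The sketch below shows one way to triage a log excerpt against the table above. The helper names and the marker strings (the phase names train_actor, train_log_probs, and update_weights, plus a few collective names) are assumptions about what Relax tracebacks contain, not a Relax API, so adjust them to your actual logs.

```python
# Hypothetical triage helper: guess the OOM phase from traceback/log text.
# Marker strings are assumptions about Relax log contents, not a documented format.
PHASE_MARKERS = [
    # Checked in order: NCCL markers come first so that a collective failing
    # inside update_weights is reported as a communication OOM.
    ("NCCL communication", ("all_reduce", "all_gather", "broadcast", "nccl")),
    ("Weight synchronization", ("update_weights",)),
    ("Log probs computation", ("train_log_probs",)),
    ("Forward pass (training)", ("train_actor",)),
]


def is_oom(log_text: str) -> bool:
    """Return True if the text looks like a CUDA out-of-memory error."""
    lowered = log_text.lower()
    return "out of memory" in lowered or "outofmemoryerror" in lowered


def classify_oom_phase(log_text: str) -> str:
    """Return the first phase whose markers appear in the text."""
    if not is_oom(log_text):
        return "not an OOM"
    lowered = log_text.lower()
    for phase, markers in PHASE_MARKERS:
        if any(marker in lowered for marker in markers):
            return phase
    # No phase marker found: startup OOMs typically land here.
    return "Model initialization or unknown"


if __name__ == "__main__":
    sample = (
        'File "actor.py", line 120, in train_actor\n'
        "torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB"
    )
    print(classify_oom_phase(sample))  # Forward pass (training)
```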
Step 2: Capture Memory Snapshots
Relax provides built-in memory profiling tools to capture detailed allocation information.
PyTorch Memory Snapshot
Record memory history and dump snapshots automatically:
```bash
python3 relax/entrypoints/train.py \
--memory-snapshot-dir /path/to/snapshots \
--memory-snapshot-num-steps 3 \
--memory-recorder torch \
# ... other args
```
This records memory allocation history from the start of training and dumps a snapshot after the specified number of steps. If an OOM occurs, a snapshot is also dumped automatically at the point of failure.
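For a quick numeric overview of a dumped snapshot, you can load the pickle and total the allocator segments per device. This is a minimal sketch that assumes the dictionary layout produced by PyTorch's memory snapshot API (a "segments" list whose entries carry "device", "total_size", and "allocated_size"); that format is a private PyTorch detail and may change between versions.

```python
import pickle
from collections import defaultdict

GiB = 1024 ** 3


def summarize_segments(snapshot: dict) -> dict:
    """Return {device: (reserved_bytes, allocated_bytes)} from a snapshot dict.

    Assumed layout: snapshot["segments"] is a list of dicts with "device",
    "total_size" (reserved by the caching allocator) and "allocated_size"
    (currently handed out to tensors).
    """
    totals = defaultdict(lambda: [0, 0])
    for seg in snapshot.get("segments", []):
        totals[seg["device"]][0] += seg["total_size"]
        totals[seg["device"]][1] += seg["allocated_size"]
    return {dev: (res, alloc) for dev, (res, alloc) in totals.items()}


def load_snapshot(path: str) -> dict:
    with open(path, "rb") as f:
        return pickle.load(f)


def print_summary(snapshot: dict) -> None:
    for dev, (reserved, allocated) in sorted(summarize_segments(snapshot).items()):
        # A large reserved-minus-allocated gap points to cached or fragmented memory.
        print(f"cuda:{dev}: reserved {reserved / GiB:.2f} GiB, "
              f"allocated {allocated / GiB:.2f} GiB, "
              f"gap {(reserved - allocated) / GiB:.2f} GiB")


if __name__ == "__main__":
    import sys

    for path in sys.argv[1:]:
        print(path)
        print_summary(load_snapshot(path))
```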
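Visualize the snapshots with the PyTorch memory visualizer. The trace_plot command takes a single snapshot file, so run it once per file:
```bash
for f in /path/to/snapshots/*.pickle; do
python -m torch.cuda._memory_viz trace_plot "$f" -o "${f%.pickle}.html"
done
```
Memray Profiler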
For CPU+GPU memory profiling, use the memray recorder:
```bash
python3 relax/entrypoints/train.py \
--memory-snapshot-dir /path/to/snapshots \
--memory-snapshot-num-steps 3 \
--memory-recorder memray \
# ... other args
```
WARNING
--memory-snapshot-num-steps is required when using memray.
Step 3: Enable NCCL Communication Memory Check
When OOM occurs inside NCCL collective operations, standard stack traces may not show how much memory was available. The --enable-cuda-memory-check flag adds memory monitoring around every low-level NCCL call:
```bash
python3 relax/entrypoints/train.py \
--enable-cuda-memory-check \
# ... other args
```
When enabled:
- Before each NCCL call (all-reduce, all-gather, broadcast, etc.), available GPU memory is checked.
- If free memory is below 5 GB, torch.cuda.empty_cache() is called automatically to release cached but unused blocks held by PyTorch's caching allocator, making that memory available to NCCL.
- If the NCCL call fails, memory information is attached to the exception for diagnosis.
TIP
--enable-cuda-memory-check introduces approximately 20% training performance degradation. Use it during debugging only, not for production training.
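The check described above can be sketched as a wrapper around a single collective call. This is not Relax's implementation: the wrapper name, reading the 5 GB threshold as GiB, and the exact shape of the re-raised error are assumptions. The injected mem_get_info and empty_cache callables stand in for torch.cuda.mem_get_info and torch.cuda.empty_cache so the sketch runs without a GPU.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")
GiB = 1024 ** 3
FREE_THRESHOLD_BYTES = 5 * GiB  # assumption: the 5 GB threshold read as GiB


def checked_collective(
    op: Callable[[], T],
    name: str,
    mem_get_info: Callable[[], Tuple[int, int]],  # stand-in for torch.cuda.mem_get_info
    empty_cache: Callable[[], None],  # stand-in for torch.cuda.empty_cache
) -> T:
    """Run a collective after a free-memory check; add memory info on failure."""
    free, total = mem_get_info()
    if free < FREE_THRESHOLD_BYTES:
        # Release cached-but-unused allocator blocks so NCCL can claim them.
        empty_cache()
        free, total = mem_get_info()
    try:
        return op()
    except RuntimeError as exc:
        # Re-raise with the memory state observed just before the call.
        raise RuntimeError(
            f"{name} failed with {free / GiB:.2f} GiB free of "
            f"{total / GiB:.2f} GiB before the call"
        ) from exc
```

With PyTorch, a call could look like checked_collective(lambda: dist.all_reduce(tensor), "all_reduce", torch.cuda.mem_get_info, torch.cuda.empty_cache). Checking before every call is what makes this useful for diagnosis and also what makes it slow.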
Step 4: Use Profiler Memory Tracking
For finer-grained analysis, enable memory tracking in the PyTorch Profiler:
```bash
python3 relax/entrypoints/train.py \
--profile-target train_overall \
--profile-step-start 2 \
--profile-step-end 4 \
--profile-with-memory \
--use-tensorboard \
--tb-project-name /path/to/tb_logs \
# ... other args
```
The --profile-with-memory flag records CUDA memory allocations and deallocations in the profiler trace; they are visible in the TensorBoard Memory view.
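Besides the TensorBoard view, the exported trace file (for example a *.pt.trace.json file, possibly gzip-compressed) can be scanned directly for the peak allocation per device. This sketch assumes the Chrome-trace layout the PyTorch profiler uses when memory profiling is on, where allocator events are named "[memory]" and their "args" include "Device Id" and "Total Allocated"; treat those field names as an assumption and check them against your own trace.

```python
import gzip
import json
from typing import Dict


def load_trace(path: str) -> dict:
    """Load a Chrome-trace JSON file, optionally gzip-compressed."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as f:
        return json.load(f)


def peak_allocated_by_device(trace: dict) -> Dict[int, int]:
    """Return the highest "Total Allocated" value seen per device id, in bytes.

    Host-side memory events may also appear; filter on "Device Type" if your
    trace includes them and you only care about GPUs.
    """
    peaks: Dict[int, int] = {}
    for event in trace.get("traceEvents", []):
        if event.get("name") != "[memory]":
            continue
        args = event.get("args", {})
        device = args.get("Device Id")
        allocated = args.get("Total Allocated")
        if device is None or allocated is None:
            continue
        peaks[device] = max(peaks.get(device, 0), allocated)
    return peaks
```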
Common Solutions
1. Enable Activation Recomputation
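This is the most effective way to reduce training memory: instead of keeping every intermediate activation for the backward pass, activations are recomputed during the backward pass, trading extra compute time for memory.
```bash
--recompute-granularity full \
--recompute-method uniform \
--recompute-num-layers 1
```
All standard Relax training scripts use these settings, and they are highly recommended.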
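Even with full recomputation, each recomputed layer still keeps its input. A rough lower bound on that storage is tokens × hidden size × bytes per element × number of layers. The sketch below computes only this term for two token budgets (12288 and 8192) under assumed model numbers (bf16, hidden size 4096, 32 layers); it ignores parallelism, the working set of the layer currently being recomputed, and embeddings, logits, and loss, so it is not a full memory model.

```python
MiB = 1024 ** 2


def checkpointed_input_bytes(tokens: int, hidden_size: int, num_layers: int,
                             bytes_per_elem: int = 2) -> int:
    """Bytes kept for layer inputs when every layer is fully recomputed.

    Rough lower bound: one [tokens, hidden_size] tensor saved per layer.
    """
    return tokens * hidden_size * bytes_per_elem * num_layers


if __name__ == "__main__":
    # Assumed example model: bf16 (2 bytes), hidden size 4096, 32 layers.
    for tokens in (12288, 8192):
        total = checkpointed_input_bytes(tokens, 4096, 32)
        print(f"{tokens} tokens -> {total / MiB:.0f} MiB of saved layer inputs")
```

Because this term scales linearly with tokens, lowering the per-GPU token budget reduces it proportionally.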
2. Reduce Max Tokens Per GPU
Lower --max-tokens-per-gpu to reduce the token count packed into each micro-batch:
```bash
# Before (OOM)
--max-tokens-per-gpu 12288

# After (reduced)
--max-tokens-per-gpu 8192
```
For log probs computation, set a separate budget if needed:
```bash
--log-probs-max-tokens-per-gpu 8192
```
3. Enable Dynamic Batching to Prevent OOM
With a fixed --micro-batch-size, a batch of unusually long sequences can exceed GPU memory. Dynamic batching caps total tokens per micro-batch, keeping memory usage predictable and preventing OOM from variable-length inputs:
```bash
--use-dynamic-batch-size \
--max-tokens-per-gpu 8192
```
Start with a conservative --max-tokens-per-gpu and increase it gradually. Dynamic batching replaces --micro-batch-size: when it is enabled, the micro-batch size is determined automatically from the token budget.
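The idea behind a token-budgeted micro-batch can be sketched as greedy packing. This is a simplified illustration with a hypothetical function name; Relax's actual strategy (ordering, balancing across ranks, handling of a single sequence longer than the cap) may differ.

```python
from typing import List


def pack_by_token_budget(seq_lens: List[int], max_tokens: int) -> List[List[int]]:
    """Greedily group sequence indices so each group's total length <= max_tokens.

    A sequence longer than max_tokens is placed alone in its own group, since
    it cannot be split here; such a sequence can still OOM on its own.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, length in enumerate(seq_lens):
        # Close the current micro-batch if adding this sequence would exceed the cap.
        if current and current_tokens + length > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += length
    if current:
        batches.append(current)
    return batches
```

For lengths [5000, 3000, 4000, 2000, 9000, 1000] and a cap of 8192 this yields [[0, 1], [2, 3], [4], [5]]. A fixed micro-batch size of 2 would instead pair the 9000- and 1000-token sequences into one 10000-token micro-batch, which is the kind of spike the token cap avoids.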
4. Use Fixed Micro-Batch Size
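If you already use dynamic batching and still hit OOM, the --max-tokens-per-gpu value may be too high, so lower it first. Alternatively, switch to a fixed micro-batch size. With --micro-batch-size 1, each micro-batch holds a single sequence, so peak memory is bounded roughly by the longest sequence rather than by a packed token budget:
```bash
# Remove --use-dynamic-batch-size and set the micro-batch size explicitly
--micro-batch-size 1
```
5. Enable Optimizer CPU Offload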
Move optimizer states (such as the Adam moments) to CPU memory. This is critical for large models (30B+ parameters):
```bash
--optimizer-cpu-offload
```
For better performance with CPU offload, overlap the device-to-host and host-to-device transfers:
```bash
--optimizer-cpu-offload \
--overlap-cpu-optimizer-d2h-h2d
```
6. Recompute Loss Function
Save memory by recomputing the loss function instead of caching intermediate results:
```bash
--recompute-loss-function
```
7. Chunk Log Probs Computation
Split log probs computation into smaller chunks to reduce peak memory:
```bash
--log-probs-chunk-size 4
```
A value of -1 (the default) computes all log probs at once. Smaller values use less memory but take longer.
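Chunking is safe because log probs for different samples are independent: computing them a few samples at a time gives the same values with a smaller peak working set. The pure-Python sketch below only demonstrates that equivalence. It assumes the chunk size counts samples, which is an assumption about --log-probs-chunk-size rather than documented behavior, and it uses nested lists in place of GPU tensors.

```python
import math
from typing import List, Sequence


def token_log_probs(logits: Sequence[Sequence[float]], targets: Sequence[int]) -> List[float]:
    """Return log_softmax(logits[t])[targets[t]] for each position, computed stably."""
    out = []
    for row, target in zip(logits, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[target] - log_z)
    return out


def chunked_log_probs(batch_logits: Sequence[Sequence[Sequence[float]]],
                      batch_targets: Sequence[Sequence[int]],
                      chunk_size: int) -> List[List[float]]:
    """Process samples chunk_size at a time; chunk_size == -1 means all at once."""
    n = len(batch_logits)
    step = max(1, n if chunk_size == -1 else chunk_size)
    results: List[List[float]] = []
    for start in range(0, n, step):
        # On a GPU, only this chunk's softmax intermediates would be live at once;
        # here the point is that the results match the unchunked computation.
        chunk = zip(batch_logits[start:start + step], batch_targets[start:start + step])
        for logits, targets in chunk:
            results.append(token_log_probs(logits, targets))
    return results
```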
8. Reduce Weight Update Buffer
For MoE models with many parameters, the weight update buffer can consume significant memory. Reduce the buffer size:
```bash
# Default is 512 MB
--update-weight-buffer-size 268435456 # 256 MB
```
9. Disable Weights Backuper
The weights backuper keeps a copy of model weights in host memory for recovery. Disabling it saves host memory:
```bash
--disable-weights-backuper
```
WARNING
Disabling the weights backuper means automatic weight recovery is unavailable if training fails.
10. Adjust Training Memory Margin
Relax reserves memory to prevent fragmentation. Adjust the margin:
```bash
# Default is 1 GB (1073741824 bytes)
--train-memory-margin-bytes 536870912 # 512 MB
```
11. Tune SGLang Memory Fraction
In colocate mode, SGLang and training share GPU memory. Reduce SGLang's allocation to leave more room for training:
```bash
# Default varies; typical values
--sglang-mem-fraction-static 0.7 # down from 0.8
```
OOM in Specific Phases
Training Forward Pass OOM
Symptoms: OOM during train_actor step.
Checklist:
- Enable recomputation: --recompute-granularity full --recompute-method uniform --recompute-num-layers 1
- Enable --use-dynamic-batch-size with a conservative --max-tokens-per-gpu
- Lower --max-tokens-per-gpu
- Enable --recompute-loss-function
- Try --optimizer-cpu-offload
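Whether --optimizer-cpu-offload pays off depends on how large the optimizer state is. The back-of-the-envelope sketch below assumes mixed-precision Adam with an fp32 master weight plus two fp32 moments per parameter (12 bytes) and a distributed optimizer that splits this state evenly across GPUs; both are assumptions about your setup, not Relax defaults, and the function name is hypothetical.

```python
GB = 10 ** 9


def adam_state_bytes_per_gpu(num_params: float, shard_ways: int,
                             bytes_per_param: int = 12) -> float:
    """Optimizer-state bytes each GPU holds if the state is split shard_ways ways.

    bytes_per_param=12 assumes an fp32 master weight plus fp32 first and second
    moments. Gradients, bf16 weights, and activations come on top of this.
    """
    return num_params * bytes_per_param / shard_ways


if __name__ == "__main__":
    # Assumed example: a 30B-parameter model with optimizer state split 8 ways.
    per_gpu = adam_state_bytes_per_gpu(30e9, 8)
    print(f"{per_gpu / GB:.0f} GB of optimizer state per GPU")  # 45 GB
```

If that figure is a large share of the GPU, offloading it to host memory frees the most room.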
Log Probs Computation OOM
Symptoms: OOM during train_log_probs step.
Checklist:
- Set a lower --log-probs-max-tokens-per-gpu (separate from the training budget)
- Use --log-probs-chunk-size 4 to chunk the computation
- Lower --max-tokens-per-gpu
Weight Synchronization OOM
Symptoms: OOM during update_weights (weight transfer from training to inference).
Checklist:
- Reduce --update-weight-buffer-size
- In colocate mode, reduce --sglang-mem-fraction-static
- Try --disable-weights-backuper
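The buffer size trades peak memory against the number of transfers: weights are streamed through a buffer of at most --update-weight-buffer-size bytes, so halving the buffer halves that peak but roughly doubles the number of transfers. The sketch below estimates this for an assumed 30B-parameter model sent as bf16, with the idealized assumption that weights pack perfectly into buffer-sized chunks; the helper name is hypothetical and the real transfer path is not modeled.

```python
import math

MiB = 1024 ** 2


def weight_sync_plan(num_params: float, buffer_bytes: int, bytes_per_param: int = 2):
    """Return (total_bytes, approx_num_transfers) for streaming weights through a buffer.

    Assumes bf16 weights (2 bytes/param) packed perfectly into buffer-sized
    chunks; bucketing by whole tensors needs somewhat more transfers.
    """
    total = int(num_params * bytes_per_param)
    return total, math.ceil(total / buffer_bytes)


if __name__ == "__main__":
    for mib in (512, 256):
        total, n = weight_sync_plan(30e9, mib * MiB)
        print(f"--update-weight-buffer-size {mib * MiB}: ~{n} transfers of <= {mib} MiB")
```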
NCCL Communication OOM
Symptoms: OOM inside NCCL calls with opaque stack traces.
Checklist:
- Enable --enable-cuda-memory-check to get detailed memory information at the failure point
- Capture a memory snapshot with --memory-snapshot-dir and --memory-recorder torch
- Reduce --train-memory-margin-bytes if memory is over-reserved
- Increase --train-memory-margin-bytes if fragmentation is the issue
Quick Reference
| Goal | Parameter |
|---|---|
| Reduce activation memory | --recompute-granularity full --recompute-method uniform --recompute-num-layers 1 |
| Cap per-batch tokens (OOM prevention) | --use-dynamic-batch-size --max-tokens-per-gpu <value> |
| Reduce per-batch memory | --max-tokens-per-gpu <lower value> |
| Move optimizer to CPU | --optimizer-cpu-offload |
| Recompute loss | --recompute-loss-function |
| Chunk log probs | --log-probs-chunk-size 4 |
| Reduce weight sync memory | --update-weight-buffer-size <lower value> |
| Save host memory | --disable-weights-backuper |
| Debug NCCL OOM | --enable-cuda-memory-check |
| Capture memory snapshot | --memory-snapshot-dir <path> --memory-recorder torch |
| Profile memory usage | --profile-with-memory |
| Adjust memory margin | --train-memory-margin-bytes <bytes> |
| SGLang memory (colocate) | --sglang-mem-fraction-static <fraction> |
Next Steps
- Performance Tuning — maximize throughput after resolving OOM
- Configuration Reference — full parameter list
- Debugging Guide — isolating training and inference issues
