全异步训练流水线
概述
全异步训练流水线(Fully Async) 是一种以最大化 GPU 利用率为目标的高吞吐 RLHF/RL 训练模式。与 Colocate(同步)模式不同,Fully Async 将 训练(Actor)、推理(Rollout)、前向计算(ActorFwd / Reference) 和 优势计算(Advantages) 部署在独立的 GPU 集群上,服务间通过 TransferQueue 交换数据,通过 DCS(Distributed Checkpoint Service)异步同步权重。
模式对比
| 维度 | Colocate(同步) | Fully Async(全异步) |
|---|---|---|
| GPU 共享 | Actor 与 Rollout 共享同一组 GPU | Actor、Rollout、ActorFwd/Reference 各自拥有独立 GPU |
| 执行模型 | 串行:Rollout 完成 → 切换到 Train → 更新权重 | 全并行:Rollout、Train、前向计算同时进行 |
| 权重同步 | 进程内 tensor 拷贝(同机器) | 跨节点 NCCL broadcast,通过 DCS(Checkpoint Engine) |
| 数据流 | 同样走 TransferQueue,但同步:Rollout 写完整批数据后 Actor 才读取 | 走 TransferQueue + StreamingDataLoader 异步流式传输(生产消费重叠) |
| Staleness | max_staleness=0(严格 On-Policy) | 可配置 max_staleness(允许一定程度 Off-Policy) |
| 角色 | actor, critic, rollout | actor, critic, rollout, advantages, reference, actor_fwd |
TIP
两种模式都使用 TransferQueue 作为数据传输层。Colocate 模式下 Rollout 和 Actor 分时复用同一组 GPU——Rollout 写完整批数据到 TransferQueue 后释放 GPU,Actor 接管 GPU 进行训练。Fully Async 模式下各服务运行在独立 GPU 上,数据的生产和消费可以并行进行。
核心优势
- 消除 GPU 空闲时间 — Rollout 和 Training 同时运行,推理引擎在训练期间持续生成数据
- 灵活的资源配比 — 训练和推理可以使用不同数量的 GPU,适应异构硬件
- 可控的 On/Off-Policy 程度 —
max_staleness参数精确控制数据新鲜度 - 流水线化的权重更新 — DCS 使权重分发与训练计算重叠
系统架构
架构图
┌───────────────────────────────────────────────────────────────────────────┐
│ Controller (Orchestrator) │
│ relax/core/controller.py │
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌────────────┐ ┌─────────┐ │
│ │ Rollout │ │ Actor │ │ ActorFwd │ │ Reference │ │ Adv │ │
│ │ Service │ │ Service │ │ Service │ │ Service │ │ Service │ │
│ └──┬───────┘ └──┬───────┘ └──┬───────┘ └──┬─────────┘ └──┬──────┘ │
└───────┼─────────────┼─────────────┼─────────────┼────────────────┼────────┘
│ │ │ │ │
▼ ▼ ▼ ▼ ▼
┌───────────────────────────────────────────────────────────────────────────┐
│ TransferQueue (Data Plane) │
│ │
│ ┌────────────────┐ ┌──────────────────────────────────┐ │
│ │ TQ Controller │◄──────┤ SimpleStorageUnit × N │ │
│ │ (Metadata Mgr) │ │ (Partitioned Data Storage) │ │
│ └────────────────┘ └──────────────────────────────────┘ │
│ ▲ │
│ │ │
│ ┌───────────────────┼────────────────────┐ │
│ │ StreamingDataset / StreamingDataLoader │ │
│ │ (relax/utils/data/stream_dataloader.py)│ │
│ └────────────────────────────────────────┘ │
└───────────────────────────────────────────────────────────────────────────┘
│ │ │ │
▼ ▼ ▼ ▼
┌───────────────────────────────────────────────────────────────────────────┐
│ Distributed Checkpoint Service (DCS) │
│ │
│ ┌──────────────┐ ┌──────────────────────────────────┐ │
│ │ Coordinator │◄───┤ CheckpointEngineClient × N │ │
│ │ (HTTP REST) │ │ (Per-rank weight send/recv) │ │
│ └──────────────┘ └──────────────────────────────────┘ │
│ │
│ ┌───────────────────────────────────────────────┐ │
│ │ DeviceDirectBackend (NCCL/GLOO) │ │
│ │ - Actor → Rollout: weight broadcast to SGLang│ │
│ │ - Actor → ActorFwd/Ref: PP-aware broadcast │ │
│ └───────────────────────────────────────────────┘ │
└───────────────────────────────────────────────────────────────────────────┘服务角色
Fully Async 模式下系统部署 6 个角色(由 relax/core/registry.py 中的 ROLES StrEnum 定义):
class ROLES(StrEnum):
actor: str = "actor" # 策略模型训练
critic: str = "critic" # 价值模型训练(可选)
rollout: str = "rollout" # SGLang 推理引擎,生成样本
advantages: str = "advantages" # 优势和回报计算
reference: str = "reference" # 参考模型前向(KL 散度)
actor_fwd: str = "actor_fwd" # 当前策略前向(log prob)角色选择逻辑(relax/core/registry.py):
def process_role(config):
if config.fully_async:
return ROLES # 全部 6 个角色
else:
return ROLES_COLOCATE # 仅 actor, critic, rollout数据流:TransferQueue 上的 StreamingDataLoader
两种模式下的 TransferQueue
Colocate 和 Fully Async 模式都使用 TransferQueue 进行数据传输。核心区别在于生产和消费的时序关系:
Colocate 模式(串行):
Rollout 完整写入分区 train_N ── 全部就绪 ──► Actor 一次性读取 train_N
(同组 GPU 分时复用;Rollout offload 后 Actor 唤醒训练)
(ref log prob、advantages 在 Actor 的 train_actor() 内串行计算)
Fully Async 模式(流式并行):
Rollout 增量写入分区 train_N ──► Actor 通过 StreamingDataLoader 消费
Rollout 可同时开始 train_N+1 ──► ActorFwd/Reference/Advantages 并行消费 train_N
(不同 GPU 集群完全并行;ref log prob、adv 独立计算并写回 TQ)分区机制:
- 分区 ID 格式:
train_{rollout_id},例如train_0、train_1、train_2 - 生产者(Rollout):完成一次 rollout 后将数据写入
train_{rollout_id} - 消费者(Actor/ActorFwd/Reference/Advantages):从对应分区读取数据,通过
task_name追踪消费进度 - 分区清理:Actor 训练完成后调用
async_clear_partition()清理分区
存储容量与 max_staleness:
# relax/core/controller.py
total_storage_size = (
self.config.rollout_batch_size
* (self.config.max_staleness + 1)
* self.config.n_samples_per_prompt
)TransferQueue 必须能同时缓存 max_staleness + 1 个 rollout batch 的数据。例如 max_staleness=2、rollout_batch_size=8、n_samples_per_prompt=8 时,需要 8 × 3 × 8 = 192 个样本的存储空间。
Task names 用于追踪不同消费者的消费进度:
| 消费者 | task_name | 消费的数据字段 |
|---|---|---|
| Actor | actor_train(StreamDataLoader)/ train(旧版) | tokens, loss_masks, log_probs, ref_log_probs, advantages, returns 等 |
| ActorFwd | actor_log_probs | tokens, total_lengths, response_lengths, loss_masks, rollout_log_probs |
| Reference | ref_log_probs | tokens, total_lengths, response_lengths, loss_masks, rollout_log_probs |
| Advantages | compute_advantages_and_returns | rollout_log_probs, log_probs, ref_log_probs, rewards 等 |
StreamingDataLoader 与 StreamingDataset
在 Fully Async 模式下,Actor 使用 StreamingDataLoader 进行流式数据消费。与 Colocate 模式下 Actor 需要等待 Rollout 完全生成一个 batch 后再读取不同,StreamingDataLoader 可以在 TransferQueue 中的数据被增量写入的同时进行消费。这是实现"训练和推理并行"的核心机制。
StreamingDataset
# TransferQueue (installed from https://github.com/redai-infra/TransferQueue)
class StreamingDataset(IterableDataset):
"""流式数据集,从 TransferQueue 动态获取数据"""
def __init__(self, config, batch_size, micro_batch_size, data_fields,
partition_id, task_name, dp_rank, fetch_batch_fn, process_batch_fn):
self.buffer = [] # 已拉取批次的缓存
self.batch_index = 0 # 当前消费位置
def __iter__(self):
while not consumed:
if self.batch_index <= len(self.buffer) - 1:
# 从缓存中读取(支持多次遍历)
yield from self.process_batch_fn(...)
else:
# 从 TransferQueue 拉取新数据
batch_data, batch_meta = self.fetch_batch_fn(...)
if batch_data is not None:
self.buffer.append((batch_data, batch_meta))核心特性:
- 按需拉取:每次拉取
global_batch_size / num_iters_per_train_update大小的数据 - 缓存复用:
buffer支持对同一批数据进行多次遍历(例如多轮训练) - 分区切换:
step(partition_id)清空缓存并切换到新的 rollout 数据分区
数据拉取函数(fetch_batch_fn)
Fully Async 模式使用定制的 get_data_from_transfer_queue() 函数(relax/utils/data/stream_dataloader.py):
# broadcast_pp 是 fully_async 的反义
fetch_batch_fn = partial(get_data_from_transfer_queue,
broadcast_pp=not getattr(args, "fully_async", False))广播策略差异:
| 模式 | broadcast_pp | 数据拉取节点 | 广播范围 |
|---|---|---|---|
| Colocate | True | tp_rank==0 && pp_rank==0 | TP 组 + PP 组 |
| Fully Async | False | tp_rank==0(每个 PP stage 独立) | 仅 TP 组(各 PP stage 独立拉取) |
- Colocate 模式:Rollout 已经将完整 batch 写入 TransferQueue。Actor 在同一组 GPU 上启动,PP rank 0 从 TQ 拉取数据并广播到其他 PP stage。所有数据一次性就绪用于训练。
- Fully Async 模式:每个 PP stage 位于不同 rank,各自独立从 TransferQueue 拉取数据,避免跨 PP stage 通信开销。由于数据可能仍在增量写入,StreamingDataLoader 会在数据未就绪时自动重试。
create_stream_dataloader
# relax/utils/data/stream_dataloader.py
def create_stream_dataloader(args, rollout_id, task_name, data_fields, dp_rank):
dataset = StreamingDataset(
config=args.tq_config,
batch_size=args.micro_batch_size * args.n_samples_per_prompt,
micro_batch_size=args.micro_batch_size,
data_fields=data_fields,
partition_id=f"train_{rollout_id}",
task_name=task_name,
dp_rank=dp_rank,
fetch_batch_fn=fetch_batch_fn,
process_batch_fn=split_dict,
)
dataloader = StreamingDataLoader(dataset)
# 计算每个 rollout 的训练步数
num_steps_per_rollout = (args.rollout_batch_size * args.n_samples_per_prompt
// args.global_batch_size)
num_microbatches = [
args.global_batch_size // dp_world_size // args.micro_batch_size
for _ in range(num_steps_per_rollout)
]
return [dataloader for _ in range(vpp_size)], num_microbatches异步权重同步:分布式 Checkpoint 服务(DCS)
DCS 在 Fully Async 中的作用
Actor 完成训练后,需要将权重分发到:
- Rollout(SGLang 引擎) — 更新推理引擎权重
- ActorFwd — 更新前向模型以计算当前策略 log prob
- Reference — 更新参考模型(依据
ref_update_interval)
DCS 架构
┌──────────────────────┐
│ DCS Coordinator │
│ (Ray Serve HTTP) │
│ │
│ - Node Registration │
│ - Topology Discovery │
│ - Weight Meta Buffer │
│ - Group Rank Assign │
└──────────┬───────────┘
│
┌──────────────────────┼──────────────────────┐
│ │ │
┌─────────▼──────────┐ ┌─────────▼──────────┐ ┌─────────▼──────────┐
│ CheckpointEngine │ │ CheckpointEngine │ │ CheckpointEngine │
│ Client (Actor) │ │ Client (ActorFwd) │ │ Client (Reference) │
│ │ │ │ │ │
│ DeviceDirectBackend│ │ DeviceDirectBackend│ │ DeviceDirectBackend│
│ (NCCL broadcast) │ │ (NCCL recv) │ │ (NCCL recv) │
└────────────────────┘ └────────────────────┘ └────────────────────┘权重更新流程
Actor → Rollout
# relax/backends/megatron/actor.py
def update_weights_fully_async(self, rollout_id, rollout_only=False, actor_fwd_only=False):
dist.barrier(group=get_gloo_group())
if not rollout_only:
run(self.checkpoint_engine_client.init_process_groups_for_actor_fwd_ref(rollout_id))
run(self.checkpoint_engine_client.update_weights_for_rollout(rollout_only, actor_fwd_only))update_weights_for_rollout 内部流程(DeviceDirectBackend):
- 暂停 Rollout 推理:HTTP 请求 SGLang 引擎
/pause_generation - 刷新 KV Cache:HTTP 请求
/flush_cache - 分发权重:
- 非专家参数:
all_gatherTP 分片 → 完整参数,PP 源 rank 广播到 Rollout(HF 格式)和 ActorFwd/Reference(原始格式) - 专家参数:额外 EP
all_gather,然后同上
- 非专家参数:
- 恢复 Rollout 推理:HTTP 请求
/continue_generation
Actor → ActorFwd/Reference
ActorFwd 和 Reference 通过 DCS 的 PP 感知通信组接收权重:
- 每个 Actor PP stage 创建独立的 NCCL process group(
update_actor_pp_{pp_rank}) - ActorFwd/Reference rank 加入这些 group 接收对应 PP stage 的权重
- 接收端轮询 Coordinator 获取权重元数据,分配空 tensor,然后通过
dist.broadcast接收
max_staleness:On-Policy 与 Off-Policy 的控制
概念
Staleness(陈旧度) 衡量训练使用的 rollout 数据与当前模型权重之间的版本差距。
- Staleness = 0:训练数据必须来自当前模型版本
- Staleness = N:训练数据可来自当前或前 N 个版本的模型
--max-staleness 2 # 允许 Rollout 最多领先 Actor 2 步对训练的影响
max_staleness = 0(On-Policy):
Rollout step 0 → Actor 训练 step 0 → Rollout step 1 → Actor 训练 step 1 → ...
(Rollout 必须等待 Actor 消费完当前数据才能继续)
max_staleness = 2(部分 Off-Policy):
Rollout: step 0 → step 1 → step 2 → [等待] → step 3 → step 4 → step 5 → [等待] → ...
Actor: ........................step 0 → step 1 → step 2 → ...............step 3 → ...
(Rollout 最多领先 2 步;超限时暂停等待 Actor 追赶)实现机制
# relax/components/rollout.py
def satisfy_staleness(partition_list, current_rollout_id, max_staleness):
"""检查当前 rollout 是否在允许的 staleness 范围内"""
if not partition_list:
return True
oldest_step = min(int(p.split("_")[-1]) for p in partition_list)
return current_rollout_id + 1 - oldest_step <= max_staleness当 TransferQueue 中有 max_staleness 个或更多未消费的分区时,Rollout 将暂停等待 Actor 消费。
不同 max_staleness 值的效果
max_staleness | 训练语义 | 吞吐量 | 稳定性 | 典型场景 |
|---|---|---|---|---|
| 0 | 严格 On-Policy | 低 | 最高 | 调试、小模型 |
| 1 | 接近 On-Policy | 中等 | 高 | 生产环境、中等模型 |
| 2-4 | 轻度 Off-Policy | 高 | 中等 | 大模型、推理较慢 |
| >4 | 显著 Off-Policy | 最高 | 需验证 | 极致吞吐场景 |
TIP
生产环境建议 max_staleness=1~2,兼顾吞吐量和训练稳定性。搭配 --eps-clip 和 --eps-clip-high 裁剪参数可缓解 Off-Policy 带来的训练不稳定问题。
训练循环
Actor 训练循环
# relax/components/actor.py
def _background_run(self):
while True:
if self._stop_event.is_set():
break
with self._lock:
local_step = self.step
if local_step >= self.config.num_rollout:
break
self._execute_training()
run(self.data_system_client.async_clear_partition(f"train_{local_step}"))
with self._lock:
self.step += 1
def _execute_training(self):
if self.step < self.config.num_critic_only_steps:
return # 跳过仅训练 Critic 的阶段
if self.config.fully_async:
ray.get(self.actor_model.train_fully_async(self.step))
self._maybe_save_model()
else:
ray.get(self.actor_model.async_train(self.step))ActorFwd 与 Reference 工作流
- 从 TransferQueue 分批拉取数据(
_get_data_from_transfer_queue) - 执行前向计算(
forward_only)获取 log prob - 将结果写回 TransferQueue(
_put_data_to_transfer_queue) - 消费完全部数据后,调用
recv_weight_fully_async()接收新权重
Advantages 服务
Advantages 服务等待 ref_log_probs 和 log_probs 都在 TransferQueue 中就绪后,计算优势和回报并写回。依赖关系由 TransferQueue 的 get_meta 自动处理——当所需字段未就绪时会阻塞等待。
数据流时序
Time ──────────────────────────────────────────────────────────────────────►
Rollout: ┌──generate(step=N)──┐ ┌──generate(step=N+1)────┐ ...
│ SGLang inference │ │ (if staleness allows) │
│ + reward scoring │ │ │
└─────────┬──────────┘ └────────────────────────┘
│
▼ Write to TransferQueue (partition=train_N)
│ Fields: tokens, loss_masks, rollout_log_probs,
│ rewards, total_lengths, response_lengths, ...
│
┌───────────────┼──────────────────────┐
│ │ │
▼ ▼ ▼
ActorFwd: Reference: Advantages:
read train_N read train_N wait for log_probs
compute compute and ref_log_probs
log_probs ref_log_probs │
write to TQ write to TQ │
│ │ │
└───────────────┼──────────────────────┘
│ All forward results ready
▼
Advantages Service:
read rollout_log_probs + log_probs + ref_log_probs + rewards
compute advantages + returns
write to TransferQueue
│
▼
Actor (Training):
StreamingDataLoader streams data
→ Megatron forward + backward + optimizer step
→ DCS distributes new weights to Rollout, ActorFwd, Reference
→ Clear partition train_N
┌───────────────┼──────────────────────┐
│ │ │
▼ ▼ ▼
Rollout: ActorFwd: Reference:
update weights recv_weight recv_weight (if needed)
resume inference (NCCL broadcast) (NCCL broadcast)配置参数
CLI 参数
| 参数 | 默认值 | 说明 |
|---|---|---|
--fully-async | false | 启用全异步训练流水线 |
--max-staleness | 0 | 最大允许 staleness(0=On-Policy,>0=部分 Off-Policy) |
--num-data-storage-units | 1 | TransferQueue SimpleStorageUnit actor 数量 |
--num-iters-per-train-update | 1 | 每个 global batch 的训练迭代次数 |
--checkpoint-engine-backend | nccl | DCS 通信后端(nccl 或 gloo) |
--polling-mode | true | TransferQueue Controller 使用轮询模式获取元数据 |
--ref-update-interval | None | 参考模型更新周期(None=不更新) |
--resource | - | JSON 格式角色资源分配,如 '{"actor": [1, 2], "rollout": [1, 4], ...}' |
配置示例
# 8 GPU 全异步(来自 scripts/training/text/run-qwen3-4B-8xgpu-async.sh)
ray job submit -- python3 relax/entrypoints/train.py \
--resource '{"actor": [1, 2], "rollout": [1, 4], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' \
--max-staleness 2 \
--num-data-storage-units 1 \
--num-iters-per-train-update 8 \
--fully-async \
--use-health-check \
...资源分配详解:
- Actor:1 副本 × 2 GPU(TP=2 训练)
- Rollout:1 副本 × 4 GPU(4 个 SGLang 引擎)
- Reference:1 副本 × 1 GPU(单 GPU 前向)
- ActorFwd:1 副本 × 1 GPU(单 GPU 前向)
- Advantages:1 副本 × 0 GPU(仅 CPU 计算)
容错机制
重启策略
| 故障角色 | 策略 | 原因 |
|---|---|---|
| Actor | 全局重启 | Actor 是核心训练服务,所有其他服务依赖于它 |
| Rollout | 全局重启 | 引擎状态复杂,难以原地恢复 |
| ActorFwd | 全局重启 | 权重通信组状态难以恢复 |
| Reference | 原地重启 | 与 Advantages 类似,可安全重新部署 |
| Advantages | 原地重启 | 无状态服务,可安全重新部署 |
| 任意角色 ≥3 次 | 全局重启 | 系统不稳定,需全量重新初始化 |
权重更新期间的容错
# relax/backends/megatron/actor.py — MegatronTrainRayActor.train_async()
rollout_only, actor_fwd_only = self._check_services_health()
# rollout_only=True:跳过 ActorFwd 权重更新(ActorFwd 不可用)
# actor_fwd_only=True:跳过 Rollout 权重更新(Rollout 不可用)
self.update_weights_fully_async(rollout_id, rollout_only, actor_fwd_only)性能调优
关键调优参数
| 参数 | 推荐值 | 影响 |
|---|---|---|
--max-staleness | 1-2 | 平衡吞吐量与训练稳定性 |
--num-iters-per-train-update | 4-8 | 值越大数据利用率越高,但单步延迟增加 |
--num-data-storage-units | 1-2 | 更多存储单元提升并行数据访问性能 |
GPU 资源分配策略
总 GPU 数:N
├── Actor(训练):~25-30%(需 TP/PP/CP 支持)
├── Rollout(推理):~50-60%(推理吞吐是瓶颈)
├── ActorFwd:~5-10%(单 GPU 通常足够)
├── Reference:~5-10%(单 GPU 通常足够)
└── Advantages:0 GPU(仅 CPU 计算)Colocate 与 Fully Async 对比
Colocate Mode Fully Async Mode
(Same GPUs, time-shared) (Dedicated GPU clusters)
┌──────────────────┐ ┌──────────────────────┐
Time ──► │ Rollout │ │ Rollout ──────────► │
│ (SGLang infer) │ │ (continuous infer) │
│ write TQ train_N │ │ │
├──────────────────┤ │ Actor ──────────► │
│ offload rollout │ │ (StreamDataLoader │
│ wake up actor │ │ streaming + train) │
├──────────────────┤ │ │
│ Actor Train │ │ ActorFwd ────────► │
│ (read TQ train_N)│ │ (compute log prob) │
│ (incl ref/adv) │ │ │
├──────────────────┤ │ Reference ────────► │
│ Weight Update │ │ (compute ref logp) │
│ (tensor copy) │ │ │
├──────────────────┤ │ Advantages ──────► │
│ offload actor │ │ (compute adv/ret) │
│ wake up rollout │ │ │
├──────────────────┤ │ DCS weight sync │
│ Rollout │ │ (overlaps training) │
│ (continue) │ └──────────────────────┘
└──────────────────┘
GPU utilization: ~30-50% GPU utilization: ~70-90%
All operations strictly serial All operations parallel
Data via TransferQueue, no overlap Data via TransferQueue, streaming overlap延伸阅读
- 系统架构 — Relax 整体架构设计
- 分布式 Checkpoint — DCS 详细文档
- 健康检查管理器 — 健康监控与故障恢复
