Fully Async Training Pipeline

Overview

The Fully Async training pipeline is a high-throughput RLHF/RL training mode designed to maximize GPU utilization. Unlike the Colocate (synchronous) mode, Fully Async runs training (Actor), inference (Rollout), forward computation (ActorFwd / Reference), and advantage calculation (Advantages) as separate services with their own resources instead of time-sharing one set of GPUs. Services exchange data through TransferQueue and synchronize weights asynchronously via the Distributed Checkpoint Service (DCS).

Design Comparison

| Dimension | Colocate (Synchronous) | Fully Async |
| --- | --- | --- |
| GPU Sharing | Actor and Rollout share the same GPUs | Actor, Rollout, ActorFwd/Reference each have dedicated GPUs |
| Execution Model | Serial: Rollout completes → switch to Train → weight update | Fully parallel: Rollout, Train, forward computation run simultaneously |
| Weight Sync | In-process tensor copy (colocated) | Cross-node NCCL broadcast via DCS (Checkpoint Engine) |
| Data Flow | Also via TransferQueue, but synchronous: Rollout writes the full batch before Actor reads | Via TransferQueue + StreamingDataLoader with async streaming (production and consumption overlap) |
| Staleness | max_staleness=0 (strict On-Policy) | Configurable max_staleness (allows some Off-Policy) |
| Roles | actor, critic, rollout | actor, critic, rollout, advantages, reference, actor_fwd |

TIP

Both modes use TransferQueue as the data transport layer. In Colocate mode, Rollout and Actor time-share the same GPUs — Rollout writes a full batch to TransferQueue, then yields GPUs for Actor to train. In Fully Async mode, services run on independent GPUs in parallel, enabling concurrent data production and consumption.

Key Advantages

  1. Eliminate GPU idle time — Rollout and Training run simultaneously; Rollout engines continue generating data during training
  2. Flexible resource allocation — Training and inference can use different numbers of GPUs, adapting to heterogeneous hardware
  3. Controllable On/Off-Policy degree — The max_staleness parameter precisely controls data freshness
  4. Pipelined weight updates — DCS enables weight distribution to overlap with training computation

Architecture

System Diagram

┌───────────────────────────────────────────────────────────────────────────┐
│                        Controller (Orchestrator)                          │
│                     relax/core/controller.py                              │
│                                                                           │
│    ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌────────────┐  ┌─────────┐  │
│    │ Rollout  │  │  Actor   │  │ ActorFwd │  │ Reference  │  │  Adv    │  │
│    │ Service  │  │ Service  │  │ Service  │  │  Service   │  │ Service │  │
│    └──┬───────┘  └──┬───────┘  └──┬───────┘  └──┬─────────┘  └──┬──────┘  │
└───────┼─────────────┼─────────────┼─────────────┼────────────────┼────────┘
        │             │             │             │                │
        ▼             ▼             ▼             ▼                ▼
┌───────────────────────────────────────────────────────────────────────────┐
│                      TransferQueue (Data Plane)                           │
│                                                                           │
│  ┌────────────────┐       ┌──────────────────────────────────┐            │
│  │ TQ Controller  │◄──────┤  SimpleStorageUnit × N           │            │
│  │ (Metadata Mgr) │       │  (Partitioned Data Storage)      │            │
│  └────────────────┘       └──────────────────────────────────┘            │
│                                    ▲                                      │
│                                    │                                      │
│                ┌───────────────────┼────────────────────┐                 │
│                │ StreamingDataset / StreamingDataLoader │                 │
│                │ (relax/utils/data/stream_dataloader.py)│                 │
│                └────────────────────────────────────────┘                 │
└───────────────────────────────────────────────────────────────────────────┘
        │             │             │             │
        ▼             ▼             ▼             ▼
┌───────────────────────────────────────────────────────────────────────────┐
│              Distributed Checkpoint Service (DCS)                         │
│                                                                           │
│  ┌──────────────┐     ┌──────────────────────────────────┐                │
│  │  Coordinator │◄────┤  CheckpointEngineClient × N      │                │
│  │  (HTTP REST) │     │  (Per-rank weight send/recv)     │                │
│  └──────────────┘     └──────────────────────────────────┘                │
│                                                                           │
│  ┌───────────────────────────────────────────────┐                        │
│  │  DeviceDirectBackend (NCCL/GLOO)              │                        │
│  │  - Actor → Rollout: weight broadcast to SGLang│                        │
│  │  - Actor → ActorFwd/Ref: PP-aware broadcast   │                        │
│  └───────────────────────────────────────────────┘                        │
└───────────────────────────────────────────────────────────────────────────┘

Service Roles

In Fully Async mode, the system deploys 6 roles (defined by the ROLES StrEnum in relax/core/registry.py):

python
class ROLES(StrEnum):
    actor: str = "actor"           # Policy model training
    critic: str = "critic"         # Value model training (optional)
    rollout: str = "rollout"       # SGLang inference engine, generates samples
    advantages: str = "advantages" # Advantage and return computation
    reference: str = "reference"   # Reference model forward (KL divergence)
    actor_fwd: str = "actor_fwd"   # Current policy forward (log prob)

Role selection logic (relax/core/registry.py):

python
def process_role(config):
    if config.fully_async:
        return ROLES           # All 6 roles
    else:
        return ROLES_COLOCATE  # Only actor, critic, rollout
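
As a minimal sketch of the selection logic above, the snippet below mirrors process_role with a plain str/Enum mix-in so it also runs on Python versions without StrEnum. The names Role, COLOCATE_ROLES, and select_roles are hypothetical stand-ins, and the contents of the Colocate subset are assumed from the comment in the excerpt.

```python
from enum import Enum
from types import SimpleNamespace


class Role(str, Enum):
    """Hypothetical stand-in for ROLES (str mix-in instead of StrEnum)."""
    ACTOR = "actor"
    CRITIC = "critic"
    ROLLOUT = "rollout"
    ADVANTAGES = "advantages"
    REFERENCE = "reference"
    ACTOR_FWD = "actor_fwd"


# Assumed Colocate subset: actor, critic, rollout.
COLOCATE_ROLES = (Role.ACTOR, Role.CRITIC, Role.ROLLOUT)


def select_roles(config):
    """Return all six roles in Fully Async mode, else the Colocate subset."""
    return tuple(Role) if config.fully_async else COLOCATE_ROLES


async_roles = select_roles(SimpleNamespace(fully_async=True))
colocate_roles = select_roles(SimpleNamespace(fully_async=False))
print([r.value for r in async_roles])
print([r.value for r in colocate_roles])
```

Because the members are also strings, a role compares equal to its plain name, which keeps lookups in JSON-style configuration simple.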

Data Flow: StreamingDataLoader on TransferQueue

TransferQueue in Both Modes

Both Colocate and Fully Async modes use TransferQueue for data transfer. The key difference is the timing relationship between production and consumption:

Colocate mode (serial):
  Rollout fully writes partition train_N ── all ready ──► Actor reads train_N at once
  (Same GPUs time-shared; Rollout offloads then Actor wakes up to train)
  (ref log prob, advantages computed inside Actor's train_actor() serially)

Fully Async mode (streaming parallel):
  Rollout writes partition train_N incrementally ──► Actor consumes via StreamingDataLoader
  Rollout can start train_N+1 simultaneously    ──► ActorFwd/Reference/Advantages consume train_N in parallel
  (Different GPU clusters run fully in parallel; ref log prob, adv computed independently and written back to TQ)

Partition mechanism:

  • Partition ID format: train_{rollout_id}, e.g. train_0, train_1, train_2
  • Producer (Rollout): writes data to train_{rollout_id} after completing a rollout
  • Consumers (Actor/ActorFwd/Reference/Advantages): read from the corresponding partition, tracked by task_name
  • Partition cleanup: Actor calls async_clear_partition() after training completes

Storage capacity and max_staleness:

python
# relax/core/controller.py
total_storage_size = (
    self.config.rollout_batch_size
    * (self.config.max_staleness + 1)
    * self.config.n_samples_per_prompt
)

TransferQueue must be able to buffer max_staleness + 1 rollout batches simultaneously. For example, with max_staleness=2, rollout_batch_size=8, n_samples_per_prompt=8, this requires 8 × 3 × 8 = 192 sample slots.
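
The same arithmetic can be checked for a few staleness values. This is a small sketch that lifts the formula into a standalone function (the function name and the loop are illustrative, not part of the controller):

```python
def total_storage_size(rollout_batch_size, max_staleness, n_samples_per_prompt):
    """Sample slots needed to buffer max_staleness + 1 rollout batches."""
    return rollout_batch_size * (max_staleness + 1) * n_samples_per_prompt


for staleness in (0, 1, 2):
    slots = total_storage_size(8, staleness, 8)
    print(f"max_staleness={staleness}: {slots} sample slots")
# max_staleness=0: 64 sample slots
# max_staleness=1: 128 sample slots
# max_staleness=2: 192 sample slots
```

Each extra unit of staleness adds one full rollout batch (rollout_batch_size × n_samples_per_prompt samples) of buffer space.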

Task names track consumption progress for different consumers:

| Consumer | task_name | Data Fields Consumed |
| --- | --- | --- |
| Actor | actor_train (StreamingDataLoader) / train (legacy) | tokens, loss_masks, log_probs, ref_log_probs, advantages, returns, etc. |
| ActorFwd | actor_log_probs | tokens, total_lengths, response_lengths, loss_masks, rollout_log_probs |
| Reference | ref_log_probs | tokens, total_lengths, response_lengths, loss_masks, rollout_log_probs |
| Advantages | compute_advantages_and_returns | rollout_log_probs, log_probs, ref_log_probs, rewards, etc. |

StreamingDataLoader and StreamingDataset

In Fully Async mode, Actor uses StreamingDataLoader for streaming data consumption. Unlike Colocate mode where Actor waits for Rollout to fully generate a batch before reading, StreamingDataLoader can consume data as it is being incrementally written to TransferQueue. This is the core mechanism enabling "training and inference in parallel".

StreamingDataset

python
# TransferQueue (installed from https://github.com/redai-infra/TransferQueue)
class StreamingDataset(IterableDataset):
    """Streaming dataset that dynamically fetches data from TransferQueue"""

    def __init__(self, config, batch_size, micro_batch_size, data_fields,
                 partition_id, task_name, dp_rank, fetch_batch_fn, process_batch_fn):
        self.buffer = []       # Cache for fetched batches
        self.batch_index = 0   # Current consumption position

    def __iter__(self):
        while not consumed:
            if self.batch_index <= len(self.buffer) - 1:
                # Read from cache (supports multi-pass training)
                yield from self.process_batch_fn(...)
            else:
                # Fetch new data from TransferQueue
                batch_data, batch_meta = self.fetch_batch_fn(...)
                if batch_data is not None:
                    self.buffer.append((batch_data, batch_meta))

Key features:

  • On-demand fetching: each fetch retrieves one batch of global_batch_size / num_iters_per_train_update samples
  • Buffer reuse: buffer supports iterating over the same batch multiple times (e.g. multi-epoch training)
  • Partition switching: step(partition_id) clears the buffer and switches to a new rollout data partition
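
The three features above can be illustrated with a toy, dependency-free model of the buffer-and-fetch loop. This is only a sketch: ToyStreamingDataset, fake_fetch, and the num_batches argument are hypothetical, and the real StreamingDataset in TransferQueue has a different constructor and termination condition.

```python
class ToyStreamingDataset:
    """Toy model of the buffer-and-fetch loop; not the TransferQueue class."""

    def __init__(self, fetch_batch_fn, partition_id, num_batches):
        self.fetch_batch_fn = fetch_batch_fn  # returns a batch, or None if not ready
        self.partition_id = partition_id
        self.num_batches = num_batches        # batches per pass (assumed known)
        self.buffer = []
        self.batch_index = 0

    def __iter__(self):
        self.batch_index = 0
        while self.batch_index < self.num_batches:
            if self.batch_index < len(self.buffer):
                batch = self.buffer[self.batch_index]  # replay from cache
            else:
                batch = self.fetch_batch_fn(self.partition_id)
                if batch is None:
                    continue                           # not written yet: try again
                self.buffer.append(batch)
            self.batch_index += 1
            yield batch

    def step(self, partition_id):
        """Drop cached batches and point at the next rollout partition."""
        self.partition_id = partition_id
        self.buffer.clear()
        self.batch_index = 0


# Fake queue: None marks a fetch that arrives before the producer has written.
store = {"train_0": [None, [1, 2], None, [3, 4]], "train_1": [[5, 6], [7, 8]]}


def fake_fetch(partition_id):
    return store[partition_id].pop(0)


dataset = ToyStreamingDataset(fake_fetch, "train_0", num_batches=2)
first_pass = list(dataset)    # fetched incrementally, skipping the None results
second_pass = list(dataset)   # served from the buffer, no new fetches
dataset.step("train_1")
next_partition = list(dataset)
print(first_pass, second_pass, next_partition)
```

The second pass never touches the queue, which is what makes multi-pass training over one rollout cheap; a production loader would wait or back off instead of retrying immediately.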

Fetch Function (fetch_batch_fn)

The fetch function is get_data_from_transfer_queue() (relax/utils/data/stream_dataloader.py), bound with a broadcast_pp flag that depends on the mode:

python
from functools import partial

# broadcast_pp is the inverse of fully_async:
# True in Colocate mode, False in Fully Async mode
fetch_batch_fn = partial(get_data_from_transfer_queue,
                         broadcast_pp=not getattr(args, "fully_async", False))

Broadcast strategy differences:

| Mode | broadcast_pp | Data Fetch Node | Broadcast Scope |
| --- | --- | --- | --- |
| Colocate | True | tp_rank==0 && pp_rank==0 | TP group + PP group |
| Fully Async | False | tp_rank==0 (each PP stage independently) | TP group only (each PP stage fetches independently) |

  • Colocate mode: Rollout has already written the full batch to TransferQueue. Actor starts on the same GPUs, PP rank 0 fetches data from TQ and broadcasts to other PP stages. All data is available at once for training.
  • Fully Async mode: Each PP stage is on a separate rank and fetches data from TransferQueue independently, avoiding cross-PP-stage communication overhead. Since data may still be written incrementally, StreamingDataLoader automatically retries when data is not yet ready.
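
To make the difference concrete, the following sketch enumerates which (tp_rank, pp_rank) pairs would read from TransferQueue under each setting. The helper names is_fetch_rank and fetch_ranks are hypothetical; only the rank conditions come from the table above.

```python
def is_fetch_rank(tp_rank, pp_rank, broadcast_pp):
    """Return True if this rank reads from TransferQueue itself."""
    if broadcast_pp:
        # Colocate: a single reader, then broadcast over the TP and PP groups.
        return tp_rank == 0 and pp_rank == 0
    # Fully Async: one reader per PP stage, broadcast over the TP group only.
    return tp_rank == 0


def fetch_ranks(tp_size, pp_size, broadcast_pp):
    """List every (tp_rank, pp_rank) pair that fetches directly."""
    return [(tp, pp)
            for pp in range(pp_size)
            for tp in range(tp_size)
            if is_fetch_rank(tp, pp, broadcast_pp)]


colocate_readers = fetch_ranks(tp_size=2, pp_size=2, broadcast_pp=True)
async_readers = fetch_ranks(tp_size=2, pp_size=2, broadcast_pp=False)
print(colocate_readers)  # [(0, 0)]
print(async_readers)     # [(0, 0), (0, 1)]
```

With TP=2 and PP=2, Fully Async trades one extra TransferQueue reader per PP stage for the removal of the cross-stage broadcast.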

create_stream_dataloader

python
# relax/utils/data/stream_dataloader.py
def create_stream_dataloader(args, rollout_id, task_name, data_fields, dp_rank):
    dataset = StreamingDataset(
        config=args.tq_config,
        batch_size=args.micro_batch_size * args.n_samples_per_prompt,
        micro_batch_size=args.micro_batch_size,
        data_fields=data_fields,
        partition_id=f"train_{rollout_id}",
        task_name=task_name,
        dp_rank=dp_rank,
        fetch_batch_fn=fetch_batch_fn,
        process_batch_fn=split_dict,
    )
    dataloader = StreamingDataLoader(dataset)

    # Compute training steps per rollout
    num_steps_per_rollout = (args.rollout_batch_size * args.n_samples_per_prompt
                            // args.global_batch_size)
    num_microbatches = [
        args.global_batch_size // dp_world_size // args.micro_batch_size
        for _ in range(num_steps_per_rollout)
    ]
    return [dataloader for _ in range(vpp_size)], num_microbatches
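
The step and microbatch arithmetic at the end of create_stream_dataloader can be tried in isolation. The sketch below copies the two integer divisions into a hypothetical helper; the input values are illustrative and not taken from a shipped configuration.

```python
def steps_and_microbatches(rollout_batch_size, n_samples_per_prompt,
                           global_batch_size, dp_world_size, micro_batch_size):
    """Return (training steps per rollout, microbatches per step for each step)."""
    num_steps = rollout_batch_size * n_samples_per_prompt // global_batch_size
    per_step = global_batch_size // dp_world_size // micro_batch_size
    return num_steps, [per_step] * num_steps


# 8 prompts x 8 samples = 64 samples per rollout; 16 samples per global batch.
num_steps, num_microbatches = steps_and_microbatches(
    rollout_batch_size=8, n_samples_per_prompt=8,
    global_batch_size=16, dp_world_size=2, micro_batch_size=2,
)
print(num_steps)         # 4
print(num_microbatches)  # [4, 4, 4, 4]
```

Because both divisions are integer divisions, sizes that do not divide evenly silently drop the remainder, so it is worth choosing batch sizes that are exact multiples.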

Async Weight Sync: Distributed Checkpoint Service (DCS)

DCS Role in Fully Async

After Actor completes training, weights must be distributed to:

  1. Rollout (SGLang engines) — update inference engine weights
  2. ActorFwd — update the forward model for current policy log prob computation
  3. Reference — update the reference model (per ref_update_interval)

DCS Architecture

                          ┌──────────────────────┐
                          │   DCS Coordinator    │
                          │   (Ray Serve HTTP)   │
                          │                      │
                          │ - Node Registration  │
                          │ - Topology Discovery │
                          │ - Weight Meta Buffer │
                          │ - Group Rank Assign  │
                          └──────────┬───────────┘
                                     │
              ┌──────────────────────┼──────────────────────┐
              │                      │                      │
    ┌─────────▼──────────┐ ┌─────────▼──────────┐ ┌─────────▼──────────┐
    │ CheckpointEngine   │ │ CheckpointEngine   │ │ CheckpointEngine   │
    │ Client (Actor)     │ │ Client (ActorFwd)  │ │ Client (Reference) │
    │                    │ │                    │ │                    │
    │ DeviceDirectBackend│ │ DeviceDirectBackend│ │ DeviceDirectBackend│
    │ (NCCL broadcast)   │ │ (NCCL recv)        │ │ (NCCL recv)        │
    └────────────────────┘ └────────────────────┘ └────────────────────┘

Weight Update Flow

Actor → Rollout

python
# relax/backends/megatron/actor.py
def update_weights_fully_async(self, rollout_id, rollout_only=False, actor_fwd_only=False):
    dist.barrier(group=get_gloo_group())
    if not rollout_only:
        run(self.checkpoint_engine_client.init_process_groups_for_actor_fwd_ref(rollout_id))
    run(self.checkpoint_engine_client.update_weights_for_rollout(rollout_only, actor_fwd_only))

Internal flow of update_weights_for_rollout (DeviceDirectBackend):

  1. Pause Rollout inference: HTTP request to SGLang engine /pause_generation
  2. Flush KV Cache: HTTP request /flush_cache
  3. Distribute weights:
    • Non-expert parameters: all_gather TP shards → full parameters, then PP source rank broadcasts to Rollout (HF format) and ActorFwd/Reference (raw format)
    • Expert parameters: additional EP all_gather, then same as above
  4. Resume Rollout inference: HTTP request /continue_generation
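
The ordering of these four phases can be sketched without any network or GPU dependency. In the snippet below, post and distribute_weights are hypothetical injected callables standing in for the HTTP requests and the NCCL broadcast; only the endpoint paths come from the steps above.

```python
def update_rollout_weights(post, distribute_weights):
    """Run the four phases in order; both callables are injected by the caller."""
    post("/pause_generation")     # 1. stop generation on the SGLang engine
    post("/flush_cache")          # 2. drop KV cache computed with the old weights
    distribute_weights()          # 3. gather shards and broadcast the new weights
    post("/continue_generation")  # 4. resume inference


calls = []
update_rollout_weights(
    post=calls.append,
    distribute_weights=lambda: calls.append("broadcast"),
)
print(calls)
# ['/pause_generation', '/flush_cache', 'broadcast', '/continue_generation']
```

Injecting the transport makes the sequence easy to unit-test; error handling (for example, what to do if the broadcast fails while the engine is paused) is deliberately left out of this sketch.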

Actor → ActorFwd/Reference

ActorFwd and Reference receive weights via DCS PP-aware communication groups:

  • Each Actor PP stage creates an independent NCCL process group (update_actor_pp_{pp_rank})
  • ActorFwd/Reference ranks join these groups to receive weights for the corresponding PP stage
  • The receiver polls the Coordinator for weight metadata, allocates empty tensors, then receives via dist.broadcast

max_staleness: On-Policy vs Off-Policy Control

Concept

Staleness measures the version gap between the rollout data used for training and the current model weights.

  • Staleness = 0: training data must come from the current model version
  • Staleness = N: training data can come from current or previous N model versions

bash
--max-staleness 2    # Allow Rollout to be up to 2 steps ahead of Actor

Impact on Training

max_staleness = 0 (On-Policy):
  Rollout step 0 → Actor trains step 0 → Rollout step 1 → Actor trains step 1 → ...
  (Rollout must wait for Actor to consume current data before continuing)

max_staleness = 2 (Partial Off-Policy):
  Rollout: step 0 → step 1 → step 2 → [wait] → step 3 → step 4 → step 5 → [wait] → ...
  Actor:   ........................step 0 → step 1 → step 2 → ...............step 3 → ...
  (Rollout can be up to 2 steps ahead; pauses when exceeding the limit)

Implementation

python
# relax/components/rollout.py
def satisfy_staleness(partition_list, current_rollout_id, max_staleness):
    """Check if the current rollout is within the allowed staleness bound."""
    if not partition_list:
        return True
    oldest_step = min(int(p.split("_")[-1]) for p in partition_list)
    return current_rollout_id + 1 - oldest_step <= max_staleness

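The check measures the distance from the oldest unconsumed partition to the rollout about to start and compares it with max_staleness. When the unconsumed partitions are the ones immediately preceding current_rollout_id, this means that once max_staleness or more of them remain in TransferQueue, Rollout pauses and waits for Actor to catch up. An empty partition list always passes.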
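
A few concrete calls show how the bound behaves. The function body is repeated from the excerpt above so the snippet is self-contained; the partition lists are made-up inputs.

```python
def satisfy_staleness(partition_list, current_rollout_id, max_staleness):
    """Check if the current rollout is within the allowed staleness bound."""
    if not partition_list:
        return True
    oldest_step = min(int(p.split("_")[-1]) for p in partition_list)
    return current_rollout_id + 1 - oldest_step <= max_staleness


print(satisfy_staleness([], 0, 2))                      # True: nothing pending
print(satisfy_staleness(["train_0"], 1, 2))             # True: 1 + 1 - 0 = 2
print(satisfy_staleness(["train_0", "train_1"], 2, 2))  # False: 2 + 1 - 0 = 3
print(satisfy_staleness(["train_1"], 2, 2))             # True: train_0 was cleared
```

The last two calls differ only in whether Actor has cleared train_0, which is the event that unblocks the next rollout.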

Effect of Different max_staleness Values

| max_staleness | Training Semantics | Throughput | Stability | Typical Scenario |
| --- | --- | --- | --- | --- |
| 0 | Strict On-Policy | Low | Highest | Debugging, small models |
| 1 | Near On-Policy | Medium | High | Production, medium models |
| 2-4 | Mild Off-Policy | High | Medium | Large models, slow inference |
| >4 | Significant Off-Policy | Highest | Needs validation | Extreme throughput priority |

TIP

For production, a max_staleness of 1 to 2 is recommended to balance throughput and training stability. Combine it with the --eps-clip and --eps-clip-high clipping parameters to mitigate Off-Policy instability.


Training Loop

Actor Training Loop

python
# relax/components/actor.py
def _background_run(self):
    while True:
        if self._stop_event.is_set():
            break
        with self._lock:
            local_step = self.step
        if local_step >= self.config.num_rollout:
            break
        self._execute_training()
        run(self.data_system_client.async_clear_partition(f"train_{local_step}"))
        with self._lock:
            self.step += 1

def _execute_training(self):
    if self.step < self.config.num_critic_only_steps:
        return  # Critic-only warmup: skip actor training for this step
    if self.config.fully_async:
        ray.get(self.actor_model.train_fully_async(self.step))
        self._maybe_save_model()
    else:
        ray.get(self.actor_model.async_train(self.step))

ActorFwd and Reference Workflow

  1. Fetch data in batches from TransferQueue (_get_data_from_transfer_queue)
  2. Execute forward computation (forward_only) to get log probs
  3. Write results back to TransferQueue (_put_data_to_transfer_queue)
  4. After all data is consumed, call recv_weight_fully_async() to receive new weights

Advantages Service

The Advantages service waits for both ref_log_probs and log_probs to be ready in TransferQueue, then computes advantages and returns and writes them back. The dependency is handled automatically by TransferQueue's get_meta — it blocks until the required fields are available.
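
The dependency can be pictured as a set difference between the fields Advantages needs and the fields already written to a partition. The sketch below is illustrative only: REQUIRED_FIELDS and missing_fields are hypothetical names, and it does not model how get_meta blocks internally.

```python
# Fields that must be present before advantages can be computed (from the text above).
REQUIRED_FIELDS = {"log_probs", "ref_log_probs"}


def missing_fields(written_fields):
    """Return the required fields that have not been written yet, sorted by name."""
    return sorted(REQUIRED_FIELDS - set(written_fields))


after_rollout = ["tokens", "rollout_log_probs", "rewards"]
after_actor_fwd = after_rollout + ["log_probs"]
after_reference = after_actor_fwd + ["ref_log_probs"]

print(missing_fields(after_rollout))    # ['log_probs', 'ref_log_probs']
print(missing_fields(after_actor_fwd))  # ['ref_log_probs']
print(missing_fields(after_reference))  # []
```

Advantages can proceed only when the list is empty, regardless of whether ActorFwd or Reference finishes first.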


Data Flow Timeline

Time ──────────────────────────────────────────────────────────────────────►

Rollout:  ┌──generate(step=N)──┐     ┌──generate(step=N+1)────┐    ...
          │ SGLang inference   │     │  (if staleness allows) │
          │ + reward scoring   │     │                        │
          └─────────┬──────────┘     └────────────────────────┘

                    ▼ Write to TransferQueue (partition=train_N)
                    │ Fields: tokens, loss_masks, rollout_log_probs,
                    │         rewards, total_lengths, response_lengths, ...

    ┌───────────────┼──────────────────────┐
    │               │                      │
    ▼               ▼                      ▼
  ActorFwd:      Reference:            Advantages:
read train_N    read train_N        wait for log_probs
  compute        compute            and ref_log_probs
 log_probs     ref_log_probs               │
 write to TQ    write to TQ                │
    │               │                      │
    └───────────────┼──────────────────────┘
                    │ All forward results ready

              Advantages Service:
                read rollout_log_probs + log_probs + ref_log_probs + rewards
                compute advantages + returns
                write to TransferQueue


              Actor (Training):
                StreamingDataLoader streams data
                 → Megatron forward + backward + optimizer step
                 → DCS distributes new weights to Rollout, ActorFwd, Reference
                 → Clear partition train_N

    ┌───────────────┼──────────────────────┐
    │               │                      │
    ▼               ▼                      ▼
 Rollout:         ActorFwd:             Reference:
 update weights   recv_weight            recv_weight (if needed)
 resume inference (NCCL broadcast)      (NCCL broadcast)

Configuration

CLI Parameters

| Parameter | Default | Description |
| --- | --- | --- |
| --fully-async | false | Enable the Fully Async training pipeline |
| --max-staleness | 0 | Maximum allowed staleness (0=On-Policy, >0=partial Off-Policy) |
| --num-data-storage-units | 1 | Number of TransferQueue SimpleStorageUnit actors |
| --num-iters-per-train-update | 1 | Number of training iterations per global batch |
| --checkpoint-engine-backend | nccl | DCS communication backend (nccl or gloo) |
| --polling-mode | true | TransferQueue Controller uses polling for metadata |
| --ref-update-interval | None | Reference model update period (None=no update) |
| --resource | - | JSON role resource allocation, e.g. '{"actor": [1, 2], "rollout": [1, 4], ...}' |

Example Configuration

bash
# 8 GPU Fully Async (from scripts/training/text/run-qwen3-4B-8xgpu-async.sh)
ray job submit -- python3 relax/entrypoints/train.py \
    --resource '{"actor": [1, 2], "rollout": [1, 4], "reference": [1, 1], "actor_fwd": [1, 1], "advantages": [1, 0]}' \
    --max-staleness 2 \
    --num-data-storage-units 1 \
    --num-iters-per-train-update 8 \
    --fully-async \
    --use-health-check \
    ...

Resource allocation breakdown:

  • Actor: 1 replica × 2 GPU (TP=2 training)
  • Rollout: 1 replica × 4 GPU (4 SGLang engines)
  • Reference: 1 replica × 1 GPU (single-GPU forward)
  • ActorFwd: 1 replica × 1 GPU (single-GPU forward)
  • Advantages: 1 replica × 0 GPU (CPU-only computation)
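
A quick way to sanity-check a --resource value is to total the GPUs it requests. The sketch below assumes, based on the breakdown above, that each entry is [replicas, GPUs per replica]; gpus_per_role is a hypothetical helper, not part of the Relax CLI.

```python
import json

resource = json.loads(
    '{"actor": [1, 2], "rollout": [1, 4], "reference": [1, 1], '
    '"actor_fwd": [1, 1], "advantages": [1, 0]}'
)


def gpus_per_role(resource):
    """Map each role to replicas x GPUs per replica (assumed entry layout)."""
    return {role: replicas * gpus for role, (replicas, gpus) in resource.items()}


usage = gpus_per_role(resource)
total = sum(usage.values())
print(usage)
print(total)  # 8
```

The total of 8 matches the 8-GPU script named in the example, with Advantages contributing no GPUs.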

Fault Tolerance

Restart Strategy

| Failed Role | Strategy | Reason |
| --- | --- | --- |
| Actor | Global Restart | Actor is the core training service; all others depend on it |
| Rollout | Global Restart | Complex engine state, difficult to recover in-place |
| ActorFwd | Global Restart | Weight communication group state is hard to recover |
| Reference | In-place Restart | Similar to Advantages, safe to redeploy |
| Advantages | In-place Restart | Stateless service, safe to redeploy |
| Any role ≥3 times | Global Restart | System unstable, full re-initialization needed |
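
The table can be read as a small decision function. The following sketch is one possible encoding; restart_strategy and the constant names are hypothetical, failure_count is assumed to include the current failure, and roles the table does not list (such as critic) are rejected rather than guessed.

```python
GLOBAL_ROLES = {"actor", "rollout", "actor_fwd"}
IN_PLACE_ROLES = {"reference", "advantages"}
MAX_FAILURES = 3  # at this many failures, any role escalates to a global restart


def restart_strategy(role, failure_count):
    """Return "global" or "in_place" for a failed role, following the table."""
    if role not in GLOBAL_ROLES | IN_PLACE_ROLES:
        raise ValueError(f"no restart rule listed for role {role!r}")
    if failure_count >= MAX_FAILURES or role in GLOBAL_ROLES:
        return "global"
    return "in_place"


print(restart_strategy("actor", 1))       # global
print(restart_strategy("advantages", 1))  # in_place
print(restart_strategy("reference", 3))   # global
```

Checking the failure count before the per-role rule ensures that a repeatedly failing stateless service still triggers full re-initialization.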

Fault Tolerance During Weight Update

python
# relax/backends/megatron/actor.py — MegatronTrainRayActor.train_async()
rollout_only, actor_fwd_only = self._check_services_health()
# rollout_only=True: skip ActorFwd weight update (ActorFwd unavailable)
# actor_fwd_only=True: skip Rollout weight update (Rollout unavailable)
self.update_weights_fully_async(rollout_id, rollout_only, actor_fwd_only)

Performance Tuning

Key Tuning Parameters

| Parameter | Recommended | Impact |
| --- | --- | --- |
| --max-staleness | 1-2 | Balance throughput vs training stability |
| --num-iters-per-train-update | 4-8 | Larger values improve data utilization but increase per-step latency |
| --num-data-storage-units | 1-2 | More storage units improve parallel data access |

GPU Resource Allocation Strategy

Total GPUs: N
├── Actor (training): ~25-30% (needs TP/PP/CP support)
├── Rollout (inference): ~50-60% (inference throughput is the bottleneck)
├── ActorFwd: ~5-10% (single GPU usually sufficient)
├── Reference: ~5-10% (single GPU usually sufficient)
└── Advantages: 0 GPU (CPU-only computation)

Colocate vs Fully Async Comparison

                Colocate Mode                           Fully Async Mode
          (Same GPUs, time-shared)                 (Dedicated GPU clusters)
            ┌──────────────────┐                     ┌──────────────────────┐
  Time ──►  │   Rollout        │                     │  Rollout ──────────► │
            │ (SGLang infer)   │                     │  (continuous infer)  │
            │ write TQ train_N │                     │                      │
            ├──────────────────┤                     │  Actor  ──────────►  │
            │ offload rollout  │                     │  (StreamingDataLoader│
            │ wake up actor    │                     │   streaming + train) │
            ├──────────────────┤                     │                      │
            │   Actor Train    │                     │  ActorFwd ────────►  │
            │ (read TQ train_N)│                     │  (compute log prob)  │
            │ (incl ref/adv)   │                     │                      │
            ├──────────────────┤                     │  Reference ────────► │
            │   Weight Update  │                     │  (compute ref logp)  │
            │ (tensor copy)    │                     │                      │
            ├──────────────────┤                     │  Advantages ──────►  │
            │ offload actor    │                     │  (compute adv/ret)   │
            │ wake up rollout  │                     │                      │
            ├──────────────────┤                     │  DCS weight sync     │
            │   Rollout        │                     │  (overlaps training) │
            │   (continue)     │                     └──────────────────────┘
            └──────────────────┘
         GPU utilization: ~30-50%                    GPU utilization: ~70-90%
         All operations strictly serial              All operations parallel
         Data via TransferQueue, no overlap           Data via TransferQueue, streaming overlap

Released under the Apache 2.0 License.