| Key | Value |
|---|---|
| name | ff-new-algorithm |
| description | Complete workflow for adding a new RL training algorithm. Covers paradigm selection, TrainingArguments subclass, trainer implementation, registry, example config, and verification. Trigger: 'add algorithm', 'new trainer', 'new training method', 'implement algorithm'. |
# New RL Algorithm Integration

Authoritative reference: `guidance/algorithms.md`
## Prerequisites

Determine your algorithm's characteristics:

- Paradigm: Coupled or Decoupled? (`constraints.md` #7)
- Dynamics: Which SDE/ODE formulation? (Flow-SDE, Dance-SDE, CPS, ODE)
- Advantage: How are advantages computed from rewards? (Most algorithms can delegate to `AdvantageProcessor`.)
- Loss: What is the policy optimization objective?
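One way to keep these four decisions explicit is to record them in a small data structure before writing any trainer code. The sketch below is purely illustrative: `AlgorithmDesign`, `Paradigm`, and `Dynamics` are hypothetical names, not flow_factory classes. The only rule it encodes comes from Common Pitfalls: a coupled algorithm cannot use ODE dynamics.

```python
from dataclasses import dataclass
from enum import Enum


class Paradigm(Enum):
    COUPLED = "coupled"
    DECOUPLED = "decoupled"


class Dynamics(Enum):
    FLOW_SDE = "Flow-SDE"
    DANCE_SDE = "Dance-SDE"
    CPS = "CPS"
    ODE = "ODE"


@dataclass(frozen=True)
class AlgorithmDesign:
    """Hypothetical record of the prerequisite decisions (not a flow_factory class)."""

    name: str
    paradigm: Paradigm
    dynamics: Dynamics
    advantage: str = "AdvantageProcessor"  # delegate unless you need something custom
    loss: str = ""  # short description of the policy objective

    def validate(self) -> None:
        # Coupled algorithms need per-step log-probabilities from the sampler;
        # deterministic ODE sampling provides none (see Common Pitfalls).
        if self.paradigm is Paradigm.COUPLED and self.dynamics is Dynamics.ODE:
            raise ValueError(
                f"{self.name}: the coupled paradigm needs stochastic dynamics, not ODE"
            )


# Decoupled + ODE passes; coupled + ODE is rejected.
AlgorithmDesign("my_algo", Paradigm.DECOUPLED, Dynamics.ODE).validate()
try:
    AlgorithmDesign("my_algo", Paradigm.COUPLED, Dynamics.ODE).validate()
except ValueError as exc:
    print(exc)  # my_algo: the coupled paradigm needs stochastic dynamics, not ODE
```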
## Phase 1: Design

- Study existing implementations:
  - Coupled example: `trainers/grpo.py` (GRPO)
  - Decoupled example: `trainers/nft.py` (DiffusionNFT) or `trainers/awm.py` (AWM)
- Identify what's shared vs. unique (`constraints.md` #11):
  - Shared: data loading, reward computation, `AdvantageProcessor`, adapter interface, checkpoint logic
  - Unique: `start()` method, loss function, algorithm-specific hyperparameters
- Per-epoch hook order: `sample()` → `prepare_feedback()` → `optimize()` (see `guidance/workflow.md`)
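The hook order amounts to a small contract: `prepare_feedback()` attaches rewards and advantages to the samples, and `optimize()` consumes them. A framework-independent toy illustrating that contract (`EpochHooks` and `Recorder` are hypothetical and do not mirror `BaseTrainer`'s real internals):

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class EpochHooks(ABC):
    """Toy template: one epoch always calls the three hooks in this order."""

    def run_epoch(self) -> None:
        samples = self.sample()         # rollouts + trajectories
        self.prepare_feedback(samples)  # rewards -> advantages, no policy gradients
        self.optimize(samples)          # policy update

    @abstractmethod
    def sample(self) -> List[Dict[str, Any]]: ...

    @abstractmethod
    def prepare_feedback(self, samples: List[Dict[str, Any]]) -> None: ...

    @abstractmethod
    def optimize(self, samples: List[Dict[str, Any]]) -> None: ...


class Recorder(EpochHooks):
    """Records the call order and checks that optimize() sees advantages."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def sample(self) -> List[Dict[str, Any]]:
        self.calls.append("sample")
        return [{"reward": 1.0}, {"reward": 0.0}]

    def prepare_feedback(self, samples: List[Dict[str, Any]]) -> None:
        self.calls.append("prepare_feedback")
        for s in samples:
            s["advantage"] = s["reward"] - 0.5  # placeholder computation

    def optimize(self, samples: List[Dict[str, Any]]) -> None:
        # By the time optimize() runs, every sample must carry an advantage.
        assert all("advantage" in s for s in samples)
        self.calls.append("optimize")


r = Recorder()
r.run_epoch()
print(r.calls)  # ['sample', 'prepare_feedback', 'optimize']
```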
## Phase 2: Configuration

### Step 1: Define Algorithm-Specific Arguments

Create a new file `src/flow_factory/hparams/training_args/my_algo.py`:
```python
from __future__ import annotations

from dataclasses import dataclass, field

from ._base import TrainingArguments


@dataclass
class MyAlgoTrainingArguments(TrainingArguments):
    """Training arguments specific to MyAlgo."""

    my_specific_param: float = field(
        default=0.1,
        metadata={"help": "Description of param."},
    )
    another_param: int = field(
        default=10,
        metadata={"help": "Description of param."},
    )
```
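The following standalone sketch shows how the subclass behaves as a plain dataclass. It substitutes a minimal stand-in for the real base class, with two hypothetical fields, and assumes the real `TrainingArguments` likewise gives every field a default. Under that assumption, each new field also needs a default: a field without one after inherited defaulted fields makes `@dataclass` raise `TypeError` when the class is defined.

```python
from dataclasses import dataclass, field, fields


@dataclass
class TrainingArguments:
    """Stand-in for the real base class; these two fields are hypothetical."""

    learning_rate: float = 1e-5
    guidance_scale: float = 1.0


@dataclass
class MyAlgoTrainingArguments(TrainingArguments):
    my_specific_param: float = field(default=0.1, metadata={"help": "Description of param."})
    another_param: int = field(default=10, metadata={"help": "Description of param."})


# Inherited and new fields are all keyword-settable.
args = MyAlgoTrainingArguments(learning_rate=1e-6, my_specific_param=0.2)
print(args.learning_rate, args.my_specific_param, args.another_param)  # 1e-06 0.2 10

# The help strings live in each field's metadata.
print({f.name: f.metadata["help"] for f in fields(args) if "help" in f.metadata})

# A field without a default after inherited defaulted fields fails at class creation.
try:
    @dataclass
    class BrokenArgs(TrainingArguments):
        required_param: float

    non_default_rejected = False
except TypeError as exc:
    non_default_rejected = True
    print("TypeError:", exc)
```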
If the algorithm uses a different CFG `guidance_scale` at optimize time than at sampling/rollout time (e.g., `kl_cfg` for a reference-model branch), override `get_preprocess_guidance_scale()` so that the data preprocessing stage encodes negative prompts:

```python
# Method on MyAlgoTrainingArguments; my_optimize_cfg stands for your own
# optimize-time CFG field (e.g., DGPO's kl_cfg).
def get_preprocess_guidance_scale(self) -> float:
    """Ensure negative prompts are encoded when optimize-time CFG needs them."""
    return max(self.guidance_scale, self.my_optimize_cfg)
```
See `topics/adapter_conventions.md`, "Classifier-Free Guidance (CFG) Convention", for the full two-stage CFG contract.
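A standalone check of the override's effect. The base-class default shown here (preprocessing follows the rollout `guidance_scale`) and the `needs_negative_prompts` helper are assumptions for illustration; the `> 1.0` threshold mirrors the CFG pitfall under Common Pitfalls, and `my_optimize_cfg = 4.5` is a hypothetical value.

```python
from dataclasses import dataclass


@dataclass
class TrainingArguments:
    # Stand-in base: only the field this example needs.
    guidance_scale: float = 1.0

    def get_preprocess_guidance_scale(self) -> float:
        # Assumed default: preprocessing follows the rollout CFG scale.
        return self.guidance_scale


@dataclass
class MyAlgoTrainingArguments(TrainingArguments):
    my_optimize_cfg: float = 4.5  # hypothetical optimize-time CFG scale

    def get_preprocess_guidance_scale(self) -> float:
        return max(self.guidance_scale, self.my_optimize_cfg)


def needs_negative_prompts(args: TrainingArguments) -> bool:
    # Hypothetical check: CFG is only active for scales above 1.0.
    return args.get_preprocess_guidance_scale() > 1.0


print(needs_negative_prompts(TrainingArguments(guidance_scale=1.0)))        # False
print(needs_negative_prompts(MyAlgoTrainingArguments(guidance_scale=1.0)))  # True
```

Without the override, the second case would also print `False`, and the optimize-time CFG branch would run without encoded negative prompts.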
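### Step 2: Register in Argument Resolver

Update three files: two in `src/flow_factory/hparams/training_args/` and the package-level `src/flow_factory/hparams/__init__.py`.

a) Add the import and registry entry in `training_args/_registry.py`:

```python
from .my_algo import MyAlgoTrainingArguments

# Assumes _registry.py already imports Dict and Type from typing.
_TRAINING_ARGS_REGISTRY: Dict[str, Type[TrainingArguments]] = {
    # ... existing entries ...
    'my_algo': MyAlgoTrainingArguments,
}
```

b) Add the re-export in `training_args/__init__.py`:

```python
from .my_algo import MyAlgoTrainingArguments
```

c) Add the re-export in `src/flow_factory/hparams/__init__.py`:

```python
from .training_args import MyAlgoTrainingArguments
```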
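The fallback mentioned under Common Pitfalls can be seen with a simplified resolver. `resolve_training_args` and `dropped_keys` are hypothetical helpers (the real lookup in `_registry.py` may differ), and the base class is again a one-field stand-in. If `my_algo` is missing from the registry, its `my_specific_param` key has no matching field and would not be parsed.

```python
from dataclasses import dataclass, fields
from typing import Dict, Set, Type


@dataclass
class TrainingArguments:
    learning_rate: float = 1e-5  # stand-in base field


@dataclass
class MyAlgoTrainingArguments(TrainingArguments):
    my_specific_param: float = 0.1


_TRAINING_ARGS_REGISTRY: Dict[str, Type[TrainingArguments]] = {
    "my_algo": MyAlgoTrainingArguments,
}


def resolve_training_args(trainer_type: str) -> Type[TrainingArguments]:
    # Hypothetical resolver: unknown names fall back to the base class.
    return _TRAINING_ARGS_REGISTRY.get(trainer_type, TrainingArguments)


def dropped_keys(trainer_type: str, train_section: dict) -> Set[str]:
    # Config keys with no matching field on the resolved class.
    known = {f.name for f in fields(resolve_training_args(trainer_type))}
    return set(train_section) - known


section = {"learning_rate": 1e-6, "my_specific_param": 0.1}
print(dropped_keys("my_algo", section))       # set()
print(dropped_keys("unregistered", section))  # {'my_specific_param'}
```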
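## Phase 3: Trainer Implementation

### Step 3: Create Trainer Class

Create `src/flow_factory/trainers/my_algo.py` (the module that the registry entry in Step 4 points to):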
```python
from .abc import BaseTrainer
from .registry import register_trainer


@register_trainer('my_algo')
class MyAlgoTrainer(BaseTrainer):
    """My custom RL algorithm trainer."""

    def start(self):
        """Main training loop: implements the 6-stage pipeline."""
        while self.should_continue_training():
            if self.log_args.save_freq > 0 and self.epoch % self.log_args.save_freq == 0:
                # save_dir is not defined in this template: derive it from your
                # checkpoint config, following an existing trainer (e.g. trainers/grpo.py).
                self.save_checkpoint(save_dir, epoch=self.epoch)
            if self.eval_args.eval_freq > 0 and self.epoch % self.eval_args.eval_freq == 0:
                self.evaluate()

            samples = self.sample()
            self.prepare_feedback(samples)
            self.optimize(samples)

            self.adapter.ema_step(step=self.epoch)
            self.epoch += 1

    def evaluate(self):
        """Evaluation loop: reuse the pattern from GRPO/NFT."""
        pass

    def sample(self):
        """Stages 2-3: K-repeat sampling + trajectory generation.

        Must return the samples consumed by prepare_feedback() and optimize().
        """
        raise NotImplementedError

    def prepare_feedback(self, samples):
        """Stages 4-5: finalize the reward buffer and compute advantages (no policy gradients)."""
        rewards = self.reward_buffer.finalize(store_to_samples=True, split='all')
        self.compute_advantages(samples, rewards, store_to_samples=True)
        adv_metrics = self.advantage_processor.pop_advantage_metrics()
        if adv_metrics:
            self.log_data(adv_metrics, step=self.step)

    def optimize(self, samples):
        """Stage 6: policy update (your algorithm's loss goes here)."""
        raise NotImplementedError
```
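The guards at the top of `start()` run before sampling, so with these modulo checks epoch 0 always saves and evaluates (capturing the pre-training baseline), and a frequency of 0 disables the hook. A standalone replay of just that logic (`periodic_actions` is a hypothetical helper, not a `BaseTrainer` method):

```python
from typing import List


def periodic_actions(epoch: int, save_freq: int, eval_freq: int) -> List[str]:
    """Mirror the save/eval guards at the top of start() for a single epoch."""
    actions = []
    if save_freq > 0 and epoch % save_freq == 0:
        actions.append("save_checkpoint")
    if eval_freq > 0 and epoch % eval_freq == 0:
        actions.append("evaluate")
    return actions


for epoch in range(7):
    print(epoch, periodic_actions(epoch, save_freq=2, eval_freq=3))
# 0 ['save_checkpoint', 'evaluate']   (both fire before any update)
# 1 []
# 2 ['save_checkpoint']
# 3 ['evaluate']
# 4 ['save_checkpoint']
# 5 []
# 6 ['save_checkpoint', 'evaluate']
```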
Note: `AdvantageProcessor` is auto-instantiated in `BaseTrainer._init_reward_model()`. All trainers delegate via `self.advantage_processor.compute_advantages()`; see `architecture.md`, "Advantage Computation".
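For intuition about what gets delegated, here is the GRPO-style group-relative advantage: each reward is normalized by the mean and standard deviation of the samples generated for the same prompt. This is a standalone sketch, not `AdvantageProcessor`'s implementation, and it assumes the simplest layout (each prompt's `group_size` samples are contiguous); handling other sampler topologies is exactly what the processor takes care of.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(
    rewards: List[float], group_size: int, eps: float = 1e-8
) -> List[float]:
    """GRPO-style: (r - group mean) / (group std + eps), per contiguous group."""
    if len(rewards) % group_size != 0:
        raise ValueError("rewards must split evenly into groups of group_size")
    advantages: List[float] = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu, sigma = mean(group), pstdev(group)
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two prompts, group_size=4 (as in the example config below).
adv = group_relative_advantages([1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0, 5.0], group_size=4)
print([round(a, 3) for a in adv])
# [-1.342, -0.447, 0.447, 1.342, 0.0, 0.0, 0.0, 0.0]
```

A group whose samples all receive the same reward yields zero advantages, so it contributes no learning signal.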
### Step 4: Register in Trainer Registry

Add to `_TRAINER_REGISTRY` in `src/flow_factory/trainers/registry.py`:

```python
'my_algo': 'flow_factory.trainers.my_algo.MyAlgoTrainer',
```
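Because the registry maps names to dotted-path strings rather than classes, the trainer module is presumably imported only when that trainer is selected. A sketch of such lazy resolution with `importlib`; `load_trainer_class` is hypothetical, and stdlib classes stand in for trainers so it runs anywhere. A typo in the dotted path would surface at this point as `ModuleNotFoundError` or `AttributeError`.

```python
import importlib
from typing import Dict


def load_trainer_class(registry: Dict[str, str], name: str) -> type:
    """Resolve 'package.module.ClassName' only when the trainer is requested."""
    try:
        dotted = registry[name]
    except KeyError:
        raise ValueError(
            f"Unknown trainer_type {name!r}; registered: {sorted(registry)}"
        ) from None
    module_path, _, class_name = dotted.rpartition(".")
    module = importlib.import_module(module_path)  # imported lazily, on first use
    return getattr(module, class_name)


# Stdlib targets stand in for flow_factory trainers so this runs anywhere.
demo_registry = {"ordered": "collections.OrderedDict", "counter": "collections.Counter"}
print(load_trainer_class(demo_registry, "counter"))  # <class 'collections.Counter'>
```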
## Phase 4: Configuration & Examples

Create an example config at `examples/my_algo/lora/flux1/default.yaml`:

```yaml
model:
  model_type: "flux1"
  model_path: "black-forest-labs/FLUX.1-dev"
  finetune_type: "lora"
  target_components: ["transformer"]

train:
  trainer_type: "my_algo"
  my_specific_param: 0.1
  learning_rate: 1e-6
  group_size: 4
  num_inference_steps: 28

scheduler:
  dynamics_type: "ODE"  # suits decoupled algorithms; coupled ones need stochastic dynamics (see Common Pitfalls)

data:
  dataset: "path/to/dataset"

rewards:
  reward_model: "PickScore"
  batch_size: 16
```
## Phase 5: Verification

## Common Pitfalls
- Not subclassing `TrainingArguments`: algorithm-specific params won't be parsed from YAML.
- Forgetting the `_registry.py` and `__init__.py` updates: the resolver falls back to the base `TrainingArguments`, silently losing your custom params.
- Using ODE dynamics with the coupled paradigm: ODE sampling is deterministic, so no log-probabilities are available and the gradients are silently incorrect.
- Not calling `self.should_continue_training()`: the loop never checks `max_epochs` and runs forever.
- Duplicating `_initialization()` logic: it is already called in `BaseTrainer.__init__`; don't re-prepare modules.
- Reimplementing advantage gather/scatter: use `self.advantage_processor.compute_advantages()` instead; it handles both sampler topologies automatically.
- Extending `GRPOTrainer` unnecessarily: unless your algorithm builds on GRPO's PPO-clipped loss, extend `BaseTrainer` directly (as NFT and AWM do).
- Optimize-time CFG without `get_preprocess_guidance_scale()`: if your algorithm calls `adapter.forward(guidance_scale=X)` with X > 1.0 while `training_args.guidance_scale` ≤ 1.0, negative prompts are not encoded at preprocessing time and CFG silently falls back to no CFG. Override `get_preprocess_guidance_scale()` in your `TrainingArguments` subclass to return `max(guidance_scale, your_optimize_cfg)`. See DGPO's `kl_cfg` for a real example.