add-reward
// Guide for adding a new reward function to AReaL. Use when user wants to create a reward function.
| name | add-reward |
| description | Guide for adding a new reward function to AReaL. Use when user wants to create a reward function. |
Add a new reward function to AReaL.
This skill is triggered when the user wants to create a new reward function.
Create `areal/reward/<name>.py`:

```python
import re
from typing import Any

from areal.utils import logging

logger = logging.getLogger("MyReward")


def <name>_reward_fn(
    prompt: str,
    completions: str,
    prompt_ids,
    completion_ids,
    answer: str | None = None,
    **kwargs: Any,
) -> float:
    """Compute reward for a single completion.

    Args:
        prompt: Prompt string
        completions: Completion string (model output)
        prompt_ids: Tokenized prompt IDs
        completion_ids: Tokenized completion IDs
        answer: Ground truth answer from dataset (optional)
        **kwargs: Additional data from dataset

    Returns:
        Reward value (float), typically 0.0 or 1.0
    """
    try:
        # Extract the answer from the completion
        extracted = _extract_answer(completions)
        # Compare with the ground truth
        if answer is not None and extracted == str(answer):
            return 1.0
        return 0.0
    except Exception:
        # Log the failure and return 0.0 instead of propagating the exception
        logger.warning("Exception in reward computation", exc_info=True)
        return 0.0


def _extract_answer(completion: str) -> str:
    """Extract the answer from a completion string.

    Implement your extraction logic here.
    """
    # Example: extract the content of the first \boxed{...}.
    # Note: [^}]+ stops at the first "}", so nested braces are not supported.
    match = re.search(r"\\boxed\{([^}]+)\}", completion)
    if match:
        return match.group(1).strip()
    return completion.strip()
```
Update `areal/reward/__init__.py`:

```python
# Add to VALID_REWARD_FN
VALID_REWARD_FN = [
    # ... existing reward functions
    "<name>",
]


# Add a branch to get_reward_fn
def get_reward_fn(name: str, **kwargs):
    # ... existing code (earlier if/elif branches)
    elif name == "<name>":
        from areal.reward.<name> import <name>_reward_fn

        return <name>_reward_fn
```
If your reward function performs blocking operations (e.g., API calls or model inference), wrap it with `AsyncRewardWrapper` in your workflow:

```python
# In your workflow
from areal.reward import AsyncRewardWrapper

self.reward_fn = AsyncRewardWrapper(reward_fn)

# Then call it asynchronously, once per sample, with the full signature
reward = await self.reward_fn(
    prompt, completions, prompt_ids, completion_ids, **data
)
```
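For intuition about what such a wrapper does, here is a minimal sketch that offloads a blocking reward function to a thread pool with `loop.run_in_executor` and scores a group of completions concurrently with `asyncio.gather`, one call per sample. `ThreadedRewardWrapper`, `slow_reward_fn`, and `score_group` are hypothetical names. This is not AReaL's `AsyncRewardWrapper`, whose executor type, timeouts, and error handling may differ.

```python
import asyncio
import functools
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


class ThreadedRewardWrapper:
    """Hypothetical stand-in for AsyncRewardWrapper (not AReaL's implementation)."""

    def __init__(self, reward_fn: Callable[..., float], max_workers: int = 4):
        self.reward_fn = reward_fn
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def __call__(self, *args: Any, **kwargs: Any) -> float:
        loop = asyncio.get_running_loop()
        # run_in_executor forwards only positional args, so bind kwargs with partial
        call = functools.partial(self.reward_fn, *args, **kwargs)
        return await loop.run_in_executor(self.executor, call)


def slow_reward_fn(prompt, completions, prompt_ids, completion_ids, answer=None, **kwargs):
    time.sleep(0.05)  # Simulates a blocking call, e.g. a remote judge
    return 1.0 if answer is not None and completions.strip() == str(answer) else 0.0


async def score_group(wrapper, prompt: str, completions: list[str], answer: str) -> list[float]:
    # One reward call per sample; gather awaits them concurrently on the pool
    return await asyncio.gather(
        *(wrapper(prompt, c, None, None, answer=answer) for c in completions)
    )


wrapper = ThreadedRewardWrapper(slow_reward_fn)
print(asyncio.run(score_group(wrapper, "What is 2+2?", ["4", " 5 "], "4")))  # [1.0, 0.0]
```

Binding keyword arguments through `functools.partial` matters here because the reward signature passes `answer` and other dataset fields as keywords.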
Create `tests/test_<name>_reward.py`:

```python
import pytest

from areal.reward.<name> import <name>_reward_fn


def test_reward_correct_answer():
    reward = <name>_reward_fn(
        prompt="What is 2+2?",
        completions="The answer is \\boxed{4}",
        prompt_ids=None,
        completion_ids=None,
        answer="4",
    )
    assert reward == 1.0


def test_reward_wrong_answer():
    reward = <name>_reward_fn(
        prompt="What is 2+2?",
        completions="The answer is \\boxed{5}",
        prompt_ids=None,
        completion_ids=None,
        answer="4",
    )
    assert reward == 0.0
```
Existing reward functions that can serve as references:

| Reward | File | Description |
|---|---|---|
| GSM8K | `areal/reward/gsm8k.py` | Math answer verification |
| Geometry3K | `areal/reward/geometry3k.py` | Geometry answer verification |
| CLEVR | `areal/reward/clevr_count_70k.py` | Counting verification |
| MathVerify | `areal/reward/math_verify.py` | General math verification |
All reward functions must follow this signature:

```python
from typing import Any


def reward_fn(
    prompt: str,  # Input prompt string
    completions: str,  # Model completion string
    prompt_ids,  # Tokenized prompt
    completion_ids,  # Tokenized completion
    **kwargs: Any,  # Additional data from dataset (e.g., answer)
) -> float:  # Reward value (typically 0.0 or 1.0)
    ...
```
Note: The reward function is called once per sample. Batching is handled by
AsyncRewardWrapper in the workflow.
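As a minimal illustration of this signature, the hypothetical `exact_match_reward_fn` below receives the ground truth only through `**kwargs`, assuming the dataset exposes it under the key `answer`, and scores a single completion by exact string match.

```python
from typing import Any


def exact_match_reward_fn(
    prompt: str,
    completions: str,
    prompt_ids,
    completion_ids,
    **kwargs: Any,
) -> float:
    """Hypothetical example: 1.0 if the stripped completion equals the answer."""
    answer = kwargs.get("answer")  # Assumed dataset field name
    if answer is None:
        # No ground truth available: score 0.0 rather than raising
        return 0.0
    return 1.0 if completions.strip() == str(answer).strip() else 0.0
```

Returning 0.0 when the answer is missing keeps the function safe to call on every sample, in line with the template's behavior.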
- Wrap blocking reward functions with `AsyncRewardWrapper` if needed
- Use `areal.utils.logging`, not `print`