add-archon-model
| Field | Value |
|---|---|
| name | add-archon-model |
| description | Guide for adding a new model to the Archon engine. Use when user wants to add support for a new HuggingFace model architecture in ArchonEngine. |
Add support for a new HuggingFace model architecture in the Archon training engine.
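This skill is triggered when the user wants to add a new ModelSpec or model type for Archon.

Before starting, ensure:

- the target model's HuggingFace config.json (with its model_type) is accessible;
- you know the HuggingFace model ID (e.g., meta-llama/Llama-3-8B).

Read the HuggingFace model's source code to extract key architecture information.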
Action: Fetch and analyze the model's HuggingFace configuration and modeling files.
1. Read the model's config.json (via AutoConfig.from_pretrained) to identify:
   - the model_type string (this is the key used for registry lookup);
   - any model-specific fields (e.g., qk_norm, attention_bias, MoE fields).
2. Read the HuggingFace modeling_*.py source to identify the attention, FFN, MoE, RoPE, and normalization variants, and check whether tie_word_embeddings appears in the config.
3. Summarize the findings in a checklist like:
Target model: <name>
HF model_type: "<model_type>" (and variants like "<model_type>_moe" if applicable)
Attention: [standard GQA / with QK norm / with bias / sliding window / ...]
FFN: [SwiGLU / GeGLU / standard MLP / ...]
MoE: [no / yes - num_experts, top_k, shared_experts]
RoPE: [standard / YaRN / NTK-aware / ...]
Norm: [RMSNorm / LayerNorm] with [pre-norm / post-norm]
Weight tying: [yes / no]
Choose the closest existing implementation as a starting point:
| Target characteristics | Reference | Why |
|---|---|---|
| Dense-only, standard GQA, no QK norm | qwen2 | Simplest baseline, pure dense |
| Has QK norm, or has MoE support | qwen3 | Supports QK norm + MoE + shared experts |
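A first pass over the config can fill most of the checklist and apply the selection table automatically. The sketch below is a minimal example that assumes an HF-style config object, such as the one returned by AutoConfig.from_pretrained. The attribute names are common Llama/Qwen/Mixtral fields. summarize_hf_config and pick_reference are hypothetical helpers, not AReaL APIs. QK norm often shows up only in modeling_*.py, not in the config, so confirm that flag by hand.

```python
from types import SimpleNamespace
from typing import Any, Dict


def summarize_hf_config(cfg: Any) -> Dict[str, Any]:
    """First-pass architecture facts from an HF-style config object.

    Missing optional fields fall back to defaults via getattr, so confirm the
    result against modeling_*.py before relying on it.
    """
    n_heads = cfg.num_attention_heads
    n_kv_heads = getattr(cfg, "num_key_value_heads", None) or n_heads
    # MoE configs name the expert count differently across model families.
    num_experts = getattr(cfg, "num_experts", None) or getattr(cfg, "num_local_experts", None) or 0
    return {
        "model_type": cfg.model_type,
        "gqa": n_kv_heads < n_heads,
        "attention_bias": bool(getattr(cfg, "attention_bias", False)),
        # Often absent from the config even when the model has QK norm.
        "qk_norm": bool(getattr(cfg, "qk_norm", False)),
        "moe": num_experts > 0,
        "rope_scaling": getattr(cfg, "rope_scaling", None),
        # Treated as untied when the field is absent (assumption).
        "tie_word_embeddings": bool(getattr(cfg, "tie_word_embeddings", False)),
    }


def pick_reference(summary: Dict[str, Any]) -> str:
    """Selection table: QK norm or MoE -> qwen3, otherwise qwen2."""
    return "qwen3" if summary["qk_norm"] or summary["moe"] else "qwen2"


example = SimpleNamespace(
    model_type="qwen3_moe", num_attention_heads=32, num_key_value_heads=4, num_experts=128
)
print(pick_reference(summarize_hf_config(example)))  # qwen3
```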
Action: Copy the reference model directory as the starting point:
areal/experimental/models/archon/<model>/
  __init__.py
  spec.py
  model/
    args.py
    model.py
    rope.py
    state_dict_adapter.py
  infra/
    parallelize.py
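The copy step can be scripted. This sketch assumes it runs from the repository root. scaffold_model_dir is a hypothetical helper that only copies files: class names and imports inside the copy (e.g., Qwen3Model) still have to be renamed to <Model> by hand.

```python
import shutil
from pathlib import Path

ARCHON_ROOT = Path("areal/experimental/models/archon")


def scaffold_model_dir(reference: str, new_model: str, root: Path = ARCHON_ROOT) -> Path:
    """Copy an existing Archon model package (e.g. "qwen2" or "qwen3") to a new directory."""
    src = root / reference
    dst = root / new_model
    if not src.is_dir():
        raise FileNotFoundError(f"reference model not found: {src}")
    if dst.exists():
        raise FileExistsError(f"refusing to overwrite {dst}")
    # Skip compiled caches; spec.py, model/, and infra/ are copied as-is.
    shutil.copytree(src, dst, ignore=shutil.ignore_patterns("__pycache__", "*.pyc"))
    return dst
```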
args.py

Adapt <Model>ModelArgs to match the target model's HuggingFace config fields.

Key changes from the reference:

1. Update the @dataclass fields to match the target model's hyperparameters:
   - common fields (dim, n_layers, n_heads, n_kv_heads, vocab_size, head_dim, hidden_dim, norm_eps, rope_theta, etc.);
   - model-specific flags (e.g., attention_bias, qk_norm, sliding_window).
2. Update from_hf_config() to correctly map HuggingFace config attributes, using getattr(hf_config, "field_name", default) for optional fields.

Critical: Verify every field mapping against the HF model's config.json. Incorrect mappings here cause silent errors downstream.
Base class contract (BaseModelArgs):
@dataclass
class <Model>ModelArgs(BaseModelArgs):
    # ... model-specific fields ...

    @classmethod
    def from_hf_config(
        cls,
        hf_config: PretrainedConfig,
        is_critic: bool = False,
        **kwargs,
    ) -> <Model>ModelArgs:
        # Map HF config fields to Archon model args
        ...
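Here is a self-contained sketch of the from_hf_config() pattern with getattr fallbacks for optional fields. SketchModelArgs is a hypothetical stand-in for <Model>ModelArgs without the BaseModelArgs base class, and storing is_critic as a field is an assumption. The HF attribute names (hidden_size, num_key_value_heads, intermediate_size, rms_norm_eps, rope_theta, head_dim) follow common Llama/Qwen configs. The defaults are also assumptions, so confirm each one against the target config.json.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


@dataclass
class SketchModelArgs:
    # Hypothetical stand-in for <Model>ModelArgs (the real class subclasses BaseModelArgs).
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    vocab_size: int
    head_dim: int
    hidden_dim: int
    norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_bias: bool = False
    is_critic: bool = False

    @classmethod
    def from_hf_config(cls, hf_config: Any, is_critic: bool = False, **kwargs: Any) -> "SketchModelArgs":
        n_heads = hf_config.num_attention_heads
        return cls(
            dim=hf_config.hidden_size,
            n_layers=hf_config.num_hidden_layers,
            n_heads=n_heads,
            # Optional: absent means plain multi-head attention.
            n_kv_heads=getattr(hf_config, "num_key_value_heads", None) or n_heads,
            vocab_size=hf_config.vocab_size,
            # Optional: some configs set head_dim explicitly, so only derive it when missing.
            head_dim=getattr(hf_config, "head_dim", None) or hf_config.hidden_size // n_heads,
            hidden_dim=hf_config.intermediate_size,
            norm_eps=getattr(hf_config, "rms_norm_eps", 1e-6),
            rope_theta=getattr(hf_config, "rope_theta", 10000.0),
            attention_bias=getattr(hf_config, "attention_bias", False),
            is_critic=is_critic,
        )


cfg = SimpleNamespace(
    hidden_size=4096, num_hidden_layers=32, num_attention_heads=32,
    num_key_value_heads=8, vocab_size=151936, intermediate_size=12288,
)
args = SketchModelArgs.from_hf_config(cfg)
print(args.head_dim, args.n_kv_heads)  # 128 8
```

Deriving head_dim only as a fallback matters because some configs set it to a value different from hidden_size // num_attention_heads.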
model.py

Adapt the model architecture to match the target model.

Key components to adapt:

1. Normalization (RMSNorm or similar):
   - check whether elementwise_affine is configurable;
   - if the model uses LayerNorm, implement it accordingly.
2. Attention module:
   - projection bias (nn.Linear(..., bias=True/False));
   - add q_norm/k_norm if the model has them, remove them if it doesn't;
   - support n_kv_heads < n_heads for grouped-query attention;
   - keep the set_cp_group / _sp_enabled pattern from the reference.
3. FeedForward module:
   - SwiGLU, w2(silu(w1(x)) * w3(x)), is the most common choice for modern LLMs;
   - for MoE models, the MoE module replaces FeedForward on designated layers.
4. TransformerBlock: pre-norm (most modern LLMs) vs post-norm; select MoE layers via _is_moe_layer() if applicable.
5. Top-level Model (<Model>Model(BaseArchonModel)):
   - submodules: tok_embeddings, layers (as a ModuleDict), norm, output/score;
   - init_weights(): match the initialization scheme from HF;
   - init_buffers(): RoPE cache + MoE buffers;
   - forward(): must follow the BaseArchonModel signature (tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> Tensor.

Base class contract (BaseArchonModel):
class <Model>Model(BaseArchonModel):
    def forward(self, tokens, positions, cu_seqlens, max_seqlen, tree_attn_meta=None) -> torch.Tensor: ...
    def init_weights(self) -> None: ...
    def init_buffers(self, buffer_device) -> None: ...
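For the FeedForward component, here is a minimal PyTorch sketch of the SwiGLU block w2(silu(w1(x)) * w3(x)). SwiGLUFeedForward is a hypothetical name. Pairing w3/w2 with HF's up_proj/down_proj is an assumption that mirrors the gate_proj -> w1 mapping in the state-dict section. The real Archon module may also carry parallelism hooks that are not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFeedForward(nn.Module):
    """Sketch of the SwiGLU FFN: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # gate projection (assumed HF gate_proj)
        self.w3 = nn.Linear(dim, hidden_dim, bias=bias)  # up projection (assumed HF up_proj)
        self.w2 = nn.Linear(hidden_dim, dim, bias=bias)  # down projection (assumed HF down_proj)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate with SiLU, multiply elementwise by the up projection, project back to dim.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```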
rope.py

Handle the rotary position embedding variant.
Options:
Standard RoPE (same as qwen2/qwen3): Re-export from qwen2:
from areal.experimental.models.archon.qwen2.model.rope import (
    apply_rotary_emb,
    precompute_rope_cache,
    repeat_kv,
    reshape_for_broadcast,
    rotate_half,
)
Custom RoPE (YaRN, NTK-aware, etc.): Implement custom precompute_rope_cache()
and apply_rotary_emb() functions. The key difference is usually in how inv_freq
is computed (scaling factors, interpolation, etc.).
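The inv_freq step is where RoPE variants usually diverge. Below is a plain-Python sketch of the standard formula, inv_freq[i] = theta ** (-2i / d), next to fixed NTK-aware scaling. NTK-aware scaling enlarges the base to theta * factor ** (d / (d - 2)). The function names are hypothetical. YaRN additionally blends per-dimension interpolation and rescales attention, which is not shown. Check the target model's HF RoPE initialization before porting any of this.

```python
from typing import List


def rope_inv_freq(head_dim: int, theta: float = 10000.0) -> List[float]:
    """Standard RoPE: inv_freq[i] = theta ** (-2i / head_dim) for i in [0, head_dim / 2)."""
    return [theta ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]


def ntk_aware_inv_freq(head_dim: int, theta: float, factor: float) -> List[float]:
    """Fixed NTK-aware scaling: same formula with base theta * factor ** (d / (d - 2)).

    The highest frequency (i = 0) is unchanged and the lowest one is divided by `factor`.
    """
    scaled_theta = theta * factor ** (head_dim / (head_dim - 2))
    return rope_inv_freq(head_dim, scaled_theta)
```

The exponent d / (d - 2) is chosen so that the lowest frequency is interpolated by exactly `factor` while the high-frequency dimensions stay almost untouched.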
state_dict_adapter.py

Map between HuggingFace and Archon weight key names.
This is the most error-prone step. The adapter must correctly handle:
1. Key name mapping (from_hf_map dict):
   - model.embed_tokens.weight -> tok_embeddings.weight
   - model.layers.{}.self_attn.q_proj.weight -> layers.{}.attention.wq.weight
   - model.layers.{}.mlp.gate_proj.weight -> layers.{}.feed_forward.w1.weight
   - model.layers.{}.input_layernorm.weight -> layers.{}.attention_norm.weight
   - lm_head.weight -> output.weight
   - keys to skip (mapped to None): rotary_emb.inv_freq (computed at runtime)
2. Reverse mapping (to_hf_map): auto-generated from from_hf_map.
3. MoE expert weights (if applicable): 3D<->2D conversion for expert weights. Copy the MoE handling from qwen3 if the model has MoE.
4. Weight tying: skip output.weight during to_hf() if tie_word_embeddings=True.
Verification approach: After implementation, the adapter should satisfy:
# Roundtrip: archon -> hf -> archon preserves all keys
hf_sd = adapter.to_hf(archon_sd)
roundtrip_sd = adapter.from_hf(hf_sd)
assert set(roundtrip_sd.keys()) == set(archon_sd.keys())
Base class contract (BaseStateDictAdapter):
class <Model>StateDictAdapter(BaseStateDictAdapter):
    def from_hf(self, hf_state_dict) -> dict[str, Any]: ...
    def to_hf(self, archon_state_dict) -> dict[str, Any]: ...
    def convert_single_to_hf(self, name, tensor) -> list[tuple[str, torch.Tensor]]: ...
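The mapping logic can be prototyped without a framework. This minimal sketch assumes layer-indexed keys use "{}" placeholders, as in the list above, and that each key carries a single numeric index. SketchStateDictAdapter, _lookup, and the model.norm.weight entry are hypothetical. MoE 3D<->2D handling is omitted, and plain strings stand in for tensors. Unmapped keys raise instead of being dropped, which avoids the silent weight drops this step is prone to.

```python
import re
from typing import Any, Dict, List, Optional, Tuple

# Hypothetical subset of an HF -> Archon key map; "{}" stands for the layer index.
FROM_HF_MAP: Dict[str, Optional[str]] = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,  # recomputed at runtime: skip
    "model.norm.weight": "norm.weight",  # assumed entry, not listed in this guide
    "lm_head.weight": "output.weight",
}

_LAYER_IDX = re.compile(r"\.(\d+)\.")


def _lookup(key: str, key_map: Dict[str, Optional[str]]) -> Tuple[bool, Optional[str]]:
    """Return (known, mapped_key); a known key mapped to None means 'skip'."""
    match = _LAYER_IDX.search(key)
    if match is None:
        template, layer = key, None
    else:
        # Replace the first numeric path component with "{}" to find the template.
        template = key[: match.start()] + ".{}." + key[match.end():]
        layer = match.group(1)
    if template not in key_map:
        return False, None
    target = key_map[template]
    if target is not None and layer is not None:
        target = target.format(layer)
    return True, target


class SketchStateDictAdapter:
    def __init__(self, tie_word_embeddings: bool = False) -> None:
        self.tie_word_embeddings = tie_word_embeddings
        self.from_hf_map = dict(FROM_HF_MAP)
        # Reverse map auto-generated from from_hf_map; skipped (None) keys have no reverse.
        self.to_hf_map = {v: k for k, v in self.from_hf_map.items() if v is not None}

    def from_hf(self, hf_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        out = {}
        for key, value in hf_state_dict.items():
            known, archon_key = _lookup(key, self.from_hf_map)
            if not known:
                raise KeyError(f"unmapped HF key: {key}")
            if archon_key is not None:
                out[archon_key] = value
        return out

    def convert_single_to_hf(self, name: str, tensor: Any) -> List[Tuple[str, Any]]:
        known, hf_key = _lookup(name, self.to_hf_map)
        return [(hf_key, tensor)] if known else []

    def to_hf(self, archon_state_dict: Dict[str, Any]) -> Dict[str, Any]:
        out = {}
        for name, tensor in archon_state_dict.items():
            if self.tie_word_embeddings and name == "output.weight":
                continue  # tied: lm_head shares embed_tokens, so it is not exported
            pairs = self.convert_single_to_hf(name, tensor)
            if not pairs:
                raise KeyError(f"no HF mapping for Archon key: {name}")
            out.update(pairs)
        return out
```

With an untied model, this sketch passes the roundtrip check from the verification approach. Returning an empty list from convert_single_to_hf for unknown names is what the weight-completeness test later relies on.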
parallelize.py

Define the parallelization strategy for the model.

The parallelize function applies the parallelisms in a fixed order; keep the order used in the reference model's parallelize.py.

Key adaptations by model architecture:

- With QK norm: use use_local_output=False (DTensor output for norm) and add SequenceParallel(sequence_dim=2) for q_norm/k_norm.
- Without QK norm: use use_local_output=True.
- With MoE: use apply_moe_ep_tp() and apply_non_moe_tp().

Function signature (must match the ParallelizeFn protocol):
def parallelize_<model>(
    model: nn.Module,
    parallel_dims: ArchonParallelDims,
    param_dtype: torch.dtype = torch.bfloat16,
    reduce_dtype: torch.dtype = torch.float32,
    loss_parallel: bool = True,
    cpu_offload: bool = False,
    reshard_after_forward_policy: str = "default",
    ac_config: ActivationCheckpointConfig | None = None,
    enable_compile: bool = True,
) -> nn.Module:
    ...
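This guide does not reproduce the ParallelizeFn protocol itself. The sketch below is a unit-test-style check of only the parameter names and which of them are required, with the names copied from the signature above. check_parallelize_signature is a hypothetical helper. The real protocol may also constrain types, which this check ignores.

```python
import inspect
from typing import Any, Callable

# Names copied from the parallelize_<model> signature in this guide.
EXPECTED_PARAMS = (
    "model",
    "parallel_dims",
    "param_dtype",
    "reduce_dtype",
    "loss_parallel",
    "cpu_offload",
    "reshard_after_forward_policy",
    "ac_config",
    "enable_compile",
)
REQUIRED_PARAMS = {"model", "parallel_dims"}


def check_parallelize_signature(fn: Callable[..., Any]) -> None:
    """Raise TypeError if fn's parameter names or required/optional split differ."""
    params = inspect.signature(fn).parameters
    names = tuple(params)
    if names != EXPECTED_PARAMS:
        raise TypeError(f"{fn.__name__}: expected {EXPECTED_PARAMS}, got {names}")
    for name, param in params.items():
        has_default = param.default is not inspect.Parameter.empty
        # Required parameters must have no default; all others must have one.
        if has_default == (name in REQUIRED_PARAMS):
            raise TypeError(f"{fn.__name__}: parameter {name!r} has the wrong required/optional status")
```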
spec.py and register

Assemble the ModelSpec and register it.
from areal.experimental.models.archon.model_spec import ModelSpec, register_model_spec
from areal.experimental.models.archon.pipeline_parallel import pipeline_llm
from areal.experimental.models.archon.<model>.infra.parallelize import parallelize_<model>
from areal.experimental.models.archon.<model>.model.args import <Model>ModelArgs
from areal.experimental.models.archon.<model>.model.model import <Model>Model
from areal.experimental.models.archon.<model>.model.state_dict_adapter import (
    <Model>StateDictAdapter,
)

<MODEL>_SPEC = ModelSpec(
    name="<Model>",
    model_class=<Model>Model,
    model_args_class=<Model>ModelArgs,
    state_dict_adapter_class=<Model>StateDictAdapter,
    parallelize_fn=parallelize_<model>,
    supported_model_types=frozenset({"<model_type>"}),  # From HF config.json
    pipelining_fn=pipeline_llm,
)
# Auto-register when module is imported
register_model_spec(<MODEL>_SPEC)
__all__ = ["<MODEL>_SPEC"]
Note: supported_model_types should include all HF model_type strings that this
implementation handles (e.g., {"qwen3", "qwen3_moe"} for Qwen3).
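This guide does not show the internals of register_model_spec. The sketch below is a hypothetical registry keyed by model_type, and it shows why every HF variant must appear in supported_model_types: a checkpoint whose model_type is missing simply has no spec to look up. SketchSpec, register_spec, and get_spec are illustrative names, not AReaL APIs.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet


@dataclass(frozen=True)
class SketchSpec:
    # Hypothetical, trimmed-down stand-in for ModelSpec.
    name: str
    supported_model_types: FrozenSet[str]


_REGISTRY: Dict[str, SketchSpec] = {}


def register_spec(spec: SketchSpec) -> None:
    """Index a spec under every HF model_type it claims; refuse silent overrides."""
    for model_type in spec.supported_model_types:
        existing = _REGISTRY.get(model_type)
        if existing is not None and existing is not spec:
            raise ValueError(f"model_type {model_type!r} already registered by {existing.name}")
        _REGISTRY[model_type] = spec


def get_spec(model_type: str) -> SketchSpec:
    try:
        return _REGISTRY[model_type]
    except KeyError:
        raise KeyError(f"no Archon spec for model_type {model_type!r}; known: {sorted(_REGISTRY)}") from None
```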
__init__.py

Add the import to areal/experimental/models/archon/__init__.py:
from areal.experimental.models.archon.<model> import spec as <model>_spec # noqa: F401
This triggers auto-registration when the module is imported.
Verify in stages, adapting to the available hardware and to the test patterns in tests/experimental/archon/.
Before writing tests, examine the existing test files to understand current patterns:
tests/experimental/archon/
  conftest.py -- Pytest configuration (version checks)
  utils.py -- Shared utilities (model loading, comparison)
  test_qwen3_args.py -- Args unit tests (CPU-only)
  test_state_dict_adapter.py -- State dict roundtrip tests
  test_weight_sync.py -- Weight completeness tests (meta device)
  test_forward.py -- Forward precision comparison (single GPU)
  ...
Test stages (write tests appropriate for the model's complexity):
Test from_hf_config() with mock HuggingFace configs:
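# Pattern: build a stand-in HF config, verify the args mapping.
# Prefer SimpleNamespace to MagicMock: a MagicMock returns a new mock for any unset
# attribute, so getattr(hf_config, "field", default) would never fall back to its default.
from types import SimpleNamespace

from areal.experimental.models.archon.<model>.model.args import <Model>ModelArgs


def test_args_from_hf_config():
    hf_config = SimpleNamespace(
        hidden_size=4096,
        num_hidden_layers=32,
        # ... set all required fields
    )
    args = <Model>ModelArgs.from_hf_config(hf_config)
    assert args.dim == 4096
    assert args.n_layers == 32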
Test key mapping roundtrip:
import torch


def test_state_dict_roundtrip():
    # Create adapter with mock config
    adapter = <Model>StateDictAdapter(mock_config)
    # Create fake archon state dict with expected keys (vocab/dim taken from the mock config)
    archon_sd = {
        "tok_embeddings.weight": torch.randn(vocab, dim),
        # ... one entry per expected parameter
    }
    # Roundtrip
    hf_sd = adapter.to_hf(archon_sd)
    roundtrip = adapter.from_hf(hf_sd)
    assert set(roundtrip.keys()) == set(archon_sd.keys())
Verify all model parameters have HF mappings:
def test_weight_completeness():
    # Create model on meta device
    with torch.device("meta"):
        model = <Model>Model(args)
    adapter = <Model>StateDictAdapter(hf_config)
    # Check every archon param has a HF mapping
    for name, _ in model.named_parameters():
        hf_pairs = adapter.convert_single_to_hf(name, torch.empty(0))
        assert len(hf_pairs) > 0, f"No HF mapping for {name}"
Compare Archon model output against HuggingFace reference:
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_forward_matches_hf():
    # Load both HF and Archon models
    # Run forward on same input
    # Compare logits within tolerance
    ...
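For the comparison step, a small helper can report useful diagnostics before asserting. The sketch below records the maximum absolute difference and the greedy-token agreement, then calls torch.testing.assert_close. compare_logits is a hypothetical helper. The right atol/rtol depend on dtype (bf16 vs fp32), so take them from the existing forward tests rather than from this example.

```python
from typing import Dict

import torch


def compare_logits(archon_logits: torch.Tensor, hf_logits: torch.Tensor, atol: float, rtol: float) -> Dict[str, float]:
    """Return diagnostics and assert the two logit tensors are close."""
    a = archon_logits.float()
    b = hf_logits.float()
    stats = {
        "max_abs_diff": (a - b).abs().max().item(),
        # Fraction of positions where both models would pick the same next token.
        "argmax_agreement": (a.argmax(dim=-1) == b.argmax(dim=-1)).float().mean().item(),
    }
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol, msg=f"logits mismatch: {stats}")
    return stats
```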
Important: Do NOT hardcode the test categories. Inspect the existing test files in
tests/experimental/archon/ and follow the same patterns, fixtures, and markers. Adapt
test scope to the model's specific features (e.g., add MoE-specific tests only if the
model has MoE).
| Model | Directory | Features |
|---|---|---|
| Qwen2 | areal/experimental/models/archon/qwen2/ | Dense, attention bias, no QK norm |
| Qwen3 | areal/experimental/models/archon/qwen3/ | Dense + MoE, QK norm, no attention bias, shared experts |
| Feature | qwen2 | qwen3 | What to check in target model |
|---|---|---|---|
| Attention bias | Yes | No | attention_bias in HF config |
| QK norm | No | Yes | qk_norm in HF config or QKNorm module in modeling file |
| MoE | No | Yes | num_experts/num_local_experts in HF config |
| Shared experts | No | Yes | num_shared_experts in HF config |
| Decoder sparse step | No | Yes | decoder_sparse_step in HF config |
| Weight tying | Both | Both | tie_word_embeddings in HF config |
| RoPE | Standard | Standard (re-export qwen2) | Check inv_freq formula in HF modeling code |
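decoder_sparse_step controls which layers use MoE. The sketch below encodes the rule that HF's Qwen-MoE decoder layers appear to use: a layer is MoE when experts exist, the layer is not listed in mlp_only_layers, and (layer_idx + 1) is a multiple of decoder_sparse_step. This is an assumption. Confirm it against the target modeling_*.py and against _is_moe_layer() in the qwen3 reference. is_moe_layer is a hypothetical name.

```python
from typing import Sequence


def is_moe_layer(
    layer_idx: int,
    num_experts: int,
    decoder_sparse_step: int = 1,
    mlp_only_layers: Sequence[int] = (),
) -> bool:
    """Assumed Qwen-MoE-style rule: MoE on every decoder_sparse_step-th layer, except MLP-only ones."""
    if num_experts <= 0 or layer_idx in mlp_only_layers:
        return False
    return (layer_idx + 1) % decoder_sparse_step == 0
```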
Common mistakes:

- Missing key mappings in state_dict_adapter.py (causes silent weight drops)
- Incorrect from_hf_config() field mapping (uses the wrong HF config attribute name)
- Forgetting None keys in from_hf_map (keys to skip, like rotary_emb.inv_freq)
- Inconsistent tensor-parallel plan in parallelize.py (use_local_output must match)
- Forgetting to add the import to areal/experimental/models/archon/__init__.py
- Missing model_type variants in the supported_model_types frozenset
- Using print instead of areal.utils.logging.getLogger()

After completion, verify that all files exist and are consistent:

- areal/experimental/models/archon/<model>/__init__.py
- areal/experimental/models/archon/<model>/spec.py -- ModelSpec + register
- areal/experimental/models/archon/<model>/model/args.py -- ModelArgs + from_hf_config
- areal/experimental/models/archon/<model>/model/model.py -- Model + Attention + FFN
- areal/experimental/models/archon/<model>/model/rope.py -- RoPE (or re-export)
- areal/experimental/models/archon/<model>/model/state_dict_adapter.py -- Key mapping
- areal/experimental/models/archon/<model>/infra/parallelize.py -- Parallel strategy
- areal/experimental/models/archon/__init__.py -- Import line added
- tests/experimental/archon/test_<model>_*.py -- Tests