en un clic
parity-testing
// Verify numerical parity between NeMo AutoModel implementations and reference HuggingFace models, including state dict and forward-pass checks.
// Verify numerical parity between NeMo AutoModel implementations and reference HuggingFace models, including state dict and forward-pass checks.
Maintain the NeMo AutoModel Fern docs site under fern/ — add, update, move, or remove pages; manage redirects, slugs, navigation, and version aliases; run validation and previews.
Guide for onboarding new model families into NeMo AutoModel, including architecture discovery, implementation patterns, registration, and validation.
Dev environment setup for NeMo AutoModel — container-based development, uv package management, installation options, environment variables, and common build pitfalls.
CI/CD reference for NeMo AutoModel — pipeline structure, commit and PR workflow, CI failure investigation, and common failure patterns.
Guide for selecting and configuring distributed training strategies in NeMo AutoModel, including FSDP2, Megatron FSDP, DDP, and parallelism settings.
Configure NeMo AutoModel job launches for interactive runs, Slurm clusters, and SkyPilot cloud execution.
| name | parity-testing |
| description | Verify numerical parity between NeMo AutoModel implementations and reference HuggingFace models, including state dict and forward-pass checks. |
| when_to_use | Verifying numerical correctness of a new or modified model against its HuggingFace reference, debugging loss divergence or output mismatches, or validating state dict mappings. |
NeMo AutoModel adds custom model implementations (combined projections, backend switching, kernel patches) on top of HuggingFace transformers. Parity testing verifies that NeMo AutoModel's implementation produces numerically equivalent results to the reference HF implementation.
Key differences that can cause divergence:
from transformers import AutoModelForCausalLM
from nemo.collections.llm import NeMoAutoModelForCausalLM
The HF model is the reference. The NeMo AutoModel is the implementation under test.
# NeMo way
nemo_model = NeMoAutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
# HF way
hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
Use seeded random tensors to guarantee reproducibility across runs.
import torch
torch.manual_seed(42)
input_ids = torch.randint(0, 32000, (1, 128))
attention_mask = torch.ones_like(input_ids)
GPU introduces non-determinism from parallel reductions and kernel launch order. Always start parity testing on CPU with float32 to isolate numerical differences caused by model implementation from those caused by hardware.
device = torch.device("cpu")
dtype = torch.float32
hf_model = hf_model.to(device=device, dtype=dtype).eval()
nemo_model = nemo_model.to(device=device, dtype=dtype).eval()
This is the fastest and most fundamental check. If the state dict adapter cannot perfectly round-trip weights, nothing else will work.
hf_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", torch_dtype=torch.float32, device_map="cpu"
)
nemo_model = NeMoAutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", torch_dtype=torch.float32, device_map="cpu"
)
adapter = nemo_model.state_dict_adapter
hf_sd = hf_model.state_dict()
# Convert HF -> custom format -> back to HF format
custom_sd = adapter.from_hf(hf_sd)
roundtrip_sd = adapter.to_hf(custom_sd)
# Check no missing or extra keys
assert set(roundtrip_sd.keys()) == set(hf_sd.keys()), (
f"Key mismatch.\n"
f" Missing from roundtrip: {set(hf_sd.keys()) - set(roundtrip_sd.keys())}\n"
f" Extra in roundtrip: {set(roundtrip_sd.keys()) - set(hf_sd.keys())}"
)
# Check all values are exactly equal (max_diff must be 0.0)
for key in hf_sd:
max_diff = (hf_sd[key] - roundtrip_sd[key]).abs().max().item()
assert max_diff == 0.0, f"Round-trip mismatch for {key}: max_diff={max_diff}"
print("Level 1 PASSED: state dict round-trip is exact.")
What to check:
lm_head.weight aliasing model.embed_tokens.weight) are handled correctly.Test individual components in isolation to localize any divergence.
Components to test: attention, MLP, layer norm / RMSNorm, RoPE, full decoder layer.
# Create a tiny config to keep tests fast
from transformers import AutoConfig
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config.num_hidden_layers = 2
config.hidden_size = 256
config.intermediate_size = 512
config.num_attention_heads = 4
config.num_key_value_heads = 2
hf_model = AutoModelForCausalLM.from_config(config).to(dtype=torch.float32, device="cpu").eval()
nemo_model = NeMoAutoModelForCausalLM.from_config(config).to(dtype=torch.float32, device="cpu").eval()
# Load the same weights into both
hf_sd = hf_model.state_dict()
adapter = nemo_model.state_dict_adapter
nemo_model.load_state_dict(adapter.from_hf(hf_sd))
Forward pass with identical seeded inputs:
torch.manual_seed(42)
input_ids = torch.randint(0, config.vocab_size, (1, 64))
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
hf_out = hf_model(input_ids, attention_mask=attention_mask)
nemo_out = nemo_model(input_ids, attention_mask=attention_mask)
max_diff, mean_diff, cos_sim = compare_tensors(
hf_out.logits, nemo_out.logits, name="component_parity_logits"
)
assert max_diff < 1e-5, f"Component parity FAILED: max_diff={max_diff}"
print("Level 2 PASSED: component parity within tolerance.")
Strict tolerance: max_diff < 1e-5 for float32 on CPU. This is tight enough to catch weight loading bugs while allowing for minor floating-point operation reordering.
Testing individual components (attention example):
# Extract matching layers
hf_attn = hf_model.model.layers[0].self_attn
nemo_attn = nemo_model.model.layers[0].self_attn
torch.manual_seed(42)
hidden_states = torch.randn(1, 64, config.hidden_size)
position_ids = torch.arange(64).unsqueeze(0)
with torch.no_grad():
hf_attn_out = hf_attn(hidden_states, position_ids=position_ids)
nemo_attn_out = nemo_attn(hidden_states, position_ids=position_ids)
compare_tensors(hf_attn_out[0], nemo_attn_out[0], name="attention_output")
Repeat for MLP, norm, and full decoder layer.
Full model forward pass on GPU with bfloat16, reflecting realistic deployment conditions.
device = torch.device("cuda")
dtype = torch.bfloat16
hf_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", torch_dtype=dtype, device_map=device
).eval()
nemo_model = NeMoAutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", torch_dtype=dtype, device_map=device
).eval()
torch.manual_seed(42)
input_ids = torch.randint(0, 32000, (1, 128), device=device)
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
hf_out = hf_model(input_ids, attention_mask=attention_mask)
nemo_out = nemo_model(input_ids, attention_mask=attention_mask)
# Verify shapes match
assert hf_out.logits.shape == nemo_out.logits.shape, (
f"Shape mismatch: {hf_out.logits.shape} vs {nemo_out.logits.shape}"
)
max_diff, mean_diff, cos_sim = compare_tensors(
hf_out.logits, nemo_out.logits, name="e2e_logits"
)
# Looser tolerance for bfloat16
assert max_diff < 1e-2, f"E2E parity FAILED: max_diff={max_diff}"
assert cos_sim > 0.9999, f"E2E parity FAILED: cosine_sim={cos_sim}"
print("Level 3 PASSED: E2E forward pass parity within tolerance.")
Tolerances for bfloat16: max_diff < 1e-2, cosine_similarity > 0.9999. bfloat16 has limited mantissa bits, so per-element differences accumulate across layers.
def compare_tensors(a, b, name=""):
"""Compare two tensors and report multiple similarity metrics.
Args:
a: Reference tensor (from HF model).
b: Test tensor (from NeMo AutoModel).
name: Label for the comparison (printed in output).
Returns:
Tuple of (max_diff, mean_diff, cosine_similarity).
"""
max_diff = (a - b).abs().max().item()
mean_diff = (a - b).abs().mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
a.flatten().float(), b.flatten().float(), dim=0
).item()
print(
f"{name}: max_diff={max_diff:.6e}, mean_diff={mean_diff:.6e}, "
f"cosine_sim={cos_sim:.8f}"
)
return max_diff, mean_diff, cos_sim
def compare_state_dicts(sd_a, sd_b, prefix=""):
"""Compare two state dicts key-by-key, reporting per-parameter differences."""
keys_a = set(sd_a.keys())
keys_b = set(sd_b.keys())
missing = keys_a - keys_b
extra = keys_b - keys_a
if missing:
print(f"{prefix}Missing keys: {missing}")
if extra:
print(f"{prefix}Extra keys: {extra}")
shared = keys_a & keys_b
max_diffs = {}
for key in sorted(shared):
diff = (sd_a[key].float() - sd_b[key].float()).abs().max().item()
if diff > 0:
max_diffs[key] = diff
print(f"{prefix}{key}: max_diff={diff:.6e}")
if not max_diffs and not missing and not extra:
print(f"{prefix}All {len(shared)} parameters match exactly.")
return missing, extra, max_diffs
Follow this procedure when a parity test fails.
Run Level 2 component tests. Determine which component (attention, MLP, norm, RoPE, decoder layer) introduces the divergence.
Verify the state dict adapter round-trip (Level 1). If round-trip is not exact, the bug is in the adapter's from_hf() or to_hf() method.
# Quick check: load NeMo model, export its weights back to HF format, compare
nemo_sd = nemo_model.state_dict()
exported_hf_sd = adapter.to_hf(nemo_sd)
compare_state_dicts(hf_sd, exported_hf_sd, prefix="weight_check: ")
Different backends (TE vs SDPA vs flex attention) can produce different results even with identical weights. Force the baseline backend for parity testing:
from nemo.collections.llm import BackendConfig
# Force SDPA attention and torch linear to match HF behavior
nemo_model = NeMoAutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
backend_config=BackendConfig(attn="sdpa", linear="torch"),
)
Replace one NeMo component's output with the HF component's output and check if downstream computation matches. This isolates exactly which component introduces divergence.
# Example: inject HF attention output into NeMo decoder layer
with torch.no_grad():
# Get HF attention output
hf_attn_out = hf_model.model.layers[0].self_attn(hidden_states, position_ids=position_ids)
# Manually run NeMo decoder layer but substitute HF attention output
nemo_layer = nemo_model.model.layers[0]
residual = hidden_states
normed = nemo_layer.input_layernorm(hidden_states)
# Use HF attention output instead of NeMo attention output
attn_out = hf_attn_out[0]
hidden_states_after_attn = residual + attn_out
# Continue with NeMo MLP
residual = hidden_states_after_attn
normed = nemo_layer.post_attention_layernorm(hidden_states_after_attn)
mlp_out = nemo_layer.mlp(normed)
final = residual + mlp_out
# If final matches HF decoder layer output, the bug is in NeMo attention.
# If final does NOT match, the bug is in NeMo MLP or norm.
After forward pass parity is confirmed, verify gradients:
hf_model.train()
nemo_model.train()
hf_out = hf_model(input_ids, attention_mask=attention_mask, labels=input_ids)
nemo_out = nemo_model(input_ids, attention_mask=attention_mask, labels=input_ids)
hf_out.loss.backward()
nemo_out.loss.backward()
# Compare gradients for corresponding parameters
for (hf_name, hf_param), (nemo_name, nemo_param) in zip(
hf_model.named_parameters(), nemo_model.named_parameters()
):
if hf_param.grad is not None and nemo_param.grad is not None:
compare_tensors(hf_param.grad, nemo_param.grad, name=f"grad_{hf_name}")
from_pretrained may break after save_pretrained + from_pretrained if the state dict adapter has asymmetries.tie_word_embeddings=True, confirm that lm_head.weight and embed_tokens.weight share the same tensor after loading.These are the key source files relevant to parity testing:
| Component | Path |
|---|---|
| State dict adapter base | components/models/common/combined_projection/state_dict_adapter.py |
| Model registry | _transformers/registry.py |
| AutoModel entry point | _transformers/auto_model.py |
| Kernel patches | _transformers/kernel_patches.py |
| Model capabilities | _transformers/capabilities.py |