add-pallas-kernel
Add, modify, or autotune a TPU/GPU Pallas kernel.
| Field | Value |
| --- | --- |
| name | add-pallas-kernel |
| description | Add, modify, or autotune a TPU/GPU Pallas kernel. |
This is a specialization of .agents/skills/run-research/SKILL.md.
Use run-research for the generic research lifecycle (branching, issue/logbook cadence, snapshot/tag discipline, reporting). This skill adds kernel-specific standards for numerics and gradient safety, backend/fallback API design, TPU/GPU performance diagnosis, and block-size autotuning.
Read `.agents/skills/run-research/SKILL.md` first. Generic lifecycle guidance lives in `run-research`; keep this file focused on kernel-specific constraints.

For a kernel K, produce:
Use the research logbook and issue workflow from run-research for experiment history and milestone updates.
Tokamax-style decomposition is preferred for maintainability:
- `reference.py`: readable vanilla JAX oracle.
- `xla.py`: default implementation (often the same math as the reference).
- `pallas_tpu.py`: TPU Pallas implementation.
- `pallas_gpu.py`: optional GPU Pallas implementation.
- `api.py`: stable user-facing entrypoint with an `implementation=` override and a fallback order.

Reference template: `lib/levanter/src/levanter/kernels/pallas/template_kernel.py`
Prefer one true batched kernel:
Expose tile choices via a dataclass with explicit defaults:
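```python
from dataclasses import dataclass


@dataclass(frozen=True, slots=True)
class BlockSizes:
    # Explicit per-dimension tile defaults, used when no tuned entry applies.
    b_block_size: int = 1024
    h_block_size: int = 512
    v_block_size: int = 2048

    @classmethod
    def get_default(cls) -> "BlockSizes":
        return cls()
```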
Rules:
- If a `block_size` arg exists, map it clearly to the new config and raise on conflicting inputs.
- When an implementation is requested explicitly (e.g. `implementation="pallas_tpu"`), fail fast on unsupported backend/shape.
- … `api.py`.

Prefer a canonical kernel input shape and make callers normalize to it:
Use an existing in-repo implementation, pseudocode, a PyTorch reference, or an Optax/JAX baseline. The baseline must be obvious and stable, not clever. If the naive baseline would materialize huge intermediates, use a streaming/blockwise baseline with identical math.
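For example, a vocabulary logsumexp baseline can avoid materializing the full `[B, V]` logits by streaming over vocab blocks with a running max and running sum; the math is identical to `jax.nn.logsumexp(x @ w, axis=-1)`. This is a sketch: `naive_lse` and `blockwise_lse` are hypothetical names, and the blockwise version assumes `V` divides evenly by `v_block`.

```python
import jax
import jax.numpy as jnp


def naive_lse(logits: jax.Array) -> jax.Array:
    """Naive baseline: needs the full [B, V] logits in memory."""
    return jax.nn.logsumexp(logits, axis=-1)


def blockwise_lse(x: jax.Array, w: jax.Array, v_block: int) -> jax.Array:
    """logsumexp(x @ w) over V, streamed in vocab tiles of size v_block.

    Carries a running max `m` and a running sum `s` of exp(logit - m), so only
    one [B, v_block] tile of logits exists at a time.
    """
    b, h = x.shape
    v = w.shape[1]
    assert v % v_block == 0, "sketch assumes V is divisible by v_block"
    # [H, V] -> [num_blocks, H, v_block], scanned over the leading axis.
    w_blocks = w.reshape(h, v // v_block, v_block).transpose(1, 0, 2)

    def step(carry, w_blk):
        m, s = carry
        logits = jnp.dot(x, w_blk, preferred_element_type=jnp.float32)
        m_new = jnp.maximum(m, logits.max(axis=-1))
        # Rescale the old sum to the new max, then add this tile's contribution.
        s = s * jnp.exp(m - m_new) + jnp.exp(logits - m_new[:, None]).sum(axis=-1)
        return (m_new, s), None

    init = (jnp.full((b,), -jnp.inf, jnp.float32), jnp.zeros((b,), jnp.float32))
    (m, s), _ = jax.lax.scan(step, init, w_blocks)
    return m + jnp.log(s)
```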
Minimum checks: value parity over a shape/dtype grid, gradient parity on small shapes, backend numerics on CPU and accelerator backends as applicable. Report pointwise deviation metrics (max/mean absolute diff), not only allclose. Use explicit shape/dtype annotations for public APIs and references (e.g. jaxtyping) where available.
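A small helper along these lines reports max/mean absolute deviation for both values and gradients over a shape/dtype grid. `deviation` and `parity_report` are hypothetical names; gradients are taken of a float32 sum of the outputs, which is one reasonable choice of scalar loss, not the only one.

```python
import jax
import jax.numpy as jnp


def deviation(a: jax.Array, b: jax.Array) -> dict[str, float]:
    """Pointwise deviation metrics, computed in float32."""
    diff = jnp.abs(a.astype(jnp.float32) - b.astype(jnp.float32))
    return {"max_abs": float(diff.max()), "mean_abs": float(diff.mean())}


def parity_report(impl, ref, shapes, dtypes, seed: int = 0) -> list[dict]:
    """Value and gradient parity of `impl` vs `ref` for every (shape, dtype)."""
    key = jax.random.PRNGKey(seed)
    rows = []
    for shape in shapes:
        for dtype in dtypes:
            key, sub = jax.random.split(key)
            x = jax.random.normal(sub, shape, dtype=jnp.float32).astype(dtype)
            val = deviation(impl(x), ref(x))
            # Gradient parity through a scalar loss (float32 sum of outputs).
            g_impl = jax.grad(lambda y: impl(y).astype(jnp.float32).sum())(x)
            g_ref = jax.grad(lambda y: ref(y).astype(jnp.float32).sum())(x)
            grad = deviation(g_impl, g_ref)
            row = {"shape": shape, "dtype": jnp.dtype(dtype).name}
            row.update({f"val_{k}": v for k, v in val.items()})
            row.update({f"grad_{k}": v for k, v in grad.items()})
            rows.append(row)
    return rows
```

Logging the full rows, rather than a single pass/fail, makes regressions in low-precision dtypes visible before they cross a tolerance.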
For in-tree kernels, add/extend tests under lib/levanter/tests/kernels/. Compare the default implementation against the reference on small CPU shapes and accelerator-aligned shapes for fast paths.
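Add `cost_estimate=` to each `pl.pallas_call`:

- Use `pl.estimate_cost` on a body-equivalent JAX function (not a kernel body with `pl.program_id`).

```python
import jax
from jax.experimental import pallas as pl

from levanter.kernels.pallas.cost_estimate_utils import with_io_bytes_accessed


def _cost_estimate(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    *,
    kernel_inputs_specs,
    kernel_outputs_specs,
) -> pl.CostEstimate | None:
    # `reference_impl` is assumed to be the module's body-equivalent vanilla JAX
    # function (e.g. the reference.py oracle); it must not use pl.program_id.
    body_cost = pl.estimate_cost(reference_impl, q, k, v)
    # Account for the kernel's I/O bytes using its input/output specs.
    return with_io_bytes_accessed(
        body_cost,
        kernel_inputs_specs=kernel_inputs_specs,
        kernel_outputs_specs=kernel_outputs_specs,
    )
```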
Use the execution environment guidance and cadence from run-research; this section adds kernel-specific constraints. For kernel-specific profiling capture/compare guidance, see docs/reference/profiling.md.
Key iteration loop: profile -> hypothesis -> change -> tests -> microbench -> profile
Always report: compile-including timing (time-to-first-step), steady-state timing, and exact hardware type and shape/dtype grid.
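A minimal timing sketch that separates compile-including time-to-first-step from steady-state time and records the device kind. `time_kernel` is a hypothetical helper; its steady-state loop blocks only after the last call, so it measures pipelined throughput rather than single-call latency.

```python
import time

import jax
import jax.numpy as jnp


def time_kernel(fn, *args, steps: int = 10) -> dict[str, object]:
    """Compile-including first call and steady-state mean time, in seconds."""
    jitted = jax.jit(fn)

    t0 = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # trace + compile + first execution
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    out = None
    for _ in range(steps):
        out = jitted(*args)
    jax.block_until_ready(out)  # wait for all queued steps
    steady = (time.perf_counter() - t0) / steps

    return {
        "time_to_first_step_s": first,
        "steady_state_s": steady,
        "backend": jax.default_backend(),
        "device_kind": jax.devices()[0].device_kind,  # exact hardware type
    }
```

Add the shape/dtype grid to each returned record so the numbers stay attributable.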
Keep tuning explicit and reviewable.
- Benchmark each `(bucket, config)` pair and capture timing + failures.
- Key tuned entries by `(tpu_type, dtype, shape_bucket[, invariants])`.
- Resolve block sizes through an `infer_block_sizes(...)` helper, with default fallback to `BlockSizes.get_default()`.

Do not key tuned tables by every exact shape; keep buckets stable and reviewable.
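A sketch of a bucketed tuning table with a lookup helper. The bucket thresholds, the `tpu_type` strings, the single table entry, and the `sanitize` constraint are placeholders, not tuned results; only the keying scheme and the default fallback follow the rules above. `shape_bucket`, `TUNED`, and `sanitize` are hypothetical names.

```python
import dataclasses
from dataclasses import dataclass

import jax.numpy as jnp


@dataclass(frozen=True)
class BlockSizes:
    b_block_size: int = 1024
    h_block_size: int = 512
    v_block_size: int = 2048

    @classmethod
    def get_default(cls) -> "BlockSizes":
        return cls()


def shape_bucket(v: int) -> str:
    """Coarse, stable buckets instead of exact shapes (thresholds illustrative)."""
    if v <= 32_768:
        return "v_small"
    if v <= 131_072:
        return "v_medium"
    return "v_large"


# Key: (tpu_type, dtype, shape_bucket). Placeholder entry, not a tuned result.
TUNED: dict[tuple[str, str, str], BlockSizes] = {
    ("v5p", "bfloat16", "v_large"): BlockSizes(1024, 512, 4096),
}


def sanitize(bs: BlockSizes, v: int) -> BlockSizes:
    """Example backend constraint (assumed): never tile past the vocab size."""
    return dataclasses.replace(bs, v_block_size=min(bs.v_block_size, v))


def infer_block_sizes(tpu_type: str, dtype, v: int) -> BlockSizes:
    key = (tpu_type, jnp.dtype(dtype).name, shape_bucket(v))
    tuned = TUNED.get(key)  # exact tuned match, if any
    if tuned is not None:
        return sanitize(tuned, v)
    return sanitize(BlockSizes.get_default(), v)  # safe default
```

Because the table is plain data keyed by buckets, a retune shows up in review as a small diff of entries.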
Support three levels of fallback, similar to the fused softmax cross-entropy kernel:
1. Look up a tuned entry by `(device, dtype, shape bucket)`.
2. Validate/sanitize it for backend constraints.
3. Fall back to default/safe entries when no exact tuned match exists.

If Mosaic reports errors like `Expected matmul acc to be 32-bit`, either:

- set `preferred_element_type=jnp.float32` in `lax.dot_general` for the kernel path, or
- set `jax_default_matmul_precision=highest` in benchmark scripts.

Prefer explicit kernel-side `preferred_element_type` for deterministic behavior.
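A minimal example of the kernel-side fix: with bf16 operands, passing `preferred_element_type=jnp.float32` to `lax.dot_general` requests a float32 result. Inside a Pallas kernel body the same call would be applied to the loaded tiles; `matmul_f32_acc` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def matmul_f32_acc(a: jax.Array, b: jax.Array) -> jax.Array:
    """[M, K] @ [K, N] on bf16 inputs with an explicit float32 result type."""
    return lax.dot_general(
        a,
        b,
        # Contract a's dim 1 with b's dim 0; no batch dimensions.
        dimension_numbers=(((1,), (0,)), ((), ())),
        preferred_element_type=jnp.float32,
    )
```

The benchmark-script alternative, `jax.config.update("jax_default_matmul_precision", "highest")`, changes a process-wide default rather than the kernel itself.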
Set LIBTPU_INIT_ARGS by TPU generation during microbench/tuning:
- v5p / v5e: `--xla_tpu_scoped_vmem_limit_kib=50000`
- v6e: `--xla_tpu_scoped_vmem_limit_kib=98304`
- v4: no special scoped-VMEM override

Capture compiler diagnostics on serious benchmark/tuning runs: HLO dumps via `--xla-dump-dir`, compiler logs via `--compiler-log-path`, and the exact `XLA_FLAGS` and `LIBTPU_INIT_ARGS` recorded with the results.
Useful scripts:
- `lib/levanter/scripts/bench/bench_fused_cross_entropy_loss_pallas.py`
- `lib/levanter/scripts/tune/tune_fused_cross_entropy_loss_block_sizes.py`

When performance is unclear, run dump-first comparisons on one fixed shape: the XLA/reference path, the full Pallas path, and decomposition variant(s) (temporary toggles). Use separate dump dirs per variant (`hlo_*`, `llo_*`, `mosaic_*`) and compare throughput, fusion/custom-call placement, schedule bundle counts, and pressure signals (heavy `vrot`/`vsel`, spills, vreg pressure).
Prefer structural fixes before broad tile sweeps when decomposition variants indicate stage-structure issues. For the full LLO workflow (flags, artifact layout, comparison checklist, replication loop), see docs/reference/llo.md.
- The `run-research` workflow.
- `lib/levanter/src/levanter/kernels/pallas/template_kernel.py`
- `lib/levanter/tests/kernels/test_pallas_template_kernel.py`
- `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss`

Tokamax kernels are useful references for API and kernel structure comparisons:

- `.venv/lib/python3.11/site-packages/tokamax/_src/ops`
- Initialize `absl.flags` before accessing Tokamax modules that depend on flags.

- `docs/reference/llo.md`
- `docs/reference/profiling.md`