add-pallas-kernel
Add, modify, or autotune a TPU/GPU Pallas kernel.
| name | add-pallas-kernel |
| description | Add, modify, or autotune a TPU/GPU Pallas kernel. |
This is a specialization of .agents/skills/run-research/SKILL.md.
Use run-research for the generic research lifecycle (branching, issue/logbook cadence, snapshot/tag discipline, reporting). This skill adds kernel-specific standards for numerics and gradient safety, backend/fallback API design, TPU/GPU performance diagnosis, and block-size autotuning.
- Read `.agents/skills/run-research/SKILL.md` first.
- Follow `run-research`; keep this file focused on kernel-specific constraints.

For a kernel `K`, produce:
Use the research logbook and issue workflow from run-research for experiment history and milestone updates.
Tokamax-style decomposition is preferred for maintainability:
- `reference.py`: readable vanilla JAX oracle.
- `xla.py`: default implementation (often same math as reference).
- `pallas_tpu.py`: TPU Pallas implementation.
- `pallas_gpu.py`: optional GPU Pallas implementation.
- `api.py`: stable user-facing entrypoint with `implementation=` override and fallback order.

Reference template: `lib/levanter/src/levanter/kernels/pallas/template_kernel.py`
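The `api.py` role described above can be sketched as follows. This is a minimal illustration, not the template's actual code: the function names, the registry, and the use of `NotImplementedError` as the "unsupported" signal are all assumptions made for the example.

```python
from typing import Callable, Optional, Sequence


# Hypothetical stand-ins for the functions exported by xla.py and pallas_tpu.py.
def _xla_impl(x):
    return [v * 2 for v in x]


def _pallas_tpu_impl(x):
    # Assumed convention: an implementation signals "unsupported here" this way.
    raise NotImplementedError("pallas_tpu requires a TPU backend")


IMPLEMENTATIONS: dict[str, Callable] = {
    "pallas_tpu": _pallas_tpu_impl,
    "xla": _xla_impl,
}
# Fastest first, safest last.
DEFAULT_ORDER: Sequence[str] = ("pallas_tpu", "xla")


def double(x, *, implementation: Optional[str] = None):
    """Stable entrypoint: explicit override fails fast, default order falls back."""
    if implementation is not None:
        if implementation not in IMPLEMENTATIONS:
            raise ValueError(f"unknown implementation: {implementation!r}")
        # Explicit request: no silent fallback, errors propagate to the caller.
        return IMPLEMENTATIONS[implementation](x)

    errors = []
    for name in DEFAULT_ORDER:
        try:
            return IMPLEMENTATIONS[name](x)
        except NotImplementedError as exc:
            errors.append(f"{name}: {exc}")
    raise RuntimeError("no implementation available: " + "; ".join(errors))
```

With no override, `double([1, 2])` skips the unsupported TPU path and returns the XLA result; with `implementation="pallas_tpu"` the same call raises instead of falling back.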
Prefer one true batched kernel:
Expose tile choices via a dataclass with explicit defaults:

    from dataclasses import dataclass


    @dataclass(frozen=True, slots=True)
    class BlockSizes:
        # Tile sizes along the batch, hidden, and vocab dimensions.
        b_block_size: int = 1024
        h_block_size: int = 512
        v_block_size: int = 2048

        @classmethod
        def get_default(cls) -> "BlockSizes":
            return cls()

Rules:
- If a legacy `block_size` arg exists, map it clearly to the new config and raise on conflicting inputs.
- When a specific implementation is requested (`implementation="pallas_tpu"`), fail fast on unsupported backend/shape.
- […] `api.py`.

Prefer a canonical kernel input shape and make callers normalize to it:
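A minimal sketch of that normalization, assuming (purely for illustration) that the canonical shape is `[rows, hidden]` and that the kernel is elementwise over rows. `_kernel_2d` is a hypothetical stand-in for the real kernel.

```python
import jax.numpy as jnp


def _kernel_2d(x):
    # Hypothetical kernel that only accepts the canonical [rows, hidden] shape.
    assert x.ndim == 2, "kernel expects [rows, hidden]"
    return x * 2.0


def apply_with_canonical_shape(x):
    """Flatten leading dims to [rows, hidden], call the kernel, restore the shape."""
    *lead, hidden = x.shape
    out = _kernel_2d(x.reshape(-1, hidden))
    return out.reshape(*lead, out.shape[-1])


x = jnp.ones((2, 3, 8))
y = apply_with_canonical_shape(x)
```

Keeping the reshape in the caller-facing wrapper means the kernel itself has one shape contract to test and tune.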
Use an existing in-repo implementation, pseudocode, a PyTorch reference, or an Optax/JAX baseline. The baseline must be obvious and stable, not clever. If the naive baseline would materialize huge intermediates, use a streaming/blockwise baseline with identical math.
Minimum checks: value parity over a shape/dtype grid, gradient parity on small shapes, backend numerics on CPU and accelerator backends as applicable. Report pointwise deviation metrics (max/mean absolute diff), not only allclose. Use explicit shape/dtype annotations for public APIs and references (e.g. jaxtyping) where available.
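The value and gradient parity checks above might be reported like this. It is a sketch under assumptions: `reference` and `candidate` are toy functions invented for the example, and the metric names are not taken from the repository.

```python
import jax
import jax.numpy as jnp


def deviation_report(actual, expected):
    """Pointwise deviation metrics, computed in float32 regardless of input dtype."""
    diff = jnp.abs(actual.astype(jnp.float32) - expected.astype(jnp.float32))
    return {"max_abs": float(diff.max()), "mean_abs": float(diff.mean())}


def reference(x):
    # Hypothetical readable oracle.
    return jnp.sum(jnp.tanh(x) ** 2)


def candidate(x):
    # Hypothetical implementation under test (same math, different form).
    t = jnp.tanh(x)
    return jnp.sum(t * t)


x = jnp.linspace(-1.0, 1.0, 16).reshape(4, 4)

# Value parity on one point of the shape/dtype grid.
value_report = deviation_report(candidate(x), reference(x))

# Gradient parity on a small shape.
grad_report = deviation_report(jax.grad(candidate)(x), jax.grad(reference)(x))
```

Logging both dictionaries per shape/dtype gives reviewers the actual error magnitude, which a bare `allclose` pass/fail hides.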
For in-tree kernels, add/extend tests under lib/levanter/tests/kernels/. Compare the default implementation against the reference on small CPU shapes and accelerator-aligned shapes for fast paths.
Add cost_estimate= to each pl.pallas_call:
- Use `pl.estimate_cost` on a body-equivalent JAX function (not a kernel body with `pl.program_id`).

    import jax
    from jax.experimental import pallas as pl

    from levanter.kernels.pallas.cost_estimate_utils import with_io_bytes_accessed


    def _cost_estimate(
        q: jax.Array,
        k: jax.Array,
        v: jax.Array,
        *,
        kernel_inputs_specs,
        kernel_outputs_specs,
    ) -> pl.CostEstimate | None:
        # reference_impl is the body-equivalent JAX function for this kernel.
        body_cost = pl.estimate_cost(reference_impl, q, k, v)
        # Add the bytes moved by the kernel's inputs and outputs to the body cost.
        return with_io_bytes_accessed(
            body_cost,
            kernel_inputs_specs=kernel_inputs_specs,
            kernel_outputs_specs=kernel_outputs_specs,
        )

Use the execution environment guidance and cadence from run-research; this section adds kernel-specific constraints. For kernel-specific profiling capture/compare guidance, see docs/reference/profiling.md.
Key iteration loop: profile -> hypothesis -> change -> tests -> microbench -> profile
Always report: compile-including timing (time-to-first-step), steady-state timing, and exact hardware type and shape/dtype grid.
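One way to collect those numbers in a microbenchmark is sketched below. The helper name and the returned field names are assumptions for illustration; the timing relies only on `jax.jit` and `jax.block_until_ready`.

```python
import time

import jax
import jax.numpy as jnp


def time_compile_and_steady_state(fn, *args, steady_iters=10):
    """Time the first (compile-including) call and the steady-state average."""
    jitted = jax.jit(fn)

    start = time.perf_counter()
    jax.block_until_ready(jitted(*args))  # includes tracing and compilation
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(steady_iters):
        out = jitted(*args)
    jax.block_until_ready(out)  # wait for async dispatch to finish
    steady = (time.perf_counter() - start) / steady_iters

    return {
        "time_to_first_step_s": first,
        "steady_state_s": steady,
        "device_kind": jax.devices()[0].device_kind,
        "shapes": [tuple(a.shape) for a in args],
        "dtypes": [str(a.dtype) for a in args],
    }


report = time_compile_and_steady_state(
    lambda a, b: a @ b, jnp.ones((128, 64)), jnp.ones((64, 32))
)
```

Blocking on the result matters because JAX dispatches asynchronously; without it the loop would measure enqueue time rather than execution time.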
Keep tuning explicit and reviewable.
- Benchmark each `(bucket, config)` pair and capture timing + failures.
- Key tuned entries by `(tpu_type, dtype, shape_bucket[, invariants])`.
- Provide an `infer_block_sizes(...)` helper, and default fallback to `BlockSizes.get_default()`.

Do not key tuned tables by every exact shape; keep buckets stable and reviewable.
Support three levels of fallback, similar to the fused softmax cross-entropy kernel:
- Look up tuned entries by `(device, dtype, shape bucket)`, validate/sanitize for backend constraints, fall back to default/safe entries when no exact tuned match exists.

If Mosaic reports errors like `Expected matmul acc to be 32-bit`:
- set `preferred_element_type=jnp.float32` in `lax.dot_general` for the kernel path, or
- set `jax_default_matmul_precision=highest` in benchmark scripts.

Prefer explicit kernel-side `preferred_element_type` for deterministic behavior.
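The kernel-side option can be illustrated outside a Pallas kernel with plain `lax.dot_general`. This sketch only shows the accumulator dtype request; `matmul_f32_acc` is a hypothetical helper name.

```python
import jax.numpy as jnp
from jax import lax


def matmul_f32_acc(a, b):
    """2-D matmul of low-precision inputs with a float32 result/accumulator."""
    # Contract a's dim 1 with b's dim 0; no batch dimensions.
    dimension_numbers = (((1,), (0,)), ((), ()))
    return lax.dot_general(
        a, b, dimension_numbers, preferred_element_type=jnp.float32
    )


a = jnp.ones((8, 16), dtype=jnp.bfloat16)
b = jnp.ones((16, 4), dtype=jnp.bfloat16)
out = matmul_f32_acc(a, b)
```

Here `out` is float32 even though both operands are bfloat16, which is the property the Mosaic error asks for.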
Set LIBTPU_INIT_ARGS by TPU generation during microbench/tuning:
- v5p / v5e: `--xla_tpu_scoped_vmem_limit_kib=50000`
- v6e: `--xla_tpu_scoped_vmem_limit_kib=98304`
- v4: no special scoped-VMEM override

Capture compiler diagnostics on serious benchmark/tuning runs: HLO dumps via `--xla-dump-dir`, compiler logs via `--compiler-log-path`, and explicit `XLA_FLAGS` and `LIBTPU_INIT_ARGS` recorded with results.
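A small helper can apply the per-generation setting and keep the flags alongside the results. This is a sketch: the helper names and the generation strings used as keys are assumptions, and the environment has to be prepared before the TPU runtime initializes for the setting to take effect.

```python
import os

# Scoped-VMEM flags per TPU generation, as listed above. v4 needs no override.
SCOPED_VMEM_FLAGS = {
    "v5p": "--xla_tpu_scoped_vmem_limit_kib=50000",
    "v5e": "--xla_tpu_scoped_vmem_limit_kib=50000",
    "v6e": "--xla_tpu_scoped_vmem_limit_kib=98304",
    "v4": "",
}


def libtpu_init_args(tpu_generation: str, existing: str = "") -> str:
    """Append the generation's scoped-VMEM flag to any existing LIBTPU_INIT_ARGS."""
    if tpu_generation not in SCOPED_VMEM_FLAGS:
        raise ValueError(f"unknown TPU generation: {tpu_generation!r}")
    parts = [p for p in (existing.strip(), SCOPED_VMEM_FLAGS[tpu_generation]) if p]
    return " ".join(parts)


def benchmark_env(tpu_generation: str, environ=None) -> dict:
    """Return a copy of the environment with LIBTPU_INIT_ARGS set for this run."""
    env = dict(os.environ if environ is None else environ)
    env["LIBTPU_INIT_ARGS"] = libtpu_init_args(
        tpu_generation, env.get("LIBTPU_INIT_ARGS", "")
    )
    return env


def recorded_flags(env: dict) -> dict:
    """Flags to store next to benchmark results for reproducibility."""
    return {name: env.get(name, "") for name in ("XLA_FLAGS", "LIBTPU_INIT_ARGS")}
```

Passing the returned dictionary as the `env` of a benchmark subprocess keeps the parent process untouched and makes the exact flags part of the recorded run.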
Useful scripts:
- `lib/levanter/scripts/bench/bench_fused_cross_entropy_loss_pallas.py`
- `lib/levanter/scripts/tune/tune_fused_cross_entropy_loss_block_sizes.py`

When performance is unclear, run dump-first comparisons on one fixed shape:
- XLA/reference path
- full Pallas path
- decomposition variant(s) (temporary toggles)

Use separate dump dirs per variant (`hlo_*`, `llo_*`, `mosaic_*`) and compare throughput, fusion/custom-call placement, schedule bundle counts, and pressure signals (heavy `vrot`/`vsel`, spills, vreg pressure).
Prefer structural fixes before broad tile sweeps when decomposition variants indicate stage-structure issues. For the full LLO workflow (flags, artifact layout, comparison checklist, replication loop), see docs/reference/llo.md.
- […] `run-research` workflow.
- `lib/levanter/src/levanter/kernels/pallas/template_kernel.py`
- `lib/levanter/tests/kernels/test_pallas_template_kernel.py`
- `lib/levanter/src/levanter/kernels/pallas/fused_cross_entropy_loss`

Tokamax kernels are useful references for API and kernel structure comparisons.
- `.venv/lib/python3.11/site-packages/tokamax/_src/ops`
- […] `absl.flags` before accessing Tokamax modules that depend on flags.
- `docs/reference/llo.md`
- `docs/reference/profiling.md`