profile-training
Profile JAX training and analyze hotspots. Use when profiling or optimizing training throughput.
| name | profile-training |
| description | Profile JAX training and analyze hotspots. Use when profiling or optimizing training throughput. |
Turn a Levanter profile directory into a deterministic, agent-consumable summary and a concrete optimization workflow:
profile_summary.v1,

Ingestion sources:
- plugins/profile/<timestamp>/*.xplane.pb
- *.xplane.pb files via --xplane-file
- When the xprof package is available: step overview timing, kernel stats, collective breakdowns, xprof bottleneck statements.
- plugins/profile/<timestamp>/perfetto_trace.json.gz
- plugins/profile/<timestamp>/*.trace.json.gz

Prefer XPlane protobuf for new work. Perfetto trace JSON commonly hits the trace
event cap; XPlane contains the uncapped timeline events needed for named-scope
regions, pre-op gaps, gap context, process/thread metadata, and xprof aggregate
tables. Use --trace-file only for a specific Perfetto JSON trace or an older
profile with no XPlane protobuf.
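The preference order described above can be sketched as a small file-discovery helper. This is only an illustration of the stated policy (XPlane first, Perfetto JSON as a fallback), not the code that profile_summary.py actually runs. The function name `pick_trace_source` is hypothetical, and it assumes capture directory names sort chronologically.

```python
from pathlib import Path
from typing import Optional, Tuple


def pick_trace_source(profile_dir: Path) -> Optional[Tuple[str, Path]]:
    """Return ("xplane" | "perfetto", path) for the newest capture, or None."""
    plugin_root = profile_dir / "plugins" / "profile"
    if not plugin_root.is_dir():
        return None
    # Assumption: <timestamp> directory names sort chronologically.
    captures = sorted(p for p in plugin_root.iterdir() if p.is_dir())
    for capture in reversed(captures):
        # Prefer the uncapped XPlane protobuf when it exists.
        xplane = sorted(capture.glob("*.xplane.pb"))
        if xplane:
            return ("xplane", xplane[0])
        # Otherwise fall back to Perfetto trace JSON exports.
        perfetto = sorted(capture.glob("perfetto_trace.json.gz"))
        perfetto += sorted(capture.glob("*.trace.json.gz"))
        if perfetto:
            return ("perfetto", perfetto[0])
    return None
```

A caller would pass the run's profiler directory and branch on the returned kind to decide between an XPlane and a Perfetto ingestion path.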
Use Levanter profiler flags so profiles land under
<trainer.log_dir>/<run_id>/profiler:
uv run ... \
--trainer.profiler true \
--trainer.profiler_start_step 5 \
--trainer.profiler_num_steps 50 \
--trainer.profiler_perfetto_link false
For profiles where xprof/HLO protobuf tables matter, enable JAX profile options through the Levanter profiler config:
uv run ... \
--trainer.profiler true \
--trainer.profiler_start_step 5 \
--trainer.profiler_num_steps 50 \
--trainer.profiler.profile_options.host_tracer_level 1 \
--trainer.profiler.profile_options.python_tracer_level 0 \
--trainer.profiler.profile_options.device_tracer_level 0 \
--trainer.profiler.profile_options.enable_hlo_proto true
Keep the profiler window short when HLO protobuf collection is enabled, because it enlarges the artifacts and can lengthen profile upload and finalization.
For better profile readability, use haliax.jax_utils.named_call and
jax.named_scope liberally in model code; these names flow into trace
annotations and make region-level summaries far more actionable.
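As a minimal sketch of the annotation advice above, the snippet below wraps the stages of a toy MLP block in nested `jax.named_scope` contexts. The function `mlp_block`, the scope names, and the shapes are made up for illustration. `haliax.jax_utils.named_call` is not shown because its exact usage is not given here.

```python
import jax
import jax.numpy as jnp


def mlp_block(x, w_up, w_down):
    # Nested scopes show up as hierarchical names (e.g. mlp/up_proj) in op metadata.
    with jax.named_scope("mlp"):
        with jax.named_scope("up_proj"):
            h = jnp.dot(x, w_up)
        with jax.named_scope("activation"):
            h = jax.nn.gelu(h)
        with jax.named_scope("down_proj"):
            return jnp.dot(h, w_down)


x = jnp.ones((4, 8))
w_up = jnp.ones((8, 16))
w_down = jnp.ones((16, 8))

# Scopes are recorded at trace time, so they apply under jit as well.
out = jax.jit(mlp_block)(x, w_up, w_down)
print(out.shape)
```

Stable, descriptive scope names are what make region-level summaries readable, so it is worth keeping them consistent across experiment variants.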
Reference:
- lib/levanter/docs/Performance-Guide.md
- .agents/skills/add-pallas-kernel/

Pick a download location for pulled profile artifacts: /tmp for
ephemeral/local, scratch/ for an in-repo working area.
# /tmp (ephemeral)
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/<run_id> \
--download-root /tmp/marin-profiles \
--breakdown-mode exclusive_global \
--output /tmp/profile_summary.json
# in-repo scratch (kept with your workspace)
mkdir -p scratch/profiles
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/<run_id> \
--download-root scratch/profiles \
--breakdown-mode exclusive_global \
--output scratch/profile_summary.json
uv run python lib/marin/tools/profile_summary.py summarize \
--artifact marin-community/marin/run-grug-125m-profile-apples-pallas_tpu-20260217-225239-055ab2-profiler:v0 \
--download-root /tmp/marin-profiles \
--output /tmp/profile_summary.json
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/grug-125m-profile-apples-pallas_tpu-20260217-225239-055ab2 \
--download-root /tmp/marin-profiles \
--output /tmp/profile_summary.json
--run-target accepts: a bare run id (requires --entity and --project),
entity/project/run_id, or a full W&B run URL. The profiler directory is
resolved from trainer.log_dir in the run config.
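To make the three accepted forms concrete, here is a hypothetical parser that normalizes each one to an (entity, project, run_id) triple. It is a sketch of the behavior described above, not the tool's own parsing code. `RunRef` and `parse_run_target` are illustrative names, and the URL branch assumes the usual W&B layout `https://wandb.ai/<entity>/<project>/runs/<run_id>`.

```python
from typing import NamedTuple, Optional
from urllib.parse import urlparse


class RunRef(NamedTuple):
    entity: str
    project: str
    run_id: str


def parse_run_target(
    target: str, entity: Optional[str] = None, project: Optional[str] = None
) -> RunRef:
    """Normalize a run target string (hypothetical helper)."""
    if target.startswith(("http://", "https://")):
        # Full run URL: /<entity>/<project>/runs/<run_id>
        parts = [p for p in urlparse(target).path.split("/") if p]
        if len(parts) >= 4 and parts[2] == "runs":
            return RunRef(parts[0], parts[1], parts[3])
        raise ValueError(f"not a W&B run URL: {target}")
    parts = target.split("/")
    if len(parts) == 3 and all(parts):
        # entity/project/run_id
        return RunRef(parts[0], parts[1], parts[2])
    if len(parts) == 1 and parts[0]:
        # Bare run id: entity and project must be supplied separately.
        if not (entity and project):
            raise ValueError("a bare run id needs entity and project")
        return RunRef(entity, project, parts[0])
    raise ValueError(f"unrecognized run target: {target}")
```

The bare-id branch fails loudly instead of guessing defaults, which mirrors the requirement that a bare run id be paired with --entity and --project.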
uv run python lib/marin/tools/profile_summary.py summarize \
--profile-dir /path/to/profiler_dir \
--output /tmp/profile_summary.json
If the directory contains *.xplane.pb, --profile-dir uses the XPlane path
automatically. When both *.xplane.pb and Perfetto trace JSON are present,
--profile-dir reads the XPlane protobuf by default (Perfetto exports are often
capped). Use --trace-file to force a specific Perfetto JSON file.
uv run python lib/marin/tools/profile_summary.py summarize \
--trace-file /path/to/perfetto_trace.json.gz \
--output /tmp/profile_summary.json
Direct XPlane timeline parsing uses protobuf and does not require
TensorFlow-generated xplane_pb2 modules. If xprof is installed, ingestion
also exports compact xprof table JSON and augments the timeline summary with
aggregate step, kernel, collective, and bottleneck evidence.
uv run --with xprof --with protobuf python lib/marin/tools/profile_summary.py summarize \
--xplane-file /path/to/profile.xplane.pb \
--xplane-output-dir /tmp/profile_xprof_tables \
--xplane-count-trace-events \
--output /tmp/profile_summary.json
Without --xplane-output-dir the command still parses XPlane timeline events
directly. Add --with xprof for xprof aggregate table augmentation; add
--xplane-output-dir to preserve the exported table JSON (this flag requires
the optional xprof package).
XPlane summaries expose hierarchical named-scope regions, pre-op gaps, gap region context, process/thread/timeline event metadata, step timing (when step markers or xprof overview rows exist), xprof bottleneck statements, kernel stats, collective breakdowns, and optimization candidates.
Summary version tag: profile_summary.v1
Generate a deterministic markdown root-cause report:
uv run python lib/marin/tools/profile_summary.py report \
--summary /tmp/profile_summary.json \
--output /tmp/profile_report.md
Trace quality checks are surfaced in trace_overview:
- suspected_truncation: true when event counts match a known export cap.
- quality_warnings: warnings that hotspot/gap attribution should be treated with caution.

Top ops:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "What are the top 10 ops by exclusive time?"
Compute vs comm and collective bottlenecks:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "Is comm or compute dominating? Which collective is worst?"
Specific pre-op gap lookup:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "gap before _linear_softmax_cross_entropy_loss_bwd_pallas_mosaic_tpu_combined.1"
Pre-op gap attribution is marker-aware:
- gap_before_ops[].payload_op: op where useful work starts after the idle period.
- gap_before_ops[].marker_op: first op observed after the gap (often lightweight setup like iota.*).

Hierarchical semantic regions (derived from tf_op paths when available):
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "show hierarchical regions"
Contextualize a noisy op:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "show context for op copy.564"
Suggested optimizations from evidence:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "What should we try next?"
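The fields named in this section can also be read straight from the summary JSON before trusting any query answer. The sketch below assumes `trace_overview` and `gap_before_ops` are top-level keys and that `quality_warnings` is a list of strings. The full profile_summary.v1 schema is not shown here, so treat that layout as an assumption. `load_summary` and `trust_notes` are hypothetical helper names.

```python
import json
from pathlib import Path
from typing import Any, Dict, List


def load_summary(path: str) -> Dict[str, Any]:
    """Read a summary JSON file written by the summarize command."""
    return json.loads(Path(path).read_text())


def trust_notes(summary: Dict[str, Any]) -> List[str]:
    """Collect caveats to read before acting on hotspot or gap attribution."""
    notes = []
    overview = summary.get("trace_overview", {})
    if overview.get("suspected_truncation"):
        notes.append("trace may be truncated: event count matches a known export cap")
    for warning in overview.get("quality_warnings", []):
        notes.append(f"quality warning: {warning}")
    for gap in summary.get("gap_before_ops", []):
        marker, payload = gap.get("marker_op"), gap.get("payload_op")
        # The marker op is often lightweight setup; the payload op is the real work.
        if marker and payload and marker != payload:
            notes.append(f"gap observed at {marker}, useful work resumes at {payload}")
    return notes


# Hand-written stand-in for illustration only, not real tool output.
example = {
    "trace_overview": {
        "suspected_truncation": True,
        "quality_warnings": ["event count equals export cap"],
    },
    "gap_before_ops": [{"marker_op": "iota.1", "payload_op": "copy.564"}],
}
for note in trust_notes(example):
    print(note)
```

With a real file, `trust_notes(load_summary("/tmp/profile_summary.json"))` would give the same kind of checklist.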
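Use a strict workflow:
1. Summarize the baseline profile into before.json.
2. Summarize the candidate profile into after.json.
3. Compare the two summaries:
uv run python lib/marin/tools/profile_summary.py compare \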
--before /tmp/profile_before.json \
--after /tmp/profile_after.json \
--strict-provenance
uv run python lib/marin/tools/profile_summary.py track \
--before /tmp/profile_before.json \
--after /tmp/profile_after.json \
--label "pallas-kernel-attempt-3" \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py history \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py bundle \
--before-run-target marin-community/marin/<baseline_run_id> \
--after-run-target marin-community/marin/<candidate_run_id> \
--output-dir /tmp/profile_bundle \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py publish \
--summary /tmp/profile_summary.json \
--report /tmp/profile_report.md \
--alias latest
The comparison reports: steady-state step-time delta, step class deltas (light/heavy when detected), compute/comm/host/stall share deltas, semantic family deltas with workload-normalized metrics, provenance checks (trace hash/run identity), and regressed/improved ops by exclusive duration.
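To show what two of these deltas mean arithmetically, the sketch below computes a relative step-time change and per-category share changes from plain numbers. The function names, category keys, and input values are placeholders chosen for illustration. They are not the profile_summary.v1 schema or output from a real run.

```python
from typing import Dict


def step_time_delta_pct(before_s: float, after_s: float) -> float:
    """Relative change in steady-state step time, in percent (negative is faster)."""
    return 100.0 * (after_s - before_s) / before_s


def share_deltas(before: Dict[str, float], after: Dict[str, float]) -> Dict[str, float]:
    """Change in time share per category (after minus before), in percentage points."""
    categories = sorted(set(before) | set(after))
    return {
        c: round(100.0 * (after.get(c, 0.0) - before.get(c, 0.0)), 2)
        for c in categories
    }


# Made-up shares (fractions of step time) for a baseline and a candidate.
before = {"compute": 0.70, "comm": 0.20, "host": 0.05, "stall": 0.05}
after = {"compute": 0.62, "comm": 0.28, "host": 0.05, "stall": 0.05}

print(step_time_delta_pct(2.0, 1.5))
print(share_deltas(before, after))
```

Reading the two together matters: a faster step with a rising comm share suggests compute improved while collectives became the next bottleneck.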
MVP is successful when:
profile_summary.v1,