profile-training
Profile JAX training and analyze hotspots. Use when profiling or optimizing training throughput.
| name | profile-training |
| description | Profile JAX training and analyze hotspots. Use when profiling or optimizing training throughput. |
Turn a Levanter profile directory into a deterministic, agent-consumable summary (profile_summary.v1) and a concrete optimization workflow.

Ingestion sources:
- plugins/profile/<timestamp>/*.xplane.pb
- *.xplane.pb files via --xplane-file
- when the xprof package is available: step overview timing, kernel stats, collective breakdowns, xprof bottleneck statements
- plugins/profile/<timestamp>/perfetto_trace.json.gz
- plugins/profile/<timestamp>/*.trace.json.gz

Prefer XPlane protobuf for new work. Perfetto trace JSON commonly hits the trace
event cap; XPlane contains the uncapped timeline events needed for named-scope
regions, pre-op gaps, gap context, process/thread metadata, and xprof aggregate
tables. Use --trace-file only for a specific Perfetto JSON trace or an older
profile with no XPlane protobuf.
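The preference order above can be sketched in a few lines of Python. This is only an illustration: `pick_trace_source` is a hypothetical helper, not the logic inside profile_summary.py, and it assumes timestamp directory names sort chronologically.

```python
import tempfile
from pathlib import Path
from typing import Optional


def pick_trace_source(profile_dir: Path) -> Optional[Path]:
    """Return the preferred trace file under plugins/profile/<timestamp>/.

    Hypothetical helper: it mirrors the preference order described in the
    text (XPlane first, Perfetto JSON as fallback), nothing more.
    """
    patterns = ("*.xplane.pb", "perfetto_trace.json.gz", "*.trace.json.gz")
    for pattern in patterns:
        matches = sorted(profile_dir.glob(f"plugins/profile/*/{pattern}"))
        if matches:
            # Assumption: the lexicographically last timestamp is the newest.
            return matches[-1]
    return None


# Build a throwaway profile directory that contains both formats.
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp) / "plugins" / "profile" / "2026_02_17_22_52_39"
    run_dir.mkdir(parents=True)
    (run_dir / "host.xplane.pb").touch()
    (run_dir / "perfetto_trace.json.gz").touch()
    chosen = pick_trace_source(Path(tmp))
    print(chosen.name)  # host.xplane.pb
```

The XPlane file wins whenever it exists, so a capped Perfetto export is only read when nothing better is available.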
Use Levanter profiler flags so profiles land under
<trainer.log_dir>/<run_id>/profiler:
uv run ... \
--trainer.profiler true \
--trainer.profiler_start_step 5 \
--trainer.profiler_num_steps 50 \
--trainer.profiler_perfetto_link false
For profiles where xprof/HLO protobuf tables matter, enable JAX profile options through the Levanter profiler config:
uv run ... \
--trainer.profiler true \
--trainer.profiler_start_step 5 \
--trainer.profiler_num_steps 50 \
--trainer.profiler.profile_options.host_tracer_level 1 \
--trainer.profiler.profile_options.python_tracer_level 0 \
--trainer.profiler.profile_options.device_tracer_level 0 \
--trainer.profiler.profile_options.enable_hlo_proto true
Keep the profiler window short when enabling HLO protobuf collection: it enlarges artifacts and can increase profile upload/finalization time.
For better profile readability, use haliax.jax_utils.named_call and
jax.named_scope liberally in model code; these names flow into trace
annotations and make region-level summaries far more actionable.
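As a minimal sketch of the naming advice above, the toy function below wraps two matmuls in nested jax.named_scope blocks. The function, shapes, and scope names (`mlp_block`, `mlp`, `up_proj`, `down_proj`) are made up for illustration, and haliax.jax_utils.named_call is not shown here.

```python
import jax
import jax.numpy as jnp


def mlp_block(x, w_in, w_out):
    # Ops traced inside a named_scope carry that name in their metadata,
    # so they can be grouped under "mlp/up_proj" and "mlp/down_proj".
    with jax.named_scope("mlp"):
        with jax.named_scope("up_proj"):
            h = jax.nn.gelu(x @ w_in)
        with jax.named_scope("down_proj"):
            return h @ w_out


# Illustrative shapes only.
x = jnp.ones((4, 8))
w_in = jnp.ones((8, 16))
w_out = jnp.ones((16, 8))
out = jax.jit(mlp_block)(x, w_in, w_out)
print(out.shape)
```

Scopes cost nothing at run time because they only annotate the traced program, which is why using them liberally is cheap.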
Reference:
- lib/levanter/docs/Performance-Guide.md
- .agents/skills/add-pallas-kernel/

Pick a download location for pulled profile artifacts: /tmp for
ephemeral/local use, scratch/ for an in-repo working area.
# /tmp (ephemeral)
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/<run_id> \
--download-root /tmp/marin-profiles \
--breakdown-mode exclusive_global \
--output /tmp/profile_summary.json
# in-repo scratch (kept with your workspace)
mkdir -p scratch/profiles
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/<run_id> \
--download-root scratch/profiles \
--breakdown-mode exclusive_global \
--output scratch/profile_summary.json
# from an explicit artifact reference
uv run python lib/marin/tools/profile_summary.py summarize \
--artifact marin-community/marin/run-grug-125m-profile-apples-pallas_tpu-20260217-225239-055ab2-profiler:v0 \
--download-root /tmp/marin-profiles \
--output /tmp/profile_summary.json
# from a run target (entity/project/run_id)
uv run python lib/marin/tools/profile_summary.py summarize \
--run-target marin-community/marin/grug-125m-profile-apples-pallas_tpu-20260217-225239-055ab2 \
--download-root /tmp/marin-profiles \
--output /tmp/profile_summary.json
--run-target accepts: a bare run id (requires --entity and --project),
entity/project/run_id, or a full W&B run URL. The profiler directory is
resolved from trainer.log_dir in the run config.
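The three accepted forms can be told apart with a small parser. The sketch below is hypothetical (`parse_run_target` and `RunTarget` are illustrative names, not part of the tool) and assumes the usual W&B URL layout `https://wandb.ai/<entity>/<project>/runs/<run_id>`.

```python
from typing import NamedTuple, Optional
from urllib.parse import urlparse


class RunTarget(NamedTuple):
    entity: str
    project: str
    run_id: str


def parse_run_target(
    target: str, entity: Optional[str] = None, project: Optional[str] = None
) -> RunTarget:
    """Hypothetical parser for the three --run-target forms described above."""
    if target.startswith(("http://", "https://")):
        # Assumed layout: /<entity>/<project>/runs/<run_id>
        parts = [p for p in urlparse(target).path.split("/") if p]
        if len(parts) >= 4 and parts[2] == "runs":
            return RunTarget(parts[0], parts[1], parts[3])
        raise ValueError(f"unrecognized run URL: {target}")
    parts = target.split("/")
    if len(parts) == 3:
        return RunTarget(*parts)
    if len(parts) == 1:
        # A bare run id is ambiguous without entity and project.
        if not (entity and project):
            raise ValueError("a bare run id requires entity and project")
        return RunTarget(entity, project, target)
    raise ValueError(f"unrecognized run target: {target}")


print(parse_run_target("marin-community/marin/abc123"))
print(parse_run_target("abc123", entity="marin-community", project="marin"))
```

Failing loudly on a bare id without entity and project matches the requirement stated above, rather than guessing a default.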
uv run python lib/marin/tools/profile_summary.py summarize \
--profile-dir /path/to/profiler_dir \
--output /tmp/profile_summary.json
If the directory contains *.xplane.pb, --profile-dir uses the XPlane path
automatically. When both *.xplane.pb and Perfetto trace JSON are present,
--profile-dir reads the XPlane protobuf by default (Perfetto exports are often
capped). Use --trace-file to force a specific Perfetto JSON file.
uv run python lib/marin/tools/profile_summary.py summarize \
--trace-file /path/to/perfetto_trace.json.gz \
--output /tmp/profile_summary.json
Direct XPlane timeline parsing uses protobuf and does not require
TensorFlow-generated xplane_pb2 modules. If xprof is installed, ingestion
also exports compact xprof table JSON and augments the timeline summary with
aggregate step, kernel, collective, and bottleneck evidence.
uv run --with xprof --with protobuf python lib/marin/tools/profile_summary.py summarize \
--xplane-file /path/to/profile.xplane.pb \
--xplane-output-dir /tmp/profile_xprof_tables \
--xplane-count-trace-events \
--output /tmp/profile_summary.json
Without --xplane-output-dir the command still parses XPlane timeline events
directly. Add --with xprof for xprof aggregate table augmentation; add
--xplane-output-dir to preserve the exported table JSON (this flag requires
the optional xprof package).
XPlane summaries expose hierarchical named-scope regions, pre-op gaps, gap region context, process/thread/timeline event metadata, step timing (when step markers or xprof overview rows exist), xprof bottleneck statements, kernel stats, collective breakdowns, and optimization candidates.
Summary version tag: profile_summary.v1
Generate a deterministic markdown root-cause report:
uv run python lib/marin/tools/profile_summary.py report \
--summary /tmp/profile_summary.json \
--output /tmp/profile_report.md
Trace quality checks are surfaced in trace_overview:
- suspected_truncation: true when event counts match a known export cap.
- quality_warnings: warnings to treat hotspot/gap attribution with caution.

Top ops:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "What are the top 10 ops by exclusive time?"
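Before acting on an answer such as the top-ops ranking, a script can gate on the trace_overview quality fields described earlier. This is a sketch under assumptions: `trace_quality_issues` is a hypothetical helper, quality_warnings is assumed to be a list of strings, and the JSON below is a stand-in rather than real tool output.

```python
import json


def trace_quality_issues(summary: dict) -> list:
    """Collect reasons to treat hotspot/gap attribution with caution."""
    overview = summary.get("trace_overview", {})
    issues = []
    if overview.get("suspected_truncation"):
        issues.append("event count matches a known export cap (possible truncation)")
    # Assumption: quality_warnings is a list of human-readable strings.
    issues.extend(str(w) for w in overview.get("quality_warnings", []))
    return issues


# Stand-in summary containing only the two documented fields.
example_json = """
{"trace_overview": {"suspected_truncation": true,
                    "quality_warnings": ["gap attribution may be incomplete"]}}
"""
example = json.loads(example_json)
for issue in trace_quality_issues(example):
    print("WARNING:", issue)
```

With a real file, the same function could be fed `json.load(open("/tmp/profile_summary.json"))`. An empty result means neither field raised a concern.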
Compute vs comm and collective bottlenecks:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "Is comm or compute dominating? Which collective is worst?"
Specific pre-op gap lookup:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "gap before _linear_softmax_cross_entropy_loss_bwd_pallas_mosaic_tpu_combined.1"
Pre-op gap attribution is marker-aware:
- gap_before_ops[].payload_op: op where useful work starts after the idle period.
- gap_before_ops[].marker_op: first op observed after the gap (often lightweight setup like iota.*).

Hierarchical semantic regions (derived from tf_op paths when available):
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "show hierarchical regions"
Contextualize a noisy op:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "show context for op copy.564"
Suggested optimizations from evidence:
uv run python lib/marin/tools/profile_summary.py query \
--summary /tmp/profile_summary.json \
--question "What should we try next?"
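The marker-aware gap fields (payload_op and marker_op) can also be read straight from the summary JSON instead of through a query. The sketch below assumes gap_before_ops is a top-level list of objects, which the text does not confirm. `marker_only_gaps` and the sample records are hypothetical.

```python
def marker_only_gaps(summary: dict) -> list:
    """Return (marker_op, payload_op) pairs where the two differ.

    These are gaps whose first observed op is only lightweight setup, so the
    idle time is better attributed to the payload op that follows it.
    """
    pairs = []
    # Assumption: gap_before_ops sits at the top level of the summary.
    for gap in summary.get("gap_before_ops", []):
        marker = gap.get("marker_op")
        payload = gap.get("payload_op")
        if marker and payload and marker != payload:
            pairs.append((marker, payload))
    return pairs


# Hypothetical records using only the two documented field names.
example = {
    "gap_before_ops": [
        {"marker_op": "iota.12", "payload_op": "fusion.7"},
        {"marker_op": "copy.564", "payload_op": "copy.564"},
    ]
}
for marker, payload in marker_only_gaps(example):
    print(f"gap starts at {marker}, useful work resumes at {payload}")
```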
Use a strict workflow: summarize the baseline profile to before.json and the candidate profile to after.json, then compare the two summaries.
uv run python lib/marin/tools/profile_summary.py compare \
--before /tmp/profile_before.json \
--after /tmp/profile_after.json \
--strict-provenance
uv run python lib/marin/tools/profile_summary.py track \
--before /tmp/profile_before.json \
--after /tmp/profile_after.json \
--label "pallas-kernel-attempt-3" \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py history \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py bundle \
--before-run-target marin-community/marin/<baseline_run_id> \
--after-run-target marin-community/marin/<candidate_run_id> \
--output-dir /tmp/profile_bundle \
--history /tmp/profile_regression_history.jsonl
uv run python lib/marin/tools/profile_summary.py publish \
--summary /tmp/profile_summary.json \
--report /tmp/profile_report.md \
--alias latest
The comparison reports: steady-state step-time delta, step class deltas (light/heavy when detected), compute/comm/host/stall share deltas, semantic family deltas with workload-normalized metrics, provenance checks (trace hash/run identity), and regressed/improved ops by exclusive duration.
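To make "regressed/improved ops by exclusive duration" concrete, here is a minimal sketch of how such a ranking could be computed. It is not the tool's implementation: `rank_op_deltas` is hypothetical, and the inputs are plain op-name to exclusive-duration mappings invented for the example, not the profile_summary.v1 schema.

```python
def rank_op_deltas(before: dict, after: dict) -> list:
    """Return (op, delta) pairs, worst regression first, best improvement last.

    Ops missing from one side are treated as having zero exclusive duration.
    """
    ops = set(before) | set(after)
    deltas = [(op, after.get(op, 0.0) - before.get(op, 0.0)) for op in ops]
    # Sort by descending delta; break ties by op name for determinism.
    return sorted(deltas, key=lambda item: (-item[1], item[0]))


# Hypothetical exclusive durations (arbitrary time units).
before = {"fusion.1": 120.0, "all-reduce.3": 80.0, "copy.564": 10.0}
after = {"fusion.1": 100.0, "all-reduce.3": 95.0, "copy.564": 10.0}

for op, delta in rank_op_deltas(before, after):
    label = "regressed" if delta > 0 else "improved" if delta < 0 else "unchanged"
    print(f"{op}: {delta:+.1f} ({label})")
```

The name-based tie-break keeps the output stable across runs, in line with the deterministic reports the tool aims for.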
MVP is successful when:
profile_summary.v1,