sharding-stats
| name | sharding-stats |
| description | Investigate and explain TorchRec planner sharding statistics output, especially how HBM storage is computed per table and per rank. Use when the user asks about sharding stats, storage breakdown, or memory estimation. |
| allowed-tools | Read, Grep, Bash, Task |
| argument-hint | ["table name","sharding type","or question about stats"] |
Investigate and explain TorchRec planner sharding statistics, especially HBM storage computation, for: $ARGUMENTS
You are analyzing the output of EmbeddingStats (in torchrec/distributed/planner/stats.py). This skill covers how to read the stats table and how each number is computed from source code.
Key source files:

- torchrec/distributed/planner/stats.py: generates the bordered stats table output (EmbeddingStats class)
- torchrec/distributed/planner/shard_estimators.py: core estimation logic
  - EmbeddingStorageEstimator: orchestrates storage estimation
  - calculate_shard_storages(): assembles the final Storage per shard
  - _calculate_shard_io_sizes(): dispatches to sharding-type-specific I/O calculations
  - _calculate_rw_shard_io_sizes(), _calculate_tw_shard_io_sizes(), etc.
  - _calculate_storage_specific_sizes(): tensor + optimizer + cache aux
  - _calculate_tensor_sizes(): proportional tensor size per shard
  - _calculate_optimizer_sizes(): optimizer state multiplier
  - calculate_pipeline_io_cost(): pipeline type I/O multipliers
- torchrec/distributed/planner/storage_reservations.py: dense storage and KJT storage reservation
- torchrec/distributed/planner/types.py: Perf, Storage, ShardingOption dataclasses
- torchrec/distributed/planner/utils.py: bytes_to_gb (1 GB = 2^30 bytes), bytes_to_mb
- torchrec/distributed/embedding_types.py: Sharder storage_usage() implementations

The stats output has these sections:
used_hbm[rank] = sparse_hbm[rank] + dense_storage.hbm + kjt_storage.hbm
- sparse_hbm[rank] = sum of shard.storage.hbm for every embedding shard placed on that rank
- dense_storage = non-embedding model parameters (from HeuristicalStorageReservation)
- kjt_storage = KeyedJaggedTensor input buffers

Each shard's HBM is computed in calculate_shard_storages():
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
Where:
hbm_specific_size = tensor_size + optimizer_size + cache_aux_size
The sharder determines the raw tensor bytes (sharder.storage_usage()):

- Base case: hbm_storage = num_embeddings × emb_dim × element_size
- Sequence embeddings: hbm_storage = num_embeddings × emb_dim × element_size + num_embeddings × 4
  - The extra shape[0] × 4 bytes is metadata overhead for sequence embeddings
- UVM caching: the HBM share is ddr_storage × caching_ratio

Each shard takes a proportional share of that storage (_calculate_tensor_sizes()):
tensor_size = ceil(hbm_storage × prod(shard_size) / prod(full_shape))
For RW sharding with world_size shards: shard_size = [num_embeddings / world_size, emb_dim]

Optimizer state scales with the tensor size (_calculate_optimizer_sizes()):
optimizer_size = ceil(tensor_size × optimizer_multiplier)
| Optimizer | Multiplier |
|---|---|
| SGD | 0 |
| Adam | 2 |
| RowWiseAdagrad | 1 / emb_dim |
| Default/unknown | 1 |
| None (inference) | 0 |
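The tensor and optimizer formulas above can be sketched in a few lines of plain Python. This is an illustration of the arithmetic only, not the TorchRec implementation: the function names and the string keys for optimizers are hypothetical, and the multipliers are copied from the table above.

```python
import math


def optimizer_multiplier(optimizer, emb_dim):
    """Return the multiplier from the table above.

    `optimizer` is an illustrative string label (or None for inference),
    not a TorchRec class.
    """
    if optimizer is None or optimizer == "sgd":
        return 0.0
    if optimizer == "adam":
        return 2.0
    if optimizer == "rowwise_adagrad":
        return 1.0 / emb_dim
    return 1.0  # default / unknown optimizer


def tensor_size(hbm_storage, shard_size, full_shape):
    """Proportional share of hbm_storage for one shard, rounded up."""
    numerator = hbm_storage * math.prod(shard_size)
    denominator = math.prod(full_shape)
    return -(-numerator // denominator)  # integer ceiling division


def optimizer_size(tensor_bytes, optimizer, emb_dim):
    """Optimizer state bytes for one shard, rounded up."""
    return math.ceil(tensor_bytes * optimizer_multiplier(optimizer, emb_dim))


# Example: a 1000 x 64 fp32 table (256,000 bytes), RW-sharded over 4 ranks.
shard_bytes = tensor_size(1000 * 64 * 4, [250, 64], [1000, 64])
adam_bytes = optimizer_size(shard_bytes, "adam", 64)
rowwise_bytes = optimizer_size(shard_bytes, "rowwise_adagrad", 64)
```

With these inputs each shard holds 64,000 bytes of weights. Adam adds 128,000 bytes of state, while RowWiseAdagrad adds only 1,000 bytes, because it keeps one value per row rather than one per element.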
Cache auxiliary state (_calculate_cache_aux_state_sizes()) only applies to UVM caching (fused_uvm_caching kernel). For the fused kernel it is 0.
Input and output buffer sizes are computed by _calculate_shard_io_sizes(), which dispatches to sharding-type-specific functions.
Constants:
- input_data_type_size = 8 (BIGINT_DTYPE, int64 indices)
- output_data_type_size = tensor.element_size() (or output_dtype if specified)

For RW sharding (_calculate_rw_shard_io_sizes()):
batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i) / world_size
batch_outputs = batch_inputs # if non-pooled (sequence)
= sum(num_poolings_i × batch_size_i) # if pooled
input_size = ceil(batch_inputs × world_size × input_data_type_size) # per shard
output_size = ceil(batch_outputs × world_size × shard_dim × output_data_type_size) # per shard
For TW sharding (_calculate_tw_shard_io_sizes()):
batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i) # no division
input_size = ceil(batch_inputs × world_size × input_data_type_size)
output_size = ceil(batch_outputs × world_size × emb_dim × output_data_type_size)
For CW sharding (_calculate_cw_shard_io_sizes()):
# Same as TW but output uses shard_sizes[i][1] (shard column dim) instead of full emb_dim
output_size = ceil(batch_outputs × world_size × shard_sizes[i][1] × output_data_type_size)
Critical insight for sequence (non-pooled) embeddings:
When is_pooled=False, batch_outputs = batch_inputs, which means output includes one full embedding vector per input index. With large input_lengths, the output buffer can be enormous — often 90%+ of total storage.
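To see how much the pooled/sequence distinction matters, the RW formulas above can be written as a small function. This is a minimal sketch of the stated formulas, not the body of _calculate_rw_shard_io_sizes(). The function name, argument names, and the split of the input lengths across four features are assumptions made for illustration.

```python
import math

INPUT_DATA_TYPE_SIZE = 8  # int64 indices, as listed under "Constants"


def rw_shard_io_sizes(input_lengths, num_poolings, batch_sizes,
                      world_size, shard_dim, output_data_type_size, is_pooled):
    """Per-shard (input_size, output_size) in bytes for RW sharding."""
    batch_inputs = sum(
        length * poolings * batch
        for length, poolings, batch in zip(input_lengths, num_poolings, batch_sizes)
    ) / world_size
    if is_pooled:
        # One pooled vector per sample per feature.
        batch_outputs = sum(p * b for p, b in zip(num_poolings, batch_sizes))
    else:
        # Sequence embeddings: one vector per input index.
        batch_outputs = batch_inputs
    input_size = math.ceil(batch_inputs * world_size * INPUT_DATA_TYPE_SIZE)
    output_size = math.ceil(
        batch_outputs * world_size * shard_dim * output_data_type_size
    )
    return input_size, output_size


# Hypothetical inputs: 4 features whose input lengths sum to 6066,
# batch size 2560, 96 ranks, emb_dim 128, fp16 output.
common = dict(
    input_lengths=[2000, 2000, 2000, 66],
    num_poolings=[1, 1, 1, 1],
    batch_sizes=[2560] * 4,
    world_size=96,
    shard_dim=128,
    output_data_type_size=2,
)
seq_in, seq_out = rw_shard_io_sizes(is_pooled=False, **common)
pooled_in, pooled_out = rw_shard_io_sizes(is_pooled=True, **common)
```

Under these assumed inputs the input buffer is identical in both cases, but the sequence output buffer is roughly 16 times larger than the pooled one, since it scales with the number of indices rather than the number of samples.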
The pipeline I/O cost (calculate_pipeline_io_cost()) combines these sizes according to the pipeline type:
output_contribution = output_size if count_ephemeral_storage_cost else 0
| Pipeline Type | Formula |
|---|---|
| NONE (catch-all) | input_size + output_size |
| TRAIN_SPARSE_DIST | 2 × input_size + output_contribution |
| TRAIN_PREFETCH_SPARSE_DIST | 3 × input_size + (1 + 6/max_pass) × prefetch_size + output_contribution |
| Inference | 0 |
- prefetch_size = input_size if the table is cached, else 0
- count_ephemeral_storage_cost defaults to False

The final per-shard value is:
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
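The pipeline table can be expressed as a single dispatch function. The sketch below mirrors the formulas in the table only. The enum, the function signature, and the choice to round the prefetch case up with math.ceil are assumptions for illustration and are not taken from calculate_pipeline_io_cost().

```python
import math
from enum import Enum


class PipelineKind(Enum):
    """Hypothetical stand-in for the pipeline types named in the table."""
    NONE = "none"
    TRAIN_SPARSE_DIST = "train_sparse_dist"
    TRAIN_PREFETCH_SPARSE_DIST = "train_prefetch_sparse_dist"


def pipeline_io_cost(kind, input_size, output_size, *, is_inference=False,
                     table_cached=False, max_pass=1,
                     count_ephemeral_storage_cost=False):
    """I/O bytes charged to one shard, following the table's formulas."""
    if is_inference:
        return 0
    output_contribution = output_size if count_ephemeral_storage_cost else 0
    if kind is PipelineKind.TRAIN_SPARSE_DIST:
        return 2 * input_size + output_contribution
    if kind is PipelineKind.TRAIN_PREFETCH_SPARSE_DIST:
        prefetch_size = input_size if table_cached else 0
        cost = 3 * input_size + (1 + 6 / max_pass) * prefetch_size
        # Rounding up is an assumption of this sketch.
        return math.ceil(cost + output_contribution)
    # NONE acts as the catch-all and always counts the output buffer.
    return input_size + output_size


none_cost = pipeline_io_cost(PipelineKind.NONE, 100, 1000)
sparse_dist_cost = pipeline_io_cost(PipelineKind.TRAIN_SPARSE_DIST, 100, 1000)
prefetch_cost = pipeline_io_cost(
    PipelineKind.TRAIN_PREFETCH_SPARSE_DIST, 100, 1000,
    table_cached=True, max_pass=2,
)
```

Note the asymmetry the table implies: the catch-all case always counts the output buffer, while the two training pipelines count it only when count_ephemeral_storage_cost is set.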
The "Storage (HBM, DDR)" column in the parameter info table shows:
total_storage = sum(shard.storage for shard in sharding_option.shards)
This is the sum across ALL shards (all ranks), not per-rank.
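The difference between the per-table column and the per-rank used_hbm figure is easiest to see side by side. The sketch below uses a hypothetical ShardHBM record rather than TorchRec's ShardingOption and Shard types, and made-up byte counts.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ShardHBM:
    """Hypothetical record: one shard's HBM bytes and where it was placed."""
    table: str
    rank: int
    hbm: int


def per_table_total(shards):
    """Per-table storage column: sum over all shards on all ranks."""
    totals = defaultdict(int)
    for shard in shards:
        totals[shard.table] += shard.hbm
    return dict(totals)


def per_rank_used(shards, dense_hbm, kjt_hbm):
    """used_hbm[rank] = sparse_hbm[rank] + dense_storage.hbm + kjt_storage.hbm."""
    sparse = defaultdict(int)
    for shard in shards:
        sparse[shard.rank] += shard.hbm
    return {rank: hbm + dense_hbm + kjt_hbm for rank, hbm in sparse.items()}


shards = [
    ShardHBM("t1", rank=0, hbm=100),
    ShardHBM("t1", rank=1, hbm=100),
    ShardHBM("t2", rank=0, hbm=50),
]
table_totals = per_table_total(shards)
rank_usage = per_rank_used(shards, dense_hbm=10, kjt_hbm=5)
```

Here table t1 reports 200 bytes in the per-table view even though no single rank holds more than 100 bytes of it, which is why the per-table column should not be compared against a single device's HBM capacity.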
From HeuristicalStorageReservation:
dense_storage.hbm = (total_model_params - embedding_params) × multiplier + buffers
kjt_storage.hbm = total_kjt_size × kjt_multiplier
When analyzing a table's storage, determine these parameters:
- Pooled or sequence embedding (sequence adds shape[0] × 4)
- Optimizer: check the _optimizer_classes attribute on the tensor; RowWiseAdagrad is common for RecSys

Given: hash_size=80M, emb_dim=128, dtype=fp16, 4 features, sum(input_lengths)=6066, batch_size=2560, world_size=96, RW sharding, fused kernel, RowWiseAdagrad, PipelineType.NONE
1. hbm_storage: 80M × 128 × 2 = 20,480,000,000 bytes (19.07 GB)
2. tensor_size: ceil(20,480,000,000 × 833,333 / 80,000,000) = 213,333,248 bytes
3. optimizer_size: ceil(213,333,248 / 128) = 1,666,666 bytes (~1.6 MB)
4. hbm_specific_size: 213,333,248 + 1,666,666 = 214,999,914 bytes (~205 MB)
5. I/O sizes:
   - batch_inputs = 2560 × 6066 / 96 = 161,760
   - input_size = ceil(161,760 × 96 × 8) = 124,231,680 (~118 MB)
   - output_size = ceil(161,760 × 96 × 128 × 2) = 3,975,413,760 (~3.70 GB)
6. pipeline_io_cost: 124,231,680 + 3,975,413,760 = 4,099,645,440 (~3.82 GB)
7. shard.storage.hbm: 214,999,914 + 4,099,645,440 = 4,314,645,354 (~4.02 GB)
8. Total across 96 shards: ~385.8 GB

Key insight: Output buffer is ~92% of total. Sequence embeddings with large input_lengths dominate storage.
| Component | Per Shard | Total (96 shards) | % |
|---|---|---|---|
| Embedding weights | 203 MB | 19.07 GB | 4.9% |
| Optimizer (RowWiseAdagrad) | 1.6 MB | 0.15 GB | 0.04% |
| Input buffer (int64) | 118 MB | 11.1 GB | 2.9% |
| Output buffer (fp16) | 3.70 GB | 355.4 GB | 92.1% |
| Total | ~4.02 GB | ~385.8 GB | 100% |
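The worked example can be re-derived with a short script, which is handy when checking a different table against the same formulas. It is plain arithmetic under the example's stated inputs, including its choice of 833,333 rows per shard. The variable names are illustrative and nothing here calls TorchRec.

```python
GB = 2 ** 30  # bytes_to_gb uses 1 GB = 2^30 bytes


def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


# Inputs from the worked example.
hash_size, emb_dim, element_size = 80_000_000, 128, 2  # fp16
world_size, batch_size, sum_input_lengths = 96, 2560, 6066

hbm_storage = hash_size * emb_dim * element_size
rows_per_shard = hash_size // world_size  # 833,333, as in the example
tensor_size = ceil_div(hbm_storage * rows_per_shard, hash_size)
optimizer_size = ceil_div(tensor_size, emb_dim)  # RowWiseAdagrad: 1 / emb_dim
hbm_specific_size = tensor_size + optimizer_size  # fused kernel: no cache aux

batch_inputs = batch_size * sum_input_lengths // world_size
input_size = batch_inputs * world_size * 8  # int64 indices
output_size = batch_inputs * world_size * emb_dim * element_size  # sequence
pipeline_io_cost = input_size + output_size  # PipelineType.NONE

shard_hbm = hbm_specific_size + pipeline_io_cost
total_gb = shard_hbm * world_size / GB
output_share = output_size / shard_hbm

print(f"per shard: {shard_hbm / GB:.2f} GB, total: {total_gb:.1f} GB, "
      f"output share: {output_share:.1%}")
```

Changing a single input, for example the batch size or the sum of input lengths, shows directly how strongly the output buffer drives the total.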
See sharding_stats_example.txt for a full example.