| name | sharding-stats |
| description | Investigate and explain TorchRec planner sharding statistics output, especially how HBM storage is computed per table and per rank. Use when the user asks about sharding stats, storage breakdown, or memory estimation. |
| allowed-tools | Read, Grep, Bash, Task |
| argument-hint | ["table name","sharding type","or question about stats"] |
Sharding Stats Investigation Guide
Investigate and explain TorchRec planner sharding statistics, especially HBM storage computation, for: $ARGUMENTS
Instructions
You are analyzing the output of EmbeddingStats (in torchrec/distributed/planner/stats.py). This skill covers how to read the stats table and how each number is computed from source code.
Key Source Files
torchrec/distributed/planner/stats.py — Generates the bordered stats table output (EmbeddingStats class)
torchrec/distributed/planner/shard_estimators.py — Core estimation logic:
EmbeddingStorageEstimator — orchestrates storage estimation
calculate_shard_storages() — assembles final Storage per shard
_calculate_shard_io_sizes() — dispatches to sharding-type-specific I/O calculations
_calculate_rw_shard_io_sizes(), _calculate_tw_shard_io_sizes(), etc.
_calculate_storage_specific_sizes() — tensor + optimizer + cache aux
_calculate_tensor_sizes() — proportional tensor size per shard
_calculate_optimizer_sizes() — optimizer state multiplier
calculate_pipeline_io_cost() — pipeline type I/O multipliers
torchrec/distributed/planner/storage_reservations.py — Dense storage and KJT storage reservation
torchrec/distributed/planner/types.py — Perf, Storage, ShardingOption dataclasses
torchrec/distributed/planner/utils.py — bytes_to_gb (1 GB = 2^30 bytes), bytes_to_mb
torchrec/distributed/embedding_types.py — Sharder storage_usage() implementations
Stats Table Structure
The stats output has these sections:
- Per-Rank Summary: Rank, HBM (GB), DDR (GB), Perf (ms), Input (MB), Output (MB), Shards
- Parameter Info Table: Per-table details (FQN, Sharding, Compute Kernel, Perf, Storage, etc.)
- Batch Size & Compute Kernels: Global batch size, kernel counts and storage
- Imbalance Statistics: Total Variation, Total Distance, Chi Divergence, KL Divergence
- Peak Memory Estimation: Top-tier HBM pressure per rank
- Storage Reservation: Reserved, Planning, Dense, KJT storage
HBM Storage Computation — Step by Step
Per-Rank HBM Formula
used_hbm[rank] = sparse_hbm[rank] + dense_storage.hbm + kjt_storage.hbm
sparse_hbm[rank] = sum of shard.storage.hbm for every embedding shard placed on that rank
dense_storage = non-embedding model parameters (from HeuristicalStorageReservation)
kjt_storage = KeyedJaggedTensor input buffers
Per-Shard HBM Formula (the core calculation)
Each shard's HBM is computed in calculate_shard_storages():
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
Where:
hbm_specific_size = tensor_size + optimizer_size + cache_aux_size
Step 1: Base Tensor Storage (sharder.storage_usage())
The sharder determines the raw tensor bytes:
- EmbeddingBagCollectionSharder:
hbm_storage = num_embeddings × emb_dim × element_size
- EmbeddingCollectionSharder:
hbm_storage = num_embeddings × emb_dim × element_size + num_embeddings × 4
- The extra
shape[0] × 4 bytes is metadata overhead for sequence embeddings
- For UVM caching kernels: tensor goes to DDR, HBM gets
ddr_storage × caching_ratio
Step 2: Per-Shard Tensor Size (_calculate_tensor_sizes())
tensor_size = ceil(hbm_storage × prod(shard_size) / prod(full_shape))
For RW sharding with world_size shards: shard_size = [num_embeddings / world_size, emb_dim]
Step 3: Optimizer Size (_calculate_optimizer_sizes())
optimizer_size = ceil(tensor_size × optimizer_multiplier)
| Optimizer | Multiplier |
|---|
| SGD | 0 |
| Adam | 2 |
| RowWiseAdagrad | 1 / emb_dim |
| Default/unknown | 1 |
| None (inference) | 0 |
Step 4: Cache Auxiliary State (_calculate_cache_aux_state_sizes())
Only applies to UVM caching (fused_uvm_caching kernel). For fused kernel: 0.
Step 5: I/O Sizes (sharding-type-specific)
Computed by _calculate_shard_io_sizes() → dispatches to type-specific functions.
Constants:
input_data_type_size = 8 (BIGINT_DTYPE, int64 indices)
output_data_type_size = tensor.element_size() (or output_dtype if specified)
For RW sharding (_calculate_rw_shard_io_sizes()):
batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i) / world_size
batch_outputs = batch_inputs
= sum(num_poolings_i × batch_size_i)
input_size = ceil(batch_inputs × world_size × input_data_type_size)
output_size = ceil(batch_outputs × world_size × shard_dim × output_data_type_size)
For TW sharding (_calculate_tw_shard_io_sizes()):
batch_inputs = sum(input_length_i × num_poolings_i × batch_size_i)
input_size = ceil(batch_inputs × world_size × input_data_type_size)
output_size = ceil(batch_outputs × world_size × emb_dim × output_data_type_size)
For CW sharding (_calculate_cw_shard_io_sizes()):
output_size = ceil(batch_outputs × world_size × shard_sizes[i][1] × output_data_type_size)
Critical insight for sequence (non-pooled) embeddings:
When is_pooled=False, batch_outputs = batch_inputs, which means output includes one full embedding vector per input index. With large input_lengths, the output buffer can be enormous — often 90%+ of total storage.
Step 6: Pipeline I/O Cost (calculate_pipeline_io_cost())
output_contribution = output_size if count_ephemeral_storage_cost else 0
| Pipeline Type | Formula |
|---|
NONE (catch-all) | input_size + output_size |
TRAIN_SPARSE_DIST | 2 × input_size + output_contribution |
TRAIN_PREFETCH_SPARSE_DIST | 3 × input_size + (1 + 6/max_pass) × prefetch_size + output_contribution |
| Inference | 0 |
prefetch_size = input_size if table is cached, else 0
count_ephemeral_storage_cost defaults to False
Step 7: Final Per-Shard Storage
shard.storage.hbm = hbm_specific_size + pipeline_io_cost
Step 8: Total Storage (shown in Parameter Info table)
The "Storage (HBM, DDR)" column in the parameter info table shows:
total_storage = sum(shard.storage for shard in sharding_option.shards)
This is the sum across ALL shards (all ranks), not per-rank.
Dense Storage Reservation
From HeuristicalStorageReservation:
dense_storage.hbm = (total_model_params - embedding_params) × multiplier + buffers
- Training multiplier = 6.0 (1× params + 2× optimizer state + 3× DDP gradient buffers)
- Inference multiplier = 1.0
KJT Storage Reservation
kjt_storage.hbm = total_kjt_size × kjt_multiplier
- Training multiplier = 20 (pipelined batches)
- Inference multiplier = 1
Investigation Checklist
When analyzing a table's storage, determine these parameters:
- Tensor dtype: float32 (4 bytes) or float16/bfloat16 (2 bytes)? Cross-check with per-rank Output (MB) column.
- Sharder type: EmbeddingBagCollectionSharder vs EmbeddingCollectionSharder (EC adds
shape[0] × 4)
- is_pooled: pooled (EmbeddingBag) vs sequence (Embedding) — check "Output" column
- Optimizer: check
_optimizer_classes attribute on tensor; RowWiseAdagrad is common for RecSys
- Pipeline type: NONE, TRAIN_SPARSE_DIST, or TRAIN_PREFETCH_SPARSE_DIST
- count_ephemeral_storage_cost: defaults to False
- Compute kernel: fused, fused_uvm_caching, quant, etc.
- Caching ratio: only matters for UVM caching kernels
Worked Example: RW Sequence Embedding
Given: hash_size=80M, emb_dim=128, dtype=fp16, 4 features, sum(input_lengths)=6066, batch_size=2560, world_size=96, RW sharding, fused kernel, RowWiseAdagrad, PipelineType.NONE
- Tensor storage:
80M × 128 × 2 = 20,480,000,000 bytes (19.07 GB)
- Per-shard tensor:
ceil(20,480,000,000 × 833,333 / 80,000,000) = 213,333,248 bytes
- Optimizer:
ceil(213,333,248 / 128) = 1,666,666 bytes (~1.6 MB)
- hbm_specific:
213,333,248 + 1,666,666 = 214,999,914 bytes (~205 MB)
- I/O (RW, non-pooled):
batch_inputs = 2560 × 6066 / 96 = 161,760
input_size = ceil(161,760 × 96 × 8) = 124,231,680 (~118 MB)
output_size = ceil(161,760 × 96 × 128 × 2) = 3,975,413,760 (~3.70 GB)
- Pipeline (NONE):
124,231,680 + 3,975,413,760 = 4,099,645,440 (~3.82 GB)
- Per-shard total:
214,999,914 + 4,099,645,440 = 4,314,645,354 (~4.02 GB)
- Total (96 shards):
~385.8 GB
Key insight: Output buffer is ~92% of total — sequence embeddings with large input_lengths dominate storage.
| Component | Per Shard | Total (96 shards) | % |
|---|
| Embedding weights | 203 MB | 19.07 GB | 4.9% |
| Optimizer (RowWiseAdagrad) | 1.6 MB | 0.15 GB | 0.04% |
| Input buffer (int64) | 118 MB | 11.1 GB | 2.9% |
| Output buffer (fp16) | 3.70 GB | 355.5 GB | 92.1% |
| Total | ~4.02 GB | ~385.8 GB | |
Example Stats Output
See sharding_stats_example.txt for a full example.