| name | docstring |
|------|-----------|
| description | Write docstrings for TorchRec functions and methods following PyTorch conventions. Use when writing or updating docstrings in TorchRec code. |
This skill describes how to write docstrings for functions and methods in the TorchRec project, following PyTorch conventions.
Use raw strings (`r"""..."""`) for all docstrings to avoid issues with LaTeX/math backslashes.

Start with the function signature showing all parameters:
r"""function_name(param1, param2, *, kwarg1=default1, kwarg2=default2) -> ReturnType
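Notes:
- Include default values for parameters that have them.
- Mark keyword-only arguments with the `*` separator.
- Give the return type after `->`.

Provide a one-line description of what the function does: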
r"""apply_optimizer_in_backward(optimizer_class, params, optimizer_kwargs) -> None
Applies optimizer to parameters in backward pass for memory efficiency.
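Because `inspect.getdoc` strips a docstring's indentation, the signature line and the one-line summary can be read back directly, which makes this layout easy to check. A minimal sketch; `scale_values` is a hypothetical function, not a TorchRec API:

```python
import inspect
from typing import List


def scale_values(values: List[float], *, factor: float = 1.0) -> List[float]:
    r"""scale_values(values, *, factor=1.0) -> List[float]

    Multiplies every value by ``factor``.
    """
    # Hypothetical helper used only to illustrate the docstring layout.
    return [v * factor for v in values]


lines = inspect.getdoc(scale_values).splitlines()
signature_line, summary_line = lines[0], lines[2]  # lines[1] is the blank separator
print(signature_line)  # scale_values(values, *, factor=1.0) -> List[float]
print(summary_line)    # Multiplies every value by ``factor``.
```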
Use Sphinx math directives for mathematical expressions:
.. math::
    \text{output} = \text{input} \cdot \text{weight}^T

Or inline math: `` :math:`x^2` ``
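The reason for the `r` prefix is easy to demonstrate: in an ordinary string, the `\t` at the start of `\text` is parsed as a TAB character (and sequences such as `\c` in `\cdot` are invalid escapes that recent Python versions warn about). A minimal sketch with two hypothetical functions:

```python
def cooked_doc():
    """Computes :math:`\text{output}`."""  # no r prefix: "\t" becomes a TAB


def raw_doc():
    r"""Computes :math:`\text{output}`."""  # r prefix: the backslash is kept


print("\t" in cooked_doc.__doc__)  # True
print("\t" in raw_doc.__doc__)     # False
```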
Link to related classes and functions using Sphinx roles:
- `` :class:`~torchrec.modules.EmbeddingBagCollection` `` - Link to a class
- `` :func:`torchrec.distributed.sharding.shard` `` - Link to a function
- `` :meth:`~Module.forward` `` - Link to a method
- `` :attr:`attribute_name` `` - Reference an attribute

The `~` prefix shows only the last component of the target in the rendered link.

Example:
See :class:`~torchrec.distributed.DistributedModelParallel` for details.
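To make the `~` rule concrete, the hypothetical helper below mimics how Sphinx chooses the link text for a target. It is a simplification for illustration, not Sphinx's implementation, and it ignores the `()` that Sphinx appends to `:func:` and `:meth:` links by default (`add_function_parentheses`):

```python
def displayed_text(target: str) -> str:
    """Return the link text for a cross-reference target (simplified).

    Only the ``~`` rule is modeled: ``~a.b.C`` renders as ``C``, while a
    target without ``~`` renders in full.
    """
    if target.startswith("~"):
        # Drop the "~" and keep only the last dotted component.
        return target[1:].rsplit(".", 1)[-1]
    return target


print(displayed_text("~torchrec.modules.EmbeddingBagCollection"))  # EmbeddingBagCollection
print(displayed_text("torchrec.distributed.sharding.shard"))  # torchrec.distributed.sharding.shard
print(displayed_text("~Module.forward"))  # forward
```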
Use admonitions for important information:
.. note::
    This function requires CUDA to be available.

.. warning::
    This API is experimental and may change without notice.
Document all parameters with type annotations and descriptions:
Args:
    module (nn.Module): Module to be sharded across devices.
    device (torch.device, optional): Device to place the module. Default: ``None``
    sharders (List[ModuleSharder], optional): List of sharders to use for sharding.
        Default: ``None``
    plan (ShardingPlan, optional): Explicit sharding plan. If not provided, a plan is
        generated automatically. Default: ``None``
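Formatting rules:
- Put the type in parentheses after the parameter name: `(Type)`, or `(Type, optional)` for optional parameters.
- Add "Default: ``value``" at the end of the description.
- Format literal values with RST double backticks, e.g. ``` ``None`` ```.

Document the return value: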
Returns:
    ShardedModule: The sharded module ready for distributed training.
    The module will have its parameters distributed according to
    the sharding plan.
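The Args conventions above (type in parentheses, an `, optional` marker, a trailing `Default:`) are regular enough to check mechanically. The following is a minimal sketch of such a check, assuming entries start at the section's base indentation and continuation lines are indented further. `parse_args` and `ArgDoc` are hypothetical names, and this is not how Sphinx itself parses docstrings:

```python
import re
import textwrap
from dataclasses import dataclass
from typing import Dict, Optional

# "name (Type): text" or "name (Type, optional): text"
ENTRY_RE = re.compile(
    r"^(?P<name>\w+) \((?P<type>[^)]*?)(?P<optional>, optional)?\): (?P<text>.*)$"
)
# "Default: ``value``" anywhere in the joined description
DEFAULT_RE = re.compile(r"Default: ``(?P<value>[^`]+)``")


@dataclass
class ArgDoc:
    name: str
    type_name: str
    optional: bool
    text: str
    default: Optional[str] = None


def parse_args(body: str) -> Dict[str, ArgDoc]:
    """Parse the indented body of an ``Args:`` section (hypothetical helper)."""
    entries: Dict[str, ArgDoc] = {}
    current: Optional[ArgDoc] = None
    for line in textwrap.dedent(body).splitlines():
        if not line.strip():
            continue
        if not line[0].isspace():
            # Unindented after dedent: start of a new parameter entry.
            match = ENTRY_RE.match(line)
            if match is None:
                raise ValueError(f"malformed Args entry: {line!r}")
            current = ArgDoc(
                name=match["name"],
                type_name=match["type"],
                optional=match["optional"] is not None,
                text=match["text"],
            )
            entries[current.name] = current
        elif current is not None:
            # Indented line: continuation of the previous entry's description.
            current.text += " " + line.strip()
    for entry in entries.values():
        found = DEFAULT_RE.search(entry.text)
        entry.default = found["value"] if found else None
    return entries


ARGS_BODY = """
    module (nn.Module): Module to be sharded across devices.
    device (torch.device, optional): Device to place the module. Default: ``None``
    sharders (List[ModuleSharder], optional): List of sharders to use for sharding.
        Default: ``None``
    plan (ShardingPlan, optional): Explicit sharding plan. If not provided, a plan is
        generated automatically. Default: ``None``
"""

for doc in parse_args(ARGS_BODY).values():
    print(doc.name, doc.type_name, doc.optional, repr(doc.default))
```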
Document exceptions that may be raised:
Raises:
    ValueError: If the sharding plan is invalid for the given module.
    RuntimeError: If CUDA is not available when GPU sharding is requested.
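Each `Raises` entry should describe a condition the caller can actually trigger. A minimal sketch using a hypothetical `validate_ids` helper (not a TorchRec API), where the documented `ValueError` matches the check in the body:

```python
from typing import List


def validate_ids(ids: List[int], num_embeddings: int) -> List[int]:
    r"""validate_ids(ids, num_embeddings) -> List[int]

    Checks that every id is a valid row index for an embedding table.

    Args:
        ids (List[int]): Row indices to check.
        num_embeddings (int): Number of rows in the table.

    Returns:
        List[int]: The same ids, when all of them are in range.

    Raises:
        ValueError: If any id is negative or not less than ``num_embeddings``.
    """
    out_of_range = [i for i in ids if i < 0 or i >= num_embeddings]
    if out_of_range:
        # This is exactly the condition described under "Raises" above.
        raise ValueError(f"ids out of range [0, {num_embeddings}): {out_of_range}")
    return ids


print(validate_ids([0, 3], num_embeddings=4))  # [0, 3]
```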
Always include examples when possible:
Examples::

    >>> from torchrec.distributed.shard import shard
    >>> from torchrec.modules import EmbeddingBagCollection
    >>> from torchrec.modules.embedding_configs import EmbeddingBagConfig
    >>> ebc = EmbeddingBagCollection(
    ...     tables=[
    ...         EmbeddingBagConfig(
    ...             name="product",
    ...             embedding_dim=64,
    ...             num_embeddings=1000,
    ...             feature_names=["product_id"],
    ...         ),
    ...     ],
    ... )
    >>> # Shard the module (``plan`` is assumed to come from a sharding planner)
    >>> sharded_ebc = shard(ebc, plan=plan)
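A side benefit of the `>>>` format is that the standard-library `doctest` module can execute the examples and compare the printed results. A minimal sketch with a hypothetical pure-Python `running_total` helper, so it runs without TorchRec:

```python
import doctest
from typing import List


def running_total(lengths: List[int]) -> List[int]:
    r"""running_total(lengths) -> List[int]

    Converts per-sample lengths into offsets.

    Examples::

        >>> running_total([2, 0, 3])
        [0, 2, 2, 5]
    """
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return offsets


# Collect the ">>>" examples from the docstring and execute them.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
for test in finder.find(running_total, module=False, globs={"running_total": running_total}):
    runner.run(test)
results = runner.summarize(verbose=False)
print(results.failed, results.attempted)  # 0 1
```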
Formatting rules:
- Use `Examples::` with a double colon.
- Use the `>>>` prompt for Python code and `...` for continuation lines.
- Add `#` comments when helpful.

For long descriptions, indent continuation lines under the parameter:

Args:
    tables (List[EmbeddingBagConfig]): List of embedding table configurations.
        Each config specifies the table name, embedding dimension, number of
        embeddings, and feature names.
    device (torch.device, optional): Device to place embeddings. Default: ``None``
When a parameter takes one of a fixed set of values, list the options:

Args:
    sharding_type (ShardingType): How to shard the embedding table. Options include:

        - ``TABLE_WISE``: Each table on a single device
        - ``ROW_WISE``: Rows distributed across devices
        - ``COLUMN_WISE``: Columns distributed across devices
        - ``TABLE_ROW_WISE``: Combination of table and row sharding
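When the options mirror an enum, a small check can catch a list that has drifted out of date. Everything here is hypothetical: `PlacementKind` stands in for an enum such as `ShardingType`, and `undocumented_members` is an illustrative helper that looks for each member name in double backticks:

```python
import enum
from typing import Callable, List, Type


class PlacementKind(enum.Enum):
    """Hypothetical stand-in for an option enum (not TorchRec's ``ShardingType``)."""

    TABLE_WISE = "table_wise"
    ROW_WISE = "row_wise"
    COLUMN_WISE = "column_wise"


def describe_placement(kind: PlacementKind) -> str:
    r"""describe_placement(kind) -> str

    Args:
        kind (PlacementKind): How to place the table. Options are:

            - ``TABLE_WISE``: Each table on a single device
            - ``ROW_WISE``: Rows distributed across devices
            - ``COLUMN_WISE``: Columns distributed across devices
    """
    return kind.value


def undocumented_members(func: Callable, enum_cls: Type[enum.Enum]) -> List[str]:
    """Return member names of ``enum_cls`` that ``func``'s docstring never mentions."""
    doc = func.__doc__ or ""
    # Each option is expected to appear as ``NAME`` in the docstring.
    return [name for name in enum_cls.__members__ if f"``{name}``" not in doc]


print(undocumented_members(describe_placement, PlacementKind))  # []
```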
For complex types, describe the structure of the input:

Args:
    kjt (KeyedJaggedTensor): Sparse features in KeyedJaggedTensor format.
        Contains keys (feature names), values (embedding indices), and
        lengths/offsets for variable-length sequences.
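To make the keys/values/lengths description concrete, the sketch below splits a flat layout back into per-feature, per-sample id lists in plain Python. It assumes the key-major ordering that `KeyedJaggedTensor` uses (all lengths for the first key, then all lengths for the second, and so on) and illustrates the layout only; it is not TorchRec code, and `split_jagged` is a hypothetical name:

```python
from typing import Dict, List


def split_jagged(
    keys: List[str], values: List[int], lengths: List[int]
) -> Dict[str, List[List[int]]]:
    """Split flat ``values`` into per-key, per-sample lists using ``lengths``.

    ``lengths`` is key-major: the first ``stride`` entries belong to ``keys[0]``,
    the next ``stride`` entries to ``keys[1]``, and so on.
    """
    if len(lengths) % len(keys) != 0:
        raise ValueError("len(lengths) must be a multiple of len(keys)")
    stride = len(lengths) // len(keys)  # batch size
    # offsets[i] is where the i-th jagged segment starts in ``values``.
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    if offsets[-1] != len(values):
        raise ValueError("sum(lengths) must equal len(values)")
    result: Dict[str, List[List[int]]] = {}
    for k, key in enumerate(keys):
        segments = range(k * stride, (k + 1) * stride)
        result[key] = [values[offsets[i]:offsets[i + 1]] for i in segments]
    return result


# Two features, batch size 3, eight ids in total.
print(split_jagged(["Feature0", "Feature1"], list(range(8)), [2, 0, 1, 1, 1, 3]))
# {'Feature0': [[0, 1], [], [2]], 'Feature1': [[3], [4], [5, 6, 7]]}
```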
A complete example:

from typing import Optional

import torch
from torch import nn
from torchrec.distributed.types import ShardingEnv, ShardingPlan


def shard_modules(
    module: nn.Module,
    plan: ShardingPlan,
    env: ShardingEnv,
    device: Optional[torch.device] = None,
) -> nn.Module:
    r"""shard_modules(module, plan, env, device=None) -> nn.Module

    Shard a module's embedding tables according to a sharding plan.

    This function takes a module containing embedding tables and distributes
    them across multiple devices according to the provided sharding plan.
    It supports various sharding strategies including table-wise, row-wise,
    and column-wise sharding.

    Args:
        module (nn.Module): The module containing embedding tables to shard.
        plan (ShardingPlan): The sharding plan specifying how each table
            should be distributed.
        env (ShardingEnv): The sharding environment containing process group
            information and device topology.
        device (torch.device, optional): Target device for local shards.
            Default: ``None`` (uses current device)

    Returns:
        nn.Module: The sharded module with distributed embedding tables.

    Raises:
        ValueError: If the plan references tables not present in the module.
        RuntimeError: If the sharding environment is not properly initialized.

    .. note::
        This function modifies the module in-place for efficiency.

    .. warning::
        This is an experimental API and may change in future releases.

    Examples::

        >>> from torchrec.distributed import shard_modules
        >>> from torchrec.distributed.planner import EmbeddingShardingPlanner
        >>>
        >>> # ``module``, ``sharders``, and ``env`` are assumed to exist
        >>> # Create a sharding plan
        >>> planner = EmbeddingShardingPlanner()
        >>> plan = planner.plan(module, sharders)
        >>>
        >>> # Shard the module
        >>> sharded_module = shard_modules(module, plan, env)
    """
    # implementation
When writing a TorchRec docstring, ensure:
- The raw string prefix is used (`r"""`).
- Cross-references use Sphinx roles (`:func:`, `:class:`, `:meth:`).
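Some of these checks can be automated. The hypothetical `docstring_problems` helper below flags a missing `Args:` section, undocumented parameters, a missing `Returns:` section for annotated non-`None` returns, and a missing `Examples::` section. Whether the `r` prefix was used cannot be seen from `__doc__` at runtime, so that item would need a source-level check:

```python
import inspect
from typing import Callable, List


def docstring_problems(func: Callable) -> List[str]:
    """Return checklist problems that can be detected from ``func`` at runtime."""
    doc = inspect.getdoc(func)
    if not doc:
        return ["missing docstring"]
    sig = inspect.signature(func)
    params = [name for name in sig.parameters if name not in ("self", "cls")]
    problems: List[str] = []
    if params and "Args:" not in doc:
        problems.append("missing Args section")
    for name in params:
        # Args entries are written as "name (Type): ...".
        if f"{name} (" not in doc:
            problems.append(f"parameter {name!r} is not documented")
    returns_value = sig.return_annotation not in (inspect.Signature.empty, None)
    if returns_value and "Returns:" not in doc:
        problems.append("missing Returns section")
    if "Examples::" not in doc:
        problems.append("missing Examples:: section")
    return problems


def double(x: int) -> int:
    r"""double(x) -> int

    Doubles ``x``.

    Args:
        x (int): Value to double.

    Returns:
        int: Twice ``x``.

    Examples::

        >>> double(2)
        4
    """
    return 2 * x


def add(x, y):
    """Adds things."""
    return x + y


print(docstring_problems(double))  # []
print(docstring_problems(add))
```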