| name | torch-geometric |
| description | Graph Neural Networks (PyG). Node/graph classification, link prediction, GCN, GAT, GraphSAGE, heterogeneous graphs, molecular property prediction, for geometric deep learning. |
PyTorch Geometric (PyG)
Routing Boundary
Use this skill only for PyTorch Geometric, torch_geometric, PyG, graph neural networks, GCN/GAT, graph classification, node classification, link prediction, and heterogeneous graph learning. Do not use it for generic neural networks, CNN/image classification, graph visualization, or molecule-only tasks unless PyG or graph neural network modeling is explicit.
Overview
PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). Apply this skill for deep learning on graphs and irregular structures, including mini-batch processing, multi-GPU training, and geometric deep learning applications.
Naming Compatibility
Use torch-geometric as the canonical skill ID. Treat torch_geometric,
PyG, and pytorch geometric as API or keyword spellings that route to this
same skill, not as separate expert roles.
When to Use This Skill
This skill should be used when working with:
- Graph-based machine learning: Node classification, graph classification, link prediction
- Molecular property prediction: Drug discovery, chemical property prediction
- Social network analysis: Community detection, influence prediction
- Citation networks: Paper classification, recommendation systems
- 3D geometric data: Point clouds, meshes, molecular structures
- Heterogeneous graphs: Multi-type nodes and edges (e.g., knowledge graphs)
- Large-scale graph learning: Neighbor sampling, distributed training
Quick Start
Installation
uv pip install torch_geometric
For additional dependencies (sparse operations, clustering):
uv pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
Basic Graph Creation
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
Loading a Benchmark Dataset
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
print(f"Dataset: {dataset}")
print(f"Nodes: {data.num_nodes}, Edges: {data.num_edges}")
print(f"Features: {data.num_node_features}, Classes: {dataset.num_classes}")
Core Concepts
Data Structure
PyG represents graphs using the torch_geometric.data.Data class with these key attributes:
data.x: Node feature matrix [num_nodes, num_node_features]
data.edge_index: Graph connectivity in COO format [2, num_edges]
data.edge_attr: Edge feature matrix [num_edges, num_edge_features] (optional)
data.y: Target labels for nodes or graphs
data.pos: Node spatial positions [num_nodes, num_dimensions] (optional)
- Custom attributes: Can add any attribute (e.g., data.train_mask, data.batch)
Important: These attributes are not mandatory—extend Data objects with custom attributes as needed.
Edge Index Format
Edges are stored in COO (coordinate) format as a [2, num_edges] tensor:
- First row: source node indices
- Second row: target node indices
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
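As a minimal sketch of how this layout is usually produced, the snippet below builds the same edge_index from a list of undirected edges. The name undirected_edges is a hypothetical input for illustration; only plain PyTorch is used, so it runs without PyG installed.

```python
import torch

# Hypothetical input: each pair is one undirected edge between two node ids.
undirected_edges = [(0, 1), (1, 2)]

# An undirected edge is stored as two directed edges, one per direction.
pairs = []
for u, v in undirected_edges:
    pairs.append((u, v))
    pairs.append((v, u))

# A list of (source, target) pairs has shape [num_edges, 2];
# transposing gives the [2, num_edges] COO layout described above.
edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

src, dst = edge_index  # first row: sources, second row: targets
print(edge_index)
```

Transposing followed by contiguous() is a common way to move from a pair list to the COO layout, because most tabular edge sources store one edge per row.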
Mini-Batch Processing
PyG handles batching by creating block-diagonal adjacency matrices, concatenating multiple graphs into one large disconnected graph:
- Adjacency matrices are stacked diagonally
- Node features are concatenated along the node dimension
- A batch vector maps each node to its source graph
- No padding needed—computationally efficient
from torch_geometric.loader import DataLoader
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f"Batch size: {batch.num_graphs}")
    print(f"Total nodes: {batch.num_nodes}")
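To make the block-diagonal idea concrete, here is a rough sketch of what such a collation step does, written in plain PyTorch. The helper collate_graphs is hypothetical and is not PyG's implementation; it only illustrates concatenating features, shifting node indices, and building the batch vector.

```python
import torch

def collate_graphs(graphs):
    """Merge (x, edge_index) pairs into one large disconnected graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift node ids so each graph occupies its own diagonal block.
        edge_indices.append(edge_index + offset)
        # Every node remembers which graph it came from.
        batch.append(torch.full((x.size(0),), graph_id, dtype=torch.long))
        offset += x.size(0)
    return torch.cat(xs, dim=0), torch.cat(edge_indices, dim=1), torch.cat(batch)

# Two toy graphs with 2 and 3 nodes, each with 3 features per node.
g1 = (torch.ones(2, 3), torch.tensor([[0, 1], [1, 0]]))
g2 = (torch.zeros(3, 3), torch.tensor([[0, 1, 2], [1, 2, 0]]))

x, edge_index, batch = collate_graphs([g1, g2])
print(x.shape, edge_index, batch)
```

Because the second graph's node ids are shifted by the size of the first, no edge crosses between graphs, which is why message passing on the merged graph matches running each graph separately.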
Building Graph Neural Networks
Message Passing Paradigm
GNNs in PyG follow a neighborhood aggregation scheme:
- Transform node features
- Propagate messages along edges
- Aggregate messages from neighbors
- Update node representations
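The four steps above can be sketched in plain PyTorch as a single round of mean aggregation. The function mean_aggregate_layer and its update rule (adding the neighbor mean to the node's own features) are assumptions chosen for illustration, not a specific PyG layer.

```python
import torch

def mean_aggregate_layer(x, edge_index, weight):
    src, dst = edge_index
    # 1. Transform node features.
    h = x @ weight
    # 2. Propagate: one message per edge, taken from the source node.
    messages = h[src]
    # 3. Aggregate: sum messages per target node, then divide by in-degree.
    summed = torch.zeros_like(h).index_add_(0, dst, messages)
    ones = torch.ones(dst.numel(), dtype=h.dtype, device=h.device)
    deg = torch.zeros(h.size(0), dtype=h.dtype, device=h.device).index_add_(0, dst, ones)
    mean = summed / deg.clamp(min=1).unsqueeze(-1)
    # 4. Update: combine each node's own features with its neighborhood mean.
    return torch.relu(h + mean)

x = torch.tensor([[1.0], [2.0], [4.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
weight = torch.eye(1)  # identity transform keeps the arithmetic easy to follow

out = mean_aggregate_layer(x, edge_index, weight)
print(out)
```

Node 1 has two neighbors with features 1 and 4, so its neighborhood mean is 2.5 and its updated value is 2 + 2.5 = 4.5.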
Using Pre-Built Layers
PyG provides 40+ convolutional layers. Common ones include:
GCNConv (Graph Convolutional Network):
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is only active while the module is in training mode.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GATConv (Graph Attention Network):
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, 8, heads=8, dropout=0.6)
        # The 8 heads are concatenated, so the next layer sees 8 * 8 channels.
        self.conv2 = GATConv(8 * 8, num_classes, heads=1, concat=False, dropout=0.6)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
GraphSAGE:
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(num_features, 64)
        self.conv2 = SAGEConv(64, num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
Custom Message Passing Layers
For custom layers, inherit from MessagePassing:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # sum incoming messages per target node
        self.lin = torch.nn.Linear(in_channels, out_channels)
    def forward(self, x, edge_index):
        # Self-loops let each node keep its own features during aggregation.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        # Symmetric normalization: 1 / sqrt(deg(source) * deg(target)).
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        # x_j holds the source-node features for every edge.
        return norm.view(-1, 1) * x_j
Key methods:
forward(): Main entry point
message(): Constructs messages from source to target nodes
aggregate(): Aggregates messages (usually don't override—set aggr parameter)
update(): Updates node embeddings after aggregation
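Variable naming convention: Appending _i or _j to the name of a tensor passed to propagate() maps it per edge to the target node (_i) or the source node (_j). For example, passing x=x makes x_j available in message() as the source-node features of each edge.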
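A small sketch of what the _i and _j suffixes resolve to, assuming the default message flow from source to target. It uses plain PyTorch indexing rather than the MessagePassing machinery, so the variable names here are illustrative only.

```python
import torch

x = torch.tensor([[10.0], [20.0], [30.0]])
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
src, dst = edge_index

# One row per edge: features of the node the message comes from.
x_j = x[src]
# One row per edge: features of the node the message is sent to.
x_i = x[dst]

# A message that depends on both endpoints, e.g. a feature difference.
diff = x_j - x_i
print(x_j.squeeze(-1), x_i.squeeze(-1), diff.squeeze(-1))
```

Both tensors have one row per edge rather than one row per node, which is why a message() function can combine them elementwise.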
Working with Datasets
Loading Built-in Datasets
PyG provides extensive benchmark datasets:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
from torch_geometric.datasets import QM9
dataset = QM9(root='/tmp/QM9')
from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')
Check references/datasets_reference.md for a comprehensive list.
Creating Custom Datasets
For datasets that fit in memory, inherit from InMemoryDataset:
from torch_geometric.data import InMemoryDataset, Data
import torch
class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Load the processed data written by process().
        self.load(self.processed_paths[0])
    @property
    def raw_file_names(self):
        return ['my_data.csv']
    @property
    def processed_file_names(self):
        return ['data.pt']
    def download(self):
        pass
    def process(self):
        # Build a list of Data objects (a single toy graph here).
        data_list = []
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        y = torch.tensor([0], dtype=torch.long)
        data = Data(x=x, edge_index=edge_index, y=y)
        data_list.append(data)
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]
        self.save(data_list, self.processed_paths[0])
For large datasets that don't fit in memory, inherit from Dataset and implement len() and get(idx).
Loading Graphs from CSV
import pandas as pd
import torch
from torch_geometric.data import Data
# Node features: one row per node, in node-index order.
nodes_df = pd.read_csv('nodes.csv')
x = torch.tensor(nodes_df[['feat1', 'feat2']].values, dtype=torch.float)
# Edges: transpose the (num_edges, 2) table into the [2, num_edges] COO layout.
edges_df = pd.read_csv('edges.csv')
edge_index = torch.tensor(edges_df[['source', 'target']].values.T, dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
Training Workflows
Node Classification (Single Graph)
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
model = GCN(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    # Only the training nodes contribute to the loss.
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
Graph Classification (Multiple Graphs)
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, num_classes)
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        # Collapse node embeddings into one vector per graph.
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(dataset.num_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.train()
for epoch in range(100):
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss / len(loader):.4f}')
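The readout step in the classifier turns node embeddings into one vector per graph using the batch vector. As a rough illustration of that reduction, the hypothetical mean_pool helper below reproduces the idea in plain PyTorch; it is a sketch, not the global_mean_pool source.

```python
import torch

def mean_pool(x, batch, num_graphs):
    """Average node features per graph, using batch as the graph id of each node."""
    summed = torch.zeros(num_graphs, x.size(1), dtype=x.dtype).index_add_(0, batch, x)
    # Number of nodes per graph; clamp avoids division by zero for empty graphs.
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
    return summed / counts.unsqueeze(-1).to(x.dtype)

# Three nodes: the first two belong to graph 0, the last one to graph 1.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = torch.tensor([0, 0, 1])

pooled = mean_pool(x, batch, num_graphs=2)
print(pooled)
```

The output has one row per graph regardless of how many nodes each graph contains, which is what lets a fixed-size linear layer follow the pooling step.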
Large-Scale Graphs with Neighbor Sampling
For large graphs, use NeighborLoader to sample subgraphs:
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[25, 10],
batch_size=128,
input_nodes=data.train_mask,
)
model.train()
for batch in train_loader:
    optimizer.zero_grad()
    out = model(batch)
    # The first batch.batch_size nodes are the seed nodes of this mini-batch.
    loss = F.nll_loss(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()
Important:
- Output subgraphs are directed
- Node indices are relabeled (0 to batch.num_nodes - 1)
- Only use seed node predictions for loss computation
- Sampling beyond 2-3 hops is generally not feasible, because the number of sampled nodes can grow multiplicatively with every additional hop
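A back-of-the-envelope calculation shows why deep sampling becomes expensive. The helper max_sampled_nodes is a hypothetical illustration that computes a worst-case upper bound; real batches are smaller, since sampled neighbors overlap and many nodes have fewer neighbors than the requested fan-out.

```python
def max_sampled_nodes(batch_size, num_neighbors):
    """Worst-case node count for one mini-batch, given per-hop fan-outs."""
    frontier = batch_size  # nodes added in the most recent hop
    total = batch_size     # seed nodes are always included
    for fanout in num_neighbors:
        frontier *= fanout
        total += frontier
    return total

# Two hops, as in the loader above: 128 + 128*25 + 128*25*10
print(max_sampled_nodes(128, [25, 10]))          # 35328

# Four hops with the same style of fan-out
print(max_sampled_nodes(128, [25, 10, 10, 10]))  # 3555328
```

Two extra hops raise the bound by roughly a factor of one hundred, which is the practical reason most configurations stop at two or three hops.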
Advanced Features
Heterogeneous Graphs
For graphs with multiple node and edge types, use HeteroData:
from torch_geometric.data import HeteroData
data = HeteroData()
data['paper'].x = torch.randn(100, 128)
data['author'].x = torch.randn(200, 64)
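# Row 0 indexes authors (200 nodes), row 1 indexes papers (100 nodes).
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 200, (500,)),
    torch.randint(0, 100, (500,)),
])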
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)
Convert homogeneous models to heterogeneous:
from torch_geometric.nn import to_hetero
model = GNN(...)
model = to_hetero(model, data.metadata(), aggr='sum')
out = model(data.x_dict, data.edge_index_dict)
Or use HeteroConv for custom edge-type-specific operations:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv
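class HeteroGNN(torch.nn.Module):
    def __init__(self, metadata):
        super().__init__()
        # HeteroConv only returns features for destination node types, so a
        # reverse relation is included to keep 'author' embeddings updated.
        # This assumes the data holds ('paper', 'rev_writes', 'author') edges,
        # e.g. added with torch_geometric.transforms.ToUndirected().
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
            ('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), 64),
        }, aggr='sum')
        self.conv2 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(64, 32),
            ('author', 'writes', 'paper'): SAGEConv((64, 64), 32),
            ('paper', 'rev_writes', 'author'): SAGEConv((64, 64), 32),
        }, aggr='sum')
    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict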
Transforms
Apply transforms to modify graph structure or features:
from torch_geometric.transforms import NormalizeFeatures, AddSelfLoops, Compose
transform = NormalizeFeatures()
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
transform = Compose([
AddSelfLoops(),
NormalizeFeatures(),
])
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=transform)
Common transforms:
- Structure: ToUndirected, AddSelfLoops, RemoveSelfLoops, KNNGraph, RadiusGraph
- Features: NormalizeFeatures, NormalizeScale, Center
- Sampling: RandomNodeSplit, RandomLinkSplit
- Positional Encoding: AddLaplacianEigenvectorPE, AddRandomWalkPE
See references/transforms_reference.md for the full list.
Model Explainability
PyG provides explainability tools to understand model predictions:
from torch_geometric.explain import Explainer, GNNExplainer
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)
node_idx = 10
explanation = explainer(data.x, data.edge_index, index=node_idx)
print(f'Node {node_idx} explanation:')
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[node_idx].topk(5).indices}')
Pooling Operations
For hierarchical graph representations:
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool
class HierarchicalGNN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.pool1 = TopKPooling(64, ratio=0.8)
        self.conv2 = GCNConv(64, 64)
        self.pool2 = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        # TopKPooling returns (x, edge_index, edge_attr, batch, perm, score).
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)
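To show what a ratio-based pooling step does to the graph, here is a simplified single-graph sketch in plain PyTorch. The function topk_pool_single_graph is hypothetical: it takes the scores as given and omits the learned projection and score-based feature gating that an actual pooling layer applies.

```python
import math
import torch

def topk_pool_single_graph(x, edge_index, score, ratio):
    """Keep the ceil(ratio * N) highest-scoring nodes and the edges among them."""
    num_nodes = x.size(0)
    k = math.ceil(ratio * num_nodes)
    perm = score.topk(k).indices  # ids of the kept nodes, best first
    # Map old node ids to new consecutive ids; dropped nodes get -1.
    new_id = torch.full((num_nodes,), -1, dtype=torch.long)
    new_id[perm] = torch.arange(k)
    src, dst = new_id[edge_index[0]], new_id[edge_index[1]]
    keep = (src >= 0) & (dst >= 0)  # an edge survives only if both ends do
    return x[perm], torch.stack([src[keep], dst[keep]]), perm

# Path graph 0 - 1 - 2 - 3, stored with both edge directions.
x = torch.arange(4, dtype=torch.float).view(4, 1)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
score = torch.tensor([0.9, 0.1, 0.8, 0.7])

new_x, new_edge_index, perm = topk_pool_single_graph(x, edge_index, score, ratio=0.75)
print(perm, new_edge_index)
```

Dropping node 1 disconnects node 0 from the rest, so only the edge between the old nodes 2 and 3 remains, relabeled as nodes 1 and 2.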
Common Patterns and Best Practices
Check Graph Properties
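from torch_geometric.utils import is_undirected
print(f"Is undirected: {is_undirected(data.edge_index)}")
# Connected components are counted with SciPy on the sparse adjacency matrix.
from scipy.sparse.csgraph import connected_components
from torch_geometric.utils import to_scipy_sparse_matrix
adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)
num_components, _ = connected_components(adj, directed=False)
print(f"Connected components: {num_components}")
from torch_geometric.utils import contains_self_loops
print(f"Has self-loops: {contains_self_loops(data.edge_index)}")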
GPU Training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)
for batch in loader:
    batch = batch.to(device)
Save and Load Models
torch.save(model.state_dict(), 'model.pth')
model = GCN(num_features, num_classes)
model.load_state_dict(torch.load('model.pth'))
model.eval()
Layer Capabilities
When choosing layers, consider these capabilities:
- SparseTensor: Supports efficient sparse matrix operations
- edge_weight: Handles one-dimensional edge weights
- edge_attr: Processes multi-dimensional edge features
- Bipartite: Works with bipartite graphs (different source/target dimensions)
- Lazy: Enables initialization without specifying input dimensions
See the GNN cheatsheet at references/layer_capabilities.md.
Resources
Bundled References
This skill includes detailed reference documentation:
references/layers_reference.md: Complete listing of all 40+ GNN layers with descriptions and capabilities
references/datasets_reference.md: Comprehensive dataset catalog organized by category
references/transforms_reference.md: All available transforms and their use cases
references/api_patterns.md: Common API patterns and coding examples
Scripts
Utility scripts are provided in scripts/:
scripts/visualize_graph.py: Visualize graph structure using networkx and matplotlib
scripts/create_gnn_template.py: Generate boilerplate code for common GNN architectures
scripts/benchmark_model.py: Benchmark model performance on standard datasets
Execute scripts directly or read them for implementation patterns.
Official Resources