| name | neural-dynamics-universal-translator-foundation |
| description | Foundation model for neural spiking data using multi-task masking (MtM) to translate across population, region, and single-neuron levels. Enables zero-shot and few-shot brain decoding across multiple brain areas. Activation triggers: neural translator, foundation model spiking, MtM, multi-task masking, IBL dataset, brain decoding. |
Neural Dynamics Universal Translator Foundation
A foundation model for neural spiking data that "translates" across the spatial scales of the brain (single neuron, brain region, and population) using self-supervised multi-task masking.
Metadata
- Source: arXiv:2407.14668
- Authors: Yizi Zhang, Yanchen Wang, Donato Jiménez Benetó, Zixuan Wang, Mehdi Azabou, Blake Richards, Olivier Winter, Eva Dyer, Liam Paninski, Cole Hurwitz
- Published: 2024-07
- Code: https://ibl-mtm.github.io/
Core Methodology
Key Innovation
The Neural Dynamics Universal Translator introduces a Multi-Task Masking (MtM) approach that enables a single foundation model to:
- Process neural activity across different time steps, neurons, and brain regions
- Generalize to unseen animals with unspecified neuron correspondence
- Perform few-shot learning with minimal supervision
- Bridge the gap between brain region-specific models and whole-brain analysis
Technical Framework
1. Multi-Task Masking Strategy
The model alternates between masking and reconstructing neural activity across three dimensions:
| Masking Type | Description | Prompt Token |
|---|---|---|
| Temporal Masking | Masks time steps | TEMPORAL |
| Neuron Masking | Masks individual neurons | NEURON |
| Region Masking | Masks entire brain regions | REGION |
Each masking objective is associated with a learnable prompt token that enables "mode switching" during evaluation.
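The three masking schemes in the table can be illustrated on a toy array. This is a minimal NumPy sketch, not the authors' implementation: the function name `make_mask`, the `ratio` default, and the `region_ids` labels are assumptions introduced for illustration.

```python
import numpy as np

def make_mask(mode, T, N, region_ids, rng, ratio=0.25):
    """Return a boolean (T, N) array that is True where activity is hidden."""
    mask = np.zeros((T, N), dtype=bool)
    if mode == "temporal":
        # Hide whole time bins across every neuron
        t_idx = rng.choice(T, size=max(1, int(T * ratio)), replace=False)
        mask[t_idx, :] = True
    elif mode == "neuron":
        # Hide whole neurons across every time bin
        n_idx = rng.choice(N, size=max(1, int(N * ratio)), replace=False)
        mask[:, n_idx] = True
    elif mode == "region":
        # Hide every neuron that belongs to one randomly chosen region
        region = rng.choice(np.unique(region_ids))
        mask[:, region_ids == region] = True
    else:
        raise ValueError(f"unknown mode: {mode}")
    return mask

rng = np.random.default_rng(0)
T, N = 8, 6
region_ids = np.array([0, 0, 1, 1, 2, 2])  # hypothetical region label per neuron
masks = {m: make_mask(m, T, N, region_ids, rng)
         for m in ("temporal", "neuron", "region")}
```

Temporal masks cover full rows, neuron masks cover full columns, and region masks cover the group of columns that share a label.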
2. Model Architecture
Input: Neural spike trains (time × neurons)
│
├─► Embedding Layer: Convert spike counts to embeddings
│
├─► Prompt Token: Prepend the token for the masking objective (TEMPORAL/NEURON/REGION)
│
├─► Transformer Encoder: Process temporal dependencies
│
└─► Output: Reconstruct masked activity
3. Training Dataset
- International Brain Laboratory (IBL) Repeated Site Dataset
- 48 animals across multiple experimental sessions
- Target regions: Secondary visual areas, hippocampus, thalamus
- Neuropixels recordings with consistent anatomical targeting
Mathematical Formulation
Self-Supervised Objective
Given neural activity tensor $X \in \mathbb{R}^{T \times N}$ where $T$ is time steps and $N$ is neurons:
1. Apply masking based on the selected mode $m$:
   - Temporal: Mask $X_{t_1:t_2, :}$
   - Neuron: Mask $X_{:, \mathcal{N}_{masked}}$
   - Region: Mask $X_{:, \mathcal{R}_{masked}}$
2. Add the prompt embedding $p_m$ corresponding to mode $m$.
3. Minimize the reconstruction loss:
$$\mathcal{L} = \mathbb{E}_{X \sim \mathcal{D}} \left[ \| X - \hat{X} \|^2 \right]$$
where $\hat{X} = f_\theta(X_{unmasked}, p_m)$
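As a small worked example of the reconstruction loss, the sketch below evaluates the squared error only on the masked entries, which is how the training loop in Step 3 applies it. The helper name `masked_mse` and the toy values are assumptions for illustration.

```python
import numpy as np

def masked_mse(x, x_hat, mask):
    """Mean squared error over the entries where mask is True."""
    diff = (x - x_hat)[mask]
    return float(np.mean(diff ** 2))

# Toy 2 x 2 example: two entries are masked and reconstructed
x = np.array([[0.0, 2.0],
              [1.0, 3.0]])
x_hat = np.array([[0.5, 2.0],
                  [1.0, 1.0]])
mask = np.array([[True, False],
                 [False, True]])

loss = masked_mse(x, x_hat, mask)  # ((0 - 0.5)^2 + (3 - 1)^2) / 2 = 2.125
```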
Implementation Guide
Prerequisites
pip install torch numpy scipy pandas
pip install ibl-neuropixel
Step-by-Step
Step 1: Load and Preprocess Neural Data
import numpy as np

def load_neural_data(session_path):
    """Load Neuropixels spike data for one IBL session"""
    # The spike files are NumPy arrays, so they are read with np.load
    # (scipy's loadmat only reads MATLAB .mat files)
    spike_times = np.load(f"{session_path}/spikes.times.npy")
    clusters = np.load(f"{session_path}/spikes.clusters.npy")
    bin_size = 0.01  # seconds (10 ms bins)
    duration = 2.0   # seconds of activity to bin, measured from t = 0
    spike_counts = bin_spikes(spike_times, clusters, bin_size, duration)
    return spike_counts

def bin_spikes(spike_times, clusters, bin_size, duration):
    """Convert spike times to a (time bins x neurons) count matrix"""
    n_bins = int(round(duration / bin_size))
    n_neurons = int(clusters.max()) + 1
    counts = np.zeros((n_bins, n_neurons))
    for i in range(n_bins):
        t_start = i * bin_size
        t_end = (i + 1) * bin_size
        # Spikes that fall inside the half-open bin [t_start, t_end)
        mask = (spike_times >= t_start) & (spike_times < t_end)
        active_clusters = clusters[mask]
        for c in active_clusters:
            counts[i, c] += 1
    return counts
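The per-bin loop above rescans every spike for each bin, which becomes slow for long sessions. A possible vectorized variant is sketched below. The name `bin_spikes_vectorized` is hypothetical, and the sketch assumes spike times in seconds and integer cluster ids.

```python
import numpy as np

def bin_spikes_vectorized(spike_times, clusters, bin_size, duration):
    """Bin spikes into a (time bins x neurons) count matrix in one pass."""
    n_bins = int(round(duration / bin_size))
    n_neurons = int(clusters.max()) + 1
    counts = np.zeros((n_bins, n_neurons))
    # Bin index of every spike
    bin_idx = np.floor(spike_times / bin_size).astype(int)
    # Drop spikes outside [0, duration)
    keep = (spike_times >= 0) & (bin_idx < n_bins)
    # np.add.at accumulates correctly when (bin, cluster) pairs repeat
    np.add.at(counts, (bin_idx[keep], clusters[keep]), 1)
    return counts

# Toy data: five spikes from three clusters, two bins of 0.5 s each
spike_times = np.array([0.05, 0.15, 0.17, 0.95, 1.50])
clusters = np.array([0, 1, 1, 2, 0])
counts = bin_spikes_vectorized(spike_times, clusters, bin_size=0.5, duration=1.0)
```

In this toy case the spike at 1.50 s lies outside the 1 s window and is dropped, so cluster 0 contributes a single count.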
Step 2: Implement Multi-Task Masking
import torch
import torch.nn as nn

class MultiTaskMasking(nn.Module):
    """Multi-task masking for neural activity"""
    def __init__(self, n_neurons, embed_dim=128, n_heads=8):
        super().__init__()
        self.embed_dim = embed_dim
        # One learnable prompt token per masking objective
        self.prompts = nn.ParameterDict({
            'temporal': nn.Parameter(torch.randn(1, 1, embed_dim)),
            'neuron': nn.Parameter(torch.randn(1, 1, embed_dim)),
            'region': nn.Parameter(torch.randn(1, 1, embed_dim))
        })
        # Each time bin (a vector of n_neurons counts) becomes one token
        self.input_embedding = nn.Linear(n_neurons, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=512,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.output_proj = nn.Linear(embed_dim, n_neurons)

    def forward(self, x, mode, mask_indices=None):
        """
        Args:
            x: Input spike counts (batch, time, neurons)
            mode: 'temporal', 'neuron', or 'region'
            mask_indices: Tuple of (batch, time, neuron) index tensors to mask,
                or None to leave the input unmasked
        Returns:
            reconstructed: Reconstructed activity (batch, time, neurons)
        """
        mask = torch.ones_like(x)
        # Indexing with None would zero the whole tensor, so guard against it
        if mask_indices is not None:
            mask[mask_indices] = 0
        x_embed = self.input_embedding(x * mask)
        # Prepend the prompt token for the selected objective
        prompt = self.prompts[mode]
        x_embed = torch.cat([prompt.expand(x.size(0), -1, -1), x_embed], dim=1)
        h = self.transformer(x_embed)
        # Drop the prompt position before projecting back to neurons
        reconstructed = self.output_proj(h[:, 1:, :])
        return reconstructed
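To make the shape bookkeeping of the forward pass concrete, here is a minimal sketch with toy sizes. The sizes, the single encoder layer, and the variable names are assumptions chosen to keep the example small; they are not the configuration used above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, T, N, embed_dim = 2, 5, 4, 8          # hypothetical toy sizes

x = torch.poisson(torch.ones(batch, T, N))    # synthetic spike counts
embed = nn.Linear(N, embed_dim)               # one token per time bin
prompt = nn.Parameter(torch.randn(1, 1, embed_dim))
layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=2, dim_feedforward=16, batch_first=True
)
encoder = nn.TransformerEncoder(layer, num_layers=1)
readout = nn.Linear(embed_dim, N)

# Prepend the prompt token: sequence length grows from T to T + 1
tokens = torch.cat([prompt.expand(batch, -1, -1), embed(x)], dim=1)
hidden = encoder(tokens)
# Drop the prompt position so the output aligns with the input time bins
recon = readout(hidden[:, 1:, :])
```

Here `tokens` has shape `(batch, T + 1, embed_dim)` and `recon` has the same shape as `x`.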
Step 3: Training Loop
import random

def train_mtm_model(model, dataloader, epochs=100, lr=1e-4):
    """Train multi-task masking model"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    modes = ['temporal', 'neuron', 'region']
    mask_ratio = 0.15
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in dataloader:
            x = batch['activity'].float()
            batch_size, T, N = x.shape
            # Pick one masking objective per batch
            mode = random.choice(modes)
            # create_neuron_mask and create_region_mask are assumed to follow
            # the same interface as create_temporal_mask below
            if mode == 'temporal':
                mask_indices = create_temporal_mask(batch_size, T, N, mask_ratio)
            elif mode == 'neuron':
                mask_indices = create_neuron_mask(batch_size, T, N, mask_ratio)
            else:
                mask_indices = create_region_mask(batch_size, T, N, mask_ratio)
            reconstructed = model(x, mode, mask_indices)
            # Score the reconstruction on the masked entries only
            loss = criterion(reconstructed[mask_indices], x[mask_indices])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / len(dataloader):.4f}")
    return model

def create_temporal_mask(batch_size, T, N, ratio):
    """Create temporal masking indices"""
    # Mask at least one time bin so the loss is never taken over an empty set
    n_mask = max(1, int(T * ratio))
    mask_t = torch.randperm(T)[:n_mask]
    # Every (batch, masked time bin, neuron) combination
    mask_indices = torch.cartesian_prod(
        torch.arange(batch_size),
        mask_t,
        torch.arange(N)
    )
    return mask_indices[:, 0], mask_indices[:, 1], mask_indices[:, 2]
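The training loop calls `create_neuron_mask` and `create_region_mask`, which this guide does not define. One possible sketch that mirrors `create_temporal_mask` follows. The optional `region_ids` argument and the fallback that splits neurons into contiguous pseudo-regions are assumptions for illustration; with real recordings the region labels would come from the probe's anatomical annotations.

```python
import torch

def create_neuron_mask(batch_size, T, N, ratio):
    """Mask a random subset of neurons across all time bins."""
    n_mask = max(1, int(N * ratio))
    mask_n = torch.randperm(N)[:n_mask]
    idx = torch.cartesian_prod(
        torch.arange(batch_size), torch.arange(T), mask_n
    )
    return idx[:, 0], idx[:, 1], idx[:, 2]

def create_region_mask(batch_size, T, N, ratio, region_ids=None):
    """Mask every neuron of one randomly chosen region.

    region_ids: optional (N,) integer tensor with a region label per neuron
    (hypothetical argument). If omitted, neurons are split into contiguous
    blocks of about N * ratio neurons that stand in for regions.
    """
    if region_ids is None:
        block = max(1, int(N * ratio))
        region_ids = torch.arange(N) // block
    regions = torch.unique(region_ids)
    chosen = regions[torch.randint(len(regions), (1,))]
    mask_n = torch.nonzero(region_ids == chosen, as_tuple=True)[0]
    idx = torch.cartesian_prod(
        torch.arange(batch_size), torch.arange(T), mask_n
    )
    return idx[:, 0], idx[:, 1], idx[:, 2]

# Example: zero out the masked entries of a toy tensor
x = torch.ones(2, 5, 8)
b, t, n = create_neuron_mask(2, 5, 8, 0.25)
x[(b, t, n)] = 0
```

Both helpers return the same `(batch, time, neuron)` index tuple as `create_temporal_mask`, so the result can be used directly for tensor indexing.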
Step 4: Few-Shot Adaptation
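import torch.nn.functional as F

class FewShotAdapter(nn.Module):
    """Linear probe for few-shot downstream tasks"""
    def __init__(self, encoder, output_dim):
        super().__init__()
        self.encoder = encoder
        # Freeze the pretrained model; only the linear probe is trained
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.classifier = nn.Linear(encoder.embed_dim, output_dim)

    def forward(self, x):
        with torch.no_grad():
            # Use the transformer's hidden states (embed_dim wide) rather than
            # the reconstructed spike counts, so features match the classifier
            enc = self.encoder
            prompt = enc.prompts['neuron'].expand(x.size(0), -1, -1)
            h = enc.transformer(torch.cat([prompt, enc.input_embedding(x)], dim=1))
            features = h[:, 1:, :]  # drop the prompt position
        # Average over time, then classify
        return self.classifier(features.mean(dim=1))

def few_shot_adapt(model, support_set, n_shots=5):
    """
    Adapt model with few labeled examples
    Args:
        model: Pretrained MtM model
        support_set: Dict with 'activity' (batch, time, neurons) and integer 'labels'
        n_shots: Number of examples per class (support_set is assumed to
            already contain this many examples per class)
    """
    # Infer the number of classes from the labels (assumes labels 0..C-1)
    n_classes = int(support_set['labels'].max()) + 1
    model.eval()  # disable dropout in the frozen encoder
    adapter = FewShotAdapter(model, output_dim=n_classes)
    optimizer = torch.optim.Adam(adapter.classifier.parameters(), lr=1e-3)
    for epoch in range(50):
        logits = adapter(support_set['activity'])
        loss = F.cross_entropy(logits, support_set['labels'])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return adapter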
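`few_shot_adapt` expects a support set that already holds a few labeled trials per class. A minimal sketch of how such a set might be drawn from a larger labeled pool is shown below. The helper name `sample_support_set` and the synthetic data are assumptions for illustration.

```python
import torch

def sample_support_set(activity, labels, n_shots=5, seed=0):
    """Pick up to n_shots random trials per class from a labeled pool."""
    gen = torch.Generator().manual_seed(seed)
    keep = []
    for c in torch.unique(labels):
        # Indices of all trials with label c, in random order
        idx = torch.nonzero(labels == c, as_tuple=True)[0]
        perm = idx[torch.randperm(len(idx), generator=gen)]
        keep.append(perm[:n_shots])
    keep = torch.cat(keep)
    return {'activity': activity[keep], 'labels': labels[keep]}

# Synthetic pool: 20 trials, 10 time bins, 4 neurons, 2 classes
activity = torch.randn(20, 10, 4)
labels = torch.arange(20) % 2
support = sample_support_set(activity, labels, n_shots=3)
```

The returned dict has the `'activity'` and `'labels'` keys that `few_shot_adapt` reads.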
Applications
1. Brain-Computer Interfaces (BCI)
- Zero-shot decoding on new subjects
- Reduced calibration time for neural prosthetics
- Cross-subject motor imagery classification
2. Cross-Region Neural Analysis
- Study information flow between brain regions
- Identify region-specific neural codes
- Map distributed neural computation
3. Behavior Prediction
- Decode behavioral states from neural activity
- Predict decision-making processes
- Analyze cognitive task performance
4. Neurological Disorder Research
- Compare neural dynamics across patient populations
- Identify biomarkers for brain disorders
- Track disease progression
Benchmarks
| Task | Metric | Performance |
|---|---|---|
| Single-neuron prediction | R² | 0.72 |
| Region-level prediction | R² | 0.68 |
| Forward prediction (200ms) | MAE | 0.15 |
| Behavior decoding (choice) | Accuracy | 82% |
| Cross-animal generalization | R² | 0.61 |
Pitfalls
- Data Quality: Model performance heavily depends on spike sorting quality
- Temporal Resolution: Requires high temporal resolution recordings (≥1kHz sampling)
- Recording Stability: Assumes consistent electrode placement across sessions
- Animal Variability: May require fine-tuning for animals with significant anatomical differences
- Computational Cost: Large transformer models require significant GPU memory
Related Skills
- brain-dit-fmri-foundation-model
- neurostorm-fmri-foundation
- reve-eeg-foundation
- spike-mllm-multimodal-spiking
- meta-learning-in-context-brain-decoding
References
@article{zhang2024universal,
title={Towards a "universal translator" for neural dynamics at single-cell, single-spike resolution},
author={Zhang, Yizi and Wang, Yanchen and Jim{\'e}nez Benet{\'o}, Donato and Wang, Zixuan and Azabou, Mehdi and Richards, Blake and Winter, Olivier and others},
journal={arXiv preprint arXiv:2407.14668},
year={2024}
}