This is the official documentation for SPARKS, a Sequential Predictive Autoencoder for the Representation of spiKing Signals. SPARKS combines a variational autoencoder with a novel Hebbian attention layer and a predictive learning rule to obtain consistent latent embeddings that capture most of the variance from high-dimensional neuronal responses.
sparks is built in Python (v >= 3.12) using PyTorch. We provide high-level training/testing functions to get started quickly, and users who are already familiar with PyTorch should be able to easily make full use of SPARKS. The following instructions were tested on a Linux server (Rocky Linux 8.7 Green Obsidian) with CUDA 12.2 and on macOS (Sonoma 14.5 with Apple M2 Max).
SPARKS can be installed via pip or conda.
$ virtualenv .env && source .env/bin/activate
$ pip install sparks-ai
$ conda create -n sparks python==3.12
$ conda activate sparks
$ pip install sparks-ai
Important: By default, sparks is installed without the dependencies required to run the demos. To install the library with all dependencies, please run pip install sparks-ai[demos]. In particular, due to compatibility issues with the allen_sdk library, please refer to the multisession demo notebook for instructions on installing it.
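To quickly verify the base installation, you can try importing the package from the command line. This is a simple sanity check we suggest, not an official installation step:
$ python -c "import sparks"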
sparks follows standard PyTorch syntax, and scripts will often look like this:
train_dl = MyTrainDataLoader()  # the train/test functions expect a list of dls, one for each session
test_dl = MyTestDataLoader()

sparks = SPARKS(n_neurons_per_session=n_neurons,  # number of neurons in the recording(s)
                embed_dim=embed_dim,  # attention embedding dimension
                latent_dim=latent_dim,  # latent dimension
                tau_p=tau_p,  # number of previous time-steps in the PCC
                tau_f=tau_f,  # number of future time-steps in the PCC
                output_dim_per_session=output_dim_per_session,  # output dimension
                ).to(device)

optimizer = torch.optim.Adam(sparks.parameters(), lr=lr)  # optimise all SPARKS parameters (encoder and decoder)
loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in tqdm.tqdm(range(args.n_epochs)):
    train(sparks=sparks,
          train_dls=[train_dl],
          loss_fn=loss_fn,
          optimizer=optimizer,
          beta=beta,  # KLD regularisation strength
          device=device)

    if (epoch + 1) % test_period == 0:
        test_loss, encoder_outputs, decoder_outputs = test(sparks,
                                                           [test_dl],
                                                           loss_fn=loss_fn,
                                                           device=device)
The detailed configuration of the encoder is done through the HebbianAttentionConfig, AttentionConfig and ProjectionConfig dataclasses:
from sparks.models.dataclasses import HebbianAttentionConfig, AttentionConfig, ProjectionConfig

my_hebbian_config = HebbianAttentionConfig(tau_s=1.0,  # decay time constant in ms
                                           dt=0.001,  # time-step in ms
                                           n_heads=1,  # number of attention heads
                                           data_type='ephys',  # type of data, can be 'ephys' or 'calcium'
                                           sliding=False,  # whether to use sliding window attention
                                           window_size=1,  # the size of the sliding window
                                           block_size=1,  # the size of the block for the attention mechanism in sliding mode
                                           params={})  # any other custom parameters

my_conv_attn_config = AttentionConfig(n_layers=1,  # number of conventional attention layers
                                      n_heads=1,  # number of attention heads
                                      params={})  # any other custom parameters

my_custom_projection = ProjectionConfig(custom_head=my_projection)  # a torch.nn.Module object
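Note that my_projection above is not defined in this snippet: any torch.nn.Module can serve as the custom projection head. Below is a minimal, purely illustrative sketch; the layer sizes are placeholders that depend on your configuration, not values prescribed by SPARKS:
import torch.nn as nn

# Illustrative only: choose input/output dimensions to match your model configuration
my_projection = nn.Sequential(nn.Linear(embed_dim, embed_dim),
                              nn.ReLU())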
By default, SPARKS uses an MLP with a single hidden layer as its decoder, but you can directly pass your favourite decoder to the model. The forward pass of the decoder should take as arguments (latents, session_id). Here is an example with a simple linear decoder:
from typing import Any

import torch
import torch.nn as nn


class linear(nn.Module):
    def __init__(self,
                 in_dim: int,  # latent_dim x tau_p
                 output_dim_per_session: Any,
                 id_per_session: Any = None) -> None:
        super(linear, self).__init__()
        self.out_layers = nn.ModuleDict({str(sess_id): nn.Linear(in_dim, output_dim)
                                         for sess_id, output_dim in zip(id_per_session, output_dim_per_session)})

    def forward(self, x, sess_id: int = 0) -> torch.Tensor:
        sess_id = str(sess_id)
        return self.out_layers[sess_id](x.flatten(1))


my_decoder = linear(in_dim, output_dim_per_session=[output_dim], id_per_session=[0])
sparks = SPARKS(n_neurons_per_session=in_dim,  # number of neurons in the recording(s)
                embed_dim=embed_dim,  # attention embedding dimension
                latent_dim=latent_dim,  # latent dimension
                tau_p=tau_p,  # number of previous time-steps in the PCC
                tau_f=tau_f,  # number of future time-steps in the PCC
                output_dim_per_session=output_dim,  # output dimension
                ).to(device)
The train and test functions rely on the user to provide a list of PyTorch DataLoaders, one for each session to train/test on.
In the underlying PyTorch Dataset (or any iterable), examples should be formatted with shape NxT for the neural signals and CxHxWx...xT for targets during supervised learning, with N the number of neurons, T the number of time-steps, and CxHxWx... the arbitrary dimensions of the target signal.
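For concreteness, a minimal Dataset following this convention could look like the sketch below. The class and variable names are ours and the random data is purely illustrative; only the per-example NxT and CxT shapes matter here:
import torch
from torch.utils.data import Dataset, DataLoader


class ToySpikingDataset(Dataset):
    """Illustrative dataset: random spike trains of shape (N, T) and targets of shape (C, T)."""
    def __init__(self, n_examples=100, n_neurons=50, n_timesteps=200, target_dim=3):
        self.spikes = (torch.rand(n_examples, n_neurons, n_timesteps) < 0.1).float()  # N x T per example
        self.targets = torch.randn(n_examples, target_dim, n_timesteps)  # C x T per example

    def __len__(self):
        return len(self.spikes)

    def __getitem__(self, idx):
        return self.spikes[idx], self.targets[idx]


train_dl = DataLoader(ToySpikingDataset(), batch_size=16, shuffle=True)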
Training relies on backpropagation-through-time (BPTT), for which memory requirements scale quadratically with the number of time-steps.
To allow training on long examples, we support an online version of the algorithm that ignores the dependency of the current encoder state on previous time-steps and has fixed complexity w.r.t. the number of time-steps. It can be enabled through the online argument of the train function:
train(sparks=sparks,
      train_dls=[train_dl],
      loss_fn=loss_fn,
      optimizer=optimizer,
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device)
When training on multiple sessions, SPARKS trains one Hebbian attention layer per session. To resolve the correct layer at runtime, each session is assigned a unique id; by default, ids are zero-indexed. See the corresponding demo notebook for more details!
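As a rough sketch of what a two-session setup might look like, assuming (as suggested by the examples above) that per-session quantities are passed as lists and that one DataLoader is provided per session; the exact constructor signature may differ, so please refer to the multisession demo notebook:
# Hypothetical two-session example: one DataLoader and one neuron count per session
train_dls = [train_dl_session0, train_dl_session1]

sparks = SPARKS(n_neurons_per_session=[n_neurons_session0, n_neurons_session1],
                embed_dim=embed_dim,
                latent_dim=latent_dim,
                tau_p=tau_p,
                tau_f=tau_f,
                output_dim_per_session=[output_dim_session0, output_dim_session1],
                ).to(device)

train(sparks=sparks,
      train_dls=train_dls,  # session ids are zero-indexed by default
      loss_fn=loss_fn,
      optimizer=optimizer,
      beta=beta,
      device=device)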
Now available in beta version: sparse sliding-window attention for scaling up the number of neurons. Feel free to look at the implementation in sparks.models.sparse_attention for details; documentation and demos are coming soon.