Introduction

This is the official documentation for SPARKS, a Sequential Predictive Autoencoder for the Representation of spiKing Signals. SPARKS combines a variational autoencoder with a novel Hebbian attention layer and predictive learning rule to obtain consistent latent embeddings that capture most of the variance from high dimensional neuronal responses.

About this library

sparks is built in Python (v >= 3.12) using PyTorch. We provide high-level training/testing functions to get started quickly, and users who are already familiar with PyTorch should be able to easily make full use of SPARKS. The following instructions were tested on a Linux server (Rocky Linux 8.7 Green Obsidian) with CUDA 12.2 and on macOS (Sonoma 14.5 with an Apple M2 Max).

Getting Started

Installation

SPARKS can be installed via pip or conda.

Installation with pip

  1. Create a new virtual environment by running
    $ virtualenv .env && source .env/bin/activate

  2. Depending on your system's requirements, you may want to install PyTorch manually, for instance to install the correct CUDA version (please see more details here).

  3. Install using pip:
    $ pip install sparks-ai

Installation with conda (recommended)

  1. Install Anaconda or Miniconda (please see more details here).

  2. Create a new conda environment by running
    $ conda create -n sparks python==3.12
    $ conda activate sparks
  3. Depending on your system's requirements, you may want to install PyTorch manually, for instance to install the correct CUDA version (please see more details here).

  4. Install using pip:
    $ pip install sparks-ai

Important: By default, sparks is installed without the dependencies required to run the demos. To install the library with all dependencies, run pip install sparks-ai[demos]. In particular, due to compatibility issues with the allen_sdk library, please see the multisession demo notebook for how to install it.

Minimum example

sparks follows standard PyTorch syntax, and scripts will often look like this:


import torch
import tqdm

# SPARKS and the train/test helpers are imported from the sparks package
# (see the demo notebooks for the exact import paths)

train_dl = MyTrainDataLoader()  # the train/test functions expect a list of dls, one for each session
test_dl = MyTestDataLoader()
sparks = SPARKS(n_neurons_per_session=n_neurons,  # number of neurons in the recording(s)
                embed_dim=embed_dim,  #  attention embedding dimension
                latent_dim=latent_dim,  # latent dimension
                tau_p=tau_p,  # Number of previous time-steps in the PCC
                tau_f=tau_f,  # Number of future time-steps in the PCC
                output_dim_per_session=output_dim_per_session,  # output dimension
                ).to(device)

optimizer = torch.optim.Adam(sparks.parameters(), lr=lr)
loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in tqdm.tqdm(range(n_epochs)):
    train(sparks=sparks,
          train_dls=[train_dl],
          loss_fn=loss_fn,
          optimizer=optimizer,
          beta=beta,  # KLD regularisation strength
          device=device)

    if (epoch + 1) % test_period == 0:
        test_loss, encoder_outputs, decoder_outputs = test(sparks,
                                                           [test_dl],
                                                           loss_fn=loss_fn,
                                                           device=device)

Configuration

Encoder configuration

The detailed configuration of the encoder is done through the HebbianAttentionConfig, AttentionConfig and ProjectionConfig dataclasses:


from sparks.models.dataclasses import HebbianAttentionConfig, AttentionConfig, ProjectionConfig
my_hebbian_config = HebbianAttentionConfig(tau_s=1.0,          # decay time constant in ms
                                           dt=0.001,           # time-step in ms
                                           n_heads=1,          # number of attention heads
                                           data_type='ephys',  # type of data, can be 'ephys' or 'calcium'
                                           sliding=False,      # whether to use sliding window attention
                                           window_size=1,      # the size of the sliding window
                                           block_size=1,       # the size of the blocks for the attention mechanism in sliding mode
                                           params={},          # any other custom parameters
                                           )
my_conv_attn_config = AttentionConfig(n_layers=1,  # number of conventional attention layers
                                      n_heads=1,   # number of attention heads
                                      params={},   # any other custom parameters
                                      )

my_custom_projection = ProjectionConfig(custom_head=my_projection)  # my_projection is a torch.nn.Module object
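
These configuration objects are then handed to the model when it is constructed. The sketch below is for illustration only: the keyword names hebbian_config, attention_config and projection_config are assumptions that are not confirmed by this documentation, so check the SPARKS constructor signature in your installed version.


sparks = SPARKS(n_neurons_per_session=n_neurons,
                embed_dim=embed_dim,
                latent_dim=latent_dim,
                tau_p=tau_p,
                tau_f=tau_f,
                output_dim_per_session=output_dim_per_session,
                hebbian_config=my_hebbian_config,        # assumed keyword name
                attention_config=my_conv_attn_config,    # assumed keyword name
                projection_config=my_custom_projection,  # assumed keyword name
                ).to(device)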

Decoder configuration

By default, SPARKS uses an MLP with a single hidden layer as its decoder, but you can directly pass your favourite decoder to the model. The forward pass of the decoder should take (latents, session_id) as arguments; here is an example with a simple linear decoder:


from typing import Any

import torch
import torch.nn as nn


class linear(nn.Module):
    def __init__(self,
                 in_dim: int,  # latent_dim x tau_p
                 output_dim_per_session: Any,
                 id_per_session: Any = None) -> None:

        super(linear, self).__init__()

        if id_per_session is None:
            id_per_session = list(range(len(output_dim_per_session)))

        # one output layer per session, indexed by session id
        self.out_layers = nn.ModuleDict({str(sess_id): nn.Linear(in_dim, output_dim)
                                         for sess_id, output_dim in zip(id_per_session, output_dim_per_session)})

    def forward(self, x, sess_id: int = 0) -> torch.Tensor:
        sess_id = str(sess_id)
        return self.out_layers[sess_id](x.flatten(1))


my_decoder = linear(in_dim, output_dim_per_session=[output_dim], id_per_session=[0])
sparks = SPARKS(n_neurons_per_session=in_dim,  # number of neurons in the recording(s)
                embed_dim=embed_dim,  #  attention embedding dimension
                latent_dim=latent_dim,  # latent dimension
                tau_p=tau_p,  # Number of previous time-steps in the PCC
                tau_f=tau_f,  # Number of future time-steps in the PCC
                output_dim_per_session=output_dim,  # output dimension,
                ).to(device)
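
The custom decoder (my_decoder above) can then be supplied to the model in place of the default MLP decoder; the exact constructor argument is not shown here, so check the SPARKS signature in your installed version.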

Data

Processing

The train and test functions rely on the user to provide a list of PyTorch DataLoaders, one for each session to train/test on.

In the underlying PyTorch Dataset (or any iterable), examples should be formatted with shape NxT for the neural signals and CxHxWx...xT for targets during supervised learning, with N the number of neurons, T the number of time-steps, and CxHxWx... the arbitrary dimensions of the target signal.
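
As a minimal sketch, a compatible Dataset can simply return one NxT spike tensor and one target tensor per example; the class and variable names below (SpikingDataset, train_spikes, train_targets) are hypothetical and only illustrate the expected shapes.


import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class SpikingDataset(Dataset):
    # hypothetical dataset yielding NxT spike trains and CxHxWx...xT targets
    def __init__(self, spikes: np.ndarray, targets: np.ndarray) -> None:
        self.spikes = torch.as_tensor(spikes, dtype=torch.float)    # shape (n_examples, N, T)
        self.targets = torch.as_tensor(targets, dtype=torch.float)  # shape (n_examples, C, ..., T)

    def __len__(self) -> int:
        return len(self.spikes)

    def __getitem__(self, idx: int):
        return self.spikes[idx], self.targets[idx]


# one DataLoader per session, passed as a list to the train/test functions
train_dl = DataLoader(SpikingDataset(train_spikes, train_targets), batch_size=32, shuffle=True)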

Training

Online training

Training relies on backpropagation-through-time (BPTT), for which memory requirements scale quadratically with the number of time-steps.

To allow training on long examples, we support an online version of the algorithm that ignores the dependency of the current encoder state on previous time-steps and has fixed complexity w.r.t. the number of time-steps. This can be enabled via the online argument of the train function:


train(sparks=sparks,
      train_dls=[train_dl],
      loss_fn=loss_fn,
      optimizer=optimizer,
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device) 

Multi session training

When training on multiple sessions, SPARKS trains one Hebbian attention layer per session. To resolve the correct layer at runtime, each session is assigned a unique id; ids are zero-indexed by default. See the corresponding demo notebook for more details!
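
As a rough sketch, and assuming the *_per_session arguments and the train_dls list each take one entry per session (check the call signatures in your installed version), a two-session setup could look like this, with session ids 0 and 1 assigned in list order:


sparks = SPARKS(n_neurons_per_session=[n_neurons_sess0, n_neurons_sess1],  # one entry per session
                embed_dim=embed_dim,
                latent_dim=latent_dim,
                tau_p=tau_p,
                tau_f=tau_f,
                output_dim_per_session=[output_dim_sess0, output_dim_sess1],
                ).to(device)

train(sparks=sparks,
      train_dls=[train_dl_sess0, train_dl_sess1],  # session ids default to 0, 1, ...
      loss_fn=loss_fn,
      optimizer=optimizer,
      beta=beta,
      device=device)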

Sparse attention

Now available in beta version: sparse sliding-window attention for scaling up the number of neurons. Feel free to look at the implementation in sparks.models.sparse_attention for details; documentation and demos are coming soon.
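
Based on the HebbianAttentionConfig fields shown above, enabling the sliding-window variant might look like the sketch below; the window_size and block_size values are illustrative only.


from sparks.models.dataclasses import HebbianAttentionConfig

sparse_hebbian_config = HebbianAttentionConfig(tau_s=1.0,
                                               dt=0.001,
                                               n_heads=1,
                                               data_type='ephys',
                                               sliding=True,      # enable sliding-window (sparse) attention
                                               window_size=128,   # size of the sliding window (illustrative)
                                               block_size=32)     # block size in sliding mode (illustrative)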