Introduction

This is the official documentation for SPARKS, a Sequential Predictive Autoencoder for the Representation of spiKing Signals. SPARKS combines a variational autoencoder with a novel Hebbian attention layer and a predictive learning rule to obtain consistent latent embeddings that capture most of the variance of high-dimensional neuronal responses.

About this library

sparks is built in Python (v >= 3.9) using PyTorch. We provide high-level training/testing functions to get started quickly, and users already familiar with PyTorch should be able to easily make full use of SPARKS. The following instructions were tested on a Linux server (Rocky Linux 8.7 Green Obsidian) with CUDA 12.2 and on macOS (Sonoma 14.5, Apple M2 Max).

Getting Started

Installation

SPARKS can be installed via pip or conda.

Installation with pip

  1. Create a new virtual environment by running
    $ virtualenv .env && source .env/bin/activate

  2. Depending on your system's requirements, you may want to install PyTorch manually, for instance to select the correct CUDA version (please see the official PyTorch installation instructions for more details).

  3. Install using pip:
    $ pip install sparks-ai

Installation with conda (recommended)

  1. Install Anaconda or Miniconda (please see the official conda documentation for more details).

  2. Create a new conda environment by running
    $ conda create -n sparks python=3.9
    $ conda activate sparks
  3. Depending on your system's requirements, you may want to install PyTorch manually, for instance to select the correct CUDA version (please see the official PyTorch installation instructions for more details).

  4. Install using pip:
    $ pip install sparks-ai
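
You can then check that the installation succeeded by importing the package (assuming it is importable as sparks):

    $ python -c "import sparks"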

Important: By default, sparks is installed without the dependencies required to run the scripts using the monkey reaching and Allen datasets (Chowdhury et al., 2020; de Vries et al., 2020). To install the library with all dependencies, please run pip install sparks-ai[scripts]. However, due to compatibility issues with the allensdk library, we recommend installing it first by following its official installation instructions.
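
For instance, the full sequence might look as follows (assuming allensdk installs cleanly from PyPI in your environment; the quotes protect the extras specifier in some shells):

    $ pip install allensdk
    $ pip install "sparks-ai[scripts]"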

Minimum example

sparks follows standard PyTorch syntax, and scripts will often look like this:


import torch
import tqdm

# HebbianTransformerEncoder, mlp, train and test are provided by the sparks
# package; adjust the import paths to match your installation.

train_dl = MyTrainDataLoader()  # the train/test functions expect a list of dls, one for each session
test_dl = MyTestDataLoader()

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             ).to(device)

n_inputs_decoder = latent_dim * tau_p  # the decoder reads the latents from the past tau_p time-steps
decoding_network = mlp(in_dim=n_inputs_decoder,
                       output_dim_per_session=input_size).to(device)

optimizer = torch.optim.Adam(list(encoding_network.parameters())
                             + list(decoding_network.parameters()), lr=lr)
loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in tqdm.tqdm(range(n_epochs)):
    train(encoder=encoding_network,
          decoder=decoding_network,
          train_dls=[train_dl],
          loss_fn=loss_fn,
          optimizer=optimizer,
          latent_dim=latent_dim,
          tau_p=tau_p,  # past time-steps window
          tau_f=tau_f,  # future time-steps window
          beta=beta,  # KLD regularisation strength
          device=device)

    if (epoch + 1) % test_period == 0:
        test_loss, encoder_outputs, decoder_outputs = test(encoding_network,
                                                           decoding_network,
                                                           test_dl,
                                                           latent_dim=latent_dim,
                                                           tau_p=tau_p,
                                                           tau_f=tau_f,
                                                           loss_fn=loss_fn,
                                                           device=device)

Data

Support

sparks supports both electrophysiology and calcium imaging data. Make sure to select the correct data type with the data_type argument during the encoder initialisation:


my_ephys_encoder = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                             embed_dim=embed_dim,  #  attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             data_type='ephys'  # by default  
                                             ).to(device)


my_ca_encoder = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                          embed_dim=embed_dim,  #  attention embedding dimension
                                          latent_dim=latent_dim,  # latent dimension
                                          tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                          dt_per_sess=dt,  # sampling period
                                          n_layers=n_layers,  # number of conventional attention layers
                                          n_heads=n_heads,  # number of attention heads
                                          data_type='ca'  
                                          ).to(device)

Processing

The train and test functions expect the user to provide a list of PyTorch DataLoaders, one for each session to train/test on.

In the underlying PyTorch Dataset (or any iterable), examples should be formatted with shape NxT for the neural signals and CxHxWx...xT for the targets during supervised learning, with N the number of neurons, T the number of time-steps, and CxHxWx... the arbitrary dimensions of the target signal.
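
As an illustration, a minimal Dataset for unsupervised training on a single session could look as follows (MySpikesDataset and the array shapes are hypothetical; only the NxT convention comes from the library):

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MySpikesDataset(Dataset):  # hypothetical example dataset
    def __init__(self, spikes):
        # spikes: array of shape (n_examples, N, T)
        self.spikes = torch.as_tensor(spikes, dtype=torch.float32)

    def __len__(self):
        return len(self.spikes)

    def __getitem__(self, idx):
        return self.spikes[idx]  # one example of shape (N, T)


# e.g., 100 random binary spike trains from N = 50 neurons over T = 200 time-steps
spikes = np.random.binomial(1, 0.05, size=(100, 50, 200))
train_dl = DataLoader(MySpikesDataset(spikes), batch_size=16, shuffle=True)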

Training

Online training

Training relies on backpropagation-through-time (BPTT), for which memory requirements scale quadratically with the number of time-steps.

To allow training on long examples, we support an online version of the algorithm that ignores the dependency of the current encoder state on previous time-steps and has fixed complexity w.r.t. the number of time-steps. Enable it with the online argument of the train function:


train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=[train_dl],
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device) 

Multi session training

When training on multiple sessions, SPARKS trains one Hebbian attention layer per session. To resolve the correct layer at runtime, each layer is assigned a unique id; by default, ids are zero-indexed:


train_dls = [MyTrainDl(session) for session in [session_1, session_2, session_3]]
n_neurons_per_sess = [10, 20, 30]

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons_per_sess,  # number of neurons in each recording
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             ).to(device)

print(encoding_network.id_per_sess) # outputs [0, 1, 2]

# the train function will automatically assign ids [0, 1, 2] to each session and resolve them
# against those saved by the model at runtime
train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=train_dls,
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device)

You can also define custom ids for your sessions and pass them to the train or test function:



train_dls = [MyTrainDl(session) for session in [session_1, session_2, session_3]]
sessions_ids = [111, 222, 333]

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons_per_sess,  # number of neurons in each recording
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             id_per_sess=sessions_ids
                                             ).to(device)

train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=train_dls,
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      sess_ids=sessions_ids,
      device=device)

Coming soon!

Sparse attention for scaling up the number of neurons.