This is the official documentation for SPARKS, a Sequential Predictive Autoencoder for the Representation of spiKing Signals. SPARKS combines a variational autoencoder with a novel Hebbian attention layer and predictive learning rule to obtain consistent latent embeddings that capture most of the variance from high-dimensional neuronal responses.
sparks is built in Python (v >= 3.9) using PyTorch. We provide high-level training/testing functions to get started quickly, and users already familiar with PyTorch should be able to easily make full use of SPARKS. The following instructions were tested on a Linux server (Rocky Linux 8.7 Green Obsidian) with CUDA 12.2 and on macOS (Sonoma 14.5, Apple M2 Max).
SPARKS can be installed via pip or conda.

With pip:

$ virtualenv .env && source .env/bin/activate
$ pip install sparks-ai

With conda:

$ conda create -n sparks python==3.9
$ conda activate sparks
$ pip install sparks-ai
Important: By default, sparks is installed without the dependencies required to run the scripts using the monkey reaching and Allen datasets (Chowdhury et al., 2020; de Vries et al., 2020). To install the library with all dependencies, please run pip install sparks[scripts]. However, due to compatibility issues with the allensdk library, we recommend installing it first from this link.
sparks follows standard PyTorch syntax, and scripts will often look like this:
import torch
import tqdm

# HebbianTransformerEncoder, mlp, train and test are imported from the sparks package;
# MyTrainDataLoader/MyTestDataLoader are placeholders for user-defined data loaders
train_dl = MyTrainDataLoader()  # the train/test functions expect a list of dls, one for each session
test_dl = MyTestDataLoader()

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             ).to(device)

n_inputs_decoder = latent_dim * tau_p
decoding_network = mlp(in_dim=n_inputs_decoder,
                       output_dim_per_session=input_size).to(device)

optimizer = torch.optim.Adam(list(encoding_network.parameters())
                             + list(decoding_network.parameters()), lr=lr)
loss_fn = torch.nn.BCEWithLogitsLoss()

for epoch in tqdm.tqdm(range(n_epochs)):
    train(encoder=encoding_network,
          decoder=decoding_network,
          train_dls=[train_dl],
          loss_fn=loss_fn,
          optimizer=optimizer,
          latent_dim=latent_dim,
          tau_p=tau_p,  # past time-steps window
          tau_f=tau_f,  # future time-steps window
          beta=beta,  # KLD regularisation strength
          device=device)

    if (epoch + 1) % test_period == 0:
        test_loss, encoder_outputs, decoder_outputs = test(encoding_network,
                                                           decoding_network,
                                                           [test_dl],
                                                           latent_dim=latent_dim,
                                                           tau_p=tau_p,
                                                           tau_f=tau_f,
                                                           loss_fn=loss_fn,
                                                           device=device)
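Note that n_inputs_decoder = latent_dim * tau_p: the decoder consumes a flattened window of the tau_p most recent latent vectors. The snippet below is only a shape sketch of that relation; the values and the flattening order are illustrative assumptions, not the library's internal layout:

import torch

latent_dim, tau_p = 16, 5                           # hypothetical values
latent_window = torch.randn(1, latent_dim, tau_p)   # batch of latent vectors over the past window
decoder_input = latent_window.flatten(start_dim=1)  # shape (1, latent_dim * tau_p)
assert decoder_input.shape[1] == latent_dim * tau_p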
sparks supports both electrophysiology and calcium imaging data. Make sure to select the correct data type with the data_type argument when initialising the encoder:
my_ephys_encoder = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             data_type='ephys'  # default value
                                             ).to(device)

my_ca_encoder = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons,  # number of neurons in the recording(s)
                                          embed_dim=embed_dim,  # attention embedding dimension
                                          latent_dim=latent_dim,  # latent dimension
                                          tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                          dt_per_sess=dt,  # sampling period
                                          n_layers=n_layers,  # number of conventional attention layers
                                          n_heads=n_heads,  # number of attention heads
                                          data_type='ca'
                                          ).to(device)
The train and test functions rely on the user to provide a list of PyTorch DataLoaders, one for each session to train/test on. In the underlying PyTorch Dataset (or any iterable), examples should be formatted with shape NxT for the neural signals and CxHxWx...xT for targets during supervised learning, with N the number of neurons, T the number of time-steps, and CxHxWx... the arbitrary dimensions of the target signal.
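As an illustration, here is a minimal, hypothetical Dataset yielding examples in this format; the class name, tensor contents and target shape are placeholders, not part of the sparks API:

import torch
from torch.utils.data import Dataset, DataLoader

class MySessionDataset(Dataset):
    # Placeholder dataset: spikes has shape (n_examples, N, T), targets has shape (n_examples, C, T)
    def __init__(self, spikes, targets):
        self.spikes = spikes
        self.targets = targets

    def __len__(self):
        return len(self.spikes)

    def __getitem__(self, idx):
        return self.spikes[idx], self.targets[idx]  # (N, T), (C, T)

n_examples, N, T, C = 100, 50, 200, 3
spikes = (torch.rand(n_examples, N, T) < 0.1).float()   # binary spike trains
targets = torch.randn(n_examples, C, T)                 # arbitrary target signal
train_dl = DataLoader(MySessionDataset(spikes, targets), batch_size=16, shuffle=True)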
Training relies on backpropagation-through-time (BPTT), for which memory requirements scale quadratically with the number of time-steps. To allow training on long examples, we support an online version of the algorithm that ignores the dependency of the current encoder state on previous time-steps and has fixed complexity w.r.t. the number of time-steps. It can be enabled with the online argument of the train function:
train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=[train_dl],
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device)
When training on multiple sessions, SPARKS trains one Hebbian attention layer per session. To resolve the correct layer at runtime, each session is assigned a unique id; by default, ids are zero-indexed:
train_dls = [MyTrainDl(session) for session in [session_1, session_2, session_3]]
n_neurons_per_sess = [10, 20, 30]

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons_per_sess,  # number of neurons in each recording
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             ).to(device)

print(encoding_network.id_per_sess)  # outputs [0, 1, 2]

# the training function will automatically assign ids [0, 1, 2] to each session and resolve them
# against those saved by the model at runtime
train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=train_dls,
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      device=device)
You can also define custom ids for your sessions, and pass them to the test or train function:
train_dls = [MyTrainDl(session) for session in [session_1, session_2, session_3]]
sessions_ids = [111, 222, 333]

encoding_network = HebbianTransformerEncoder(n_neurons_per_sess=n_neurons_per_sess,  # number of neurons in each recording
                                             embed_dim=embed_dim,  # attention embedding dimension
                                             latent_dim=latent_dim,  # latent dimension
                                             tau_s_per_sess=tau_s,  # STDP decay(s) in seconds, one per attention head
                                             dt_per_sess=dt,  # sampling period
                                             n_layers=n_layers,  # number of conventional attention layers
                                             n_heads=n_heads,  # number of attention heads
                                             id_per_sess=sessions_ids
                                             ).to(device)

train(encoder=encoding_network,
      decoder=decoding_network,
      train_dls=train_dls,
      loss_fn=loss_fn,
      optimizer=optimizer,
      latent_dim=latent_dim,
      tau_p=tau_p,  # past time-steps window
      tau_f=tau_f,  # future time-steps window
      beta=beta,  # KLD regularisation strength
      online=True,  # False by default
      sess_ids=sessions_ids,
      device=device)
Sparse attention for scaling up the number of neurons.