BaseModel

class spikeometric.models.BaseModel

Bases: torch_geometric.nn.conv.message_passing.MessagePassing

Base class for all spiking neural networks.

Extends the MessagePassing class from torch_geometric by adding stimulus support and a forward method that calculates the spikes of the network at time t using the following steps:

  1. input(): Calculates the input to each neuron.

  2. non_linearity(): Applies a non-linearity to the input to get the neuron’s response.

  3. emit_spikes(): Determines the spikes of the network at time t from the response.

These methods are overridden by child classes to implement different models.
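The three-step pipeline can be illustrated with a framework-free sketch. ToyModel below is hypothetical and does not subclass the real BaseModel; the logistic non-linearity, the bias parameter and the Bernoulli spike draw are assumed choices, one of many that a child class might make.

```python
import math
import random


class ToyModel:
    """Hypothetical stand-in for a BaseModel subclass (not the spikeometric API).

    Mirrors the pipeline input() -> non_linearity() -> emit_spikes().
    """

    def __init__(self, bias: float, seed: int = 0):
        self.bias = bias                # hypothetical tunable parameter
        self.rng = random.Random(seed)  # seeded for reproducible spikes

    def input(self, synaptic, stimulus):
        # Step 1: total input per neuron = synaptic input + stimulus + bias
        return [s + u + self.bias for s, u in zip(synaptic, stimulus)]

    def non_linearity(self, inputs):
        # Step 2: logistic response maps each input to a probability in [0, 1]
        return [1.0 / (1.0 + math.exp(-x)) for x in inputs]

    def emit_spikes(self, probs):
        # Step 3: one Bernoulli draw per neuron (1 = spike, 0 = silent)
        return [1 if self.rng.random() < p else 0 for p in probs]

    def step(self, synaptic, stimulus):
        return self.emit_spikes(self.non_linearity(self.input(synaptic, stimulus)))
```

A child class that wanted deterministic spiking would only need to replace emit_spikes(), for example with a threshold, leaving the other two steps untouched.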

To simulate the network, a default simulate() method is provided; child classes can override it if they need a different simulation method.

If the model has any tunable parameters, they can be tuned to match a desired firing rate using the tune() method. Child classes can override tune() to optimize other target functions.

There are also methods for saving and loading the model.

property tunable_parameters: dict

Returns a dictionary of the tunable parameters.

connectivity_filter(W0: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor

Returns the connectivity filter of the network. The connectivity filter determines the time dependency of the weights of the network. The default connectivity filter is just the initial synaptic weights, which means that the spikes only affect the neurons for one time step.

Parameters
  • W0 (torch.Tensor) – The initial synaptic weights of the network [n_edges, T]

  • edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]

Returns

W – The connectivity filter of the network [n_edges, T]

Return type

torch.Tensor
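To make the "one time step" remark concrete, here is a plain-Python sketch of a single edge's filter row over a window of T steps. It assumes the time axis runs from oldest (column 0) to newest (column T - 1), and that the default filter puts all of its weight on the newest column; the exponentially decaying alternative is a hypothetical example of what a child class could return instead.

```python
import math
from typing import List


def one_step_filter(w0: float, T: int) -> List[float]:
    # All weight on the newest column: a spike affects the target for one step only
    return [0.0] * (T - 1) + [w0]


def exponential_filter(w0: float, T: int, tau: float) -> List[float]:
    # Hypothetical alternative: weight decays with the age of the spike
    return [w0 * math.exp(-(T - 1 - k) / tau) for k in range(T)]


def edge_contribution(filter_row: List[float], history: List[int]) -> float:
    # Sum over the window of filter weight times source spike at each step
    return sum(w * x for w, x in zip(filter_row, history))
```

With the one-step filter, a spike three steps in the past contributes nothing, while with the exponential filter it still contributes w0 · exp(-3/tau).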

input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor
emit_spikes(inputs: torch.Tensor) → torch.Tensor
non_linearity(inputs: torch.Tensor) → torch.Tensor
synaptic_input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor

Calculates the synaptic input to each neuron using torch_geometric’s message passing framework. The propagate method first calls the message method to compute the message along each edge, then aggregates the messages with the aggregation method (sum in this case), and finally passes the result to the update method to compute the new state of the neurons. Only the message method is overridden; the default aggregation and update methods are used.

\[I_i(t) = \sum_{j \in \mathcal{N}_i} \mathbf{W}_{ij} \cdot \mathbf{x}_j(t)\]

Parameters
  • edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]

  • W (torch.Tensor) – The weights of the edges [n_edges, T]

  • state (torch.Tensor) – The state of the neurons [n_neurons, T]

  • **kwargs – Additional arguments

Returns

synaptic_input – The synaptic input to each neuron [n_neurons, 1]

Return type

torch.Tensor
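The message-then-sum computation can be sketched in plain Python. This assumes torch_geometric's default source_to_target flow (edge_index[0] holds the source j, edge_index[1] the target i) and reads the "·" in the formula as a dot product over the T time steps, which is what the [n_edges, T] and [n_neurons, T] shapes suggest. The sketch returns a flat list rather than an [n_neurons, 1] tensor.

```python
from typing import List


def synaptic_input(edge_index: List[List[int]], W: List[List[float]],
                   state: List[List[int]]) -> List[float]:
    """Sketch of I_i(t) = sum over j in N_i of W_ij . x_j(t).

    W has shape [n_edges, T]; state has shape [n_neurons, T].
    """
    out = [0.0] * len(state)
    for e, (j, i) in enumerate(zip(edge_index[0], edge_index[1])):
        # message along edge e: dot product of its filter with source j's history
        m = sum(w * x for w, x in zip(W[e], state[j]))
        # aggregation: add the message to the target neuron's total
        out[i] += m
    return out
```

For three neurons with edges 0→2, 1→2 and 2→0, both incoming messages to neuron 2 are summed into a single value, while neuron 1, which has no incoming edges, receives 0.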

message(state_j: torch.Tensor, W: torch.Tensor) → torch.Tensor

Calculates the message from the j-th neuron to the i-th neuron. This method is called by the propagate method of torch_geometric’s MessagePassing class.

\[m_{ij} = \mathbf{W}_{ij} \cdot \mathbf{x}_j(t)\]

Parameters
  • state_j (torch.Tensor) – The state of the j-th neuron [n_edges, T]

  • W (torch.Tensor) – The weights of the edges [n_edges, T]

Returns

message – The message from the j-th neuron to the i-th neuron [n_edges, 1]

Return type

torch.Tensor
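The "_j" suffix on state_j refers to torch_geometric's convention of gathering node features onto edges before message() is called: with the default source_to_target flow, each edge receives the row of its source node. The plain-Python sketch below mimics that gather and then reduces over the time window to one value per edge, keeping the [n_edges, 1] shape. Summing over T is an assumption inferred from the documented shapes.

```python
from typing import List


def lift_to_edges(state: List[List[int]], sources: List[int]) -> List[List[int]]:
    # Gather the source neuron's [T] history for every edge: [n_neurons, T] -> [n_edges, T]
    return [state[j] for j in sources]


def message(state_j: List[List[int]], W: List[List[float]]) -> List[List[float]]:
    # Per-edge dot product over the time window, kept as [n_edges, 1]
    return [[sum(w * x for w, x in zip(w_row, x_row))]
            for w_row, x_row in zip(W, state_j)]
```

Here sources plays the role of edge_index[0]; the same neuron can appear several times if it has several outgoing edges.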

stimulus_input(t: int, **kwargs) → torch.Tensor

Calculates the stimulus input to the network at time t.

Parameters

t (int) – The current time step

Returns

stimulus_input – The stimulus input to the network [n_neurons]

Return type

torch.Tensor

stimulus_filter(stimulus: torch.Tensor, **kwargs) → torch.Tensor

Filters the stimulus to the network. The default stimulus filter is just the stimulus itself.

Parameters
  • stimulus (torch.Tensor) – The stimulus to the network

  • **kwargs – Additional arguments

Returns

stimulus – The filtered stimulus to the network [n_neurons]

Return type

torch.Tensor

add_stimulus(stimulus: Callable)

Adds a stimulus to the network.
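A sketch of how the three stimulus hooks could fit together, under the assumption that a stimulus is a callable mapping a time step t to one input value per neuron. StimulusHolder and pulse are hypothetical names for illustration; the real classes may combine stimuli differently.

```python
from typing import Callable, List, Optional


class StimulusHolder:
    """Hypothetical sketch of add_stimulus / stimulus_filter / stimulus_input."""

    def __init__(self, n_neurons: int):
        self.n_neurons = n_neurons
        self.stimulus: Optional[Callable[[int], List[float]]] = None

    def add_stimulus(self, stimulus: Callable[[int], List[float]]) -> None:
        # Store a callable mapping time step t to a per-neuron input vector
        self.stimulus = stimulus

    def stimulus_filter(self, stimulus: List[float]) -> List[float]:
        # Default filter: pass the stimulus through unchanged
        return stimulus

    def stimulus_input(self, t: int) -> List[float]:
        # Without a registered stimulus, every neuron gets zero external input
        if self.stimulus is None:
            return [0.0] * self.n_neurons
        return self.stimulus_filter(self.stimulus(t))


def pulse(t: int) -> List[float]:
    # Hypothetical stimulus: drive neuron 0 with strength 2.0 every 10th step
    return [2.0 if t % 10 == 0 else 0.0, 0.0, 0.0]
```

Keeping the stimulus as a callable means it is evaluated lazily at each step, so arbitrarily long simulations never need a precomputed [n_neurons, n_steps] stimulus array.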

forward(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor

Calculates the new state of the network at time t+1 from the state at time t.

Parameters
  • edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]

  • W (torch.Tensor) – The edge weights of the connectivity filter [n_edges, T]

  • state (torch.Tensor) – The state of the network from time t - T to time t [n_neurons, T]

Returns

spikes – The spikes of the network at time t+1 [n_neurons]

Return type

torch.Tensor
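One forward step followed by an update of the state window can be sketched in plain Python. The sketch assumes the window's columns run from oldest to newest, so after computing the spikes at t+1 the oldest column is dropped and the new spikes are appended. The deterministic threshold standing in for non_linearity() and emit_spikes() is a hypothetical simplification chosen to keep the example reproducible.

```python
from typing import List


def forward(edge_index: List[List[int]], W: List[List[float]],
            state: List[List[int]], threshold: float = 0.5) -> List[int]:
    # Step 1 (input): filter-weighted sum of each source neuron's spike history
    inputs = [0.0] * len(state)
    for e, (j, i) in enumerate(zip(edge_index[0], edge_index[1])):
        inputs[i] += sum(w * x for w, x in zip(W[e], state[j]))
    # Steps 2-3: hypothetical deterministic threshold instead of a probabilistic draw
    return [1 if v > threshold else 0 for v in inputs]


def roll_window(state: List[List[int]], spikes: List[int]) -> List[List[int]]:
    # Drop the oldest column and append the new spikes as the newest one
    return [row[1:] + [s] for row, s in zip(state, spikes)]
```

On a chain 0→1→2 whose filters weight only the newest column, a spike in neuron 0 reaches neuron 1 after one call and neuron 2 after the next.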

simulate(data, n_steps: int, verbose: bool = True, equilibration_steps: int = 100, store_as_dtype: torch.dtype = torch.int32, **kwargs)

Simulates the network for n_steps time steps given the connectivity. Returns the state of the network at each time step.

Parameters
  • data (torch_geometric.data.Data) – The data containing the connectivity.

  • n_steps (int) – The number of time steps to simulate

  • verbose (bool) – If True, a progress bar is shown

  • equilibration_steps (int) – The number of time steps to simulate before we start recording the state of the network.

  • store_as_dtype (torch.dtype) – The dtype to store the state of the network as.

Returns

x – The state of the network at each time step. The state is a binary tensor where 1 means that the neuron is active.

Return type

torch.Tensor[n_neurons, n_steps]
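The role of equilibration_steps can be seen in a plain-Python simulation loop: the network is stepped equilibration_steps + n_steps times, but only the last n_steps are written into the [n_neurons, n_steps] output. The step argument is a hypothetical stand-in for the model's forward pass, and the window update assumes columns ordered from oldest to newest.

```python
from typing import Callable, List


def simulate(step: Callable[[List[List[int]]], List[int]],
             initial_state: List[List[int]], n_steps: int,
             equilibration_steps: int = 100) -> List[List[int]]:
    """Hypothetical loop: step maps a [n_neurons, T] window to new spikes."""
    state = [row[:] for row in initial_state]
    x = [[0] * n_steps for _ in range(len(state))]  # output [n_neurons, n_steps]
    for t in range(equilibration_steps + n_steps):
        spikes = step(state)
        state = [row[1:] + [s] for row, s in zip(state, spikes)]
        if t >= equilibration_steps:  # record only after the burn-in period
            for i, s in enumerate(spikes):
                x[i][t - equilibration_steps] = s
    return x
```

Discarding the burn-in period keeps the recorded activity from depending on the arbitrary initial state.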

tune(data: torch_geometric.data.data.Data, firing_rate: float, tunable_parameters: Union[str, List[str]] = 'all', lr: float = 0.1, n_steps: int = 100, n_epochs: int = 100, verbose: bool = True)

Tunes the model parameters to match a firing rate.

Parameters
  • data (torch_geometric.data.Data) – The training data containing the connectivity.

  • firing_rate (float) – The target firing rate of the network

  • tunable_parameters (list or str) – The list of parameters to tune, can be “all”, “stimulus”, “model” or a list of parameter names

  • lr (float) – The learning rate

  • n_steps (int) – The number of time steps to simulate for each epoch

  • n_epochs (int) – The number of epochs

  • verbose (bool) – If True, a progress bar is shown
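The idea of tuning toward a target firing rate can be sketched with gradient descent on a single hypothetical bias b. To keep the example small and deterministic, it matches the expected firing probability mean_i σ(I_i + b) under a logistic non-linearity instead of simulating spikes, and computes the gradient of the squared error by hand. The real tune() presumably relies on simulation and PyTorch autograd; the learning rate and epoch count here are illustrative, not the method's defaults.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def tune_bias(inputs: List[float], firing_rate: float,
              lr: float = 5.0, n_epochs: int = 200) -> float:
    """Fit bias b so that mean(sigmoid(inputs + b)) matches firing_rate."""
    b = 0.0
    for _ in range(n_epochs):
        p = [sigmoid(x + b) for x in inputs]
        rate = sum(p) / len(p)
        # d rate / d b is the mean of p * (1 - p) for the logistic function
        drate_db = sum(q * (1.0 - q) for q in p) / len(p)
        # gradient of the loss (rate - firing_rate)**2 with respect to b
        grad = 2.0 * (rate - firing_rate) * drate_db
        b -= lr * grad
    return b
```

Lowering the target rate drives b negative, which suppresses firing across all neurons at once.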

set_tunable(parameters: Union[str, List[str]])

Sets requires_grad to True for the parameters to be tuned.
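A sketch of how the tunable_parameters argument described for tune() ("all", "stimulus", "model" or a list of names) might be resolved into the set of parameters whose requires_grad is switched on. The split into model and stimulus dictionaries, the names "theta" and "strength", and the error on unknown names are hypothetical assumptions for illustration.

```python
from typing import Dict, List, Set, Union


def resolve_tunable(parameters: Union[str, List[str]],
                    model_params: Dict[str, float],
                    stimulus_params: Dict[str, float]) -> Set[str]:
    """Hypothetical resolution of the selection into parameter names."""
    if parameters == "all":
        return set(model_params) | set(stimulus_params)
    if parameters == "model":
        return set(model_params)
    if parameters == "stimulus":
        return set(stimulus_params)
    if isinstance(parameters, str):
        # assume any other single string is one parameter name
        parameters = [parameters]
    known = set(model_params) | set(stimulus_params)
    unknown = [p for p in parameters if p not in known]
    if unknown:
        raise ValueError(f"unknown parameters: {unknown}")
    return set(parameters)
```

In a PyTorch model, each selected name would then map to a parameter on which requires_grad_(True) is called, leaving the rest frozen.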

save(path: str)

Saves the model to path.

load(path: str)

Loads the model from path.

to(device: Union[str, torch.device])

Moves the model, including its random number generator, to the device.

equilibrate(edge_index: torch.Tensor, W: torch.Tensor, inital_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int32) → torch.Tensor

Equilibrates the network with the given connectivity by running it for n_steps time steps.

Parameters
  • edge_index (torch.Tensor) – The connectivity of the network

  • W (torch.Tensor) – The connectivity filter

  • inital_state (torch.Tensor) – The initial state of the network

  • n_steps (int) – The number of time steps to equilibrate for

  • store_as_dtype (torch.dtype) – The dtype to store the state of the network as

Returns

x – The state of the network at each time step

Return type

torch.Tensor