BaseModel
- class spikeometric.models.BaseModel
Bases: torch_geometric.nn.conv.message_passing.MessagePassing
Base class for all spiking neural networks.
Extends the MessagePassing class from torch_geometric by adding stimulus support and a forward method that calculates the spikes of the network at time t using the following steps:
- input(): calculates the input to each neuron
- non_linearity(): applies a non-linearity to the input to get the neuron’s response
- emit_spikes(): determines the spikes of the network at time t from the response
These methods are overridden by the child classes to implement different models.
To simulate the network, a default simulate method is provided. Child classes can override it if they need a different simulation method.
If the model has any tunable parameters, they can be tuned to match a desired firing rate using the tune method. For other target functions, child classes can override the tune method.
There are also methods for saving and loading the model.
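The three-step forward pass and the override pattern described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not spikeometric's implementation: the class names ToyBaseModel and ToyThresholdModel are hypothetical, each edge carries a single scalar weight instead of a [n_edges, T] filter, only the most recent time step of the state is used, and spikes come from a deterministic threshold on a sigmoid response rather than from sampling.

```python
import math


class ToyBaseModel:
    """Hypothetical stand-in for the three-step forward pass (not spikeometric's code)."""

    def input(self, edge_index, W, state):
        raise NotImplementedError

    def non_linearity(self, inputs):
        raise NotImplementedError

    def emit_spikes(self, response):
        raise NotImplementedError

    def forward(self, edge_index, W, state):
        inputs = self.input(edge_index, W, state)  # step 1: input to each neuron
        response = self.non_linearity(inputs)      # step 2: response of each neuron
        return self.emit_spikes(response)          # step 3: spikes at the next step


class ToyThresholdModel(ToyBaseModel):
    """Hypothetical child class that overrides the three steps."""

    def __init__(self, n_neurons, threshold=0.5):
        self.n_neurons = n_neurons
        self.threshold = threshold

    def input(self, edge_index, W, state):
        # Sum W[e] * (latest state of source j) over every edge e = (j -> i)
        total = [0.0] * self.n_neurons
        for e, (j, i) in enumerate(zip(edge_index[0], edge_index[1])):
            total[i] += W[e] * state[j][-1]
        return total

    def non_linearity(self, inputs):
        # Logistic sigmoid maps each input to a response in (0, 1)
        return [1.0 / (1.0 + math.exp(-x)) for x in inputs]

    def emit_spikes(self, response):
        # Deterministic threshold (assumption) instead of random sampling
        return [1 if r > self.threshold else 0 for r in response]


model = ToyThresholdModel(n_neurons=3)
edge_index = [[0, 1], [1, 2]]     # edges 0 -> 1 and 1 -> 2
W = [2.0, -2.0]                   # one scalar weight per edge
state = [[0, 1], [0, 0], [0, 0]]  # neuron 0 spiked at the latest step
print(model.forward(edge_index, W, state))  # [0, 1, 0]
```

Keeping forward in the base class and the three steps in the child means a new model only has to supply input, non_linearity and emit_spikes.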
- property tunable_parameters: dict
Returns a dict of the tunable parameters
- connectivity_filter(W0: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor
Returns the connectivity filter of the network. The connectivity filter determines the time dependency of the weights of the network. The default connectivity filter is just the initial synaptic weights, which means that the spikes only affect the neurons for one time step.
- Parameters
W0 (torch.Tensor) – The initial synaptic weights of the network [n_edges, T]
edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]
- Returns
W – The connectivity filter of the network [n_edges, T]
- Return type
torch.Tensor
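To make "time dependency of the weights" concrete, here are two ways a per-edge weight could be spread over the T columns of the state window. Both functions are hypothetical illustrations, the nested lists stand in for [n_edges, T] tensors, and the assumption that the last column holds the most recent time step follows the "t - T to time t" ordering given for forward. When the filter is combined with the state window by a dot product, one_step_filter lets only the latest spike matter, which is the one-time-step effect described for the default.

```python
import math


def one_step_filter(w0, T):
    """Only the most recent spike counts: weight in the last column, zeros elsewhere."""
    return [[0.0] * (T - 1) + [w] for w in w0]


def exponential_filter(w0, T, tau):
    """Hypothetical filter: the weight decays exponentially with the age of the spike.

    w0  : one weight per edge (length n_edges)
    T   : length of the state window
    tau : decay time constant, in time steps
    Returns an n_edges x T nested list; column T - 1 is the most recent step.
    """
    return [[w * math.exp(-(T - 1 - k) / tau) for k in range(T)] for w in w0]


print(one_step_filter([2.0, -1.0], T=3))  # [[0.0, 0.0, 2.0], [0.0, 0.0, -1.0]]
```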
- input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor
- synaptic_input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor
Calculates the synaptic input to each neuron using torch_geometric’s message passing framework. The propagate method first calls the message method to compute the message along each edge and then aggregates the messages using the aggregation method (sum in this case). The result is then passed to the update method to compute the new state of the neurons. We only override the message method and use the default aggregation and update methods.
\[I_i(t) = \sum_{j \in \mathcal{N}_i} \mathbf{W}_{ij} \cdot \mathbf{x}_j(t)\]
- Parameters
edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]
W (torch.Tensor) – The weights of the edges [n_edges, T]
state (torch.Tensor) – The state of the neurons [n_neurons, T]
**kwargs – Additional arguments
- Returns
synaptic_input – The synaptic input to each neuron [n_neurons, 1]
- Return type
torch.Tensor
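The gather, message and sum-aggregate sequence described above can be sketched with plain lists. It relies on torch_geometric's default flow ("source_to_target"), where edge_index[0] holds the source j and edge_index[1] the target i, and reads the "·" in the formula as a dot product over the T time steps, which is what the shapes [n_edges, T] to [n_edges, 1] suggest. The functions are hypothetical stand-ins, not the library's code.

```python
def message(state_j, W_e):
    # m_ij: dot product of the edge's filter with the source neuron's state window
    return sum(w * x for w, x in zip(W_e, state_j))


def synaptic_input(edge_index, W, state):
    """Sum-aggregate per-edge messages into target neurons; returns n_neurons x 1."""
    out = [0.0] * len(state)
    sources, targets = edge_index
    for e, (j, i) in enumerate(zip(sources, targets)):
        # Gather: torch_geometric's "_j" tensors are indexed by edge_index[0] (the source)
        m = message(state[j], W[e])
        # Scatter-sum: add the message to the target neuron i = edge_index[1][e]
        out[i] += m
    return [[v] for v in out]


edge_index = [[0, 2, 1], [1, 1, 2]]          # edges 0 -> 1, 2 -> 1, 1 -> 2
W = [[0.5, 1.0], [2.0, 0.0], [1.0, 1.0]]     # n_edges x T, with T = 2
state = [[1, 1], [0, 1], [1, 0]]             # n_neurons x T
print(synaptic_input(edge_index, W, state))  # [[0.0], [3.5], [1.0]]
```

Neuron 1 receives two messages (1.5 from neuron 0 and 2.0 from neuron 2), and the sum aggregation adds them to give 3.5.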
- message(state_j: torch.Tensor, W: torch.Tensor) → torch.Tensor
Calculates the message from the j-th neuron to the i-th neuron. This method is called by the propagate method of torch_geometric’s MessagePassing class.
\[m_{ij} = \mathbf{W}_{ij} \cdot \mathbf{x}_j(t)\]
- Parameters
state_j (torch.Tensor) – The state of the j-th neuron [n_edges, T]
W (torch.Tensor) – The weights of the edges [n_edges, T]
- Returns
message – The message from the j-th neuron to the i-th neuron [n_edges, 1]
- Return type
torch.Tensor
- stimulus_input(t: int, **kwargs) → torch.Tensor
Calculates the stimulus input to the network at time t.
- Parameters
t (int) – The current time step
- Returns
stimulus_input – The stimulus input to the network [n_neurons]
- Return type
torch.Tensor
- stimulus_filter(stimulus: torch.Tensor, **kwargs) → torch.Tensor
Filters the stimulus to the network. The default stimulus filter is just the stimulus itself.
- Parameters
stimulus (torch.Tensor) – The stimulus to the network
**kwargs – Additional arguments
- Returns
stimulus – The filtered stimulus to the network [n_neurons]
- Return type
torch.Tensor
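As an illustration of a per-time-step stimulus passed through the default identity filter, the sketch below drives a subset of neurons with periodic pulses. The pulse shape and every name in it (pulse_stimulus, targets, strength, period, duration) are hypothetical, and how spikeometric composes stimulus_filter with stimulus_input internally is an assumption, not something stated in this section.

```python
def pulse_stimulus(t, n_neurons, targets, strength, period, duration):
    """Hypothetical stimulus: `strength` to each neuron in `targets` during the
    first `duration` steps of every `period` steps, zero otherwise."""
    on = (t % period) < duration
    return [strength if on and n in targets else 0.0 for n in range(n_neurons)]


def stimulus_filter(stimulus):
    # Documented default behaviour: the filtered stimulus is the stimulus itself
    return stimulus


def stimulus_input(t, n_neurons=4):
    # One value per neuron at time step t (shape [n_neurons])
    raw = pulse_stimulus(t, n_neurons, targets={0, 2}, strength=1.5, period=10, duration=2)
    return stimulus_filter(raw)


print(stimulus_input(0))  # [1.5, 0.0, 1.5, 0.0]
print(stimulus_input(5))  # [0.0, 0.0, 0.0, 0.0]
```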
- forward(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, **kwargs) → torch.Tensor
Calculates the new state of the network at time t+1 from the state at time t.
- Parameters
edge_index (torch.Tensor) – The connectivity of the network [2, n_edges]
W (torch.Tensor) – The edge weights of the connectivity filter [n_edges, T]
state (torch.Tensor) – The state of the network from time t - T to time t [n_neurons, T]
- Returns
spikes – The spikes of the network at time t+1 [n_neurons]
- Return type
torch.Tensor
- simulate(data, n_steps: int, verbose: bool = True, equilibration_steps: int = 100, store_as_dtype: torch.dtype = torch.int32, **kwargs)
Simulates the network for n_steps time steps given the connectivity. Returns the state of the network at each time step.
- Parameters
data (torch_geometric.data.Data) – The data containing the connectivity.
n_steps (int) – The number of time steps to simulate
verbose (bool) – If True, a progress bar is shown
equilibration_steps (int) – The number of time steps to simulate before we start recording the state of the network.
store_as_dtype (torch.dtype) – The dtype to store the state of the network as.
- Returns
x – The state of the network at each time step. The state is a binary tensor where 1 means that the neuron is active.
- Return type
torch.Tensor[n_neurons, n_steps]
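A simulation loop of the kind simulate describes repeatedly applies a one-step update to a sliding window of the last T states, discards the first equilibration_steps results, and records the rest. The sketch below is a hypothetical plain-Python version: step stands in for forward, the last column of the window is assumed to be the most recent step, and the dtype handling of store_as_dtype (which in the real method controls how the recorded state is stored) is not modeled.

```python
def simulate(step, initial_state, n_steps, equilibration_steps=100):
    """Hypothetical simulation loop around a one-step update.

    step(state) -> list of new spikes (one per neuron) for the next time step
    initial_state: n_neurons x T window, last column is the most recent step
    Returns an n_neurons x n_steps record (equilibration steps are discarded).
    """
    state = [row[:] for row in initial_state]
    n_neurons = len(state)
    record = [[] for _ in range(n_neurons)]
    for s in range(equilibration_steps + n_steps):
        spikes = step(state)
        # Slide the window: drop the oldest column, append the newest spikes
        state = [row[1:] + [spikes[n]] for n, row in enumerate(state)]
        if s >= equilibration_steps:
            for n in range(n_neurons):
                record[n].append(spikes[n])
    return record


def toggle(state):
    # Each neuron spikes exactly when it was silent at the latest step
    return [1 - row[-1] for row in state]


print(simulate(toggle, [[0, 0], [0, 1]], n_steps=4, equilibration_steps=3))
# [[0, 1, 0, 1], [1, 0, 1, 0]]
```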
- tune(data: torch_geometric.data.data.Data, firing_rate: float, tunable_parameters: Union[str, List[str]] = 'all', lr: float = 0.1, n_steps: int = 100, n_epochs: int = 100, verbose: bool = True)
Tunes the model parameters to match a firing rate.
- Parameters
data (torch_geometric.data.Data) – The training data containing the connectivity.
firing_rate (float) – The target firing rate of the network
tunable_parameters (list or str) – The list of parameters to tune, can be “all”, “stimulus”, “model” or a list of parameter names
lr (float) – The learning rate
n_steps (int) – The number of time steps to simulate for each epoch
n_epochs (int) – The number of epochs
verbose (bool) – If True, a progress bar is shown
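The idea behind tuning a parameter to a target firing rate by gradient descent can be shown on the smallest possible case. This is a toy under loudly stated assumptions, not the algorithm tune uses: there is a single hypothetical bias b, the firing rate is computed in closed form as sigmoid(b) rather than by simulating n_steps time steps, and the squared-error loss is chosen for illustration because the loss and optimizer used by tune are not documented in this section. The learning rate is larger than tune's default of 0.1 because this toy loss has small gradients.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tune_bias(firing_rate, lr=1.0, n_epochs=2000, b=0.0):
    """Toy tuner: fit a bias b so that a neuron firing with probability
    sigmoid(b) matches firing_rate, minimising (sigmoid(b) - firing_rate)**2."""
    for _ in range(n_epochs):
        p = sigmoid(b)
        # d/db (p - r)^2 = 2 (p - r) * p (1 - p)
        grad = 2.0 * (p - firing_rate) * p * (1.0 - p)
        b -= lr * grad
    return b


b = tune_bias(0.1)
print(round(sigmoid(b), 3))  # 0.1
```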
- set_tunable(parameters: Union[str, List[str]])
Sets requires_grad to True for the parameters to be tuned
- to(device: Union[str, torch.device])
Moves the model to the device, including the random number generator
- equilibrate(edge_index: torch.Tensor, W: torch.Tensor, inital_state: torch.Tensor, n_steps=100, store_as_dtype: torch.dtype = torch.int32) → torch.Tensor
Equilibrates the network under a given connectivity by simulating it for n_steps time steps.
- Parameters
edge_index (torch.Tensor) – The connectivity of the network
W (torch.Tensor) – The connectivity filter
inital_state (torch.Tensor) – The initial state of the network
n_steps (int) – The number of time steps to equilibrate for
store_as_dtype (torch.dtype) – The dtype to store the state of the network as
- Returns
x – The state of the network at each time step
- Return type
torch.Tensor