from spikeometric.models.sa_model import SAModel
import torch
[docs]class ThresholdSAM(SAModel):
r"""
The generative model used in the paper `"Systematic errors in connectivity inferred from activity in strongly coupled recurrent circuits" <https://www.biorxiv.org/content/10.1101/512053v1.full>`_.
Is is a threshold-based model, where the neurons spike if the synaptic activation is above a threshold.
Since the model is threshold-based, it is not tunable.
It is defined by the following equations:
#. .. math:: g_i(t+1) = r \: \sum_j (W_0)_{j,i}\: s_j(t) + b_i + \mathcal{E}_i(t+1)
#. .. math:: X_i(t+1) = \begin{cases}
1 & \text{if } g_i(t+1) > \theta \\
0 & \text{otherwise}
\end{cases}
where :math:`r` is the scaling of the recurrent connections, :math:`b_i` is the strength of a background input
and :math:`\mathcal{E}_i(t+1)` is an external stimulus. The response to the input is then compared to a threshold :math:`\theta`
and the neurons spike if the activation is above the threshold.
Between each time step, the synaptic activation is decayed by a factor of :math:`(1-\Delta t/\tau)`
if the neuron did not spike and increased by :math:`\Delta t` if it did.
Parameters
----------
r : float
The scaling of the recurrent connections. (tunable)
b : float
The strength of the background input. (tunable)
tau : float
The time constant of the neurons.
dt : float
The time step of the simulation.
sigma : float
The standard deviation of the noise.
rho : float
The sparsity of the noise.
theta : float
The threshold of the neurons (the neurons will spike if the activation is above the threshold).
rng : torch.Generator
The random number generator.
"""
def __init__(self, r: float, b: float, tau: float, dt: float, sigma: float, rho: float, theta: float, rng: torch.Generator):
super().__init__()
# Buffers
self.register_buffer("r", torch.tensor(r, dtype=torch.float))
self.register_buffer("b", torch.tensor(b, dtype=torch.float))
self.register_buffer("tau", torch.tensor(tau, dtype=torch.float))
self.register_buffer("dt", torch.tensor(dt, dtype=torch.float))
self.register_buffer("sigma", torch.tensor(sigma, dtype=torch.float))
self.register_buffer("rho", torch.tensor(rho, dtype=torch.float))
self.register_buffer("theta", torch.tensor(theta, dtype=torch.float))
self.register_buffer("T", torch.tensor(1, dtype=torch.int))
self._rng = rng
[docs] def emit_spikes(self, input: torch.Tensor) -> torch.Tensor:
r"""
Emit spikes from the network. The neuron will spike if the input it receives is above the threshold.
.. math:: X_i(t+1) = \begin{cases}
1 & \text{if } g_i(t+1) > \theta \\
0 & \text{otherwise}
\end{cases}
Parameters
-----------
activation : torch.Tensor [n_neurons]
The activation of the neurons at the current time step.
Returns
--------
spikes : torch.Tensor [n_neurons]
The spikes of the network at the current time step.
"""
return input > self.theta
[docs] def update_activation(self, spikes: torch.Tensor, activation: torch.Tensor) -> torch.Tensor:
r"""
Update the activation of the neurons according to the formula:
.. math::
s_i(t+1) = s_i(t)(1 - \frac{\Delta t}{\tau}) + X_i(t)\Delta t
where :math:`s_i(t)` is the activation of neuron :math:`i` at time :math:`t`, :math:`X_i(t)`
indicates whether the neuron spiked at time :math:`t` or not, and :math:`\tau` is the time constant of the neurons.
Parameters
-----------
spikes : torch.Tensor [n_neurons, 1]
The spikes of the network at the current time step.
activation : torch.Tensor [n_neurons, 1]
The activation of the neurons at the previous time step.
Returns
--------
activation : torch.Tensor [n_neurons, 1]
The activation of the neurons at the current time step.
"""
return activation*(1 - self.dt/self.tau) + spikes*self.dt