PoissonGLM
- class spikeometric.models.PoissonGLM(alpha: float, beta: float, T: int, tau: float, dt: float, r: float, b: float, rng=None)
Bases: spikeometric.models.base_model.BaseModel
The Poisson GLM model from section S.7 of the paper “Systematic errors in connectivity inferred from activity in strongly coupled recurrent circuits”.
It is a Poisson generalized linear model (GLM) that passes the input to each neuron through an exponential non-linearity and samples a spike count from a Poisson distribution.
More specifically, we have the following equations:
- \[g_i(t+1) = r \: \sum_{t' = 0}^{T-1} \sum_{j\in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-t')c(t') + b_i + \mathcal{E}_i(t+1)\]
- \[\mu_i(t+1) = \frac{\Delta t}{\alpha}\: e^{\beta g_i(t+1)}\]
- \[X_i(t+1) \sim \text{Pois}(\mu_i(t+1))\]
The first equation is implemented in the input() method. It gives the input to neuron \(i\) at time \(t+1\) as a convolution of the spike history of the neighboring neurons with the coupling filter \(c(t) = e^{- \Delta t \frac{t}{\tau}}\), weighted by the connectivity matrix \(W_0\) and scaled by the recurrent scaling factor \(r\). A uniform background input \(b_i\) (the same value \(b\) for every neuron) and an external stimulus \(\mathcal{E}_i(t+1)\) are added to this recurrent term.
The second equation is implemented in the non_linearity() method and gives the expected spike count of neuron \(i\) at time \(t+1\) as a function of its input \(g_i(t+1)\).
The third equation is implemented in the emit_spikes() method, which samples the spike count of neuron \(i\) at time \(t+1\) from a Poisson distribution with mean \(\mu_i(t+1)\).
- Parameters
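alpha (float) – The \(\alpha\) parameter of the model; the expected spike count in the non-linearity is divided by \(\alpha\). (tunable)
beta (float) – The \(\beta\) parameter of the model; it sets the gain of the exponential non-linearity. (tunable)
T (int) – The number of past time steps whose spikes contribute to the input (the length of the coupling filter).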
tau (float) – The time constant of the exponential coupling filter.
dt (float) – The time step of the simulation in milliseconds.
r (float) – The scaling of the recurrent connections. (tunable)
b (float) – The strength of the uniform background input. (tunable)
rng (torch.Generator, optional) – The random number generator to use for sampling the spikes. If not provided, a new one will be created.
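To see how the three equations fit together, here is a minimal pure-Python sketch of the generative loop for a small network. It is not the spikeometric implementation: the function names `simulate` and `poisson_sample` are hypothetical, the weights are assumed to be a dense matrix `W0[j][i]` for the edge \(j \to i\) (a zero entry meaning no connection), the external stimulus is omitted (\(\mathcal{E}_i = 0\)), and the spike history is kept newest first so that `history[k]` holds \(X(t-k)\).

```python
import math
import random


def poisson_sample(mu: float, rng: random.Random) -> int:
    """Draw one Poisson(mu) count with Knuth's method (adequate for small mu)."""
    limit = math.exp(-mu)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1


def simulate(W0, alpha, beta, T, tau, dt, r, b, n_steps, seed=0):
    """Hypothetical sketch: run the PoissonGLM equations with no external stimulus."""
    rng = random.Random(seed)
    n = len(W0)
    # Coupling filter c(k) = exp(-dt * k / tau) for k = 0, ..., T-1.
    c = [math.exp(-dt * k / tau) for k in range(T)]
    # history[k] holds X(t - k) for all neurons; start from silence.
    history = [[0] * n for _ in range(T)]
    spikes = []
    for _ in range(n_steps):
        new_state = []
        for i in range(n):
            # Equation 1: filtered recurrent input plus background b (stimulus = 0).
            rec = sum(W0[j][i] * history[k][j] * c[k]
                      for k in range(T) for j in range(n))
            g = r * rec + b
            # Equation 2: exponential non-linearity.
            mu = dt / alpha * math.exp(beta * g)
            # Equation 3: Poisson spike count.
            new_state.append(poisson_sample(mu, rng))
        # Shift the history window: newest state first, oldest state dropped.
        history = [new_state] + history[:-1]
        spikes.append(new_state)
    return spikes


# Illustrative parameter values only (not taken from the paper).
W0 = [[0.0, 0.5, -0.5],
      [0.2, 0.0, 0.0],
      [0.0, 0.3, 0.0]]
spikes = simulate(W0, alpha=1.0, beta=1.0, T=5, tau=10.0, dt=1.0,
                  r=1.0, b=-2.0, n_steps=100)
print(len(spikes), len(spikes[0]))  # 100 3
```

Summing over all \(j\) with a dense matrix is equivalent to summing over the neighbors \(\mathcal{N}(i)\), because absent edges carry zero weight. The double loop costs \(O(N^2 T)\) per step, which is why an edge-list representation is preferable for large sparse networks.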
- input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, t=-1) → torch.Tensor
The input to the network at time t+1.
\[g_i(t+1) = r \: \sum_{t' = 0}^{T-1} \sum_{j\in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-t')c(t') + b_i + \mathcal{E}_i(t+1)\]
- Parameters
edge_index (torch.Tensor[int]) – The edge index of the network.
W (torch.Tensor[float]) – The weights of the network.
state (torch.Tensor[int]) – The state of the network at time t.
t (int) – The time step of the simulation.
- Returns
The input to the network at time t+1.
- Return type
torch.Tensor
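A minimal sketch of the input computation on an edge list. The layouts are assumptions, since they are not documented here: `edge_index` is taken to be a pair of lists `[sources, targets]` with edge \(e\) running from \(j\) to \(i\), `W[e][k]` is assumed to already hold the filtered weight \((W_0)_{j,i}\,c(k)\) produced by connectivity_filter(), and `state[j][k]` is assumed to store \(X_j(t-k)\) newest first. The function name `glm_input` is illustrative.

```python
import math


def glm_input(edge_index, W, state, r, b, stimulus):
    """Hypothetical sketch of g_i(t+1) = r * sum_k sum_j W[e][k] * X_j(t-k) + b + E_i(t+1)."""
    sources, targets = edge_index
    rec = [0.0] * len(state)
    for e, (j, i) in enumerate(zip(sources, targets)):
        # Message along edge j -> i: spike history of j weighted by the edge's filter.
        rec[i] += sum(w * x for w, x in zip(W[e], state[j]))
    # Scale the recurrent term, then add background and external stimulus.
    return [r * rec_i + b + s_i for rec_i, s_i in zip(rec, stimulus)]


# Two neurons, one edge 0 -> 1 with W0 = 0.5, T = 2, dt = tau = 1.
c = [math.exp(-k) for k in range(2)]
W = [[0.5 * c_k for c_k in c]]
state = [[1, 2],   # X_0(t) = 1, X_0(t-1) = 2
         [0, 0]]   # neuron 1 was silent
g = glm_input(([0], [1]), W, state, r=2.0, b=-1.0, stimulus=[0.0, 0.0])
print(g)  # neuron 0 gets only b; neuron 1 gets 2 * (0.5 + e^-1) - 1
```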
- non_linearity(input: torch.Tensor) → torch.Tensor
The exponential non-linearity of the model. Calculates an expected spike count from the input.
\[\mu_i(t+1) = \frac{\Delta t}{\alpha}\: e^{\beta g_i(t+1)}\]
- Parameters
input (torch.Tensor[float]) – The input to the network at time t+1.
- Returns
The expected spike count of the network at time t+1.
- Return type
torch.Tensor
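The non-linearity is a one-liner per neuron. The sketch below (function name `expected_counts` is hypothetical) also makes the roles of the tunable parameters concrete: at \(g_i = 0\) the expected count is \(\Delta t / \alpha\), and each increase of \(g_i\) by \(1/\beta\) multiplies it by \(e\).

```python
import math


def expected_counts(g, alpha, beta, dt):
    """mu_i(t+1) = (dt / alpha) * exp(beta * g_i(t+1)) for every neuron i."""
    return [dt / alpha * math.exp(beta * g_i) for g_i in g]


mu = expected_counts([0.0, 0.5], alpha=4.0, beta=2.0, dt=1.0)
print(mu)  # first entry 0.25, second 0.25 * e (since beta * 0.5 = 1)
```

Because \(\mu\) grows exponentially with \(g_i\), strongly excitatory input can produce very large expected counts; in this pure-Python sketch, `math.exp` raises `OverflowError` for arguments above roughly 709.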
- emit_spikes(rates: torch.Tensor) → torch.Tensor
Samples the spikes from a Poisson distribution with rate \(\mu_i(t+1)\).
\[X_i(t+1) \sim \text{Pois}(\mu_i(t+1))\]
- Parameters
rates (torch.Tensor[float]) – The expected spike count of the network at time t+1.
- Returns
The state of the network at time t+1.
- Return type
torch.Tensor
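Sampling can be sketched with inverse-transform sampling, walking up the Poisson CDF until it exceeds a uniform draw. This is a textbook technique swapped in for illustration (the function name `sample_counts` is hypothetical), not necessarily how emit_spikes() samples; in PyTorch, `torch.poisson(rates, generator=rng)` draws Poisson samples from a tensor of rates directly.

```python
import math
import random


def sample_counts(rates, rng):
    """Draw X_i ~ Pois(rates[i]) by inverse-transform sampling (suited to small rates)."""
    counts = []
    for mu in rates:
        u = rng.random()
        k = 0
        p = math.exp(-mu)   # P(X = 0); underflows to 0 for mu above about 745
        cdf = p
        # Step k upward until the CDF passes u; stop early if p underflows to 0.
        while u > cdf and p > 0.0:
            k += 1
            p *= mu / k     # P(X = k) = P(X = k-1) * mu / k
            cdf += p
        counts.append(k)
    return counts


rng = random.Random(42)
draws = sample_counts([2.0] * 20000, rng)
print(sum(draws) / len(draws))  # close to the mean 2.0
```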
- connectivity_filter(W0: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor
The connectivity filter of the network is a tensor that contains the synaptic weights between two neurons \(i\) and \(j\) at time step \(t\) after a spike event. It is computed by filtering the initial synaptic weights \(W_0\) with the exponential coupling kernel \(c\):
\[W_{i,j}(t) = (W_0)_{i,j} \: c(t) = (W_0)_{i,j} e^{- \Delta t \frac{t}{\tau}}\]
The filter is defined only for \(t = 0, \dots, T-1\), so spikes older than the last \(T\) time steps have no effect on the input.
- Parameters
W0 (torch.Tensor[float]) – The initial synaptic weights of the network.
edge_index (torch.Tensor[int]) – The edge index of the network.
- Returns
W (torch.Tensor[float]) – The connectivity filter of the network.
edge_index (torch.Tensor[int]) – The edge index of the network.
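A sketch of the filtering step, assuming (hypothetically) that `W0` is a flat list with one weight per edge, aligned with the columns of `edge_index`, and that \(T\), \(\tau\) and \(\Delta t\) are passed explicitly instead of being read from the model. Each edge gets a row of \(T\) filtered weights; the function name `filter_weights` is illustrative.

```python
import math


def filter_weights(W0, edge_index, T, tau, dt):
    """Return (W, edge_index) with W[e][t] = W0[e] * exp(-dt * t / tau) for t = 0..T-1."""
    c = [math.exp(-dt * t / tau) for t in range(T)]
    # One row of T filtered weights per edge; spikes older than T steps get no weight.
    W = [[w * c_t for c_t in c] for w in W0]
    return W, edge_index


W, edges = filter_weights([2.0, -1.0], ([0, 1], [1, 0]), T=3, tau=2.0, dt=1.0)
print(W[0])  # 2.0, then 2 * e^-0.5, then 2 * e^-1
```

Because \(c(0) = 1\), the first column reproduces \(W_0\) itself; the filter then decays by a factor \(e^{-\Delta t/\tau}\) per step.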