BernoulliGLM

class spikeometric.models.BernoulliGLM(theta: float, dt: float, coupling_window: int, alpha: float, abs_ref_scale: int, abs_ref_strength: float, rel_ref_scale: int, rel_ref_strength: int, beta: float, r: float, rng=None)

Bases: spikeometric.models.base_model.BaseModel

The Bernoulli GLM from “Inferring causal connectivity from pairwise recordings and optogenetics”.

This is a Generalized Linear Model with a logit link function and a Bernoulli distributed response. Intuitively, it passes the input to each neuron through a sigmoid nonlinearity to get a probability of firing and samples spikes from the resulting Bernoulli distribution.

More formally, the model can be broken into three steps, each of which is implemented as a separate method in this class:

  1. \[g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + r\sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)\]
  2. \[p_i(t+1) = \sigma(g_i(t+1) - \theta) \Delta t\]
  3. \[X_i(t+1) \sim \text{Bernoulli}(p_i(t+1))\]

The first equation is implemented in the input() method and gives the input to neuron \(i\) at time \(t+1\) as a sum of the refractory, synaptic and external inputs. The refractory input is calculated by convolving the spike history of the neuron itself with a refractory filter \(ref\). The synaptic input is obtained by convolving the spike history of the neuron’s neighbors with the coupling filter \(c\), weighted by the synaptic weights \(W_0\). The external input is given by evaluating an external input function \(\mathcal{E}\) at time \(t+1\).

The second equation is implemented in non_linearity() which computes the probability that the neuron \(i\) spikes at time \(t+1\) by passing its input \(g_i(t+1)\) through a sigmoid nonlinearity with threshold \(\theta\). The probability is then scaled by the time step size \(\Delta t\) to get the probability of spiking in an interval of length \(\Delta t\).

Finally, the third equation is implemented in emit_spikes() which samples the spike of the neuron \(i\) at time \(t+1\) from the Bernoulli distribution with probability \(p_i(t+1)\).

Parameters
  • theta (float) – The threshold activation \(\theta\). When the input to a neuron exceeds \(\theta\), the sigmoid output is above 0.5 (before scaling by \(\Delta t\)). (tunable)

  • dt (float) – The time step size \(\Delta t\) in milliseconds.

  • coupling_window (int) – Length of the coupling window \(c_w\) in time steps.

  • alpha (float) – The decay rate \(\alpha\) of the negative activation during the relative refractory period. (tunable)

  • abs_ref_scale (int) – The absolute refractory period of the neurons \(A_{ref}\) in time steps.

  • abs_ref_strength (float) – The large negative activation \(abs\) added to the neurons during the absolute refractory period.

  • rel_ref_scale (int) – The relative refractory period of the neurons \(R_{ref}\) in time steps.

  • rel_ref_strength (float) – The negative activation \(rel\) added to the neurons during the relative refractory period. (tunable)

  • beta (float) – The decay rate \(\beta\) of the weights. (tunable)

  • r (float) – The scaling of the recurrent connections. (tunable)

  • rng (torch.Generator) – The random number generator for sampling from the Bernoulli distribution.

input(edge_index: torch.Tensor, W: torch.Tensor, state: torch.Tensor, t=-1) → torch.Tensor

Computes the input at time step t+1 by adding together the refractory input from the neuron's own spike history, the synaptic input from neighboring neurons, and the external (stimulus) input.

\[g_i(t+1) = \sum_{\tau=0}^{T-1} \left(X_i(t-\tau)ref(\tau) + r\sum_{j \in \mathcal{N}(i)} (W_0)_{j, i} X_j(t-\tau) c(\tau)\right) + \mathcal{E}_i(t+1)\]
Parameters
  • edge_index (torch.Tensor [2, n_edges]) – The connectivity of the network

  • W (torch.Tensor [n_edges, T]) – The weights of the edges

  • state (torch.Tensor [n_neurons, T]) – The state of the neurons

  • t (int) – The current time step

Returns

synaptic_input

Return type

torch.Tensor [n_neurons, 1]
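The sum above can be sketched for a single neuron in plain Python. This is an illustrative sketch only, not the library's implementation: it uses nested lists instead of tensors, stores the filters in natural lag order (index \(\tau\) is the lag), and the names `total_input`, `history`, `neighbors` and `w0` are hypothetical. The argument `r` scales the recurrent term as in the class-level equation; `r=1.0` leaves it unscaled.

```python
def total_input(i, history, ref, coupling, w0, neighbors, external=0.0, r=1.0):
    """Input g_i(t+1) for neuron i (illustrative, not the spikeometric API).

    history[j][tau] -- X_j(t - tau): spike (0 or 1) of neuron j, tau steps back
    ref[tau]        -- refractory filter value at lag tau
    coupling[tau]   -- coupling filter value at lag tau
    w0[(j, i)]      -- initial weight of the edge from neuron j to neuron i
    neighbors[i]    -- list of presynaptic neurons j of neuron i
    """
    g = 0.0
    for tau in range(len(ref)):
        # Refractory term: the neuron's own spike history times ref(tau).
        g += history[i][tau] * ref[tau]
        # Synaptic term: neighbors' spike history times weight and c(tau).
        for j in neighbors[i]:
            g += r * w0[(j, i)] * history[j][tau] * coupling[tau]
    # External input evaluated at time t+1.
    return g + external


# Toy network with two neurons and a filter length of T = 3 (made-up values).
ref = [-100.0, -30.0, 0.0]
coupling = [1.0, 0.5, 0.25]
history = [[0, 1, 0], [1, 0, 1]]   # history[j][0] is the most recent step
w0 = {(1, 0): 2.0}                 # a single edge from neuron 1 to neuron 0
neighbors = {0: [1], 1: []}

g0 = total_input(0, history, ref, coupling, w0, neighbors, external=0.5)
g1 = total_input(1, history, ref, coupling, w0, neighbors)
```

Here neuron 0 receives a refractory contribution from its own spike one step back, a synaptic contribution from the two spikes of neuron 1, and the external input of 0.5.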

non_linearity(input: torch.Tensor) → torch.Tensor

Computes the probability that a neuron spikes given its input

\[p_i(t+1) = \sigma(g_i(t+1) - \theta) \Delta t\]
Parameters

input (torch.Tensor [n_neurons, 1]) – The synaptic input to the neurons

Returns

probabilities – The probability that a neuron spikes

Return type

torch.Tensor [n_neurons, 1]
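As a small worked illustration, the thresholded sigmoid can be written in plain Python. This sketch follows the class-level description, in which the sigmoid output is scaled by the time step \(\Delta t\). The function names `sigmoid` and `spike_probability` are hypothetical and the code does not use the library itself.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def spike_probability(g, theta, dt):
    """Probability of a spike in an interval of length dt, given input g."""
    return sigmoid(g - theta) * dt


# At threshold (g == theta) the sigmoid is exactly 0.5, so p = 0.5 * dt.
p_at_threshold = spike_probability(5.0, theta=5.0, dt=0.1)

# A strongly negative input, such as during an absolute refractory period,
# drives the probability towards zero.
p_refractory = spike_probability(-100.0, theta=5.0, dt=0.1)
```

Because the sigmoid is bounded by 1, the resulting probability never exceeds `dt`, so `dt` should be at most 1 for the result to remain a valid probability.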

emit_spikes(probabilities: torch.Tensor) → torch.Tensor

Emits spikes from the neurons given their probabilities of spiking

\[P(X_i(t+1) = 1) = p_i(t+1)\]
Parameters

probabilities (torch.Tensor [n_neurons, 1]) – The probability that a neuron spikes

Returns

spikes – The spikes emitted by the neurons (1 if the neuron spikes, 0 otherwise)

Return type

torch.Tensor [n_neurons, 1]
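The sampling step can be illustrated with Python's standard `random` module. This is a sketch of Bernoulli sampling in general, not the library's implementation: the helper name `sample_spikes` is hypothetical, and a seeded `random.Random` stands in for the `rng` argument of the model.

```python
import random


def sample_spikes(probabilities, rng):
    """Draw one Bernoulli sample per neuron: 1 with probability p, else 0."""
    # rng.random() is uniform on [0, 1), so the comparison is true
    # with probability p.
    return [1 if rng.random() < p else 0 for p in probabilities]


rng = random.Random(0)  # a fixed seed makes the draws reproducible
spikes = sample_spikes([0.0, 0.3, 1.0], rng)

# Averaging many draws recovers the spike probability approximately.
draws = [sample_spikes([0.3], rng)[0] for _ in range(10_000)]
empirical_rate = sum(draws) / len(draws)
```

A neuron with probability 0 never spikes and one with probability 1 always spikes. For intermediate values the empirical rate over many draws approaches the probability.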

connectivity_filter(W0: torch.Tensor, edge_index: torch.Tensor) → torch.Tensor

The connectivity filter constructs a tensor holding the weights of the edges in the network. This is done by filtering the initial coupling weights \(W_0\) with the coupling filter \(c\) and using a refractory filter \(ref\) as self-edge weights to emulate the refractory period.

For the coupling edges, we are given an initial weight \((W_0)_{i,j}\) for each edge. This tells us how strong the connection between neurons \(i\) and \(j\) is immediately after a spike event. We then use an exponential decay as our coupling filter \(c\) to model the decay of the connection strength over the next \(c_w\) time steps. This period is called the coupling window. Formally, at time step \(t\) after a spike event, the weight of an edge between neurons \(i\) and \(j\) is given by

\[\begin{split}W_{i, j}(t) = \begin{cases} (W_0)_{i,j} \: e^{-\beta t \Delta t} & \text{if } t < c_w \\ 0 & \text{if } c_w \leq t \end{cases}\end{split}\]
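For a single edge, the piecewise definition above can be evaluated directly. The following plain-Python sketch is illustrative only: the function name `coupling_weights` and the explicit filter length `T` are assumptions, and the parameter values are made up.

```python
import math


def coupling_weights(w0, beta, dt, coupling_window, T):
    """Weights W_ij(t) of one edge for t = 0, ..., T-1 steps after a spike."""
    return [
        w0 * math.exp(-beta * t * dt) if t < coupling_window else 0.0
        for t in range(T)
    ]


# Initial weight 2.0 decaying over a coupling window of 3 time steps.
w = coupling_weights(w0=2.0, beta=0.5, dt=0.1, coupling_window=3, T=5)
```

The first entry equals the initial weight, each subsequent entry inside the window shrinks by a factor of \(e^{-\beta \Delta t}\), and entries at or beyond \(c_w\) are zero.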

The self-edges are used to implement the absolute and relative refractory periods. A neuron enters an absolute refractory period after it spikes, during which it cannot spike again. The absolute refractory period is modeled by setting the weight of the self-edge to \(abs\) for \(A_{ref}\) time steps. After this, the neuron enters the relative refractory period. During this period, the neuron can spike again but the probability of doing so is reduced. This is modeled by weighting spike events by \(rel \, e^{-\alpha t \Delta t}\) for the next \(R_{ref}\) time steps.

That is, the refractory filter \(ref\) is given by

\[\begin{split}ref(t) = \begin{cases} abs & \text{if } t < A_{ref} \\ rel e^{-\alpha t \Delta t} & \text{if } A_{ref} \leq t < A_{ref} + R_{ref} \\ 0 & \text{if } A_{ref} + R_{ref} \leq t \end{cases}\end{split}\]

And we set \(W_{i, i}(t) = ref(t)\) for all neurons \(i\).
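The refractory filter can be tabulated the same way. This is a minimal plain-Python sketch of the piecewise formula, not the library's code: the function name `refractory_filter`, the explicit filter length `T`, and the parameter values are all assumptions for illustration.

```python
import math


def refractory_filter(abs_strength, rel_strength, alpha, dt,
                      abs_ref_scale, rel_ref_scale, T):
    """Self-edge weights ref(t) for t = 0, ..., T-1 steps after a spike."""
    weights = []
    for t in range(T):
        if t < abs_ref_scale:
            # Absolute refractory period: constant large negative activation.
            weights.append(abs_strength)
        elif t < abs_ref_scale + rel_ref_scale:
            # Relative refractory period: decaying negative activation.
            # As in the formula, t is counted from the spike itself.
            weights.append(rel_strength * math.exp(-alpha * t * dt))
        else:
            weights.append(0.0)
    return weights


ref = refractory_filter(abs_strength=-100.0, rel_strength=-30.0, alpha=0.2,
                        dt=0.1, abs_ref_scale=2, rel_ref_scale=3, T=7)
```

With these values the first two entries equal the absolute strength, the next three decay in magnitude from the relative strength, and the remaining entries are zero.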

All of this information can be represented by a tensor \(W\) of shape \(N\times N\times T\), where W[i, j, t] is the weight of the edge from neuron \(i\) to neuron \(j\) at time step \(t\) after a spike event.

Now, we remove all the zero weights from \(W\) and flatten the tensor to get a tensor of shape \(E\times T\), where \(E\) is the number of edges in the network. Then, W[i, t] is the weight of the \(i\)-th edge at time step \(t\) after a spike event, and we can use the edge_index tensor to tell us which edge corresponds to which neuron pair.

A final remark: the weights are returned flipped in time to make the convolution operation more efficient. That is, with zero-based indexing, W[i, T-1-t] is the weight of the \(i\)-th edge at time step \(t\) after a spike event.
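The flattening and time flip can be sketched with nested lists. This is an illustration under stated assumptions rather than the library's implementation: `to_edge_list` is a hypothetical helper, an edge is dropped when all of its \(T\) weights are zero, and zero-based indexing is used so that the weight at lag \(t\) lands in column \(T-1-t\).

```python
def to_edge_list(W_dense):
    """Convert a dense N x N x T weight array to (edge_index, W_flat).

    W_dense[i][j][t] is the weight of the edge from i to j at lag t.
    Returns edge_index = [sources, targets] and W_flat of shape E x T,
    with every row reversed in time.
    """
    sources, targets, rows = [], [], []
    n = len(W_dense)
    for i in range(n):
        for j in range(n):
            weights = W_dense[i][j]
            if any(v != 0.0 for v in weights):       # keep only existing edges
                sources.append(i)
                targets.append(j)
                rows.append(list(reversed(weights)))  # flip in time
    return [sources, targets], rows


T = 3
ref = [-100.0, -30.0, 0.0]       # self-edge weights (made-up values)
W_dense = [
    [ref, [0.5, 0.25, 0.0]],     # edges 0 -> 0 and 0 -> 1
    [[0.0, 0.0, 0.0], ref],      # no edge 1 -> 0, self-edge 1 -> 1
]
edge_index, W_flat = to_edge_list(W_dense)

# The weight of edge 1 (from neuron 0 to neuron 1) at lag t = 0:
lag0_weight = W_flat[1][T - 1 - 0]
```

After the flip, the most recent lag sits in the last column, which lines up with a spike history stored oldest-first.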

Parameters
  • W0 (torch.Tensor [n_edges,]) – The initial weights of the edges

  • edge_index (torch.Tensor [2, n_edges]) – The edge index

Returns

  • W (torch.Tensor [n_edges, T]) – The connectivity filter

  • edge_index (torch.Tensor [2, n_edges]) – The edge index (with self edges added)